Chatbot
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

389 lines
14 KiB

import asyncio
import os, tempfile
import logging
2 years ago
import json
import requests
from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download
import io
import base64
from PIL import Image, PngImagePlugin
logger = logging.getLogger(__name__)
async def generate_sync(
    prompt: str,
    api_key: str,
    bot,
    typing_fn,
    api_endpoint="pygmalion-6b",
):
    """Generate the bot's reply to ``prompt`` using a RunPod serverless text endpoint.

    Submits a job to ``https://api.runpod.ai/v2/<api_endpoint>/run``, then polls the
    job status every 5 seconds for up to 360 seconds.

    Args:
        prompt: Full prompt text to complete.
        api_key: RunPod API key, sent as a Bearer token.
        bot: Object providing ``temperature`` (sampling temperature) and ``name``.
        typing_fn: Coroutine function awaited on each poll while the job is
            ``IN_PROGRESS`` (presumably a chat "typing..." indicator).
        api_endpoint: RunPod endpoint id.

    Returns:
        The cleaned reply text.

    Raises:
        ValueError: ``"<ERROR> <http status>"`` if the submit request is rejected,
            ``"RETURN CODE <status>: <error>"`` if the job reaches any state other
            than ``COMPLETED``, or ``"<TIMEOUT>"`` if it does not finish in time.
    """
    base_url = f"https://api.runpod.ai/v2/{api_endpoint}"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    max_new_tokens = 200
    prompt_num_tokens = await num_tokens(prompt)
    input_data = {
        "input": {
            "prompt": prompt,
            # max_length counts prompt + generated tokens; capped at 2048, the same
            # MAX_TOKENS budget that get_full_prompt builds prompts against.
            "max_length": min(prompt_num_tokens + max_new_tokens, 2048),
            "temperature": bot.temperature,
            "do_sample": True,
        }
    }
    logger.info("sending request to runpod.io")
    r = requests.post(f"{base_url}/run", json=input_data, headers=headers, timeout=180)
    if r.status_code != 200:
        # Log the raw body: error responses are not guaranteed to be JSON.
        logger.info(r.text)
        raise ValueError(f"<ERROR> {r.status_code}")
    r_json = r.json()
    logger.info(r_json)

    status_url = f"{base_url}/status/{r_json['id']}"
    TIMEOUT = 360
    DELAY = 5
    for _ in range(TIMEOUT // DELAY):
        # Bounded so a single stalled poll cannot hang this coroutine forever.
        r = requests.get(status_url, headers=headers, timeout=30)
        r_json = r.json()
        logger.info(r_json)
        status = r_json["status"]
        if status == 'IN_PROGRESS':
            await typing_fn()
            await asyncio.sleep(DELAY)
        elif status == 'IN_QUEUE':
            await asyncio.sleep(DELAY)
        elif status == 'COMPLETED':
            return _extract_reply(r_json["output"], prompt, bot.name)
        else:
            # ``error`` may be missing or null; RunPod sends newlines escaped.
            err_msg = str(r_json.get("error") or "").replace("\\n", "\n")
            raise ValueError(f"RETURN CODE {status}: {err_msg}")
    raise ValueError("<TIMEOUT>")


def _extract_reply(output, prompt: str, bot_name: str) -> str:
    """Turn a completed job's ``output`` (a string or list of strings) into one reply."""
    if isinstance(output, list):
        if not output:
            raise ValueError("<ERROR>")
        # First of the longest candidates -- the same pick a stable reverse sort by
        # length makes, without mutating the response payload.
        text = max(output, key=len)
    else:
        text = output
    # The model output starts with the echoed prompt; keep only the continuation.
    answer = text.removeprefix(prompt)
    # Cut where the model starts writing the user's next turn.
    idx = answer.find("\nYou:")
    if idx != -1:
        reply = answer[:idx].strip()
    else:
        reply = answer.removesuffix('<|endoftext|>').strip()
    # Merge follow-up lines spoken by the bot into a single line.
    reply = reply.replace(f"\n{bot_name}: ", " ")
    reply = reply.replace("\n<BOT>: ", " ")
    # Substitute the real name; previously the literal text "{bot.name}" was inserted.
    reply = reply.replace("<BOT>", bot_name)
    reply = reply.replace("<USER>", "You")
    return reply
2 years ago
async def get_full_prompt(simple_prompt: str, bot, chat_history):
    """Build the generation prompt for ``simple_prompt``, including as much recent history as fits.

    Layout: the bot's persona, the scenario, a ``<START>`` marker, the selected
    history lines (``"<bot.name>: ..."`` for the bot's own messages, ``"You: ..."``
    otherwise), the new ``"You: <simple_prompt>"`` line, and a trailing
    ``"<bot.name>:"`` cue for the model to continue.

    History is walked newest-first. The walk:

    * skips the very newest entry (assumed to be the message being answered; TODO confirm);
    * stops at a message starting with ``!begin``;
    * skips other ``!`` commands and ``<ERROR>`` messages;
    * stops once adding a message would reach ``MAX_TOKENS - max_new_tokens``.

    Per-message token counts are cached on ``chat_item.num_tokens``.
    """
    MAX_TOKENS = 2048
    max_new_tokens = 200

    def _render(history_items) -> str:
        # Single source of truth for the prompt layout, used for budgeting and output.
        lines = [
            bot.name + "'s Persona: " + bot.persona,
            "Scenario: " + bot.scenario,
            "<START>",
        ]
        for item in history_items:
            speaker = bot.name if item.is_own_message else "You"
            lines.append(speaker + ": " + item.message["en"])
        lines.append("You: " + simple_prompt)
        return "\n".join(lines) + "\n" + bot.name + ":"

    total_num_tokens = await num_tokens(_render([]))
    visible_history = []
    skipped_current = False
    for _key, chat_item in reversed(chat_history.chat_history.items()):
        if not skipped_current:
            skipped_current = True
            continue
        message = chat_item.message["en"]
        # Must be checked before the generic "!" filter, which would also match it.
        if message.startswith('!begin'):
            break
        if message.startswith(('!', '<ERROR>')):
            continue
        if chat_item.num_tokens is None:
            # NOTE(review): counted with user_name as the prefix, while the prompt
            # uses bot.name / "You" -- a slight estimate; confirm this is intended.
            chat_item.num_tokens = await num_tokens(
                "{}: {}".format(chat_item.user_name, message)
            )
        logger.debug("History: %s [%s]", chat_item, chat_item.num_tokens)
        # Reserve max_new_tokens of the context for the generated reply.
        if total_num_tokens + chat_item.num_tokens >= MAX_TOKENS - max_new_tokens:
            break
        visible_history.append(chat_item)
        total_num_tokens += chat_item.num_tokens

    # Collected newest-first; the prompt needs chronological order.
    return _render(reversed(visible_history))
# Tokenizer loaded on first use and reused; from_pretrained is far too costly per call.
_tokenizer = None


def _get_tokenizer():
    """Return the PygmalionAI/pygmalion-6b tokenizer, loading it only on the first call."""
    global _tokenizer
    if _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")
    return _tokenizer


async def num_tokens(input_text: str):
    """Return the number of pygmalion-6b tokens in ``input_text``.

    Special tokens are not added, so the count reflects only the text itself.
    """
    encoding = _get_tokenizer().encode(input_text, add_special_tokens=False)
    return len(encoding)
async def estimate_num_tokens(input_text: str):
    """Cheaply approximate the token count of ``input_text`` without a tokenizer.

    Uses roughly four characters per token: the character count divided by four,
    rounded down, plus one. The result is therefore always at least 1.
    """
    whole_tokens, _remainder = divmod(len(input_text), 4)
    return whole_tokens + 1
async def download_image(url, path):
    """Stream the resource at ``url`` into the local file ``path``.

    Raises:
        ValueError: ``"<ERROR> <http status>"`` if the server does not answer 200.
            Previously this case was silently ignored, and callers were left
            holding a path to a file that was never written.
    """
    # The context manager releases the streamed connection even on the error path.
    with requests.get(url, stream=True, timeout=60) as r:
        if r.status_code != 200:
            raise ValueError(f"<ERROR> {r.status_code}")
        with open(path, 'wb') as f:
            # Iterating the response directly yields tiny 128-byte chunks.
            for chunk in r.iter_content(chunk_size=64 * 1024):
                f.write(chunk)
async def generate_image(input_prompt: str, negative_prompt: str, api_url: str, api_key: str, typing_fn):
    """Generate three 512x768 images on a RunPod Stable Diffusion endpoint and download them.

    Submits a job to ``api_url + "run"``, then polls ``api_url + "status/<id>"``
    every 5 seconds for up to 360 seconds. Each ``output[i]["image"]`` URL is
    downloaded into ``./images/``.

    Args:
        input_prompt: Positive prompt.
        negative_prompt: Negative prompt.
        api_url: Endpoint base URL, ending with ``/``.
        api_key: RunPod API key, sent as a Bearer token.
        typing_fn: Coroutine function awaited on each poll while the job is
            ``IN_PROGRESS`` (presumably a typing indicator).

    Returns:
        List of local ``.jpg`` file paths.

    Raises:
        ValueError: on a non-200 submit (``"<ERROR> <http status>"``), when the job
            is not queued or fails (``"RETURN CODE ..."``), or on timeout or empty
            output (``"<ERROR>"``).
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    input_data = {
        "input": {
            "prompt": input_prompt,
            "negative_prompt": negative_prompt,
            "width": 512,
            "height": 768,
            "num_outputs": 3,
        },
    }
    logger.info("sending request to runpod.io")
    r = requests.post(api_url + "run", json=input_data, headers=headers, timeout=180)
    if r.status_code != 200:
        logger.debug(r.text)
        raise ValueError(f"<ERROR> {r.status_code}")
    r_json = r.json()
    logger.debug(r_json)
    status = r_json["status"]
    if status != 'IN_QUEUE':
        err_msg = r_json["error"] if "error" in r_json else ""
        raise ValueError(f"RETURN CODE {status}: {err_msg}")

    status_url = api_url + "status/" + r_json["id"]
    TIMEOUT = 360
    DELAY = 5
    output = None
    for _ in range(TIMEOUT // DELAY):
        r = requests.get(status_url, headers=headers, timeout=30)
        r_json = r.json()
        logger.debug(r_json)
        status = r_json["status"]
        if status == 'IN_PROGRESS':
            await typing_fn()
            await asyncio.sleep(DELAY)
        elif status == 'IN_QUEUE':
            await asyncio.sleep(DELAY)
        elif status == 'COMPLETED':
            output = r_json["output"]
            break
        else:
            err_msg = str(r_json.get("error") or "").replace("\\n", "\n")
            raise ValueError(f"RETURN CODE {status}: {err_msg}")

    if not output:
        raise ValueError("<ERROR>")
    os.makedirs("./images", exist_ok=True)
    files = []
    for image in output:
        # mkstemp reserves a unique name atomically (public API, no race), unlike
        # tempfile's private candidate-name generator.
        fd, filename = tempfile.mkstemp(suffix=".jpg", prefix="", dir="./images")
        os.close(fd)
        try:
            await download_image(image["image"], filename)
        except Exception:
            # Do not leave an empty placeholder file behind.
            os.remove(filename)
            raise
        files.append(filename)
    return files
async def generate_image1(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images on the RunPod endpoint labelled AnythingV4, via ``generate_image``."""
    endpoint_url = "https://api.runpod.ai/v1/sd-anything-v4/"
    return await generate_image(
        input_prompt,
        negative_prompt,
        api_url=endpoint_url,
        api_key=api_key,
        typing_fn=typing_fn,
    )
2 years ago
async def generate_image2(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images on the RunPod endpoint labelled OpenJourney, via ``generate_image``."""
    endpoint_url = "https://api.runpod.ai/v1/sd-openjourney/"
    return await generate_image(
        input_prompt,
        negative_prompt,
        api_url=endpoint_url,
        api_key=api_key,
        typing_fn=typing_fn,
    )
2 years ago
async def generate_image3(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images on the RunPod endpoint labelled Hassanblend, via ``generate_image``."""
    endpoint_url = "https://api.runpod.ai/v1/mf5f6mocy8bsvx/"
    return await generate_image(
        input_prompt,
        negative_prompt,
        api_url=endpoint_url,
        api_key=api_key,
        typing_fn=typing_fn,
    )
2 years ago
async def generate_image4(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images on the RunPod endpoint labelled DeliberateV2, via ``generate_image_automatic``."""
    endpoint_url = "https://api.runpod.ai/v1/lxdhmiccp3vdsf/"
    return await generate_image_automatic(
        input_prompt,
        negative_prompt,
        api_url=endpoint_url,
        api_key=api_key,
        typing_fn=typing_fn,
    )
2 years ago
async def generate_image5(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images on this RunPod endpoint via ``generate_image_automatic``.

    The original comment labels this endpoint "Hassanblend", the same label as
    generate_image3, although the endpoint id and request path differ.
    """
    endpoint_url = "https://api.runpod.ai/v1/13rrs00l7yxikf/"
    return await generate_image_automatic(
        input_prompt,
        negative_prompt,
        api_url=endpoint_url,
        api_key=api_key,
        typing_fn=typing_fn,
    )
2 years ago
async def generate_image6(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images on the RunPod endpoint labelled PFG, via ``generate_image_automatic``."""
    endpoint_url = "https://api.runpod.ai/v1/5j1xzlsyw84vk5/"
    return await generate_image_automatic(
        input_prompt,
        negative_prompt,
        api_url=endpoint_url,
        api_key=api_key,
        typing_fn=typing_fn,
    )
2 years ago
async def generate_image7(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images on the RunPod v2 endpoint labelled ChilloutMix, via ``generate_image_automatic``."""
    endpoint_url = "https://api.runpod.ai/v2/rrjxafqx66osr4/"
    return await generate_image_automatic(
        input_prompt,
        negative_prompt,
        api_url=endpoint_url,
        api_key=api_key,
        typing_fn=typing_fn,
    )
async def generate_image8(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images on the RunPod v2 endpoint labelled AbyssOrangeMixV3, via ``generate_image_automatic``."""
    endpoint_url = "https://api.runpod.ai/v2/vuyifmsasm3ix7/"
    return await generate_image_automatic(
        input_prompt,
        negative_prompt,
        api_url=endpoint_url,
        api_key=api_key,
        typing_fn=typing_fn,
    )
async def serverless_automatic_request(payload, cmd, api_url: str, api_key: str, typing_fn):
    """Run ``cmd`` (e.g. ``"txt2img"``) with ``payload`` on a RunPod serverless endpoint.

    The job input is a copy of ``payload`` with ``"api_endpoint"`` set to
    ``str(cmd)``; ``cmd`` is also sent at the top level of the request. The job is
    submitted to ``api_url + "run"``, and ``api_url + "status/<id>"`` is polled every
    5 seconds for up to 360 seconds.

    Args:
        payload: Request parameters; not modified.
        cmd: Command name forwarded to the worker.
        api_url: Endpoint base URL, ending with ``/``.
        api_key: RunPod API key, sent as a Bearer token.
        typing_fn: Coroutine function awaited on each poll while the job is
            ``IN_PROGRESS``.

    Returns:
        The job's ``output`` value.

    Raises:
        ValueError: on a non-200 submit (``"<ERROR> <http status>"``), when the job
            is not queued or fails (``"RETURN CODE ..."``), or on timeout or empty
            output (``"<ERROR> <last status>"``).
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    input_data = {
        # Build a new dict instead of payload.update(...) so the caller's dict is
        # not mutated.
        "input": {**payload, "api_endpoint": str(cmd)},
        "cmd": cmd,
    }
    logger.info("sending request to runpod.io")
    r = requests.post(api_url + "run", json=input_data, headers=headers, timeout=180)
    if r.status_code != 200:
        # Previously fell through and returned None, crashing the caller later.
        logger.debug(r.text)
        raise ValueError(f"<ERROR> {r.status_code}")
    r_json = r.json()
    logger.debug(r_json)
    status = r_json["status"]
    if status != 'IN_QUEUE':
        err_msg = r_json["error"] if "error" in r_json else ""
        raise ValueError(f"RETURN CODE {status}: {err_msg}")

    status_url = api_url + "status/" + r_json["id"]
    TIMEOUT = 360
    DELAY = 5
    output = None
    for _ in range(TIMEOUT // DELAY):
        r = requests.get(status_url, headers=headers, timeout=30)
        r_json = r.json()
        logger.debug(r_json)
        status = r_json["status"]
        if status == 'IN_PROGRESS':
            await typing_fn()
            await asyncio.sleep(DELAY)
        elif status == 'IN_QUEUE':
            await asyncio.sleep(DELAY)
        elif status == 'COMPLETED':
            output = r_json["output"]
            break
        else:
            # Same unescaping of newlines as the other RunPod helpers in this module.
            err_msg = str(r_json.get("error") or "").replace("\\n", "\n")
            raise ValueError(f"RETURN CODE {status}: {err_msg}")
    if not output:
        raise ValueError(f"<ERROR> {status}")
    return output
async def generate_image_automatic(input_prompt: str, negative_prompt: str, api_url: str, api_key: str, typing_fn):
    """Generate three 512x768 images with a ``txt2img`` job and save them as PNG files.

    The images come back as base64 strings in ``output['images']``. Each is decoded
    and saved under ``./images/``, with ``output['info']`` stored as the PNG text
    chunk ``"parameters"``.

    Args:
        input_prompt: Positive prompt.
        negative_prompt: Negative prompt.
        api_url: Endpoint base URL, ending with ``/``.
        api_key: RunPod API key.
        typing_fn: Coroutine function forwarded to ``serverless_automatic_request``.

    Returns:
        List of local ``.png`` file paths.

    Raises:
        ValueError: propagated from ``serverless_automatic_request``.
    """
    payload = {
        "prompt": input_prompt,
        # Was misspelled "nagative_prompt", so the negative prompt was never applied.
        "negative_prompt": negative_prompt,
        "steps": 25,
        "seed": -1,
        "width": 512,
        "height": 768,
        "batch_size": 3,
        "restore_faces": True,
    }
    output = await serverless_automatic_request(payload, "txt2img", api_url, api_key, typing_fn)

    # Optional img2img refinement pass; disabled.
    upscale = False
    if upscale:
        for idx, source_image in enumerate(output['images']):
            refine_payload = {
                "init_images": [source_image],
                "prompt": input_prompt,
                "negative_prompt": negative_prompt,
                "steps": 20,
                "seed": -1,
            }
            upscaled_output = await serverless_automatic_request(
                refine_payload, "img2img", api_url, api_key, typing_fn
            )
            # One init image per request, so the result is at index 0; write it
            # back to this image's slot (the old counter never advanced).
            output['images'][idx] = upscaled_output['images'][0]

    os.makedirs("./images", exist_ok=True)
    pnginfo = PngImagePlugin.PngInfo()
    pnginfo.add_text("parameters", output['info'])
    files = []
    for encoded in output['images']:
        # mkstemp reserves a unique name atomically via the public tempfile API.
        fd, filename = tempfile.mkstemp(suffix=".png", prefix="", dir="./images")
        os.close(fd)
        # Accept both plain base64 and a "data:image/png;base64,..." data URL; the
        # payload is the part after the comma (the old [0] kept the header).
        raw = base64.b64decode(encoded.split(",", 1)[-1])
        with Image.open(io.BytesIO(raw)) as image:
            image.save(filename, pnginfo=pnginfo)
        files.append(filename)
    return files