Chatbot
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

257 lines
8.0 KiB

2 years ago
import asyncio
import os, tempfile
import logging
import requests
from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download
# Module-level logger, named after this module, used by all helpers below.
logger = logging.getLogger(__name__)
def _extract_reply(text: str, prompt: str, bot_name: str) -> str:
    """Cut the bot's reply out of the raw generated ``text``."""
    # Some workers echo the prompt back; drop it if present.
    answer = text.removeprefix(prompt)
    # The model tends to keep writing the user's next turn; stop there.
    idx = answer.find("\nYou:")
    if idx != -1:
        reply = answer[:idx].strip()
    else:
        reply = answer.removesuffix('<|endoftext|>').strip()
    # Merge multi-line bot turns into one message.
    reply = reply.replace(f"\n{bot_name}: ", " ")
    reply = reply.replace("\n<BOT>: ", " ")
    return reply


async def generate_sync(
    prompt: str,
    api_key: str,
    bot_name: str,
):
    """Generate a chat reply for ``prompt`` via the koboldai.net async API.

    Submits a generation job, then polls its status every ``DELAY`` seconds
    for up to ``TIMEOUT`` seconds.

    Args:
        prompt: Full prompt text sent to the model.
        api_key: Value sent in the ``apikey`` request header.
        bot_name: Name of the bot; used to strip repeated
            ``"\\n<bot_name>: "`` prefixes from the reply.

    Returns:
        The cleaned reply text on success. On failure a string starting
        with ``"<ERROR>"`` (submission rejected, malformed status response,
        no generations, or polling timed out).

    Raises:
        requests.RequestException: If an HTTP request fails or times out.
    """
    submit_url = "https://koboldai.net/api/v2/generate/async"
    status_url = "https://koboldai.net/api/v2/generate/text/status/"
    headers = {
        "Content-Type": "application/json",
        "accept": "application/json",
        "apikey": f"{api_key}",
    }

    input_data = {
        "prompt": prompt,
        "params": {
            "n": 1,
            # Not sent: frmtadsnsp, frmtrmblln, frmtrmspch, frmttriminc,
            # singleline, soft_prompt, sampler_order.
            "max_context_length": 1024,
            "max_length": 512,
            "rep_pen": 1.1,
            "rep_pen_range": 1024,
            "rep_pen_slope": 0.7,
            "temperature": 0.75,
            "tfs": 1.0,
            "top_a": 0.0,
            "top_k": 0,
            "top_p": 0.9,
            "typical": 1.0,
        },
        "softprompts": [],
        "trusted_workers": False,
        "nsfw": True,
        # "workers": [],
        "models": ["PygmalionAI/pygmalion-6b"],
    }

    logger.info("sending request to koboldai.net")
    # requests is blocking; run it in a worker thread so the event loop
    # stays responsive while waiting for the server.
    r = await asyncio.to_thread(
        requests.post, submit_url, json=input_data, headers=headers, timeout=180
    )
    r_json = r.json()
    logger.info(r_json)

    if "id" not in r_json:
        # "message" is only read on the error path; it may be absent.
        status = r_json.get("message", "")
        return f"<ERROR> {status}"

    job_id = r_json["id"]
    TIMEOUT = 360
    DELAY = 11
    for _ in range(TIMEOUT // DELAY):
        r = await asyncio.to_thread(
            requests.get, status_url + job_id, headers=headers, timeout=60
        )
        r_json = r.json()
        logger.info(r_json)
        if "done" not in r_json:
            return "<ERROR>"
        if r_json["done"]:
            generations = r_json.get("generations") or []
            if not generations:
                return "<ERROR>"
            return _extract_reply(generations[0]["text"], prompt, bot_name)
        await asyncio.sleep(DELAY)

    # Polling window exhausted without the job finishing.
    return "<ERROR> timeout"
async def get_full_prompt(simple_prompt: str, bot, chat_history):
    """Build the full model prompt: persona, scenario, recent history, new message.

    History is walked from newest to oldest. The newest entry is skipped,
    a message starting with ``!begin`` stops the walk, and messages starting
    with ``!`` or ``<ERROR>`` are left out. Entries are added until the token
    total would reach ``MAX_TOKENS - max_new_tokens``.

    Args:
        simple_prompt: The new user message to answer.
        bot: Object providing ``name``, ``persona`` and ``scenario`` strings.
        chat_history: Object whose ``chat_history`` attribute is a mapping of
            chat items. Each item provides ``message["en"]``, ``user_name``,
            ``is_own_message`` and a ``num_tokens`` attribute, which is filled
            in here when it is ``None``.

    Returns:
        The assembled prompt string, ending with ``"<bot.name>:"``.
    """
    MAX_TOKENS = 2048
    max_new_tokens = 200

    # Fixed parts of the prompt; built once and reused for counting and output.
    header = (
        bot.name + "'s Persona: " + bot.persona + "\n"
        + "Scenario: " + bot.scenario + "\n"
        + "<START>" + "\n"
    )
    footer = "You: " + simple_prompt + "\n" + bot.name + ":"

    total_num_tokens = await num_tokens(header + footer)
    # TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens??
    token_budget = MAX_TOKENS - max_new_tokens

    visible_history = []
    newest_first = reversed(chat_history.chat_history.items())
    # The newest entry is skipped; simple_prompt is appended separately below.
    next(newest_first, None)
    for _key, chat_item in newest_first:
        message = chat_item.message["en"]
        if message.startswith('!begin'):
            break
        if message.startswith(('!', '<ERROR>')):
            continue
        if chat_item.num_tokens is None:
            # Cache the count on the item so it is computed only once.
            chat_item.num_tokens = await num_tokens(
                "{}: {}".format(chat_item.user_name, message)
            )
        logger.debug("History: %s [%s]", chat_item, chat_item.num_tokens)
        if total_num_tokens + chat_item.num_tokens >= token_budget:
            break
        visible_history.append(chat_item)
        total_num_tokens += chat_item.num_tokens

    parts = [header]
    # Collected newest-first; emit oldest-first.
    for chat_item in reversed(visible_history):
        speaker = bot.name if chat_item.is_own_message else "You"
        parts.append(speaker + ": " + chat_item.message["en"] + "\n")
    parts.append(footer)
    return "".join(parts)
# Loaded tokenizers keyed by model name, so each is loaded only once.
_TOKENIZER_CACHE = {}


def _get_tokenizer(model_name: str = "PygmalionAI/pygmalion-6b"):
    """Return the tokenizer for ``model_name``, loading it on first use."""
    tokenizer = _TOKENIZER_CACHE.get(model_name)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        _TOKENIZER_CACHE[model_name] = tokenizer
    return tokenizer


async def num_tokens(input_text: str):
    """Return the number of tokens in ``input_text``.

    The text is encoded with the ``PygmalionAI/pygmalion-6b`` tokenizer
    without special tokens. The tokenizer is loaded once and reused, since
    this function is called for every history item.

    Args:
        input_text: Text to tokenize.

    Returns:
        Length of the encoded token sequence.
    """
    tokenizer = _get_tokenizer()
    encoding = tokenizer.encode(input_text, add_special_tokens=False)
    return len(encoding)
async def estimate_num_tokens(input_text: str):
    """Return a rough token estimate: one token per four characters, plus one.

    Args:
        input_text: Text whose token count is estimated.

    Returns:
        ``len(input_text) // 4 + 1``.
    """
    full_groups, _remainder = divmod(len(input_text), 4)
    return full_groups + 1
async def generate_image(input_prompt: str, negative_prompt: str, model: str, api_key: str):
    """Generate images via the stablehorde.net async API and save them locally.

    Submits a 512x512 generation job for ``model``, polls its status every
    ``DELAY`` seconds for up to ``TIMEOUT`` seconds, then downloads each
    result into ``./images``.

    Args:
        input_prompt: Text prompt for the image.
        negative_prompt: Accepted for interface compatibility; it is not
            sent to the API.
        model: Name of the model to request.
        api_key: Value sent in the ``apikey`` request header.

    Returns:
        A list of file paths of the downloaded images on success.
        ``"<ERROR>"`` if a status response lacks ``done`` or reports
        ``faulted``. ``None`` if the job submission did not return HTTP 202.

    Raises:
        ValueError: If polling ends without any generations.
        requests.RequestException: If an HTTP request fails or times out.
    """
    import uuid  # local import: only needed to name the output files

    submit_url = "https://stablehorde.net/api/v2/generate/async"
    status_url = "https://stablehorde.net/api/v2/generate/status/"
    headers = {
        "Content-Type": "application/json",
        "accept": "application/json",
        "apikey": f"{api_key}",
    }

    input_data = {
        "prompt": input_prompt,
        "params": {
            # "negative_prompt": negative_prompt,
            "width": 512,
            "height": 512,
        },
        "nsfw": True,
        "trusted_workers": False,
        # "workers": [],
        # Pass the model name itself, not the literal text "{model}".
        "models": [model],
    }

    logger.info("sending request to stablehorde.net")
    # requests is blocking; run it in a worker thread so the event loop
    # stays responsive while waiting for the server.
    r = await asyncio.to_thread(
        requests.post, submit_url, json=input_data, headers=headers, timeout=180
    )
    r_json = r.json()
    logger.info(r_json)

    if r.status_code != 202:
        # Submission was not accepted; the response was logged above.
        logger.warning("stablehorde.net rejected the request: HTTP %s", r.status_code)
        return None

    job_id = r_json["id"]
    TIMEOUT = 360
    DELAY = 11
    output = None
    for _ in range(TIMEOUT // DELAY):
        r = await asyncio.to_thread(
            requests.get, status_url + job_id, headers=headers, timeout=60
        )
        r_json = r.json()
        logger.info(r_json)
        if "done" not in r_json:
            return "<ERROR>"
        if r_json.get("faulted"):
            return "<ERROR>"
        if r_json["done"]:
            output = r_json["generations"]
            break
        await asyncio.sleep(DELAY)

    if not output:
        raise ValueError("<ERROR>")

    os.makedirs("./images", exist_ok=True)
    files = []
    for image in output:
        # Random unique name; avoids relying on tempfile's private helpers.
        filename = "./images/" + uuid.uuid4().hex + ".jpg"
        await download_image(image["img"], filename)
        files.append(filename)
    return files
async def generate_image1(input_prompt: str, negative_prompt: str, api_key: str):
    """Generate images with the "Deliberate" model.

    Forwards all arguments to ``generate_image`` and returns its result.
    """
    model_name = "Deliberate"
    result = await generate_image(input_prompt, negative_prompt, model_name, api_key)
    return result
async def generate_image2(input_prompt: str, negative_prompt: str, api_key: str):
    """Generate images with the "PFG" model.

    Forwards all arguments to ``generate_image`` and returns its result.
    """
    model_name = "PFG"
    result = await generate_image(input_prompt, negative_prompt, model_name, api_key)
    return result
async def generate_image3(input_prompt: str, negative_prompt: str, api_key: str):
    """Generate images with the "Hassanblend" model.

    Forwards all arguments to ``generate_image`` and returns its result.
    """
    model_name = "Hassanblend"
    result = await generate_image(input_prompt, negative_prompt, model_name, api_key)
    return result
def _download_image_blocking(url, path):
    """Stream ``url`` into the file at ``path`` (blocking)."""
    # The context manager releases the streamed connection in every case.
    with requests.get(url, stream=True, timeout=180) as r:
        if r.status_code != 200:
            # Best-effort download: report the failure, write nothing.
            logger.warning(
                "image download failed with HTTP %s, %s not written",
                r.status_code,
                path,
            )
            return
        with open(path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)


async def download_image(url, path):
    """Download ``url`` and write the response body to ``path``.

    The file is only written when the server answers with HTTP 200; any
    other status is logged and otherwise ignored. The blocking download
    runs in a worker thread so the event loop is not stalled.

    Args:
        url: Address of the image to fetch.
        path: Destination file path; opened in binary write mode.

    Raises:
        requests.RequestException: If the HTTP request fails or times out.
    """
    await asyncio.to_thread(_download_image_blocking, url, path)