Chatbot
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

233 lines
7.2 KiB

import asyncio
import functools
import logging
import os
import tempfile
import uuid

import requests
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoTokenizer
# Module-level logger; the RunPod request helpers in this module log through it.
logger = logging.getLogger(__name__)
async def generate_sync(
    prompt: str,
    api_key: str,
    bot_name: str,
):
    """Generate a chat reply for ``prompt`` with the RunPod pygmalion-6b endpoint.

    Submits a ``runsync`` job. If RunPod answers before the job has finished
    (status ``IN_QUEUE`` or ``IN_PROGRESS``), the job's status endpoint is
    polled every ``poll_delay`` seconds for up to ``poll_timeout`` seconds.

    Args:
        prompt: Full model prompt, e.g. as built by ``get_full_prompt``.
        api_key: RunPod API key, sent as a bearer token.
        bot_name: Character name; newline-prefixed ``<bot_name>: `` turn
            markers in the generated text are folded into a single space.

    Returns:
        The cleaned reply, or ``"<ERROR>"`` if the job failed, reported an
        unexpected status, or did not complete within the polling timeout.
    """
    base_url = "https://api.runpod.ai/v2/pygmalion-6b"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    max_context_tokens = 2048
    max_new_tokens = 200
    request_timeout = 180  # seconds allowed for each HTTP request
    poll_timeout = 180  # total seconds spent polling an unfinished job
    poll_delay = 5
    pending_statuses = ("IN_PROGRESS", "IN_QUEUE")

    prompt_num_tokens = await num_tokens(prompt)
    input_data = {
        "input": {
            "prompt": prompt,
            # max_length counts prompt + generated tokens: allow up to
            # max_new_tokens of output without exceeding the context window.
            "max_length": min(prompt_num_tokens + max_new_tokens, max_context_tokens),
            "temperature": 0.75,
            "do_sample": True,
        }
    }
    logger.info("sending request to runpod.io")
    # requests is blocking; run it in a worker thread so the event loop
    # keeps serving other tasks while we wait (up to request_timeout).
    r = await asyncio.to_thread(
        requests.post,
        f"{base_url}/runsync",
        json=input_data,
        headers=headers,
        timeout=request_timeout,
    )
    r_json = r.json()
    logger.info(r_json)
    status = r_json.get("status")
    if status == "COMPLETED":
        return _extract_reply(r_json["output"], prompt, bot_name)
    if status not in pending_statuses:
        return "<ERROR>"

    status_url = f"{base_url}/status/{r_json['id']}"
    for _ in range(poll_timeout // poll_delay):
        r = await asyncio.to_thread(
            requests.get, status_url, headers=headers, timeout=request_timeout
        )
        r_json = r.json()
        logger.info(r_json)
        status = r_json.get("status")
        if status == "COMPLETED":
            return _extract_reply(r_json["output"], prompt, bot_name)
        if status not in pending_statuses:
            return "<ERROR>"
        await asyncio.sleep(poll_delay)
    # The job never completed within poll_timeout.
    return "<ERROR>"


def _extract_reply(text: str, prompt: str, bot_name: str) -> str:
    """Turn raw model output into the bot's reply.

    Strips the echoed prompt, cuts at the next ``\\nYou:`` turn (or strips a
    trailing ``<|endoftext|>``), and folds extra ``\\n<bot_name>: `` turns
    into the same line.
    """
    answer = text.removeprefix(prompt)
    idx = answer.find("\nYou:")
    if idx != -1:
        reply = answer[:idx].strip()
    else:
        reply = answer.removesuffix("<|endoftext|>").strip()
    return reply.replace(f"\n{bot_name}: ", " ")
async def get_full_prompt(simple_prompt: str, bot, chat_history):
    """Build the model prompt for ``simple_prompt``, including recent chat history.

    The prompt is the bot's persona and scenario, a ``<START>`` marker, as
    much recent history as fits in the token budget (oldest first), then the
    new user message and an open ``"<bot.name>:"`` turn for the model to fill.

    History is walked newest-first:
      * the newest entry is skipped (presumably it is ``simple_prompt``
        itself -- confirm with the caller);
      * the walk stops at the most recent message starting with ``!begin``;
      * other messages starting with ``!`` are ignored;
      * the walk stops at the first message that would exceed the budget, so
        the included history is always a contiguous recent window.

    Args:
        simple_prompt: The user's new message.
        bot: Object with ``name``, ``persona`` and ``scenario`` attributes.
        chat_history: Object whose ``chat_history`` dict holds items with
            ``message["en"]``, ``user_name``, ``num_tokens`` (lazily filled
            here) and ``is_own_message``.

    Returns:
        The complete prompt string.
    """
    max_tokens = 2048
    max_new_tokens = 200
    # The context window must hold both the prompt and the generated reply.
    token_budget = max_tokens - max_new_tokens

    header = (
        f"{bot.name}'s Persona: {bot.persona}\n"
        f"Scenario: {bot.scenario}\n"
        "<START>\n"
    )
    footer = f"You: {simple_prompt}\n{bot.name}:"
    total_num_tokens = await num_tokens(header + footer)

    visible_history = []
    newest_first = reversed(chat_history.chat_history.values())
    next(newest_first, None)  # skip the newest entry (see docstring)
    for chat_item in newest_first:
        message = chat_item.message["en"]
        if message.startswith("!begin"):
            break
        if message.startswith("!"):
            continue
        logger.debug("History: %s", chat_item)
        if chat_item.num_tokens is None:
            # NOTE: counted as "<user_name>: <message>" while the prompt uses
            # "You"/bot.name and adds "\n", so the count is approximate.
            chat_item.num_tokens = await num_tokens(f"{chat_item.user_name}: {message}")
        if total_num_tokens + chat_item.num_tokens > token_budget:
            break
        visible_history.append(chat_item)
        total_num_tokens += chat_item.num_tokens
    logger.debug("Prompt size: %d tokens", total_num_tokens)

    history_lines = []
    for chat_item in reversed(visible_history):
        speaker = bot.name if chat_item.is_own_message else "You"
        history_lines.append(f"{speaker}: {chat_item.message['en']}\n")
    return header + "".join(history_lines) + footer
@functools.lru_cache(maxsize=None)
def _get_tokenizer():
    """Load the pygmalion-6b tokenizer once and reuse it for all later calls."""
    return AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")


async def num_tokens(input_text: str):
    """Return the number of pygmalion-6b tokens ``input_text`` encodes to.

    Encodes with ``add_special_tokens=False``, so only the text itself is
    counted. The tokenizer is loaded on first use and cached; previously it
    was reloaded with ``from_pretrained`` on every call.
    """
    tokenizer = _get_tokenizer()
    return len(tokenizer.encode(input_text, add_special_tokens=False))
async def estimate_num_tokens(input_text: str):
    """Cheaply approximate the token count of ``input_text`` without a tokenizer.

    Heuristic: one token per four characters (rounded down), plus one.
    """
    char_count = len(input_text)
    return 1 + char_count // 4
async def generate_image(input_prompt: str, negative_prompt: str, api_key: str):
    """Generate images with the RunPod sd-anything-v4 endpoint and save them locally.

    Submits an asynchronous ``run`` job, polls its status every ``delay``
    seconds for up to ``timeout`` seconds, then downloads each output image
    into ``./images`` under a fresh random ``.jpg`` name.

    Args:
        input_prompt: Positive prompt for the image model.
        negative_prompt: Negative prompt for the image model.
        api_key: RunPod API key, sent as a bearer token.

    Returns:
        List of paths of the downloaded image files.

    Raises:
        ValueError: if the submission is rejected (non-200 or status other
            than ``IN_QUEUE``), the job ends in an unexpected status, or it
            does not complete within ``timeout`` seconds.
    """
    base_url = "https://api.runpod.ai/v1/sd-anything-v4"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    input_data = {
        "input": {
            "prompt": input_prompt,
            "negative_prompt": negative_prompt,
            "width": 512,
            "height": 768,
            "nsfw": True,
        }
    }
    timeout = 180  # total seconds spent polling the job
    delay = 5
    request_timeout = 60  # seconds allowed for each HTTP request

    logger.info("sending request to runpod.io")
    # requests is blocking; run it in a worker thread to keep the loop free.
    r = await asyncio.to_thread(
        requests.post,
        f"{base_url}/run",
        json=input_data,
        headers=headers,
        timeout=request_timeout,
    )
    if r.status_code != 200:
        # Previously this fell through and crashed later on an undefined `output`.
        logger.info(r.text)
        raise ValueError(f"HTTP STATUS {r.status_code}")
    r_json = r.json()
    logger.info(r_json)
    status = r_json["status"]
    if status != "IN_QUEUE":
        raise ValueError(f"RETURN CODE {status}")

    status_url = f"{base_url}/status/{r_json['id']}"
    for _ in range(timeout // delay):
        r = await asyncio.to_thread(
            requests.get, status_url, headers=headers, timeout=request_timeout
        )
        r_json = r.json()
        logger.info(r_json)
        status = r_json["status"]
        if status == "COMPLETED":
            output = r_json["output"]
            break
        if status not in ("IN_PROGRESS", "IN_QUEUE"):
            raise ValueError(f"RETURN CODE {status}")
        await asyncio.sleep(delay)
    else:
        # Loop exhausted without COMPLETED: `output` was never assigned.
        raise ValueError(f"TIMEOUT after {timeout}s (last status {status})")

    os.makedirs("./images", exist_ok=True)
    files = []
    for image in output:
        # Public, collision-resistant name (replaces private tempfile API).
        filename = os.path.join("./images", uuid.uuid4().hex + ".jpg")
        await download_image(image["image"], filename)
        files.append(filename)
    return files
async def download_image(url, path):
    """Download ``url`` into the file at ``path``.

    The blocking HTTP transfer runs in a worker thread so the event loop is
    not stalled. Best-effort: a non-200 response is logged and nothing is
    written.

    Returns:
        True if the file was written, False if the server did not answer 200.
    """

    def _download():
        # The context manager closes the streamed response (and releases the
        # pooled connection) on every path, including non-200 responses.
        with requests.get(url, stream=True, timeout=60) as r:
            if r.status_code != 200:
                logger.warning("image download failed (%s): %s", r.status_code, url)
                return False
            with open(path, "wb") as f:
                for chunk in r.iter_content(chunk_size=64 * 1024):
                    f.write(chunk)
        return True

    return await asyncio.to_thread(_download)