You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
249 lines
8.0 KiB
249 lines
8.0 KiB
import asyncio
import functools
import logging
import os
import tempfile

import requests
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoTokenizer
|
|
|
|
# Module-level logger shared by the RunPod request helpers in this file.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _extract_reply(output: str, prompt: str, bot_name: str) -> str:
    """Clean a raw RunPod completion into the bot's reply text.

    The echoed prompt is removed and the text is cut before the next "You:"
    turn. If there is no such turn, a trailing "<|endoftext|>" marker is
    dropped instead. Follow-up lines that start with the bot's own speaker
    tag (its name, or the generic "<BOT>") are joined into the reply with
    spaces.
    """
    answer = output.removeprefix(prompt)
    cut = answer.find("\nYou:")
    if cut != -1:
        reply = answer[:cut].strip()
    else:
        reply = answer.removesuffix("<|endoftext|>").strip()
    # f-string: match the bot's actual name, not the literal text "{bot_name}".
    reply = reply.replace(f"\n{bot_name}: ", " ")
    return reply.replace("\n<BOT>: ", " ")


async def generate_sync(
    prompt: str,
    api_key: str,
    bot_name: str,
):
    """Ask the RunPod ``pygmalion-6b`` endpoint to continue *prompt*.

    The job is submitted to the synchronous ``runsync`` endpoint. If the job
    is still queued or running when that call returns, its status endpoint is
    polled every ``DELAY`` seconds for up to ``TIMEOUT`` seconds.

    Args:
        prompt: Full prompt text sent to the model.
        api_key: RunPod API key, sent as a bearer token.
        bot_name: The bot's display name. Extra lines in the completion that
            start with "<bot_name>: " are merged into the reply.

    Returns:
        The cleaned reply text, or the sentinel string "<ERROR>" if the job
        fails, reports an unknown status, or does not finish in time.

    Raises:
        requests.RequestException: Network errors, including request timeouts,
            are not caught here.
    """
    endpoint = "https://api.runpod.ai/v2/pygmalion-6b/runsync"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    max_new_tokens = 200
    prompt_num_tokens = await num_tokens(prompt)
    input_data = {
        "input": {
            "prompt": prompt,
            # Total length (prompt + reply) is capped at 2048 tokens, the same
            # limit get_full_prompt budgets the prompt against.
            "max_length": min(prompt_num_tokens + max_new_tokens, 2048),
            "temperature": 0.80,
            "do_sample": True,
        }
    }

    logger.info("sending request to runpod.io")
    # requests is blocking; run it in a worker thread so the event loop is not stalled.
    r = await asyncio.to_thread(
        requests.post, endpoint, json=input_data, headers=headers, timeout=180
    )
    r_json = r.json()
    logger.info(r_json)
    status = r_json.get("status")

    if status == "COMPLETED":
        return _extract_reply(r_json["output"], prompt, bot_name)
    if status not in ("IN_PROGRESS", "IN_QUEUE"):
        return "<ERROR>"

    # The job did not finish within the runsync call: poll its status.
    job_id = r_json["id"]
    status_endpoint = "https://api.runpod.ai/v2/pygmalion-6b/status/" + job_id
    TIMEOUT = 360
    DELAY = 5
    for _ in range(TIMEOUT // DELAY):
        r = await asyncio.to_thread(
            requests.get, status_endpoint, headers=headers, timeout=30
        )
        r_json = r.json()
        logger.info(r_json)
        status = r_json.get("status")
        if status in ("IN_PROGRESS", "IN_QUEUE"):
            await asyncio.sleep(DELAY)
        elif status == "COMPLETED":
            return _extract_reply(r_json["output"], prompt, bot_name)
        else:
            return "<ERROR>"

    # Polling budget exhausted: report the error sentinel instead of falling
    # through and returning None.
    logger.warning("runpod.io job %s did not finish within %s s", job_id, TIMEOUT)
    return "<ERROR>"
|
|
|
|
async def get_full_prompt(simple_prompt: str, bot, chat_history):
    """Build the chat prompt for the bot's answer to *simple_prompt*.

    The prompt contains, in order:

    - the bot's persona and scenario;
    - as many of the most recent earlier messages as fit in the token budget,
      oldest first;
    - the user's new message;
    - a trailing "<bot name>:" cue for the model to continue.

    Args:
        simple_prompt: The user's new message.
        bot: Object providing ``name``, ``persona`` and ``scenario`` strings.
        chat_history: Object whose ``chat_history`` attribute is an ordered
            mapping of chat items. Presumably these are oldest first — TODO
            confirm. Each item provides ``message["en"]``, ``user_name``,
            ``is_own_message`` and a cached ``num_tokens``.

    Returns:
        The complete prompt string.
    """
    MAX_TOKENS = 2048
    max_new_tokens = 200

    # Preamble and closing turn; these are always part of the prompt.
    header = bot.name + "'s Persona: " + bot.persona + "\n"
    header += "Scenario: " + bot.scenario + "\n"
    header += "<START>" + "\n"
    tail = "You: " + simple_prompt + "\n" + bot.name + ":"

    # Keep max_new_tokens free for the reply: generate_sync requests at most
    # min(prompt tokens + max_new_tokens, MAX_TOKENS) tokens in total.
    token_budget = MAX_TOKENS - max_new_tokens
    total_num_tokens = await num_tokens(header + tail)

    visible_history = []
    history_items = reversed(chat_history.chat_history.items())
    # The newest entry is always skipped. Presumably it is the message being
    # answered, which is already included via ``tail`` — confirm with the caller.
    next(history_items, None)
    for _, chat_item in history_items:
        message = chat_item.message["en"]
        if message.startswith("!begin"):
            break  # history before a "!begin" command is never used
        if message.startswith(("!", "<ERROR>")):
            continue  # skip commands and failed generations
        if chat_item.num_tokens is None:
            # Cached on the item, so each message is tokenized only once.
            # NOTE(review): counted as "user_name: message", while the prompt
            # renders "<bot name>: " / "You: " plus a newline, so the count
            # is approximate.
            chat_item.num_tokens = await num_tokens(
                "{}: {}".format(chat_item.user_name, message)
            )
        logger.debug("History: %s [%s]", chat_item, chat_item.num_tokens)
        if total_num_tokens + chat_item.num_tokens >= token_budget:
            break
        visible_history.append(chat_item)
        total_num_tokens += chat_item.num_tokens

    parts = [header]
    for chat_item in reversed(visible_history):
        speaker = bot.name if chat_item.is_own_message else "You"
        parts.append(speaker + ": " + chat_item.message["en"] + "\n")
    parts.append(tail)
    return "".join(parts)
|
|
|
|
|
|
@functools.lru_cache(maxsize=4)
def _load_tokenizer(model_name: str):
    """Load the Hugging Face tokenizer for *model_name* once and reuse it."""
    return AutoTokenizer.from_pretrained(model_name)


async def num_tokens(input_text: str, model_name: str = "PygmalionAI/pygmalion-6b"):
    """Return how many tokens *input_text* encodes to.

    Special tokens are not added, so the count covers only the text itself.

    Args:
        input_text: Text to tokenize.
        model_name: Hugging Face model id whose tokenizer is used. Defaults to
            the Pygmalion-6B tokenizer this module previously hard-coded.

    Returns:
        The token count as an int.
    """
    # Loading a tokenizer reads (and possibly downloads) model files, so it is
    # memoized rather than repeated on every call. get_full_prompt may call
    # this once per history message.
    tokenizer = _load_tokenizer(model_name)
    return len(tokenizer.encode(input_text, add_special_tokens=False))
|
|
|
|
async def estimate_num_tokens(input_text: str):
    """Cheaply approximate the token count of *input_text* without a tokenizer.

    Assumes about four characters per token and adds one, so the estimate is
    never zero.
    """
    quarter, _remainder = divmod(len(input_text), 4)
    return quarter + 1
|
|
|
|
|
|
async def generate_image(input_prompt: str, negative_prompt: str, api_url: str, api_key: str):
    """Run a RunPod image-generation job and download the resulting images.

    The job is submitted to ``api_url + "run"``. Its ``status/<id>`` endpoint
    is then polled every ``DELAY`` seconds for up to ``TIMEOUT`` seconds, and
    every image URL in the job output is downloaded into ``./images``.

    Args:
        input_prompt: Text prompt for the image.
        negative_prompt: Negative prompt, sent as ``negative_prompt``.
        api_url: Base endpoint URL including the trailing slash, e.g.
            "https://api.runpod.ai/v1/sd-openjourney/".
        api_key: RunPod API key, sent as a bearer token.

    Returns:
        List of paths of the ``.jpg`` files in ``./images``, one per output image.

    Raises:
        ValueError: If the submission is rejected, the job ends in a status
            other than COMPLETED, or no output arrives before the timeout.
    """
    endpoint = api_url + "run"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    input_data = {
        "input": {
            "prompt": input_prompt,
            "negative_prompt": negative_prompt,
            "width": 512,
            "height": 512,
            # "nsfw": True
        },
    }

    logger.info("sending request to runpod.io")
    # requests is blocking; run it in a worker thread so the event loop is not stalled.
    r = await asyncio.to_thread(
        requests.post, endpoint, json=input_data, headers=headers, timeout=60
    )
    if r.status_code != 200:
        # Raise the same error type as the other failure paths instead of
        # silently returning None; the body may not be JSON, so log it raw.
        logger.error("runpod.io rejected the job: HTTP %s %s", r.status_code, r.text)
        raise ValueError(f"RETURN CODE {r.status_code}")
    r_json = r.json()
    logger.info(r_json)

    status = r_json["status"]
    if status != "IN_QUEUE":
        raise ValueError(f"RETURN CODE {status}")

    job_id = r_json["id"]
    status_endpoint = api_url + "status/" + job_id
    TIMEOUT = 360
    DELAY = 5
    output = None
    for _ in range(TIMEOUT // DELAY):
        r = await asyncio.to_thread(
            requests.get, status_endpoint, headers=headers, timeout=30
        )
        r_json = r.json()
        logger.info(r_json)
        status = r_json["status"]
        if status in ("IN_PROGRESS", "IN_QUEUE"):
            await asyncio.sleep(DELAY)
        elif status == "COMPLETED":
            output = r_json["output"]
            break
        else:
            raise ValueError(f"RETURN CODE {status}")

    if not output:
        raise ValueError("<ERROR>")

    os.makedirs("./images", exist_ok=True)
    files = []
    for image in output:
        # mkstemp atomically reserves a unique file name through a public API
        # (unlike tempfile._get_candidate_names). Close the descriptor so
        # download_image can reopen the path for writing.
        fd, filename = tempfile.mkstemp(suffix=".jpg", dir="./images")
        os.close(fd)
        await download_image(image["image"], filename)
        files.append(filename)

    return files
|
|
|
|
async def generate_image1(input_prompt: str, negative_prompt: str, api_key: str):
    """Generate images with the RunPod ``sd-anything-v4`` endpoint.

    Thin wrapper around ``generate_image``; returns its list of saved file paths.
    """
    base_url = "https://api.runpod.ai/v1/sd-anything-v4/"
    return await generate_image(
        input_prompt=input_prompt,
        negative_prompt=negative_prompt,
        api_url=base_url,
        api_key=api_key,
    )
|
|
|
|
async def generate_image2(input_prompt: str, negative_prompt: str, api_key: str):
    """Generate images with the RunPod ``sd-openjourney`` endpoint.

    Thin wrapper around ``generate_image``; returns its list of saved file paths.
    """
    base_url = "https://api.runpod.ai/v1/sd-openjourney/"
    return await generate_image(
        input_prompt=input_prompt,
        negative_prompt=negative_prompt,
        api_url=base_url,
        api_key=api_key,
    )
|
|
|
|
async def generate_image3(input_prompt: str, negative_prompt: str, api_key: str):
    """Generate images with the RunPod endpoint ``mf5f6mocy8bsvx``.

    Thin wrapper around ``generate_image``; returns its list of saved file paths.
    """
    base_url = "https://api.runpod.ai/v1/mf5f6mocy8bsvx/"
    return await generate_image(
        input_prompt=input_prompt,
        negative_prompt=negative_prompt,
        api_url=base_url,
        api_key=api_key,
    )
|
|
|
|
async def download_image(url, path):
    """Download *url* to the file *path* without blocking the event loop.

    The response body is streamed to disk in chunks. On a non-200 response
    nothing is written and a warning is logged. The call stays best-effort
    and does not raise in that case.

    Args:
        url: Image URL to fetch.
        path: Destination file path; opened in ``"wb"`` mode, so any existing
            content is replaced.
    """

    def _fetch() -> None:
        # Blocking download, run in a worker thread. The ``with`` block closes
        # the streamed response, releasing the connection even on errors.
        with requests.get(url, stream=True, timeout=60) as r:
            if r.status_code != 200:
                logger.warning("failed to download %s: HTTP %s", url, r.status_code)
                return
            with open(path, "wb") as f:
                for chunk in r.iter_content(chunk_size=65536):
                    f.write(chunk)

    await asyncio.to_thread(_fetch)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|