Hendrik Langer
2 years ago
3 changed files with 302 additions and 4 deletions
@ -0,0 +1,256 @@ |
|||||
|
import asyncio |
||||
|
import os, tempfile |
||||
|
import logging |
||||
|
|
||||
|
import requests |
||||
|
|
||||
|
from transformers import AutoTokenizer, AutoConfig |
||||
|
from huggingface_hub import hf_hub_download |
||||
|
|
||||
|
logger = logging.getLogger(__name__) |
||||
|
|
||||
|
|
||||
|
async def generate_sync(
    prompt: str,
    api_key: str,
    bot_name: str,
):
    """Submit *prompt* to the KoboldAI Horde and poll until a reply is ready.

    Parameters:
        prompt: full model prompt (persona + history + user message).
        api_key: KoboldAI Horde API key, sent in the ``apikey`` header.
        bot_name: the bot's display name, used to strip echoed turn markers.

    Returns the cleaned-up reply text, or a string starting with "<ERROR>"
    when submission fails or the job does not finish in time.
    """
    # Set the API endpoint URL
    endpoint = "https://koboldai.net/api/v2/generate/async"

    # Set the headers for the request
    headers = {
        "Content-Type": "application/json",
        "accept": "application/json",
        "apikey": f"{api_key}"
    }

    max_new_tokens = 200
    prompt_num_tokens = await num_tokens(prompt)

    # Define your inputs
    input_data = {
        "prompt": prompt,
        "params": {
            "n": 1,
            # "frmtadsnsp": False,
            # "frmtrmblln": False,
            # "frmtrmspch": False,
            # "frmttriminc": False,
            "max_context_length": 1024,
            "max_length": 512,
            "rep_pen": 1.1,
            "rep_pen_range": 1024,
            "rep_pen_slope": 0.7,
            # "singleline": False,
            # "soft_prompt": "",
            "temperature": 0.75,
            "tfs": 1.0,
            "top_a": 0.0,
            "top_k": 0,
            "top_p": 0.9,
            "typical": 1.0,
            # "sampler_order": [0],
        },
        "softprompts": [],
        "trusted_workers": False,
        "nsfw": True,
        # "workers": [],
        "models": ["PygmalionAI/pygmalion-6b"]
    }

    logger.info("sending request to koboldai.net")

    # Make the request. `requests` is blocking, so run it in a worker thread
    # to keep the asyncio event loop responsive.
    r = await asyncio.to_thread(
        requests.post, endpoint, json=input_data, headers=headers, timeout=180
    )

    r_json = r.json()
    logger.info(r_json)
    # BUG FIX: a successful submission has no "message" key, so indexing
    # raised KeyError before the job id could be read. Use .get() instead.
    status = r_json.get("message")

    if "id" in r_json:
        job_id = r_json["id"]
        TIMEOUT = 360
        DELAY = 11
        for i in range(TIMEOUT // DELAY):
            endpoint = "https://koboldai.net/api/v2/generate/text/status/" + job_id
            # Poll in a worker thread; a timeout prevents hanging forever.
            r = await asyncio.to_thread(
                requests.get, endpoint, headers=headers, timeout=180
            )
            r_json = r.json()
            logger.info(r_json)
            if "done" not in r_json:
                return "<ERROR>"
            if r_json["done"] == True:
                text = r_json["generations"][0]["text"]
                answer = text.removeprefix(prompt)
                # Cut off anything the model hallucinated for the user's turn.
                idx = answer.find("\nYou:")
                if idx != -1:
                    reply = answer[:idx].strip()
                else:
                    reply = answer.removesuffix('<|endoftext|>').strip()
                # BUG FIX: the first replace used a literal "{bot_name}"
                # (missing f-prefix) and the second discarded its result.
                reply = reply.replace(f"\n{bot_name}: ", " ")
                reply = reply.replace("\n<BOT>: ", " ")
                return reply
            else:
                await asyncio.sleep(DELAY)
        # BUG FIX: the loop used to fall through and implicitly return None;
        # callers expect an "<ERROR>" string on failure.
        return "<ERROR> timeout"
    else:
        # BUG FIX: was a literal "{status}" (missing f-prefix).
        return f"<ERROR> {status}"
||||
|
|
||||
|
async def get_full_prompt(simple_prompt: str, bot, chat_history):
    """Build the full model prompt: persona header, as much chat history
    as fits in the context window, then the user's newest message.

    Parameters:
        simple_prompt: the user's latest message text.
        bot: provides .name, .persona, .scenario (and .greeting, unused).
        chat_history: provides an ordered mapping .chat_history of items
            with .message["en"], .user_name, .num_tokens, .is_own_message.

    Returns the assembled prompt string ending with "<bot.name>:".
    """

    def _header() -> str:
        # Persona block that precedes every conversation.
        header = bot.name + "'s Persona: " + bot.persona + "\n"
        header += "Scenario: " + bot.scenario + "\n"
        header += "<START>" + "\n"
        # header += bot.name + ": " + bot.greeting + "\n"
        return header

    def _tail() -> str:
        # The user's new message plus the cue for the bot's reply.
        return "You: " + simple_prompt + "\n" + bot.name + ":"

    # Prompt without history — measures the fixed token cost of header+tail.
    prompt = _header() + _tail()

    MAX_TOKENS = 2048
    max_new_tokens = 200
    total_num_tokens = await num_tokens(prompt)
    visible_history = []
    current_message = True
    # Walk history newest-first, collecting items until the budget is spent.
    for key, chat_item in reversed(chat_history.chat_history.items()):
        # Skip the newest entry: it is the message we are answering.
        if current_message:
            current_message = False
            continue
        if chat_item.message["en"].startswith('!begin'):
            break
        if chat_item.message["en"].startswith('!'):
            continue
        if chat_item.message["en"].startswith('<ERROR>'):
            continue
        # if chat_item.message["en"] == bot.greeting:
        #     continue
        # Lazily compute and cache the per-item token count.
        # BUG FIX: was "== None"; identity comparison is the correct idiom.
        if chat_item.num_tokens is None:
            chat_item.num_tokens = await num_tokens("{}: {}".format(chat_item.user_name, chat_item.message["en"]))
        # TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens??
        logger.debug(f"History: " + str(chat_item) + " [" + str(chat_item.num_tokens) + "]")
        if total_num_tokens + chat_item.num_tokens < MAX_TOKENS - max_new_tokens:
            visible_history.append(chat_item)
            total_num_tokens += chat_item.num_tokens
        else:
            break
    # Collected newest-first; flip back to chronological order.
    visible_history = reversed(visible_history)

    # Rebuild the prompt, now including the history that fits.
    prompt = _header()
    for chat_item in visible_history:
        if chat_item.is_own_message:
            prompt += bot.name + ": " + chat_item.message["en"] + "\n"
        else:
            prompt += "You" + ": " + chat_item.message["en"] + "\n"
    prompt += _tail()

    return prompt
||||
|
|
||||
|
|
||||
|
async def num_tokens(input_text: str) -> int:
    """Count the tokens in *input_text* with the Pygmalion-6B tokenizer.

    The tokenizer is loaded once and memoized on the function object:
    BUG FIX — the original called AutoTokenizer.from_pretrained() on every
    invocation, which hits the hub cache / network each time and is far too
    slow for per-message token counting.
    """
    tokenizer = getattr(num_tokens, "_tokenizer", None)
    if tokenizer is None:
        # First call: load (and possibly download) the tokenizer, then cache it.
        tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")
        num_tokens._tokenizer = tokenizer
    encoding = tokenizer.encode(input_text, add_special_tokens=False)
    return len(encoding)
||||
|
|
||||
|
async def estimate_num_tokens(input_text: str):
    """Cheap token-count heuristic: roughly one token per four characters,
    always at least 1."""
    approx_tokens = len(input_text) // 4
    return approx_tokens + 1
||||
|
|
||||
|
|
||||
|
async def generate_image(input_prompt: str, negative_prompt: str, model: str, api_key: str):
    """Request images from the Stable Horde, poll until done, download them.

    Parameters:
        input_prompt: positive generation prompt.
        negative_prompt: currently unused (the param is commented out below).
        model: Stable Horde model name to request.
        api_key: Stable Horde API key, sent in the ``apikey`` header.

    Returns a list of local file paths under ./images. Returns "<ERROR>" on a
    malformed/faulted status response (kept for caller compatibility) and
    raises ValueError when the job fails to produce output.
    """
    # Set the API endpoint URL
    endpoint = "https://stablehorde.net/api/v2/generate/async"

    # Set the headers for the request
    headers = {
        "Content-Type": "application/json",
        "accept": "application/json",
        "apikey": f"{api_key}"
    }

    # Define your inputs
    input_data = {
        "prompt": input_prompt,
        "params": {
            # "negative_prompt": negative_prompt,
            "width": 512,
            "height": 512,
        },
        "nsfw": True,
        "trusted_workers": False,
        # "workers": [],
        # BUG FIX: was the literal string "{model}" (missing f-prefix), so
        # the requested model name was never actually sent to the horde.
        "models": [f"{model}"]
    }

    logger.info("sending request to stablehorde.net")

    # Make the request. `requests` is blocking, so run it in a worker thread
    # to keep the asyncio event loop responsive.
    r = await asyncio.to_thread(
        requests.post, endpoint, json=input_data, headers=headers, timeout=180
    )
    r_json = r.json()
    logger.info(r_json)

    # BUG FIX: `output` was only assigned inside the 202 branch, so any other
    # status code crashed with NameError instead of reaching the error path.
    output = None
    if r.status_code == 202:
        # status = r_json["message"]
        job_id = r_json["id"]
        TIMEOUT = 360
        DELAY = 11
        for i in range(TIMEOUT // DELAY):
            endpoint = "https://stablehorde.net/api/v2/generate/status/" + job_id
            # Poll in a worker thread; a timeout prevents hanging forever.
            r = await asyncio.to_thread(
                requests.get, endpoint, headers=headers, timeout=180
            )
            r_json = r.json()
            logger.info(r_json)
            if "done" not in r_json:
                return "<ERROR>"
            if "faulted" in r_json and r_json["faulted"] == True:
                return "<ERROR>"
            if r_json["done"] == True:
                output = r_json["generations"]
                break
            else:
                await asyncio.sleep(DELAY)

    if not output:
        raise ValueError(f"<ERROR>")

    os.makedirs("./images", exist_ok=True)
    files = []
    for image in output:
        # BUG FIX: tempfile._get_candidate_names() is a private CPython API;
        # build a unique name from os.urandom instead.
        filename = "./images/" + os.urandom(8).hex() + ".jpg"
        await download_image(image["img"], filename)
        files.append(filename)

    return files
||||
|
|
||||
|
async def generate_image1(input_prompt: str, negative_prompt: str, api_key: str):
    """Generate images using the "Deliberate" Stable Horde model."""
    model_name = "Deliberate"
    return await generate_image(input_prompt, negative_prompt, model_name, api_key)
||||
|
|
||||
|
async def generate_image2(input_prompt: str, negative_prompt: str, api_key: str):
    """Generate images using the "PFG" Stable Horde model."""
    model_name = "PFG"
    return await generate_image(input_prompt, negative_prompt, model_name, api_key)
||||
|
|
||||
|
async def generate_image3(input_prompt: str, negative_prompt: str, api_key: str):
    """Generate images using the "Hassanblend" Stable Horde model."""
    model_name = "Hassanblend"
    return await generate_image(input_prompt, negative_prompt, model_name, api_key)
||||
|
|
||||
|
async def download_image(url, path):
    """Stream the file at *url* to local *path*.

    Best effort: a non-200 response silently writes nothing (the caller
    treats missing files as a failed generation).
    """
    # BUG FIX: added a timeout — without one a stalled server hangs forever.
    # NOTE(review): this still blocks the event loop while downloading;
    # consider asyncio.to_thread if transfers get large.
    r = requests.get(url, stream=True, timeout=180)
    if r.status_code == 200:
        with open(path, 'wb') as f:
            # iter_content with an explicit chunk size: iterating the
            # response object directly yields tiny 128-byte chunks.
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
|
|
Loading…
Reference in new issue