Hendrik Langer
2 years ago
3 changed files with 302 additions and 4 deletions
@@ -0,0 +1,256 @@
import asyncio
import logging
import os
import tempfile

import requests

from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)


async def generate_sync(
    prompt: str,
    api_key: str,
    bot_name: str,
):
    # Set the API endpoint URL
    endpoint = "https://koboldai.net/api/v2/generate/async"

    # Set the headers for the request
    headers = {
        "Content-Type": "application/json",
        "accept": "application/json",
        "apikey": api_key,
    }

    max_new_tokens = 200
    prompt_num_tokens = await num_tokens(prompt)

    # Define your inputs
    input_data = {
        "prompt": prompt,
        "params": {
            "n": 1,
            # "frmtadsnsp": False,
            # "frmtrmblln": False,
            # "frmtrmspch": False,
            # "frmttriminc": False,
            "max_context_length": 1024,
            "max_length": 512,
            "rep_pen": 1.1,
            "rep_pen_range": 1024,
            "rep_pen_slope": 0.7,
            # "singleline": False,
            # "soft_prompt": "",
            "temperature": 0.75,
            "tfs": 1.0,
            "top_a": 0.0,
            "top_k": 0,
            "top_p": 0.9,
            "typical": 1.0,
            # "sampler_order": [0],
        },
        "softprompts": [],
        "trusted_workers": False,
        "nsfw": True,
        # "workers": [],
        "models": ["PygmalionAI/pygmalion-6b"],
    }
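    # Note on the sampler settings above (per the public KoboldAI Horde
    # parameter docs): rep_pen/rep_pen_range/rep_pen_slope shape the
    # repetition penalty, and top_k=0 disables top-k filtering so sampling
    # is governed by temperature together with top_p/tfs/typical.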

    logger.info("sending request to koboldai.net")

    # Make the request
    r = requests.post(endpoint, json=input_data, headers=headers, timeout=180)

    r_json = r.json()
    logger.info(r_json)
    # Use .get() so a successful response without a "message" field
    # does not raise a KeyError.
    status = r_json.get("message")

    if "id" in r_json:
        job_id = r_json["id"]
        TIMEOUT = 360
        DELAY = 11
        for _ in range(TIMEOUT // DELAY):
            endpoint = "https://koboldai.net/api/v2/generate/text/status/" + job_id
            r = requests.get(endpoint, headers=headers, timeout=30)
            r_json = r.json()
            logger.info(r_json)
            if "done" not in r_json:
                return "<ERROR>"
            if r_json["done"]:
                text = r_json["generations"][0]["text"]
                answer = text.removeprefix(prompt)
                # Cut the reply off at the next simulated user turn, if any.
                idx = answer.find("\nYou:")
                if idx != -1:
                    reply = answer[:idx].strip()
                else:
                    reply = answer.removesuffix('<|endoftext|>').strip()
                reply = reply.replace(f"\n{bot_name}: ", " ")
                reply = reply.replace("\n<BOT>: ", " ")
                return reply
            else:
                await asyncio.sleep(DELAY)
        # Polling timed out without a finished generation.
        return "<ERROR>"
    else:
        return f"<ERROR> {status}"


async def get_full_prompt(simple_prompt: str, bot, chat_history):

    # Prompt without history
    prompt = bot.name + "'s Persona: " + bot.persona + "\n"
    prompt += "Scenario: " + bot.scenario + "\n"
    prompt += "<START>" + "\n"
    #prompt += bot.name + ": " + bot.greeting + "\n"
    prompt += "You: " + simple_prompt + "\n"
    prompt += bot.name + ":"

    MAX_TOKENS = 2048
    max_new_tokens = 200
    total_num_tokens = await num_tokens(prompt)
    visible_history = []
    current_message = True
    for key, chat_item in reversed(chat_history.chat_history.items()):
        if current_message:
            # The newest entry is the message being answered; skip it.
            current_message = False
            continue
        if chat_item.message["en"].startswith('!begin'):
            break
        if chat_item.message["en"].startswith('!'):
            continue
        if chat_item.message["en"].startswith('<ERROR>'):
            continue
        #if chat_item.message["en"] == bot.greeting:
        #    continue
        if chat_item.num_tokens is None:
            chat_item.num_tokens = await num_tokens("{}: {}".format(chat_item.user_name, chat_item.message["en"]))
        # TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens??
        logger.debug(f"History: {chat_item} [{chat_item.num_tokens}]")
        if total_num_tokens + chat_item.num_tokens < MAX_TOKENS - max_new_tokens:
            visible_history.append(chat_item)
            total_num_tokens += chat_item.num_tokens
        else:
            break
    visible_history = reversed(visible_history)

    # Rebuild the prompt, now with as much history as fits the token budget.
    prompt = bot.name + "'s Persona: " + bot.persona + "\n"
    prompt += "Scenario: " + bot.scenario + "\n"
    prompt += "<START>" + "\n"
    #prompt += bot.name + ": " + bot.greeting + "\n"
    for chat_item in visible_history:
        if chat_item.is_own_message:
            prompt += bot.name + ": " + chat_item.message["en"] + "\n"
        else:
            prompt += "You" + ": " + chat_item.message["en"] + "\n"
    prompt += "You: " + simple_prompt + "\n"
    prompt += bot.name + ":"

    return prompt
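
# For reference, the assembled Pygmalion-style prompt looks roughly like
# this (the persona, scenario and messages are invented for illustration):
#
#     Julia's Persona: A friendly AI assistant.
#     Scenario: Julia chats with a visitor.
#     <START>
#     You: Hello!
#     Julia: Hi there!
#     You: How are you?
#     Julia: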


# Cache the tokenizer at module level so it is loaded once, not on every call.
_tokenizer = None


async def num_tokens(input_text: str):
    # os.makedirs("./models/pygmalion-6b", exist_ok=True)
    # hf_hub_download(repo_id="PygmalionAI/pygmalion-6b", filename="config.json", cache_dir="./models/pygmalion-6b")
    # config = AutoConfig.from_pretrained("./models/pygmalion-6b/config.json")
    global _tokenizer
    if _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")
    encoding = _tokenizer.encode(input_text, add_special_tokens=False)
    return len(encoding)


async def estimate_num_tokens(input_text: str):
    # Cheap heuristic: roughly four characters per token.
    return len(input_text) // 4 + 1
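
# Quick comparison of the two counters (exact numbers depend on the
# tokenizer; these values are only illustrative):
#
#     await estimate_num_tokens("Hello world!")  # len 12 -> 12 // 4 + 1 == 4
#     await num_tokens("Hello world!")           # exact count from the tokenizer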


async def generate_image(input_prompt: str, negative_prompt: str, model: str, api_key: str):

    # Set the API endpoint URL
    endpoint = "https://stablehorde.net/api/v2/generate/async"

    # Set the headers for the request
    headers = {
        "Content-Type": "application/json",
        "accept": "application/json",
        "apikey": api_key,
    }

    # Define your inputs
    input_data = {
        "prompt": input_prompt,
        "params": {
            # "negative_prompt": negative_prompt,
            "width": 512,
            "height": 512,
        },
        "nsfw": True,
        "trusted_workers": False,
        # "workers": [],
        "models": [model],
    }

    logger.info("sending request to stablehorde.net")

    # Make the request
    r = requests.post(endpoint, json=input_data, headers=headers, timeout=180)
    r_json = r.json()
    logger.info(r_json)

    # Initialize before the status check so the fallthrough below cannot
    # hit an undefined name when the request is rejected.
    output = None
    if r.status_code == 202:
        #status = r_json["message"]
        job_id = r_json["id"]
        TIMEOUT = 360
        DELAY = 11
        for _ in range(TIMEOUT // DELAY):
            endpoint = "https://stablehorde.net/api/v2/generate/status/" + job_id
            r = requests.get(endpoint, headers=headers, timeout=30)
            r_json = r.json()
            logger.info(r_json)
            #status = r_json["message"]
            if "done" not in r_json:
                return "<ERROR>"
            if r_json.get("faulted"):
                return "<ERROR>"
            if r_json["done"]:
                output = r_json["generations"]
                break
            else:
                await asyncio.sleep(DELAY)

    if not output:
        raise ValueError("<ERROR>")

    os.makedirs("./images", exist_ok=True)
    files = []
    for image in output:
        # mkstemp creates the file and returns an open descriptor; close it
        # so download_image can reopen the path itself.
        fd, filename = tempfile.mkstemp(suffix=".jpg", dir="./images")
        os.close(fd)
        await download_image(image["img"], filename)
        files.append(filename)

    return files
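
# Usage sketch (the prompt and the anonymous "0000000000" Horde API key are
# placeholder values for illustration):
#
#     files = await generate_image("a watercolor fox", "", "Deliberate", "0000000000")
#     for path in files:
#         logger.info("image saved to %s", path)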


async def generate_image1(input_prompt: str, negative_prompt: str, api_key: str):
    return await generate_image(input_prompt, negative_prompt, "Deliberate", api_key)


async def generate_image2(input_prompt: str, negative_prompt: str, api_key: str):
    return await generate_image(input_prompt, negative_prompt, "PFG", api_key)


async def generate_image3(input_prompt: str, negative_prompt: str, api_key: str):
    return await generate_image(input_prompt, negative_prompt, "Hassanblend", api_key)


async def download_image(url, path):
    r = requests.get(url, stream=True, timeout=180)
    if r.status_code == 200:
        with open(path, 'wb') as f:
            # Stream the body in chunks so large images never have to fit
            # into memory at once.
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
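
# Usage sketch (the URL below is hypothetical):
#
#     await download_image("https://example.com/render.jpg", "./images/out.jpg")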