Chatbot
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

149 lines
5.0 KiB

2 years ago
import asyncio
import requests
import json
import os, tempfile
import logging
# Module-level logger, registered in the logging hierarchy under this module's name.
logger = logging.getLogger(name=__name__)
class StableHordeWrapper:
    """Base client for the Stable Horde asynchronous generation API.

    A job is submitted to ``SUBMIT_URL`` and then polled at ``STATUS_URL``
    until it is done, faults, or ``POLL_TIMEOUT`` seconds of polling elapse.
    Subclasses build the request payload and override the two URLs when they
    target a different job type (see ``StableHordeTextWrapper``).
    """

    # Image-generation endpoints; ``{job_id}`` is filled from the submit reply.
    SUBMIT_URL = "https://stablehorde.net/api/v2/generate/async"
    STATUS_URL = "https://stablehorde.net/api/v2/generate/status/{job_id}"
    # Total polling budget and delay between status polls, in seconds.
    POLL_TIMEOUT = 360
    POLL_DELAY = 11

    def __init__(self, api_key: str, endpoint_name: str, model_name: str):
        self.api_key = api_key
        # Subclasses send this as the single entry of the payload's "models".
        self.endpoint_name = endpoint_name
        self.model_name = model_name

    @staticmethod
    async def _request(method, url, **kwargs):
        """Run a blocking ``requests`` call in a worker thread.

        Keeps the event loop responsive while waiting on the network.

        Raises:
            ValueError: "<HTTP ERROR>" on any transport-level failure.
        """
        try:
            return await asyncio.to_thread(requests.request, method, url, **kwargs)
        except requests.exceptions.RequestException as e:
            raise ValueError("<HTTP ERROR>") from e

    @staticmethod
    def _parse_status(r_json):
        """Interpret one status-poll reply.

        Returns:
            The ``generations`` list (possibly empty) once the job is done,
            or ``None`` while it is still pending.

        Raises:
            ValueError: if the reply is not a job status or the job faulted.
        """
        if "done" not in r_json:
            raise ValueError("<ERROR>")
        if r_json.get("faulted"):
            raise ValueError("<ERROR> Faulted")
        if r_json["done"]:
            return r_json.get("generations") or []
        return None

    @staticmethod
    def _is_progressing(r_json):
        """Return True if the job is being processed or should start soon."""
        if (r_json.get("processing") or 0) > 0:
            return True
        wait_time = r_json.get("wait_time")
        if not wait_time:  # missing, None or 0
            return False
        # A missing queue position is treated as "far back in the queue".
        return wait_time < 20 and r_json.get("queue_position", float("inf")) < 100

    async def generate(self, input_data: dict, typing_fn, timeout=180):
        """Submit ``input_data`` as a horde job and wait for its generations.

        Args:
            input_data: JSON-serialisable request payload built by a subclass.
            typing_fn: Zero-argument coroutine function. It is awaited after a
                poll while the job is processing or near the front of the queue.
            timeout: Per-request HTTP timeout in seconds.

        Returns:
            The non-empty ``generations`` list from the final status reply.

        Raises:
            ValueError: on transport errors, a non-202 submit reply, a
                malformed or faulted status reply, or when no output arrives
                within the polling budget.
        """
        headers = {
            "Content-Type": "application/json",
            "accept": "application/json",
            "apikey": f"{self.api_key}",
        }
        logger.info('sending request to stablehorde.net. endpoint="%s"', self.endpoint_name)
        r = await self._request(
            "post", self.SUBMIT_URL, json=input_data, headers=headers, timeout=timeout
        )
        # Check the status before parsing: error pages are not always JSON.
        if r.status_code != 202:
            logger.debug(r.text)
            raise ValueError(f"<ERROR> HTTP code {r.status_code}")
        r_json = r.json()
        logger.debug(r_json)
        job_id = r_json.get("id")
        if not job_id:
            raise ValueError("<ERROR>")
        status_url = self.STATUS_URL.format(job_id=job_id)

        output = None
        for attempt in range(max(1, self.POLL_TIMEOUT // self.POLL_DELAY)):
            # Poll immediately, then wait between polls (no sleep after the last).
            if attempt:
                await asyncio.sleep(self.POLL_DELAY)
            r = await self._request("get", status_url, headers=headers, timeout=timeout)
            r_json = r.json()
            logger.info(r_json)
            output = self._parse_status(r_json)
            if output is not None:
                break
            if self._is_progressing(r_json):
                await typing_fn()
        if not output:
            raise ValueError("<ERROR> TIMEOUT / NO OUTPUT")
        return output


class StableHordeTextWrapper(StableHordeWrapper):
    """Text-generation client; uses the horde's ``/generate/text`` endpoints."""

    SUBMIT_URL = "https://stablehorde.net/api/v2/generate/text/async"
    STATUS_URL = "https://stablehorde.net/api/v2/generate/text/status/{job_id}"

    def _build_payload(self, prompt, temperature, max_new_tokens):
        """Return the JSON body for a single text generation of ``prompt``."""
        return {
            "prompt": prompt,
            "params": {
                "n": 1,
                "max_context_length": 1024,
                "max_length": max_new_tokens,
                "rep_pen": 1.1,
                "rep_pen_range": 1024,
                "rep_pen_slope": 0.7,
                "temperature": temperature,
                "tfs": 1.0,
                "top_a": 0.0,
                "top_k": 0,
                "top_p": 0.9,
                "typical": 1.0,
            },
            "softprompts": [],
            "trusted_workers": False,
            "nsfw": True,
            "models": [f"{self.endpoint_name}"],
        }

    async def generate(self, prompt, typing_fn, temperature=0.72, max_new_tokens=200, timeout=180):
        """Generate a continuation of ``prompt``.

        Args:
            prompt: Text to continue.
            typing_fn: Zero-argument coroutine function awaited while waiting.
            temperature: Sampling temperature sent to the workers.
            max_new_tokens: Sent as the payload's ``max_length``.
            timeout: Per-request HTTP timeout in seconds.

        Returns:
            The first generation's text, with a leading copy of ``prompt``
            removed if the worker echoed it.

        Raises:
            ValueError: propagated from ``StableHordeWrapper.generate``.
        """
        input_data = self._build_payload(prompt, temperature, max_new_tokens)
        output = await super().generate(input_data, typing_fn, timeout)
        return output[0]["text"].removeprefix(prompt)
class StableHordeImageWrapper(StableHordeWrapper):
    """Image-generation client; downloads results into ``IMAGE_DIR``."""

    # Directory downloaded images are written to (relative to the CWD).
    IMAGE_DIR = "./.data/images"

    async def download_image(self, url, path, timeout=60):
        """Download ``url`` to ``path`` without blocking the event loop.

        Args:
            url: Address of the image to fetch.
            path: Destination file; overwritten if it already exists.
            timeout: Connect/read timeout in seconds for the HTTP request.

        Returns:
            True if the image was written. False if the server answered with
            a non-200 status or the transfer failed; no partial file is kept.
        """
        return await asyncio.to_thread(self._download_blocking, url, path, timeout)

    @staticmethod
    def _download_blocking(url, path, timeout):
        """Blocking body of ``download_image``: stream the response to disk."""
        try:
            # The context manager releases the streamed connection.
            with requests.get(url, stream=True, timeout=timeout) as r:
                if r.status_code != 200:
                    logger.warning("image download failed: HTTP %s for %s", r.status_code, url)
                    return False
                with open(path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=64 * 1024):
                        f.write(chunk)
        except requests.exceptions.RequestException:
            logger.warning("image download failed for %s", url, exc_info=True)
            try:
                os.remove(path)  # drop a partially written file
            except FileNotFoundError:
                pass
            return False
        return True

    def _build_payload(self, input_prompt, negative_prompt):
        """Return the JSON body for a 512x512 image job."""
        prompt = input_prompt
        if negative_prompt:
            # The horde reads a negative prompt after a "###" separator.
            prompt = f"{input_prompt} ### {negative_prompt}"
        return {
            "prompt": prompt,
            "params": {
                "width": 512,
                "height": 512,
            },
            "nsfw": True,
            "trusted_workers": False,
            "models": [f"{self.endpoint_name}"],
        }

    async def generate(self, input_prompt: str, negative_prompt: str, typing_fn, timeout=180):
        """Generate images for ``input_prompt`` and download them.

        Args:
            input_prompt: What to draw.
            negative_prompt: What to avoid; ignored when empty.
            typing_fn: Zero-argument coroutine function awaited while waiting.
            timeout: Per-request HTTP timeout in seconds for the job requests.

        Returns:
            Paths (under ``IMAGE_DIR``) of the images that were downloaded.

        Raises:
            ValueError: propagated from ``StableHordeWrapper.generate``, or if
                none of the generated images could be downloaded.
        """
        input_data = self._build_payload(input_prompt, negative_prompt)
        output = await super().generate(input_data, typing_fn, timeout)
        os.makedirs(self.IMAGE_DIR, exist_ok=True)
        files = []
        for image in output:
            # 64 random bits give a collision-safe name without private tempfile APIs.
            filename = f"{self.IMAGE_DIR}/{os.urandom(8).hex()}.jpg"
            if await self.download_image(image["img"], filename):
                files.append(filename)
        if not files:
            raise ValueError("<ERROR> image download failed")
        return files