import asyncio
import json
import logging
import os
import tempfile
import uuid

import requests

logger = logging.getLogger(__name__)


class StableHordeWrapper:
    """Base client for the Stable Horde asynchronous generation API.

    ``generate`` submits a request payload, polls the job's status URL until
    the job reports ``done`` (or the polling budget runs out) and returns the
    job's ``generations`` list. Subclasses build the payload and post-process
    that list.
    """

    # Submission URL and status-URL template ("{job_id}" is filled in).
    # These are the image-generation endpoints; StableHordeTextWrapper
    # overrides them with the text endpoints.
    generate_url = "https://stablehorde.net/api/v2/generate/async"
    status_url = "https://stablehorde.net/api/v2/generate/status/{job_id}"

    # Total polling budget and delay between status polls, in seconds.
    poll_timeout = 360
    poll_delay = 11

    def __init__(self, api_key: str, endpoint_name: str, model_name: str):
        self.api_key = api_key
        # Sent as the sole entry of the payload's "models" list by the
        # subclasses below.
        self.endpoint_name = endpoint_name
        # NOTE(review): not read anywhere in this file — confirm whether the
        # "models" entry should use this instead of endpoint_name.
        self.model_name = model_name

    @staticmethod
    async def _request(method: str, url: str, headers: dict, timeout, **kwargs):
        """Run a blocking ``requests`` call in a worker thread.

        Keeps the event loop responsive while waiting on the network and
        converts transport errors into ``ValueError``, the error type the
        rest of this class raises.
        """
        try:
            return await asyncio.to_thread(
                requests.request, method, url,
                headers=headers, timeout=timeout, **kwargs,
            )
        except requests.exceptions.RequestException as e:
            raise ValueError(f" request to {url} failed: {e}") from e

    @staticmethod
    def _parse_json(response):
        """Decode a JSON response body, raising ``ValueError`` with context."""
        try:
            return response.json()
        except ValueError as e:  # json / requests JSONDecodeError subclass ValueError
            raise ValueError(
                f" invalid JSON response (HTTP {response.status_code})"
            ) from e

    async def generate(self, input_data: dict, typing_fn, timeout=180) -> list:
        """Submit *input_data* to the horde and poll until the job finishes.

        Args:
            input_data: JSON-serialisable request payload.
            typing_fn: zero-argument coroutine function, awaited while the job
                is processing or is near the front of the queue.
            timeout: per-HTTP-request timeout in seconds.

        Returns:
            The ``generations`` list from the final status response.

        Raises:
            ValueError: on transport errors, a non-202 submission response,
                a malformed or faulted status response, or when no output
                arrives within ``poll_timeout`` seconds.
        """
        headers = {
            "Content-Type": "application/json",
            "accept": "application/json",
            "apikey": f"{self.api_key}",
        }
        logger.info(
            'sending request to stablehorde.net. endpoint="%s"', self.endpoint_name
        )

        r = await self._request(
            "POST", self.generate_url, headers, timeout, json=input_data
        )
        r_json = self._parse_json(r)
        logger.debug(r_json)
        if r.status_code != 202:
            raise ValueError(f" HTTP code {r.status_code}: {r_json}")

        status_url = self.status_url.format(job_id=r_json["id"])
        output = None
        for _ in range(self.poll_timeout // self.poll_delay):
            r = await self._request("GET", status_url, headers, timeout)
            r_json = self._parse_json(r)
            logger.info(r_json)
            if "done" not in r_json:
                raise ValueError(
                    f" unexpected status response (HTTP {r.status_code}): {r_json}"
                )
            if r_json.get("faulted"):
                raise ValueError(" Faulted")
            if r_json["done"]:
                output = r_json.get("generations")
                break

            # Signal activity only when the job is running or about to run,
            # so the typing indicator is not shown for long queue waits.
            wait_time = r_json.get("wait_time")
            queue_position = r_json.get("queue_position")
            almost_up = bool(
                wait_time
                and wait_time < 20
                and queue_position is not None
                and queue_position < 100
            )
            if r_json.get("processing") == 1 or almost_up:
                await typing_fn()
            await asyncio.sleep(self.poll_delay)

        if not output:
            raise ValueError(" TIMEOUT / NO OUTPUT")
        return output


class StableHordeTextWrapper(StableHordeWrapper):
    """Text generation through the horde's text endpoints."""

    generate_url = "https://stablehorde.net/api/v2/generate/text/async"
    status_url = "https://stablehorde.net/api/v2/generate/text/status/{job_id}"

    async def generate(self, prompt, typing_fn, temperature=0.72,
                       max_new_tokens=200, timeout=180):
        """Generate a continuation of *prompt*.

        Args:
            prompt: text to continue.
            typing_fn: see ``StableHordeWrapper.generate``.
            temperature: sampling temperature sent in the request params.
            max_new_tokens: sent as the request's ``max_length``.
            timeout: per-HTTP-request timeout in seconds.

        Returns:
            The first generation's text with a leading copy of *prompt*
            removed.

        Raises:
            ValueError: propagated from ``StableHordeWrapper.generate``.
        """
        input_data = {
            "prompt": prompt,
            "params": {
                "n": 1,
                "max_context_length": 1024,
                "max_length": max_new_tokens,
                "rep_pen": 1.1,
                "rep_pen_range": 1024,
                "rep_pen_slope": 0.7,
                "temperature": temperature,
                "tfs": 1.0,
                "top_a": 0.0,
                "top_k": 0,
                "top_p": 0.9,
                "typical": 1.0,
            },
            "softprompts": [],
            "trusted_workers": False,
            "nsfw": True,
            "models": [f"{self.endpoint_name}"],
        }
        output = await super().generate(input_data, typing_fn, timeout)
        return output[0]["text"].removeprefix(prompt)


class StableHordeImageWrapper(StableHordeWrapper):
    """Image generation; results are downloaded into a local directory."""

    # Directory (relative to the working directory) that receives images.
    image_dir = "./.data/images"

    async def download_image(self, url, path, timeout=60):
        """Stream *url* into the file *path*.

        Returns:
            True if the server answered 200 and the body was written,
            False otherwise (no file is created in that case).
        """
        def _fetch():
            with requests.get(url, stream=True, timeout=timeout) as r:
                if r.status_code != 200:
                    return False
                with open(path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=64 * 1024):
                        f.write(chunk)
                return True

        # requests is blocking; keep it off the event loop.
        return await asyncio.to_thread(_fetch)

    async def generate(self, input_prompt: str, negative_prompt: str,
                       typing_fn, timeout=180):
        """Generate images for *input_prompt* and download them.

        Args:
            input_prompt: text prompt.
            negative_prompt: NOTE(review): currently not sent to the horde
                (the original only had it as a commented-out param) — confirm
                the expected encoding before wiring it in.
            typing_fn: see ``StableHordeWrapper.generate``.
            timeout: per-HTTP-request timeout in seconds.

        Returns:
            Paths of the successfully downloaded images under ``image_dir``.
            Images whose download fails are logged and skipped.
        """
        input_data = {
            "prompt": input_prompt,
            "params": {
                "width": 512,
                "height": 512,
            },
            "nsfw": True,
            "trusted_workers": False,
            "models": [f"{self.endpoint_name}"],
        }
        output = await super().generate(input_data, typing_fn, timeout)

        os.makedirs(self.image_dir, exist_ok=True)
        files = []
        for image in output:
            # NOTE(review): saved with a .jpg suffix regardless of the
            # format the horde actually returns — confirm.
            filename = os.path.join(self.image_dir, uuid.uuid4().hex + ".jpg")
            if await self.download_image(image["img"], filename):
                files.append(filename)
            else:
                logger.warning("failed to download image %s", image["img"])
        return files