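"""Thin async wrappers around runpod.io serverless endpoints.

RunpodWrapper submits a job to the /run endpoint and polls /status until it
finishes; the subclasses build payloads for a text-generation worker, two
image-generation workers, and an AUTOMATIC1111 (stable-diffusion-webui) worker.
"""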
import asyncio
import base64
import io
import json
import logging
import os
import tempfile

import requests
from PIL import Image, PngImagePlugin

logger = logging.getLogger(__name__)


class RunpodWrapper(object):
    """Base class for calling runpod.io serverless endpoints."""

    def __init__(self, api_key: str, endpoint_name: str, model_name: str):
        self.api_key = api_key
        self.endpoint_name = endpoint_name
        self.model_name = model_name

    async def generate(self, input_data: dict, typing_fn, timeout=180):
        # Set the API endpoint URL
        endpoint = f"https://api.runpod.ai/v2/{self.endpoint_name}/run"
        # Set the headers for the request
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        logger.info(f"sending request to runpod.io. endpoint=\"{self.endpoint_name}\"")
        # Submit the job
        try:
            r = requests.post(endpoint, json=input_data, headers=headers, timeout=timeout)
        except requests.exceptions.RequestException as e:
            raise ValueError("<HTTP ERROR>") from e
        r_json = r.json()
        logger.debug(r_json)
        if r.status_code == 200:
            job_id = r_json["id"]
            TIMEOUT = 360
            DELAY = 5
            # Poll the job status until it completes, fails, or times out
            for i in range(TIMEOUT // DELAY):
                endpoint = f"https://api.runpod.ai/v2/{self.endpoint_name}/status/{job_id}"
                r = requests.get(endpoint, headers=headers, timeout=timeout)
                r_json = r.json()
                logger.info(r_json)
                status = r_json["status"]
                if status == 'IN_PROGRESS':
                    # Show a typing indicator while the job is running
                    await typing_fn()
                    await asyncio.sleep(DELAY)
                elif status == 'IN_QUEUE':
                    await asyncio.sleep(DELAY)
                elif status == 'COMPLETED':
                    return r_json["output"]
                else:
                    err_msg = r_json.get("error", "")
                    err_msg = err_msg.replace("\\n", "\n")
                    raise ValueError(f"<ERROR> RETURN CODE {status}: {err_msg}")
            raise ValueError("<ERROR> TIMEOUT")
        else:
            raise ValueError(f"<ERROR> HTTP {r.status_code}")


class RunpodTextWrapper(RunpodWrapper):
    async def generate(self, prompt, typing_fn, temperature=0.72, max_new_tokens=200, timeout=180):
        # Define your inputs
        input_data = {
            "input": {
                "prompt": prompt,
                "max_length": min(max_new_tokens, 2048),
                "temperature": temperature,
                "do_sample": True,
            }
        }
        output = await super().generate(input_data, typing_fn, timeout)
        # Strip the prompt that the model echoes back
        output = output.removeprefix(prompt)
        return output

    async def generate2(self, prompt, typing_fn, temperature=0.72, max_new_tokens=200, timeout=180):
        return await self.generate(prompt, typing_fn, temperature, max_new_tokens, timeout)


class RunpodImageWrapper(RunpodWrapper):
    async def download_image(self, url, path):
        r = requests.get(url, stream=True)
        if r.status_code == 200:
            with open(path, 'wb') as f:
                for chunk in r:
                    f.write(chunk)

    async def generate(self, input_prompt: str, negative_prompt: str, typing_fn, timeout=180):
        # Define your inputs
        input_data = {
            "input": {
                "prompt": input_prompt,
                "negative_prompt": negative_prompt,
                "width": 512,
                "height": 768,
                "num_outputs": 3,
                # "nsfw": True
            },
        }
        output = await super().generate(input_data, typing_fn, timeout)
        # Download each returned image into a temporary file
        os.makedirs("./.data/images", exist_ok=True)
        files = []
        for image in output:
            temp_name = next(tempfile._get_candidate_names())
            filename = "./.data/images/" + temp_name + ".jpg"
            await self.download_image(image["image"], filename)
            files.append(filename)
        return files


class RunpodImageWrapper2(RunpodWrapper):
    async def download_image(self, url, path):
        r = requests.get(url, stream=True)
        if r.status_code == 200:
            with open(path, 'wb') as f:
                for chunk in r:
                    f.write(chunk)

    async def generate(self, input_prompt: str, negative_prompt: str, typing_fn, timeout=180):
        # Define your inputs
        input_data = {
            "input": {
                "prompt": input_prompt,
                "negative_prompt": negative_prompt,
                "h": 768,
                "w": 768,
                "num_images": 3,
                "seed": -1
            },
        }
        output = await super().generate(input_data, typing_fn, timeout)
        # Download each returned image into a temporary file
        os.makedirs("./.data/images", exist_ok=True)
        files = []
        for image in output:
            temp_name = next(tempfile._get_candidate_names())
            filename = "./.data/images/" + temp_name + ".jpg"
            await self.download_image(image["image"], filename)
            files.append(filename)
        return files


class RunpodImageAutomaticWrapper(RunpodWrapper):
    async def generate(self, input_prompt: str, negative_prompt: str, typing_fn, timeout=180):
        # Define your inputs
        input_data = {
            "input": {
                "prompt": input_prompt,
                "negative_prompt": negative_prompt,
                "steps": 25,
                "cfg_scale": 7,
                "seed": -1,
                "width": 512,
                "height": 768,
                "batch_size": 3,
                # "sampler_index": "DPM++ 2M Karras",
                # "enable_hr": True,
                # "hr_scale": 2,
                # "hr_upscaler": "ESRGAN_4x", # "Latent"
                # "denoising_strength": 0.5,
                # "hr_second_pass_steps": 15,
                "restore_faces": True,
                # "gfpgan_visibility": 0.5,
                # "codeformer_visibility": 0.5,
                # "codeformer_weight": 0.5,
                ## "override_settings": {
                ##     "filter_nsfw": False,
                ## },
                "api_endpoint": "txt2img",
            },
            "cmd": "txt2img"
        }
        output = await super().generate(input_data, typing_fn, timeout)

        upscale = False
        if upscale:
            # Optional img2img upscale pass; relies on a serverless_automatic_request()
            # helper plus api_url / api_key that are expected to be provided elsewhere.
            for count, i in enumerate(output['images']):
                payload = {
                    "init_images": [i],
                    "prompt": input_prompt,
                    "negative_prompt": negative_prompt,
                    "steps": 20,
                    "seed": -1,
                    # "sampler_index": "Euler",
                    # tile_width, tile_height, mask_blur, padding, seams_fix_width, seams_fix_denoise, seams_fix_padding, upscaler_index, save_upscaled_image, redraw_mode, save_seams_fix_image, seams_fix_mask_blur, seams_fix_type, target_size_type, custom_width, custom_height, custom_scale
                    # "script_args": ["",512,0,8,32,64,0.275,32,3,False,0,True,8,3,2,1080,1440,1.875],
                    # "script_name": "Ultimate SD upscale",
                }
                upscaled_output = await serverless_automatic_request(payload, "img2img", api_url, api_key, typing_fn)
                output['images'][count] = upscaled_output['images'][count]

        # Decode the base64-encoded images and save them as PNGs with generation metadata
        os.makedirs("./.data/images", exist_ok=True)
        files = []
        for i in output['images']:
            temp_name = next(tempfile._get_candidate_names())
            filename = "./.data/images/" + temp_name + ".png"
            image = Image.open(io.BytesIO(base64.b64decode(i.split(",", 1)[0])))
            info = output['info']
            pnginfo = PngImagePlugin.PngInfo()
            pnginfo.add_text("parameters", info)
            image.save(filename, pnginfo=pnginfo)
            files.append(filename)
        return files
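

# --- Usage sketch (illustration only) ----------------------------------------
# A minimal example of how these wrappers could be driven, assuming a valid
# runpod.io API key and serverless endpoint id. The api_key, endpoint_name,
# model_name and the no-op typing_fn below are placeholders, not values used
# by this project.
if __name__ == "__main__":
    async def _demo():
        async def typing_fn():
            # Stand-in for the chat client's "is typing" callback
            pass

        text_bot = RunpodTextWrapper(
            api_key="YOUR_RUNPOD_API_KEY",     # placeholder
            endpoint_name="your-endpoint-id",  # placeholder
            model_name="your-model",           # placeholder
        )
        reply = await text_bot.generate("Hello, how are you?", typing_fn)
        print(reply)

    asyncio.run(_demo())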