Chatbot
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

227 lines
7.9 KiB

2 years ago
import asyncio
import requests
import json
2 years ago
import os, tempfile
import io
import base64
from PIL import Image, PngImagePlugin
2 years ago
import logging
logger = logging.getLogger(__name__)
class RunpodWrapper(object):
    """Base class for RunPod serverless endpoints.

    ``generate`` submits a job to ``https://api.runpod.ai/v2/<endpoint_name>/run``
    and polls ``.../status/<job id>`` until the job completes, reports any other
    status, or the polling budget runs out. Subclasses build the
    endpoint-specific ``input_data`` and post-process the returned ``output``.
    """

    def __init__(self, api_key: str, endpoint_name: str, model_name: str):
        """Store credentials and endpoint identifiers.

        Args:
            api_key: RunPod API key, sent as a ``Bearer`` token.
            endpoint_name: RunPod endpoint id used to build the request URLs.
            model_name: Stored as ``self.model_name``; not read by this class.
        """
        self.api_key = api_key
        self.endpoint_name = endpoint_name
        self.model_name = model_name

    async def generate(self, input_data: dict, typing_fn, timeout=180, *,
                       poll_timeout=360, poll_interval=5):
        """Run a RunPod job and return its ``output`` once it completes.

        Args:
            input_data: JSON-serialisable request body posted to ``/run``
                (subclasses pass a dict with an ``"input"`` key).
            typing_fn: Zero-argument coroutine function awaited on every poll
                while the job is ``IN_PROGRESS`` (presumably a "typing..."
                indicator in the caller).
            timeout: Per-request HTTP timeout in seconds.
            poll_timeout: Approximate total seconds to keep polling the status.
            poll_interval: Seconds to sleep between status polls; must be > 0.

        Returns:
            The ``"output"`` field of the ``COMPLETED`` status response.

        Raises:
            ValueError: ``<HTTP ERROR>`` on transport errors, ``<ERROR>`` on a
                non-200 reply from ``/run``, ``<ERROR> RETURN CODE ...`` for any
                status other than IN_QUEUE / IN_PROGRESS / COMPLETED, and
                ``<ERROR> TIMEOUT`` when the polling budget is exhausted.
        """
        if poll_interval <= 0:
            raise ValueError("poll_interval must be positive")

        base_url = f"https://api.runpod.ai/v2/{self.endpoint_name}"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        logger.info(f"sending request to runpod.io. endpoint=\"{self.endpoint_name}\"")

        # requests is blocking; run each call in a worker thread so the event
        # loop (and typing_fn) keeps running while we wait on the network.
        try:
            r = await asyncio.to_thread(
                requests.post, f"{base_url}/run",
                json=input_data, headers=headers, timeout=timeout,
            )
        except requests.exceptions.RequestException as e:
            raise ValueError("<HTTP ERROR>") from e

        # Check the HTTP status before decoding: error bodies may not be JSON.
        if r.status_code != 200:
            logger.error("runpod /run returned HTTP %s: %s", r.status_code, r.text)
            raise ValueError("<ERROR>")
        r_json = r.json()
        logger.debug(r_json)

        status_url = f"{base_url}/status/{r_json['id']}"
        for _ in range(int(poll_timeout // poll_interval)):
            try:
                r = await asyncio.to_thread(
                    requests.get, status_url, headers=headers, timeout=timeout,
                )
            except requests.exceptions.RequestException as e:
                raise ValueError("<HTTP ERROR>") from e
            r_json = r.json()
            logger.info(r_json)
            status = r_json.get("status")
            if status == 'IN_PROGRESS':
                await typing_fn()
                await asyncio.sleep(poll_interval)
            elif status == 'IN_QUEUE':
                await asyncio.sleep(poll_interval)
            elif status == 'COMPLETED':
                return r_json["output"]
            else:
                # Any other status (or a reply without one) ends the job:
                # surface the server's error text, un-escaping literal "\n".
                err_msg = str(r_json.get("error") or "").replace("\\n", "\n")
                raise ValueError(f"<ERROR> RETURN CODE {status}: {err_msg}")
        raise ValueError("<ERROR> TIMEOUT")
class RunpodTextWrapper(RunpodWrapper):
    """RunPod wrapper for text-generation endpoints.

    Sends ``prompt`` with sampling settings and returns the endpoint output
    with a leading copy of the prompt removed.
    """

    async def generate(self, prompt, typing_fn, temperature=0.72, max_new_tokens=200, timeout=180):
        """Generate a completion for ``prompt``.

        Args:
            prompt: Prompt text sent to the endpoint.
            typing_fn: Coroutine function forwarded to ``RunpodWrapper.generate``.
            temperature: Sampling temperature sent with the request.
            max_new_tokens: Requested length, capped at 2048 and sent as
                ``max_length``.
            timeout: HTTP timeout in seconds, forwarded to the base class.

        Returns:
            The endpoint output; when it is a string, a leading copy of
            ``prompt`` is stripped.
        """
        input_data = {
            "input": {
                "prompt": prompt,
                "max_length": min(max_new_tokens, 2048),
                "temperature": temperature,
                "do_sample": True,
            }
        }
        # The API key already lives on self; the base takes (input_data, typing_fn, timeout).
        output = await super().generate(input_data, typing_fn, timeout)
        if isinstance(output, str):
            # The output may echo the prompt before the completion; drop that echo.
            output = output.removeprefix(prompt)
        return output

    async def generate2(self, prompt, typing_fn, temperature=0.72, max_new_tokens=200, timeout=180):
        """Alias of ``generate`` with identical arguments and result."""
        return await self.generate(prompt, typing_fn, temperature, max_new_tokens, timeout)
class RunpodImageWrapper(RunpodWrapper):
    """RunPod wrapper for image endpoints that return download URLs.

    The job output is read as a list of ``{"image": <url>}`` entries; each
    image is downloaded into ``image_dir`` and the local paths are returned.
    """

    # Directory downloaded images are written to.
    image_dir = "./.data/images"

    def _download_image_sync(self, url, path, timeout):
        """Blocking download of ``url`` into ``path``; True if HTTP 200 was received."""
        with requests.get(url, stream=True, timeout=timeout) as r:
            if r.status_code != 200:
                logger.warning("image download failed: HTTP %s for %s", r.status_code, url)
                return False
            with open(path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=64 * 1024):
                    f.write(chunk)
        return True

    async def download_image(self, url, path, timeout=180):
        """Download ``url`` to ``path`` in a worker thread.

        Returns:
            True if the body was written to ``path``; False (nothing written)
            when the server answered with a status other than 200.
        """
        return await asyncio.to_thread(self._download_image_sync, url, path, timeout)

    async def generate(self, input_prompt: str, negative_prompt: str, typing_fn, timeout=180):
        """Request three 512x768 images and download them locally.

        Args:
            input_prompt: Positive prompt.
            negative_prompt: Negative prompt.
            typing_fn: Coroutine function forwarded to ``RunpodWrapper.generate``.
            timeout: HTTP timeout in seconds (job submission, polling, downloads).

        Returns:
            Paths of the ``.jpg`` files that were downloaded successfully.
        """
        input_data = {
            "input": {
                "prompt": input_prompt,
                "negative_prompt": negative_prompt,
                "width": 512,
                "height": 768,
                "num_outputs": 3,
                # "nsfw": True
            },
        }
        output = await super().generate(input_data, typing_fn, timeout)
        os.makedirs(self.image_dir, exist_ok=True)
        files = []
        for image in output:
            # mkstemp reserves a unique name atomically, unlike the private
            # tempfile._get_candidate_names, which only suggests one.
            fd, filename = tempfile.mkstemp(suffix=".jpg", prefix="", dir=self.image_dir)
            os.close(fd)
            saved = False
            try:
                saved = await self.download_image(image["image"], filename, timeout)
            finally:
                # Never hand back (or leave behind) an empty placeholder file.
                if not saved:
                    os.remove(filename)
            files.append(filename)
        return files
class RunpodImageWrapper2(RunpodWrapper):
    """RunPod wrapper for image endpoints whose output holds a URL list.

    The job output is read as ``{"images": [<url>, ...]}``; each image is
    downloaded into ``image_dir`` and the local paths are returned.
    """

    # Directory downloaded images are written to.
    image_dir = "./.data/images"

    def _download_image_sync(self, url, path, timeout):
        """Blocking download of ``url`` into ``path``; True if HTTP 200 was received."""
        with requests.get(url, stream=True, timeout=timeout) as r:
            if r.status_code != 200:
                logger.warning("image download failed: HTTP %s for %s", r.status_code, url)
                return False
            with open(path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=64 * 1024):
                    f.write(chunk)
        return True

    async def download_image(self, url, path, timeout=180):
        """Download ``url`` to ``path`` in a worker thread.

        Returns:
            True if the body was written to ``path``; False (nothing written)
            when the server answered with a status other than 200.
        """
        return await asyncio.to_thread(self._download_image_sync, url, path, timeout)

    async def generate(self, input_prompt: str, negative_prompt: str, typing_fn, timeout=180):
        """Request three 768x768 images (random seed) and download them locally.

        Args:
            input_prompt: Positive prompt.
            negative_prompt: Negative prompt.
            typing_fn: Coroutine function forwarded to ``RunpodWrapper.generate``.
            timeout: HTTP timeout in seconds (job submission, polling, downloads).

        Returns:
            Paths of the ``.jpg`` files that were downloaded successfully.
        """
        input_data = {
            "input": {
                "prompt": input_prompt,
                "negative_prompt": negative_prompt,
                "h": 768,
                "w": 768,
                "num_images": 3,
                "seed": -1
            },
        }
        output = await super().generate(input_data, typing_fn, timeout)
        os.makedirs(self.image_dir, exist_ok=True)
        files = []
        for url in output['images']:
            # mkstemp reserves a unique name atomically, unlike the private
            # tempfile._get_candidate_names, which only suggests one.
            fd, filename = tempfile.mkstemp(suffix=".jpg", prefix="", dir=self.image_dir)
            os.close(fd)
            saved = False
            try:
                saved = await self.download_image(url, filename, timeout)
            finally:
                # Never hand back (or leave behind) an empty placeholder file.
                if not saved:
                    os.remove(filename)
            files.append(filename)
        return files
class RunpodImageAutomaticWrapper(RunpodWrapper):
    """RunPod wrapper for a txt2img worker that returns base64-encoded images.

    The payload keys (``steps``, ``cfg_scale``, ``restore_faces``,
    ``"cmd": "txt2img"``) look like an AUTOMATIC1111-style worker; confirm
    against the deployed endpoint. Images are decoded from ``output['images']``
    and saved as PNGs in ``image_dir``, with ``output['info']`` embedded as the
    ``parameters`` text chunk.
    """

    # Directory generated images are written to.
    image_dir = "./.data/images"

    async def generate(self, input_prompt: str, negative_prompt: str, typing_fn, timeout=180):
        """Generate a batch of three 512x768 images and save them locally.

        Args:
            input_prompt: Positive prompt.
            negative_prompt: Negative prompt.
            typing_fn: Coroutine function forwarded to ``RunpodWrapper.generate``.
            timeout: HTTP timeout in seconds, forwarded to the base class.

        Returns:
            Paths of the saved ``.png`` files.
        """
        input_data = {
            "input": {
                "prompt": input_prompt,
                "negative_prompt": negative_prompt,
                "steps": 25,
                "cfg_scale": 7,
                "seed": -1,
                "width": 512,
                "height": 768,
                "batch_size": 3,
                # "sampler_index": "DPM++ 2M Karras",
                # "enable_hr": True,
                # "hr_scale": 2,
                # "hr_upscaler": "ESRGAN_4x", # "Latent"
                # "denoising_strength": 0.5,
                # "hr_second_pass_steps": 15,
                "restore_faces": True,
                # "gfpgan_visibility": 0.5,
                # "codeformer_visibility": 0.5,
                # "codeformer_weight": 0.5,
                ## "override_settings": {
                ##     "filter_nsfw": False,
                ## },
                "api_endpoint": "txt2img",
            },
            "cmd": "txt2img"
        }
        output = await super().generate(input_data, typing_fn, timeout)

        # Same metadata for every image in the batch: build the PNG text chunk once.
        info = output.get('info', '')
        if not isinstance(info, str):
            info = json.dumps(info)
        pnginfo = PngImagePlugin.PngInfo()
        pnginfo.add_text("parameters", info)

        os.makedirs(self.image_dir, exist_ok=True)
        files = []
        for encoded in output['images']:
            # Accept raw base64 as well as data URLs ("data:image/png;base64,<data>").
            payload = encoded.split(",", 1)[-1]
            fd, filename = tempfile.mkstemp(suffix=".png", prefix="", dir=self.image_dir)
            os.close(fd)
            try:
                with Image.open(io.BytesIO(base64.b64decode(payload))) as image:
                    image.save(filename, pnginfo=pnginfo)
            except BaseException:
                # Don't leave an empty placeholder behind on decode/save failure.
                os.remove(filename)
                raise
            files.append(filename)
        return files