You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
338 lines
12 KiB
338 lines
12 KiB
import asyncio
|
|
import os, tempfile
|
|
import logging
|
|
|
|
import json
|
|
import requests
|
|
|
|
from transformers import AutoTokenizer, AutoConfig
|
|
from huggingface_hub import hf_hub_download
|
|
|
|
import io
|
|
import base64
|
|
from PIL import Image, PngImagePlugin
|
|
|
|
from .model_helpers import get_full_prompt, num_tokens
|
|
|
|
# Module-level logger, named after this module, used by the request helpers below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _extract_reply(text: str, prompt: str, bot_name: str) -> str:
    """Turn raw model output into the reply text returned by generate_sync."""
    # The endpoint may echo the prompt in front of the completion; drop it.
    answer = text.removeprefix(prompt)
    # Cut at the first "\nYou:" marker if present, otherwise only strip the
    # end-of-text token.
    idx = answer.find("\nYou:")
    if idx != -1:
        reply = answer[:idx].strip()
    else:
        reply = answer.removesuffix('<|endoftext|>').strip()
    reply = reply.replace("<BOT>", f"{bot_name}")
    reply = reply.replace("<USER>", "You")
    reply = reply.replace(f"\n{bot_name}: ", " ")
    return reply


async def generate_sync(
    prompt: str,
    api_key: str,
    bot,
    typing_fn,
    api_endpoint = "pygmalion-6b"
):
    """Submit a text-generation job to a RunPod endpoint and wait for the reply.

    The job is posted to ``https://api.runpod.ai/v2/{api_endpoint}/run`` and
    its status is polled every 5 seconds for up to 360 seconds.

    Args:
        prompt: Full prompt sent to the model.
        api_key: RunPod API key, sent as a bearer token.
        bot: Object providing ``model``, ``temperature`` and ``name``.
        typing_fn: Coroutine function awaited while the job is ``IN_PROGRESS``.
        api_endpoint: RunPod endpoint id.

    Returns:
        The cleaned-up reply text.

    Raises:
        ValueError: On HTTP failure, a non-200 response, a failed job, an
            empty result list or a timeout.
    """
    base_url = f"https://api.runpod.ai/v2/{api_endpoint}"

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    max_new_tokens = 200
    prompt_num_tokens = await num_tokens(prompt, bot.model)

    input_data = {
        "input": {
            "prompt": prompt,
            # Total length (prompt + completion) is capped at 2048.
            "max_length": min(prompt_num_tokens+max_new_tokens, 2048),
            "temperature": bot.temperature,
            "do_sample": True,
        }
    }

    logger.info("sending request to runpod.io")

    # requests is blocking; run it in a worker thread so the event loop stays
    # responsive while waiting on the network.
    try:
        r = await asyncio.to_thread(
            requests.post, f"{base_url}/run", json=input_data, headers=headers, timeout=180
        )
    except requests.exceptions.RequestException as e:
        raise ValueError("<HTTP ERROR>") from e

    # Check the status code before parsing so a non-JSON error page does not
    # surface as a JSON decoding error.
    if r.status_code != 200:
        logger.info(r.text)
        raise ValueError("<ERROR>")

    r_json = r.json()
    logger.info(r_json)

    job_id = r_json.get("id")
    if not job_id:
        raise ValueError(f"<ERROR> RETURN CODE {r_json.get('status')}: {r_json.get('error', '')}")

    TIMEOUT = 360
    DELAY = 5
    POLL_TIMEOUT = 30
    status_url = f"{base_url}/status/{job_id}"

    for _ in range(TIMEOUT//DELAY):
        try:
            r = await asyncio.to_thread(
                requests.get, status_url, headers=headers, timeout=POLL_TIMEOUT
            )
        except requests.exceptions.RequestException as e:
            raise ValueError("<HTTP ERROR>") from e
        r_json = r.json()
        logger.info(r_json)
        status = r_json.get("status")

        if status == 'COMPLETED':
            output = r_json["output"]
            if isinstance(output, list):
                if not output:
                    raise ValueError("<ERROR> EMPTY OUTPUT")
                # Longest candidate wins; on ties the first one is kept.
                text = max(output, key=len)
            else:
                text = output
            return _extract_reply(text, prompt, bot.name)

        if status == 'IN_PROGRESS':
            await typing_fn()
        elif status != 'IN_QUEUE':
            # The error field may be missing or not a string.
            err_msg = str(r_json.get("error", "")).replace("\\n", "\n")
            raise ValueError(f"<ERROR> RETURN CODE {status}: {err_msg}")

        await asyncio.sleep(DELAY)

    raise ValueError("<ERROR> TIMEOUT")
|
|
|
|
|
|
async def download_image(url, path):
    """Download ``url`` and write the response body to ``path``.

    The blocking transfer runs in a worker thread so the event loop is not
    stalled. Nothing is written when the server does not answer with 200.

    Args:
        url: Address of the file to fetch.
        path: Destination file path; opened in binary write mode.

    Returns:
        True if the file was written, False on a non-200 response.
    """
    def _fetch():
        # The context manager releases the streamed connection.
        with requests.get(url, stream=True, timeout=60) as r:
            if r.status_code != 200:
                logger.warning("download of %s failed with status %s", url, r.status_code)
                return False
            with open(path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=65536):
                    f.write(chunk)
        return True

    return await asyncio.to_thread(_fetch)
|
|
|
|
async def generate_image(input_prompt: str, negative_prompt: str, api_url: str, api_key: str, typing_fn):
    """Run an image job on a RunPod endpoint and download the resulting images.

    The job is posted to ``api_url + "run"`` and polled at
    ``api_url + "status/<job id>"`` every 5 seconds for up to 360 seconds.
    Each entry of the job output is expected to carry an ``"image"`` URL,
    which is saved as a ``.jpg`` file under ``./images``.

    Args:
        input_prompt: Prompt for the image model.
        negative_prompt: Negative prompt for the image model.
        api_url: Endpoint base URL, ending with a slash.
        api_key: RunPod API key, sent as a bearer token.
        typing_fn: Coroutine function awaited while the job is ``IN_PROGRESS``.

    Returns:
        List of paths of the downloaded image files.

    Raises:
        ValueError: On HTTP failure, a non-200 response, a failed job, a
            timeout, empty output, or when no image could be downloaded.
    """
    # Local import: used only for unique file names, replacing the private
    # tempfile._get_candidate_names() helper.
    import uuid

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    input_data = {
        "input": {
            "prompt": input_prompt,
            "negative_prompt": negative_prompt,
            "width": 512,
            "height": 768,
            "num_outputs": 3,
            # "nsfw": True
        },
    }

    logger.info("sending request to runpod.io")

    # requests is blocking; run it in a worker thread to keep the event loop free.
    try:
        r = await asyncio.to_thread(
            requests.post, api_url + "run", json=input_data, headers=headers, timeout=180
        )
    except requests.exceptions.RequestException as e:
        raise ValueError("<ERROR> HTTP ERROR") from e

    if r.status_code != 200:
        logger.debug(r.text)
        raise ValueError("<ERROR>")

    r_json = r.json()
    logger.debug(r_json)

    status = r_json["status"]
    if status != 'IN_QUEUE':
        err_msg = r_json["error"] if "error" in r_json else ""
        raise ValueError(f"<ERROR> RETURN CODE {status}: {err_msg}")
    job_id = r_json["id"]

    TIMEOUT = 360
    DELAY = 5
    POLL_TIMEOUT = 30
    status_url = api_url + "status/" + job_id
    output = None

    for _ in range(TIMEOUT//DELAY):
        try:
            r = await asyncio.to_thread(
                requests.get, status_url, headers=headers, timeout=POLL_TIMEOUT
            )
        except requests.exceptions.RequestException as e:
            raise ValueError("<ERROR> HTTP ERROR") from e
        r_json = r.json()
        logger.debug(r_json)
        status = r_json["status"]

        if status == 'COMPLETED':
            output = r_json["output"]
            break
        if status == 'IN_PROGRESS':
            await typing_fn()
        elif status != 'IN_QUEUE':
            # The error field may be missing or not a string.
            err_msg = str(r_json.get("error", "")).replace("\\n", "\n")
            raise ValueError(f"<ERROR> RETURN CODE {status}: {err_msg}")
        await asyncio.sleep(DELAY)

    # Covers both a polling timeout and a completed job without results.
    if not output:
        raise ValueError("<ERROR>")

    os.makedirs("./images", exist_ok=True)
    files = []
    for image in output:
        filename = "./images/" + uuid.uuid4().hex + ".jpg"
        await download_image(image["image"], filename)
        # download_image writes nothing on a non-200 response; only report
        # files that actually exist.
        if os.path.exists(filename):
            files.append(filename)
        else:
            logger.warning("image download failed: %s", image["image"])

    if not files:
        raise ValueError("<ERROR> DOWNLOAD FAILED")

    return files
|
|
|
|
async def generate_image1(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images via the AnythingV4 endpoint by delegating to generate_image."""
    api_url = "https://api.runpod.ai/v1/sd-anything-v4/"
    files = await generate_image(input_prompt, negative_prompt, api_url, api_key, typing_fn)
    return files
|
|
|
|
async def generate_image2(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images via the OpenJourney endpoint by delegating to generate_image."""
    api_url = "https://api.runpod.ai/v1/sd-openjourney/"
    files = await generate_image(input_prompt, negative_prompt, api_url, api_key, typing_fn)
    return files
|
|
|
|
async def generate_image3(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images via the endpoint labelled Hassanblend by delegating to generate_image."""
    api_url = "https://api.runpod.ai/v1/mf5f6mocy8bsvx/"
    files = await generate_image(input_prompt, negative_prompt, api_url, api_key, typing_fn)
    return files
|
|
|
|
async def generate_image4(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images via the endpoint labelled DeliberateV2 by delegating to generate_image_automatic."""
    api_url = "https://api.runpod.ai/v1/lxdhmiccp3vdsf/"
    files = await generate_image_automatic(input_prompt, negative_prompt, api_url, api_key, typing_fn)
    return files
|
|
|
|
async def generate_image5(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images via the endpoint labelled Hassanblend by delegating to generate_image_automatic."""
    # NOTE(review): generate_image3 carries the same "Hassanblend" label with a
    # different endpoint id; confirm which model this endpoint serves.
    api_url = "https://api.runpod.ai/v1/13rrs00l7yxikf/"
    files = await generate_image_automatic(input_prompt, negative_prompt, api_url, api_key, typing_fn)
    return files
|
|
|
|
async def generate_image6(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images via the endpoint labelled PFG by delegating to generate_image_automatic."""
    api_url = "https://api.runpod.ai/v1/5j1xzlsyw84vk5/"
    files = await generate_image_automatic(input_prompt, negative_prompt, api_url, api_key, typing_fn)
    return files
|
|
|
|
async def generate_image7(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images via the endpoint labelled ChilloutMix by delegating to generate_image_automatic."""
    api_url = "https://api.runpod.ai/v2/rrjxafqx66osr4/"
    files = await generate_image_automatic(input_prompt, negative_prompt, api_url, api_key, typing_fn)
    return files
|
|
|
|
async def generate_image8(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
    """Generate images via the endpoint labelled AbyssOrangeMixV3 by delegating to generate_image_automatic."""
    api_url = "https://api.runpod.ai/v2/vuyifmsasm3ix7/"
    files = await generate_image_automatic(input_prompt, negative_prompt, api_url, api_key, typing_fn)
    return files
|
|
|
|
|
|
async def serverless_automatic_request(payload, cmd, api_url: str, api_key: str, typing_fn):
    """Submit a job to a RunPod endpoint and return its output.

    The job is posted to ``api_url + "run"`` and polled at
    ``api_url + "status/<job id>"`` every 5 seconds for up to 360 seconds.

    Args:
        payload: Dict of request parameters. It is not modified; a copy with
            ``"api_endpoint"`` set to ``str(cmd)`` is sent as the job input.
        cmd: Command name, also sent as the top-level ``"cmd"`` field.
        api_url: Endpoint base URL, ending with a slash.
        api_key: RunPod API key, sent as a bearer token.
        typing_fn: Coroutine function awaited while the job is ``IN_PROGRESS``.

    Returns:
        The ``"output"`` value of the completed job.

    Raises:
        ValueError: On HTTP failure, a non-200 response, a failed job, a
            timeout or empty output.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    # Build a new dict instead of payload.update(...) so the caller's dict is
    # left untouched.
    input_data = {
        "input": {**payload, "api_endpoint": str(cmd)},
        "cmd": cmd,
    }

    logger.info("sending request to runpod.io")

    # requests is blocking; run it in a worker thread to keep the event loop free.
    try:
        r = await asyncio.to_thread(
            requests.post, api_url + "run", json=input_data, headers=headers, timeout=180
        )
    except requests.exceptions.RequestException as e:
        raise ValueError("<ERROR> HTTP ERROR") from e

    if r.status_code != 200:
        logger.debug(r.text)
        raise ValueError(f"<ERROR> HTTP {r.status_code}")

    r_json = r.json()
    logger.debug(r_json)

    status = r_json["status"]
    if status != 'IN_QUEUE':
        err_msg = r_json["error"] if "error" in r_json else ""
        raise ValueError(f"<ERROR> RETURN CODE {status}: {err_msg}")
    job_id = r_json["id"]

    TIMEOUT = 360
    DELAY = 5
    POLL_TIMEOUT = 30
    status_url = api_url + "status/" + job_id
    output = None

    for _ in range(TIMEOUT//DELAY):
        try:
            r = await asyncio.to_thread(
                requests.get, status_url, headers=headers, timeout=POLL_TIMEOUT
            )
        except requests.exceptions.RequestException as e:
            raise ValueError("<ERROR> HTTP ERROR") from e
        r_json = r.json()
        logger.debug(r_json)
        status = r_json["status"]

        if status == 'COMPLETED':
            output = r_json["output"]
            break
        if status == 'IN_PROGRESS':
            await typing_fn()
        elif status != 'IN_QUEUE':
            err_msg = r_json["error"] if "error" in r_json else ""
            raise ValueError(f"<ERROR> RETURN CODE {status}: {err_msg}")
        await asyncio.sleep(DELAY)

    # Covers both a polling timeout and a completed job without results.
    if not output:
        raise ValueError(f"<ERROR> {status}")

    return output
|
|
|
|
|
|
async def generate_image_automatic(input_prompt: str, negative_prompt: str, api_url: str, api_key: str, typing_fn):
    """Generate images with a ``txt2img`` job and save them as PNG files.

    The job is sent through ``serverless_automatic_request``. Its output is
    expected to hold base64-encoded images under ``"images"`` and a
    description under ``"info"``, which is stored in each PNG as a
    ``parameters`` text chunk.

    Args:
        input_prompt: Prompt for the image model.
        negative_prompt: Negative prompt for the image model.
        api_url: Endpoint base URL, ending with a slash.
        api_key: RunPod API key, sent as a bearer token.
        typing_fn: Coroutine function passed through to the request helper.

    Returns:
        List of paths of the PNG files written under ``./images``.

    Raises:
        ValueError: Propagated from ``serverless_automatic_request``.
    """
    # Local import: used only for unique file names, replacing the private
    # tempfile._get_candidate_names() helper.
    import uuid

    payload = {
        "prompt": input_prompt,
        # Key was misspelled "nagative_prompt", so the negative prompt was
        # never applied.
        "negative_prompt": negative_prompt,
        "steps": 25,
        "cfg_scale": 7,
        "seed": -1,
        "width": 512,
        "height": 768,
        "batch_size": 3,
        # "sampler_index": "DPM++ 2M Karras",
        # "enable_hr": True,
        # "hr_scale": 2,
        # "hr_upscaler": "ESRGAN_4x", # "Latent"
        # "denoising_strength": 0.5,
        # "hr_second_pass_steps": 15,
        "restore_faces": True,
        # "gfpgan_visibility": 0.5,
        # "codeformer_visibility": 0.5,
        # "codeformer_weight": 0.5,
        ## "override_settings": {
        ##     "filter_nsfw": False,
        ## },
    }

    output = await serverless_automatic_request(payload, "txt2img", api_url, api_key, typing_fn)

    # Optional img2img pass over every generated image; currently disabled.
    upscale = False
    if upscale:
        for idx, img in enumerate(output['images']):
            payload = {
                "init_images": [img],
                "prompt": input_prompt,
                "negative_prompt": negative_prompt,
                "steps": 20,
                "seed": -1,
                #"sampler_index": "Euler",
                # tile_width, tile_height, mask_blur, padding, seams_fix_width, seams_fix_denoise, seams_fix_padding, upscaler_index, save_upscaled_image, redraw_mode, save_seams_fix_image, seams_fix_mask_blur, seams_fix_type, target_size_type, custom_width, custom_height, custom_scale
                # "script_args": ["",512,0,8,32,64,0.275,32,3,False,0,True,8,3,2,1080,1440,1.875],
                # "script_name": "Ultimate SD upscale",
            }
            upscaled_output = await serverless_automatic_request(payload, "img2img", api_url, api_key, typing_fn)
            # Each img2img call gets one init image, so take its first result
            # and put it in the slot of the image it replaces.
            output['images'][idx] = upscaled_output['images'][0]

    os.makedirs("./images", exist_ok=True)
    info = str(output.get('info', ''))
    files = []
    for encoded in output['images']:
        filename = "./images/" + uuid.uuid4().hex + ".png"
        # Take the part after an optional "data:...;base64," prefix; a plain
        # base64 string contains no comma and is used unchanged.
        raw = base64.b64decode(encoded.split(",", 1)[-1])
        pnginfo = PngImagePlugin.PngInfo()
        pnginfo.add_text("parameters", info)
        with Image.open(io.BytesIO(raw)) as image:
            image.save(filename, pnginfo=pnginfo)

        files.append(filename)

    return files
|
|
|
|
|
|
|