From 57925310dea6002f64c4b8cb67046dea21b716b4 Mon Sep 17 00:00:00 2001 From: Hendrik Langer Date: Mon, 3 Jul 2023 22:44:47 +0200 Subject: [PATCH] update runpod image wrapper --- matrix_pygmalion_bot/bot/core.py | 28 +++++++++++------ matrix_pygmalion_bot/bot/wrappers/runpod.py | 35 +++++++++++++++++++++ 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/matrix_pygmalion_bot/bot/core.py b/matrix_pygmalion_bot/bot/core.py index 1c16232..0eb9641 100644 --- a/matrix_pygmalion_bot/bot/core.py +++ b/matrix_pygmalion_bot/bot/core.py @@ -104,15 +104,17 @@ class ChatBot(object): i = text_endpoint['id'] self.text_generators[i] = text_generator - from .wrappers.runpod import RunpodImageWrapper + from .wrappers.runpod import RunpodImageWrapper, RunpodImageWrapper2 from .wrappers.runpod import RunpodImageAutomaticWrapper from .wrappers.stablehorde import StableHordeImageWrapper self.image_generators = {} for image_endpoint in sorted(available_image_endpoints, key=lambda d: d['id']): - if image_endpoint['service'] == "runpod": - image_generator = RunpodImageWrapper(image_endpoint['api_key'], image_endpoint['endpoint'], image_endpoint['model']) - elif image_endpoint['service'] == "runpod-automatic1111": + if image_endpoint['service'] == "runpod-automatic1111": image_generator = RunpodImageAutomaticWrapper(image_endpoint['api_key'], image_endpoint['endpoint'], image_endpoint['model']) + elif image_endpoint['service'] == "runpod" and image_endpoint['model'].startswith('kandinsky'): + image_generator = RunpodImageWrapper2(image_endpoint['api_key'], image_endpoint['endpoint'], image_endpoint['model']) + elif image_endpoint['service'] == "runpod": + image_generator = RunpodImageWrapper(image_endpoint['api_key'], image_endpoint['endpoint'], image_endpoint['model']) elif image_endpoint['service'] == "stablehorde": image_generator = StableHordeImageWrapper(image_endpoint['api_key'], image_endpoint['endpoint'], image_endpoint['model']) else: @@ -211,20 +213,26 @@ class ChatBot(object): async def process_command(self, message, reply_fn, typing_fn): if message.message.startswith("!replybot"): await reply_fn("Hello World") - elif re.search("(?s)^!image(?P[0-9])?(\s(?P.*))?$", message.message, flags=re.DOTALL): - m = re.search("(?s)^!image(?P[0-9])?(\s(?P.*))?$", message.message, flags=re.DOTALL) + elif re.search("(?s)^!image(?P[0-9]+)?(\s(?P.*))?$", message.message, flags=re.DOTALL): + m = re.search("(?s)^!image(?P[0-9]+)?(\s(?P.*))?$", message.message, flags=re.DOTALL) if m['num']: num = int(m['num']) else: num = 1 - if m['cmd']: - prompt = m['cmd'].strip() - else: - prompt = "a beautiful woman" + + prompt = "a beautiful woman" negative_prompt = "out of frame, (ugly:1.3), (fused fingers), (too many fingers), (bad anatomy:1.5), (watermark:1.5), (words), letters, untracked eyes, asymmetric eyes, floating head, (logo:1.5), (bad hands:1.3), (mangled hands:1.2), (missing hands), (missing arms), backward hands, floating jewelry, unattached jewelry, floating head, doubled head, unattached head, doubled head, head in body, (misshapen body:1.1), (badly fitted headwear:1.2), floating arms, (too many arms:1.5), limbs fused with body, (facial blemish:1.5), badly fitted clothes, imperfect eyes, untracked eyes, crossed eyes, hair growing from clothes, partial faces, hair not attached to head" #"anime, cartoon, penis, fake, drawing, illustration, boring, 3d render, long neck, out of frame, extra fingers, mutated hands, monochrome, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, glitchy, bokeh, (((long neck))), 3D, 3DCG, cgstation, red eyes, multiple subjects, extra heads, close up, watermarks, logo" #"ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face" #"ugly, deformed, out of frame" + + if m['cmd']: + prompt = m['cmd'].strip() + prompt_split = prompt.rsplit('|', 1) + if len(prompt_split) == 2: + prompt = prompt_split[0].strip() + negative_prompt = prompt_split[1].strip() + try: output = await self.image_generators[num].generate(prompt, negative_prompt, typing_fn) await self.connection.room_typing(message.room_id, False) diff --git a/matrix_pygmalion_bot/bot/wrappers/runpod.py b/matrix_pygmalion_bot/bot/wrappers/runpod.py index 7d0760c..65305b3 100644 --- a/matrix_pygmalion_bot/bot/wrappers/runpod.py +++ b/matrix_pygmalion_bot/bot/wrappers/runpod.py @@ -121,6 +121,41 @@ class RunpodImageWrapper(RunpodWrapper): return files +class RunpodImageWrapper2(RunpodWrapper): + async def download_image(self, url, path): + r = requests.get(url, stream=True) + if r.status_code == 200: + with open(path, 'wb') as f: + for chunk in r: + f.write(chunk) + + async def generate(self, input_prompt: str, negative_prompt: str, typing_fn, timeout=180): + + # Define your inputs + input_data = { + "input": { + "prompt": input_prompt, + "negative_prompt": negative_prompt, + "h": 768, + "w": 768, + "num_images": 3, + "seed": -1 + }, + } + + output = await super().generate(input_data, typing_fn, timeout) + + os.makedirs("./.data/images", exist_ok=True) + files = [] + for image in output: + temp_name = next(tempfile._get_candidate_names()) + filename = "./.data/images/" + temp_name + ".jpg" + await self.download_image(image["image"], filename) + files.append(filename) + + return files + + class RunpodImageAutomaticWrapper(RunpodWrapper): async def generate(self, input_prompt: str, negative_prompt: str, typing_fn, timeout=180):