diff --git a/matrix_pygmalion_bot/ai/runpod_pygmalion.py b/matrix_pygmalion_bot/ai/runpod_pygmalion.py index 0d63f02..0fd74cb 100644 --- a/matrix_pygmalion_bot/ai/runpod_pygmalion.py +++ b/matrix_pygmalion_bot/ai/runpod_pygmalion.py @@ -19,6 +19,7 @@ async def generate_sync( prompt: str, api_key: str, bot, + typing_fn ): # Set the API endpoint URL endpoint = "https://api.runpod.ai/v2/pygmalion-6b/run" @@ -61,6 +62,7 @@ async def generate_sync( logger.info(r_json) status = r_json["status"] if status == 'IN_PROGRESS': + await typing_fn() await asyncio.sleep(DELAY) elif status == 'IN_QUEUE': await asyncio.sleep(DELAY) @@ -157,7 +159,7 @@ async def download_image(url, path): for chunk in r: f.write(chunk) -async def generate_image(input_prompt: str, negative_prompt: str, api_url: str, api_key: str): +async def generate_image(input_prompt: str, negative_prompt: str, api_url: str, api_key: str, typing_fn): # Set the API endpoint URL endpoint = api_url + "run" @@ -203,6 +205,7 @@ async def generate_image(input_prompt: str, negative_prompt: str, api_url: str, logger.debug(r_json) status = r_json["status"] if status == 'IN_PROGRESS': + await typing_fn() await asyncio.sleep(DELAY) elif status == 'IN_QUEUE': await asyncio.sleep(DELAY) @@ -226,25 +229,25 @@ async def generate_image(input_prompt: str, negative_prompt: str, api_url: str, return files -async def generate_image1(input_prompt: str, negative_prompt: str, api_key: str): - return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/sd-anything-v4/", api_key) +async def generate_image1(input_prompt: str, negative_prompt: str, api_key: str, typing_fn): + return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/sd-anything-v4/", api_key, typing_fn) -async def generate_image2(input_prompt: str, negative_prompt: str, api_key: str): - return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/sd-openjourney/", api_key) +async def generate_image2(input_prompt: str, negative_prompt: str, api_key: str, typing_fn): + return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/sd-openjourney/", api_key, typing_fn) -async def generate_image3(input_prompt: str, negative_prompt: str, api_key: str): - return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/mf5f6mocy8bsvx/", api_key) +async def generate_image3(input_prompt: str, negative_prompt: str, api_key: str, typing_fn): + return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/mf5f6mocy8bsvx/", api_key, typing_fn) -async def generate_image4(input_prompt: str, negative_prompt: str, api_key: str): - return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v1/lxdhmiccp3vdsf/", api_key) +async def generate_image4(input_prompt: str, negative_prompt: str, api_key: str, typing_fn): + return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v1/lxdhmiccp3vdsf/", api_key, typing_fn) -async def generate_image5(input_prompt: str, negative_prompt: str, api_key: str): - return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v1/13rrs00l7yxikf/", api_key) +async def generate_image5(input_prompt: str, negative_prompt: str, api_key: str, typing_fn): + return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v1/13rrs00l7yxikf/", api_key, typing_fn) -async def generate_image6(input_prompt: str, negative_prompt: str, api_key: str): - return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v1/5j1xzlsyw84vk5/", api_key) +async def generate_image6(input_prompt: str, negative_prompt: str, api_key: str, typing_fn): + return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v1/5j1xzlsyw84vk5/", api_key, typing_fn) -async def serverless_automatic_request(payload, cmd, api_url: str, api_key: str): +async def serverless_automatic_request(payload, cmd, api_url: str, api_key: str, typing_fn): # Set the API endpoint URL endpoint = api_url + "run" @@ -284,6 +287,7 @@ async def serverless_automatic_request(payload, cmd, api_url: str, api_key: str) logger.debug(r_json) status = r_json["status"] if status == 'IN_PROGRESS': + await typing_fn() await asyncio.sleep(DELAY) elif status == 'IN_QUEUE': await asyncio.sleep(DELAY) @@ -300,7 +304,7 @@ async def serverless_automatic_request(payload, cmd, api_url: str, api_key: str) return output -async def generate_image_automatic(input_prompt: str, negative_prompt: str, api_url: str, api_key: str): +async def generate_image_automatic(input_prompt: str, negative_prompt: str, api_url: str, api_key: str, typing_fn): payload = { "prompt": input_prompt, "nagative_prompt": negative_prompt, @@ -321,7 +325,7 @@ async def generate_image_automatic(input_prompt: str, negative_prompt: str, api_ ## }, } - output = await serverless_automatic_request(payload, "txt2img", api_url, api_key) + output = await serverless_automatic_request(payload, "txt2img", api_url, api_key, typing_fn) upscale = False if upscale: @@ -337,7 +341,7 @@ async def generate_image_automatic(input_prompt: str, negative_prompt: str, api_ # "script_args": ["",512,0,8,32,64,0.275,32,3,False,0,True,8,3,2,1080,1440,1.875], # "script_name": "Ultimate SD upscale", } - upscaled_output = await serverless_automatic_request(payload, "img2img", api_url, api_key) + upscaled_output = await serverless_automatic_request(payload, "img2img", api_url, api_key, typing_fn) output['images'][count] = upscaled_output['images'][count] os.makedirs("./images", exist_ok=True) diff --git a/matrix_pygmalion_bot/core.py b/matrix_pygmalion_bot/core.py index 17cbd9d..3d1007d 100644 --- a/matrix_pygmalion_bot/core.py +++ b/matrix_pygmalion_bot/core.py @@ -83,7 +83,7 @@ class Callbacks(object): print(event) await self.bot.send_message(self.client, room.room_id, "Hello World!") return - elif re.search("^!image(?P[0-9])(\s(?P.*))?$", event.body): + elif re.search("^!image(?P[0-9])?(\s(?P.*))?$", event.body): m = re.search("^!image(?P[0-9])?(\s(?P.*))?$", event.body) if m['num']: num = int(m['num']) @@ -109,38 +109,41 @@ class Callbacks(object): # else: # negative_prompt = "ugly, deformed, out of frame" try: + typing = lambda : self.client.room_typing(room.room_id, True, 15000) if self.bot.service == "runpod": if num == 1: - output = await ai.generate_image1(prompt, negative_prompt, self.bot.runpod_api_key) + output = await ai.generate_image1(prompt, negative_prompt, self.bot.runpod_api_key, typing) elif num == 2: - output = await ai.generate_image2(prompt, negative_prompt, self.bot.runpod_api_key) + output = await ai.generate_image2(prompt, negative_prompt, self.bot.runpod_api_key, typing) elif num == 3: - output = await ai.generate_image3(prompt, negative_prompt, self.bot.runpod_api_key) + output = await ai.generate_image3(prompt, negative_prompt, self.bot.runpod_api_key, typing) elif num == 4: - output = await ai.generate_image4(prompt, negative_prompt, self.bot.runpod_api_key) + output = await ai.generate_image4(prompt, negative_prompt, self.bot.runpod_api_key, typing) elif num == 5: - output = await ai.generate_image5(prompt, negative_prompt, self.bot.runpod_api_key) + output = await ai.generate_image5(prompt, negative_prompt, self.bot.runpod_api_key, typing) elif num == 6: - output = await ai.generate_image6(prompt, negative_prompt, self.bot.runpod_api_key) + output = await ai.generate_image6(prompt, negative_prompt, self.bot.runpod_api_key, typing) else: raise ValueError('no image generator with that number') elif self.bot.service == "stablehorde": if num == 1: - output = await ai.generate_image1(prompt, negative_prompt, self.bot.stablehorde_api_key) + output = await ai.generate_image1(prompt, negative_prompt, self.bot.stablehorde_api_key, typing) elif num == 2: - output = await ai.generate_image2(prompt, negative_prompt, self.bot.stablehorde_api_key) + output = await ai.generate_image2(prompt, negative_prompt, self.bot.stablehorde_api_key, typing) elif num == 3: - output = await ai.generate_image3(prompt, negative_prompt, self.bot.stablehorde_api_key) + output = await ai.generate_image3(prompt, negative_prompt, self.bot.stablehorde_api_key, typing) else: raise ValueError('no image generator with that number') else: raise ValueError('remote image generation not configured properly') except ValueError as err: + await self.client.room_typing(room.room_id, False) errormessage = f" {err=}, {type(err)=}" logger.error(errormessage) await self.bot.send_message(self.client, room.room_id, errormessage) return + await self.client.room_typing(room.room_id, False) for imagefile in output: await self.bot.send_image(self.client, room.room_id, imagefile) return @@ -221,19 +224,25 @@ class Callbacks(object): # error = e.__str__() # answer = answer.strip() # print("") - await self.client.room_typing(room.room_id, True, 15000) - answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot) - answer = answer.strip() - await self.client.room_typing(room.room_id, False) - if not (self.bot.translate is None): - translated_answer = translate.translate(answer, "en", self.bot.translate) - await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=chat_message.event_id, original_message=answer) - else: - await self.bot.send_message(self.client, room.room_id, answer, reply_to=chat_message.event_id) - if not "message_count" in self.bot.room_config[room.room_id]: - self.bot.room_config[room.room_id]["message_count"] = 0 - self.bot.room_config[room.room_id]["message_count"] += 1 - + try: + typing = lambda : self.client.room_typing(room.room_id, True, 15000) + answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot, typing) + answer = answer.strip() + await self.client.room_typing(room.room_id, False) + if not (self.bot.translate is None): + translated_answer = translate.translate(answer, "en", self.bot.translate) + await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=chat_message.event_id, original_message=answer) + else: + await self.bot.send_message(self.client, room.room_id, answer, reply_to=chat_message.event_id) + if not "message_count" in self.bot.room_config[room.room_id]: + self.bot.room_config[room.room_id]["message_count"] = 0 + self.bot.room_config[room.room_id]["message_count"] += 1 + except ValueError as err: + await self.client.room_typing(room.room_id, False) + errormessage = f" {err=}, {type(err)=}" + logger.error(errormessage) + await self.bot.send_message(self.client, room.room_id, errormessage) + return async def invite_cb(self, room: MatrixRoom, event: InviteEvent) -> None: """Automatically join all rooms we get invited to"""