Browse Source

async and typing notiications

master
Hendrik Langer 2 years ago
parent
commit
5f4f54903a
  1. 38
      matrix_pygmalion_bot/ai/runpod_pygmalion.py
  2. 35
      matrix_pygmalion_bot/core.py

38
matrix_pygmalion_bot/ai/runpod_pygmalion.py

@ -19,6 +19,7 @@ async def generate_sync(
prompt: str, prompt: str,
api_key: str, api_key: str,
bot, bot,
typing_fn
): ):
# Set the API endpoint URL # Set the API endpoint URL
endpoint = "https://api.runpod.ai/v2/pygmalion-6b/run" endpoint = "https://api.runpod.ai/v2/pygmalion-6b/run"
@ -61,6 +62,7 @@ async def generate_sync(
logger.info(r_json) logger.info(r_json)
status = r_json["status"] status = r_json["status"]
if status == 'IN_PROGRESS': if status == 'IN_PROGRESS':
await typing_fn()
await asyncio.sleep(DELAY) await asyncio.sleep(DELAY)
elif status == 'IN_QUEUE': elif status == 'IN_QUEUE':
await asyncio.sleep(DELAY) await asyncio.sleep(DELAY)
@ -157,7 +159,7 @@ async def download_image(url, path):
for chunk in r: for chunk in r:
f.write(chunk) f.write(chunk)
async def generate_image(input_prompt: str, negative_prompt: str, api_url: str, api_key: str): async def generate_image(input_prompt: str, negative_prompt: str, api_url: str, api_key: str, typing_fn):
# Set the API endpoint URL # Set the API endpoint URL
endpoint = api_url + "run" endpoint = api_url + "run"
@ -203,6 +205,7 @@ async def generate_image(input_prompt: str, negative_prompt: str, api_url: str,
logger.debug(r_json) logger.debug(r_json)
status = r_json["status"] status = r_json["status"]
if status == 'IN_PROGRESS': if status == 'IN_PROGRESS':
await typing_fn()
await asyncio.sleep(DELAY) await asyncio.sleep(DELAY)
elif status == 'IN_QUEUE': elif status == 'IN_QUEUE':
await asyncio.sleep(DELAY) await asyncio.sleep(DELAY)
@ -226,25 +229,25 @@ async def generate_image(input_prompt: str, negative_prompt: str, api_url: str,
return files return files
async def generate_image1(input_prompt: str, negative_prompt: str, api_key: str): async def generate_image1(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/sd-anything-v4/", api_key) return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/sd-anything-v4/", api_key, typing_fn)
async def generate_image2(input_prompt: str, negative_prompt: str, api_key: str): async def generate_image2(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/sd-openjourney/", api_key) return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/sd-openjourney/", api_key, typing_fn)
async def generate_image3(input_prompt: str, negative_prompt: str, api_key: str): async def generate_image3(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/mf5f6mocy8bsvx/", api_key) return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/mf5f6mocy8bsvx/", api_key, typing_fn)
async def generate_image4(input_prompt: str, negative_prompt: str, api_key: str): async def generate_image4(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v1/lxdhmiccp3vdsf/", api_key) return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v1/lxdhmiccp3vdsf/", api_key, typing_fn)
async def generate_image5(input_prompt: str, negative_prompt: str, api_key: str): async def generate_image5(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v1/13rrs00l7yxikf/", api_key) return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v1/13rrs00l7yxikf/", api_key, typing_fn)
async def generate_image6(input_prompt: str, negative_prompt: str, api_key: str): async def generate_image6(input_prompt: str, negative_prompt: str, api_key: str, typing_fn):
return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v1/5j1xzlsyw84vk5/", api_key) return await generate_image_automatic(input_prompt, negative_prompt, "https://api.runpod.ai/v1/5j1xzlsyw84vk5/", api_key, typing_fn)
async def serverless_automatic_request(payload, cmd, api_url: str, api_key: str): async def serverless_automatic_request(payload, cmd, api_url: str, api_key: str, typing_fn):
# Set the API endpoint URL # Set the API endpoint URL
endpoint = api_url + "run" endpoint = api_url + "run"
@ -284,6 +287,7 @@ async def serverless_automatic_request(payload, cmd, api_url: str, api_key: str)
logger.debug(r_json) logger.debug(r_json)
status = r_json["status"] status = r_json["status"]
if status == 'IN_PROGRESS': if status == 'IN_PROGRESS':
await typing_fn()
await asyncio.sleep(DELAY) await asyncio.sleep(DELAY)
elif status == 'IN_QUEUE': elif status == 'IN_QUEUE':
await asyncio.sleep(DELAY) await asyncio.sleep(DELAY)
@ -300,7 +304,7 @@ async def serverless_automatic_request(payload, cmd, api_url: str, api_key: str)
return output return output
async def generate_image_automatic(input_prompt: str, negative_prompt: str, api_url: str, api_key: str): async def generate_image_automatic(input_prompt: str, negative_prompt: str, api_url: str, api_key: str, typing_fn):
payload = { payload = {
"prompt": input_prompt, "prompt": input_prompt,
"nagative_prompt": negative_prompt, "nagative_prompt": negative_prompt,
@ -321,7 +325,7 @@ async def generate_image_automatic(input_prompt: str, negative_prompt: str, api_
## }, ## },
} }
output = await serverless_automatic_request(payload, "txt2img", api_url, api_key) output = await serverless_automatic_request(payload, "txt2img", api_url, api_key, typing_fn)
upscale = False upscale = False
if upscale: if upscale:
@ -337,7 +341,7 @@ async def generate_image_automatic(input_prompt: str, negative_prompt: str, api_
# "script_args": ["",512,0,8,32,64,0.275,32,3,False,0,True,8,3,2,1080,1440,1.875], # "script_args": ["",512,0,8,32,64,0.275,32,3,False,0,True,8,3,2,1080,1440,1.875],
# "script_name": "Ultimate SD upscale", # "script_name": "Ultimate SD upscale",
} }
upscaled_output = await serverless_automatic_request(payload, "img2img", api_url, api_key) upscaled_output = await serverless_automatic_request(payload, "img2img", api_url, api_key, typing_fn)
output['images'][count] = upscaled_output['images'][count] output['images'][count] = upscaled_output['images'][count]
os.makedirs("./images", exist_ok=True) os.makedirs("./images", exist_ok=True)

35
matrix_pygmalion_bot/core.py

@ -83,7 +83,7 @@ class Callbacks(object):
print(event) print(event)
await self.bot.send_message(self.client, room.room_id, "Hello World!") await self.bot.send_message(self.client, room.room_id, "Hello World!")
return return
elif re.search("^!image(?P<num>[0-9])(\s(?P<cmd>.*))?$", event.body): elif re.search("^!image(?P<num>[0-9])?(\s(?P<cmd>.*))?$", event.body):
m = re.search("^!image(?P<num>[0-9])?(\s(?P<cmd>.*))?$", event.body) m = re.search("^!image(?P<num>[0-9])?(\s(?P<cmd>.*))?$", event.body)
if m['num']: if m['num']:
num = int(m['num']) num = int(m['num'])
@ -109,38 +109,41 @@ class Callbacks(object):
# else: # else:
# negative_prompt = "ugly, deformed, out of frame" # negative_prompt = "ugly, deformed, out of frame"
try: try:
typing = lambda : self.client.room_typing(room.room_id, True, 15000)
if self.bot.service == "runpod": if self.bot.service == "runpod":
if num == 1: if num == 1:
output = await ai.generate_image1(prompt, negative_prompt, self.bot.runpod_api_key) output = await ai.generate_image1(prompt, negative_prompt, self.bot.runpod_api_key, typing)
elif num == 2: elif num == 2:
output = await ai.generate_image2(prompt, negative_prompt, self.bot.runpod_api_key) output = await ai.generate_image2(prompt, negative_prompt, self.bot.runpod_api_key, typing)
elif num == 3: elif num == 3:
output = await ai.generate_image3(prompt, negative_prompt, self.bot.runpod_api_key) output = await ai.generate_image3(prompt, negative_prompt, self.bot.runpod_api_key, typing)
elif num == 4: elif num == 4:
output = await ai.generate_image4(prompt, negative_prompt, self.bot.runpod_api_key) output = await ai.generate_image4(prompt, negative_prompt, self.bot.runpod_api_key, typing)
elif num == 5: elif num == 5:
output = await ai.generate_image5(prompt, negative_prompt, self.bot.runpod_api_key) output = await ai.generate_image5(prompt, negative_prompt, self.bot.runpod_api_key, typing)
elif num == 6: elif num == 6:
output = await ai.generate_image6(prompt, negative_prompt, self.bot.runpod_api_key) output = await ai.generate_image6(prompt, negative_prompt, self.bot.runpod_api_key, typing)
else: else:
raise ValueError('no image generator with that number') raise ValueError('no image generator with that number')
elif self.bot.service == "stablehorde": elif self.bot.service == "stablehorde":
if num == 1: if num == 1:
output = await ai.generate_image1(prompt, negative_prompt, self.bot.stablehorde_api_key) output = await ai.generate_image1(prompt, negative_prompt, self.bot.stablehorde_api_key, typing)
elif num == 2: elif num == 2:
output = await ai.generate_image2(prompt, negative_prompt, self.bot.stablehorde_api_key) output = await ai.generate_image2(prompt, negative_prompt, self.bot.stablehorde_api_key, typing)
elif num == 3: elif num == 3:
output = await ai.generate_image3(prompt, negative_prompt, self.bot.stablehorde_api_key) output = await ai.generate_image3(prompt, negative_prompt, self.bot.stablehorde_api_key, typing)
else: else:
raise ValueError('no image generator with that number') raise ValueError('no image generator with that number')
else: else:
raise ValueError('remote image generation not configured properly') raise ValueError('remote image generation not configured properly')
except ValueError as err: except ValueError as err:
await self.client.room_typing(room.room_id, False)
errormessage = f"<ERROR> {err=}, {type(err)=}" errormessage = f"<ERROR> {err=}, {type(err)=}"
logger.error(errormessage) logger.error(errormessage)
await self.bot.send_message(self.client, room.room_id, errormessage) await self.bot.send_message(self.client, room.room_id, errormessage)
return return
await self.client.room_typing(room.room_id, False)
for imagefile in output: for imagefile in output:
await self.bot.send_image(self.client, room.room_id, imagefile) await self.bot.send_image(self.client, room.room_id, imagefile)
return return
@ -221,8 +224,9 @@ class Callbacks(object):
# error = e.__str__() # error = e.__str__()
# answer = answer.strip() # answer = answer.strip()
# print("") # print("")
await self.client.room_typing(room.room_id, True, 15000) try:
answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot) typing = lambda : self.client.room_typing(room.room_id, True, 15000)
answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot, typing)
answer = answer.strip() answer = answer.strip()
await self.client.room_typing(room.room_id, False) await self.client.room_typing(room.room_id, False)
if not (self.bot.translate is None): if not (self.bot.translate is None):
@ -233,7 +237,12 @@ class Callbacks(object):
if not "message_count" in self.bot.room_config[room.room_id]: if not "message_count" in self.bot.room_config[room.room_id]:
self.bot.room_config[room.room_id]["message_count"] = 0 self.bot.room_config[room.room_id]["message_count"] = 0
self.bot.room_config[room.room_id]["message_count"] += 1 self.bot.room_config[room.room_id]["message_count"] += 1
except ValueError as err:
await self.client.room_typing(room.room_id, False)
errormessage = f"<ERROR> {err=}, {type(err)=}"
logger.error(errormessage)
await self.bot.send_message(self.client, room.room_id, errormessage)
return
async def invite_cb(self, room: MatrixRoom, event: InviteEvent) -> None: async def invite_cb(self, room: MatrixRoom, event: InviteEvent) -> None:
"""Automatically join all rooms we get invited to""" """Automatically join all rooms we get invited to"""

Loading…
Cancel
Save