Browse Source

add test command for custom remote endpoint

master
Hendrik Langer 2 years ago
parent
commit
fb826a4ae5
  1. 2
      matrix_pygmalion_bot/ai/runpod_pygmalion.py
  2. 6
      matrix_pygmalion_bot/chatlog.py
  3. 6
      matrix_pygmalion_bot/core.py

2
matrix_pygmalion_bot/ai/runpod_pygmalion.py

@ -84,6 +84,7 @@ async def generate_sync(
return reply return reply
else: else:
err_msg = r_json["error"] if "error" in r_json else "" err_msg = r_json["error"] if "error" in r_json else ""
err_msg = err_msg.replace("\\n", "\n")
raise ValueError(f"RETURN CODE {status}: {err_msg}") raise ValueError(f"RETURN CODE {status}: {err_msg}")
raise ValueError(f"<TIMEOUT>") raise ValueError(f"<TIMEOUT>")
else: else:
@ -217,6 +218,7 @@ async def generate_image(input_prompt: str, negative_prompt: str, api_url: str,
break break
else: else:
err_msg = r_json["error"] if "error" in r_json else "" err_msg = r_json["error"] if "error" in r_json else ""
err_msg = err_msg.replace("\\n", "\n")
raise ValueError(f"RETURN CODE {status}: {err_msg}") raise ValueError(f"RETURN CODE {status}: {err_msg}")
if not output: if not output:

6
matrix_pygmalion_bot/chatlog.py

@ -30,7 +30,11 @@ class ChatMessage:
if not (to_lang in self.message): if not (to_lang in self.message):
self.message[to_lang] = translate.translate(self.message["en"], "en", to_lang) self.message[to_lang] = translate.translate(self.message["en"], "en", to_lang)
return self.message[to_lang] return self.message[to_lang]
def updateText(self, new_text, language="en"):
self.message[self.language] = new_text
self.num_tokens = None
if not (language == "en"):
self.message["en"] = translate.translate(message, language, "en")
class ChatHistory: class ChatHistory:
def __init__(self, bot_name, room_name): def __init__(self, bot_name, room_name):

6
matrix_pygmalion_bot/core.py

@ -67,6 +67,7 @@ class Callbacks(object):
) )
) )
api_endpoint = "pygmalion-6b"
await self.client.room_read_markers(room.room_id, event.event_id, event.event_id) await self.client.room_read_markers(room.room_id, event.event_id, event.event_id)
# Ignore messages when disabled # Ignore messages when disabled
if "disabled" in self.bot.room_config[room.room_id] and self.bot.room_config[room.room_id]["disabled"] == True and not event.body.startswith('!start'): if "disabled" in self.bot.room_config[room.room_id] and self.bot.room_config[room.room_id]["disabled"] == True and not event.body.startswith('!start'):
@ -202,6 +203,9 @@ class Callbacks(object):
new_answer = event.body.removeprefix('!replace').strip() new_answer = event.body.removeprefix('!replace').strip()
await self.bot.send_message(self.client, room.room_id, new_answer, reply_to=chat_history_item.relates_to_event) await self.bot.send_message(self.client, room.room_id, new_answer, reply_to=chat_history_item.relates_to_event)
return return
elif event.body.startswith('!2'):
chat_message.updateText( event.body.removeprefix('!2').strip() )
api_endpoint = "ynznznpn6qz6yh"
elif event.body.startswith('!'): elif event.body.startswith('!'):
await self.bot.send_message(self.client, room.room_id, "UNKNOWN COMMAND") await self.bot.send_message(self.client, room.room_id, "UNKNOWN COMMAND")
return return
@ -232,7 +236,7 @@ class Callbacks(object):
# print("") # print("")
try: try:
typing = lambda : self.client.room_typing(room.room_id, True, 15000) typing = lambda : self.client.room_typing(room.room_id, True, 15000)
answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot, typing) answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot, typing, api_endpoint)
answer = answer.strip() answer = answer.strip()
await self.client.room_typing(room.room_id, False) await self.client.room_typing(room.room_id, False)
if not (self.bot.translate is None): if not (self.bot.translate is None):

Loading…
Cancel
Save