From a5339f2486c0c48e76f2b035a98fa91f56a33d9c Mon Sep 17 00:00:00 2001 From: Hendrik Langer Date: Thu, 30 Mar 2023 00:11:29 +0200 Subject: [PATCH] parallel tasks --- matrix_pygmalion_bot/ai/runpod_pygmalion.py | 3 + matrix_pygmalion_bot/core.py | 71 ++++++++++++++------- matrix_pygmalion_bot/helpers.py | 14 +++- 3 files changed, 64 insertions(+), 24 deletions(-) diff --git a/matrix_pygmalion_bot/ai/runpod_pygmalion.py b/matrix_pygmalion_bot/ai/runpod_pygmalion.py index 59a2cf8..355762a 100644 --- a/matrix_pygmalion_bot/ai/runpod_pygmalion.py +++ b/matrix_pygmalion_bot/ai/runpod_pygmalion.py @@ -13,6 +13,7 @@ logger = logging.getLogger(__name__) async def generate_sync( prompt: str, api_key: str, + bot_name: str, ): # Set the API endpoint URL endpoint = "https://api.runpod.ai/v2/pygmalion-6b/runsync" @@ -54,6 +55,7 @@ async def generate_sync( reply = answer[:idx].strip() else: reply = answer.removesuffix('<|endoftext|>').strip() + reply.replace("\n{bot_name}: ", " ") return reply elif status == 'IN_PROGRESS' or status == 'IN_QUEUE': job_id = r_json["id"] @@ -79,6 +81,7 @@ async def generate_sync( reply = answer[:idx].strip() else: reply = answer.removesuffix('<|endoftext|>').strip() + reply.replace("\n{bot_name}: ", " ") return reply else: return "" diff --git a/matrix_pygmalion_bot/core.py b/matrix_pygmalion_bot/core.py index 290b5c5..ff420bb 100644 --- a/matrix_pygmalion_bot/core.py +++ b/matrix_pygmalion_bot/core.py @@ -3,6 +3,7 @@ import nio from nio import (AsyncClient, AsyncClientConfig, MatrixRoom, RoomMessageText, InviteEvent, UploadResponse) import os, sys +import time import importlib import configparser import logging @@ -10,6 +11,7 @@ import logging import aiofiles.os import magic from PIL import Image +import re from .helpers import ChatItem ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod_pygmalion") @@ -23,6 +25,7 @@ STORE_PATH = "./.store/" logger = logging.getLogger(__name__) config = configparser.ConfigParser() bots = [] +background_tasks = set() class Callbacks(object): """Class to pass client to callback methods.""" @@ -32,6 +35,7 @@ class Callbacks(object): self.bot = bot async def message_cb(self, room: MatrixRoom, event: RoomMessageText) -> None: + event_id = event.event_id message = event.body is_own_message = False if event.sender == self.client.user: @@ -39,6 +43,9 @@ class Callbacks(object): is_command = False if event.body.startswith('!'): is_command = True + if not (self.bot.owner is None): + if not (event.sender == self.bot.owner or is_own_message): + return relates_to = None if 'm.relates_to' in event.source["content"]: relates_to = event.source["content"]['m.relates_to']["event_id"] @@ -49,7 +56,7 @@ class Callbacks(object): else: translated_message = translate.translate(message, self.bot.translate, "en") if hasattr(event, 'body'): - self.bot.chat_history[event.event_id] = ChatItem(event.event_id, event.server_timestamp, room.user_name(event.sender), is_own_message, relates_to, translated_message) + self.bot.chat_history[event_id] = ChatItem(event_id, event.server_timestamp, room.user_name(event.sender), is_own_message, relates_to, translated_message) if self.bot.not_synced: return print( @@ -57,7 +64,7 @@ class Callbacks(object): room.display_name, room.user_name(event.sender), event.body ) ) - await self.client.room_read_markers(room.room_id, event.event_id, event.event_id) + await self.client.room_read_markers(room.room_id, event_id, event_id) # Ignore messages from ourselves if is_own_message: return @@ -81,23 +88,26 @@ class Callbacks(object): return elif event.body.startswith('!!!'): return -# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current -# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") -# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() -# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") -# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() -# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") -# return - elif event.body.startswith('!!'): + chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current + await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") + chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() + await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") + chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() + await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") return -# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current -# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") -# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() -# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") -# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # new current -# self.bot.chat_history[chat_history_event_id] = chat_history_item -# message = chat_history_item.message -# # don't return, we generate a new answer + elif event.body.startswith('!!'): + if len(self.bot.chat_history) < 3: + return + chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current + await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") + chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() + await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") + chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # new current + self.bot.chat_history[chat_history_event_id] = chat_history_item + event_id = chat_history_item.event_id + message = chat_history_item.message + translated_message = message + # don't return, we generate a new answer full_prompt = await ai.get_full_prompt(translated_message, self.bot) num_tokens = await ai.num_tokens(full_prompt) @@ -119,15 +129,15 @@ class Callbacks(object): # answer = answer.strip() # print("") await self.client.room_typing(room.room_id, True, 15000) - answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key) + answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot.name) answer = answer.strip() await self.client.room_typing(room.room_id, False) translated_answer = answer if not (self.bot.translate is None): translated_answer = translate.translate(answer, "en", self.bot.translate) - await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=event.event_id, original_message=answer) + await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=event_id, original_message=answer) else: - await self.bot.send_message(self.client, room.room_id, answer, reply_to=event.event_id) + await self.bot.send_message(self.client, room.room_id, answer, reply_to=event_id) @@ -162,6 +172,7 @@ class ChatBot(object): self.persona = None self.scenario = None self.greeting = None + self.events = [] self.chat_history = {} if STORE_PATH and not os.path.isdir(STORE_PATH): @@ -173,6 +184,13 @@ class ChatBot(object): self.scenario = scenario self.greeting = greeting + async def event_loop(self): + while True: + await asyncio.sleep(60) + print(time.time()) + for event in self.events: + event.loop() + async def login(self): self.config = AsyncClientConfig(store_sync_tokens=True) self.client = AsyncClient(self.homeserver, self.user_id, store_path=STORE_PATH, config=self.config) @@ -181,14 +199,18 @@ class ChatBot(object): self.client.add_event_callback(self.callbacks.invite_cb, InviteEvent) sync_task = asyncio.create_task(self.watch_for_sync(self.client.synced)) + event_loop = asyncio.create_task(self.event_loop()) + background_tasks.add(event_loop) + event_loop.add_done_callback(background_tasks.discard) try: response = await self.client.login(self.password) print(response) - await self.client.sync_forever(timeout=30000, full_state=True) + #sync_forever_task = asyncio.create_task(self.client.sync_forever(timeout=30000, full_state=True)) except (asyncio.CancelledError, KeyboardInterrupt): print("Received interrupt.") await self.client.close() + #return sync_forever_task async def watch_for_sync(self, sync_event): print("Awaiting sync") @@ -279,6 +301,9 @@ async def main() -> None: bot.runpod_api_key = config['DEFAULT']['runpod_api_key'] bots.append(bot) await bot.login() - print("logged in") + print("gather") + async with asyncio.TaskGroup() as tg: + for bot in bots: + task = tg.create_task(bot.client.sync_forever(timeout=30000, full_state=True)) asyncio.get_event_loop().run_until_complete(main()) diff --git a/matrix_pygmalion_bot/helpers.py b/matrix_pygmalion_bot/helpers.py index e57f9ce..d84bfa7 100644 --- a/matrix_pygmalion_bot/helpers.py +++ b/matrix_pygmalion_bot/helpers.py @@ -1,4 +1,4 @@ - +import time class ChatItem: def __init__(self, event_id, timestamp, user_name, is_own_message, relates_to_event, message): @@ -13,3 +13,15 @@ class ChatItem: return str("{}: {}".format(self.user_name, self.message)) def getLine(self): return str("{}: {}".format(self.user_name, self.message)) + +class Event: + def __init__(self, timestamp_start, timestamp_stop=None, chance=1): + self.timestamp_start = timestamp_start + self.timestamp_stop = timestamp_stop + self.chance = chance + self.executed = 0 + def __str__(self): + return str("Event starting at timestamp {}".format(self.timestamp_start)) + def loop(self): + if timestamp_start > time.time(): + pass