From 50ef01f1eb2a5208cc88b1ddb6a8959b6ee30055 Mon Sep 17 00:00:00 2001 From: Hendrik Langer Date: Thu, 30 Mar 2023 16:15:22 +0200 Subject: [PATCH] chat history --- .gitignore | 1 + bot.conf2 | 2 +- matrix_pygmalion_bot/ai/runpod_pygmalion.py | 19 +++--- matrix_pygmalion_bot/chatlog.py | 66 ++++++++++++++++++ matrix_pygmalion_bot/core.py | 76 +++++++++------------ matrix_pygmalion_bot/helpers.py | 14 ---- 6 files changed, 112 insertions(+), 66 deletions(-) create mode 100644 matrix_pygmalion_bot/chatlog.py diff --git a/.gitignore b/.gitignore index 94e6881..b1445f2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Custom bot.conf +bot.conf2 .store images/ chatlogs/ diff --git a/bot.conf2 b/bot.conf2 index c2e3359..7c3a493 100644 --- a/bot.conf2 +++ b/bot.conf2 @@ -2,5 +2,5 @@ timestamp = 1680134005.6927342 [Hendrik] -timestamp = 1680131376.732477 +timestamp = 1680149214.3503995 diff --git a/matrix_pygmalion_bot/ai/runpod_pygmalion.py b/matrix_pygmalion_bot/ai/runpod_pygmalion.py index 355762a..550bb4b 100644 --- a/matrix_pygmalion_bot/ai/runpod_pygmalion.py +++ b/matrix_pygmalion_bot/ai/runpod_pygmalion.py @@ -32,7 +32,8 @@ async def generate_sync( "input": { "prompt": prompt, "max_length": max(prompt_num_tokens+max_new_tokens, 2048), - "temperature": 0.75 + "temperature": 0.75, + "do_sample": True, } } @@ -88,7 +89,7 @@ async def generate_sync( else: return "" -async def get_full_prompt(simple_prompt: str, bot): +async def get_full_prompt(simple_prompt: str, bot, chat_history): # Prompt without history prompt = bot.name + "'s Persona: " + bot.persona + "\n" @@ -103,19 +104,19 @@ async def get_full_prompt(simple_prompt: str, bot): total_num_tokens = await num_tokens(prompt) visible_history = [] current_message = True - for key, chat_item in reversed(bot.chat_history.items()): + for key, chat_item in reversed(chat_history.chat_history.items()): if current_message: current_message = False continue - if chat_item.message.startswith('!begin'): + if chat_item.message["en"].startswith('!begin'): break - if chat_item.message.startswith('!'): + if chat_item.message["en"].startswith('!'): continue - #if chat_item.message == bot.greeting: + #if chat_item.message["en"] == bot.greeting: # continue print("History: " + str(chat_item)) if chat_item.num_tokens == None: - chat_item.num_tokens = await num_tokens(chat_item.getLine()) + chat_item.num_tokens = await num_tokens("{}: {}".format(chat_item.user_name, chat_item.message["en"])) # TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens?? if total_num_tokens < (MAX_TOKENS - max_new_tokens): visible_history.append(chat_item) @@ -130,9 +131,9 @@ async def get_full_prompt(simple_prompt: str, bot): #prompt += bot.name + ": " + bot.greeting + "\n" for chat_item in visible_history: if chat_item.is_own_message: - prompt += bot.name + ": " + chat_item.message + "\n" + prompt += bot.name + ": " + chat_item.message["en"] + "\n" else: - prompt += "You" + ": " + chat_item.message + "\n" + prompt += "You" + ": " + chat_item.message["en"] + "\n" prompt += "You: " + simple_prompt + "\n" prompt += bot.name + ":" diff --git a/matrix_pygmalion_bot/chatlog.py b/matrix_pygmalion_bot/chatlog.py new file mode 100644 index 0000000..a9fc59a --- /dev/null +++ b/matrix_pygmalion_bot/chatlog.py @@ -0,0 +1,66 @@ +import os +import matrix_pygmalion_bot.translate as translate + +class ChatMessage: + def __init__(self, event_id, timestamp, user_name, is_own_message, relates_to_event, message, language="en", english_original_message=None): + self.event_id = event_id + self.timestamp = timestamp + self.user_name = user_name + self.is_own_message = is_own_message + self.relates_to_event = relates_to_event + self.num_tokens = None + self.message = {} + self.message[language] = message + if not (language == "en"): + if not (english_original_message is None): + self.message["en"] = english_original_message + else: + self.message["en"] = translate.translate(message, language, "en") + self.language = language + self.num_tokens = None + def __str__(self): + return str("{}: {}".format(self.user_name, self.message[self.language])) + def getTranslation(self, to_lang): + if not (to_lang in self.message): + self.message[to_lang] = translate.translate(self.message["en"], "en", to_lang) + return self.message[to_lang] + + +class ChatHistory: + def __init__(self, bot_name, room_name): + self.bot_name = bot_name + self.room_name = room_name + self.chat_history = {} + def __str__(self): + return str("Chat History for {} in room {}".format(self.bot_name, self.room_name)) + def getLen(self): + return len(self.chat_history) + def load_from_file(self): + pass + def clear(self): + self.chat_history = {} + def remove(self, num_items=1): + for i in range(num_items): + event_id, item = self.chat_history.popitem() + return item + def add(self, event_id, timestamp, user_name, is_own_message, relates_to_event, message, language="en", english_original_message=None): + chat_message = ChatMessage(event_id, timestamp, user_name, is_own_message, relates_to_event, message, language, english_original_message) + self.chat_history[chat_message.event_id] = chat_message + os.makedirs("./chatlogs", exist_ok=True) + with open("chatlogs/" + self.bot_name + "_" + self.room_name + ".txt", "a") as f: + f.write("{}: {}\n".format(user_name, message)) + return chat_message + def getLastItem(self): + key, chat_item = reversed(self.chat_history.items())[0] + return chat_item + +class BotChatHistory: + def __init__(self, bot_name): + self.bot_name = bot_name + self.chat_rooms = {} + def __str__(self): + return str("Chat History for {}".format(self.bot_name)) + def room (self, room): + if not room in self.chat_rooms: + self.chat_rooms[room] = ChatHistory(self.bot_name, room) + return self.chat_rooms[room] diff --git a/matrix_pygmalion_bot/core.py b/matrix_pygmalion_bot/core.py index e976c84..cc46482 100644 --- a/matrix_pygmalion_bot/core.py +++ b/matrix_pygmalion_bot/core.py @@ -13,7 +13,8 @@ import magic from PIL import Image import re -from .helpers import ChatItem, Event +from .helpers import Event +from .chatlog import BotChatHistory ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod_pygmalion") #from .llama_cpp import generate, get_full_prompt, get_full_prompt_chat_style #from .runpod_pygmalion import generate_sync, get_full_prompt @@ -35,8 +36,9 @@ class Callbacks(object): self.bot = bot async def message_cb(self, room: MatrixRoom, event: RoomMessageText) -> None: - event_id = event.event_id - message = event.body + if not hasattr(event, 'body'): + return + #message = event.body is_own_message = False if event.sender == self.client.user: is_own_message = True @@ -49,14 +51,15 @@ class Callbacks(object): relates_to = None if 'm.relates_to' in event.source["content"]: relates_to = event.source["content"]['m.relates_to']["event_id"] - translated_message = message + language = "en" if not (self.bot.translate is None) and not is_command: - if 'original_message' in event.source["content"]: - translated_message = event.source["content"]['original_message'] - else: - translated_message = translate.translate(message, self.bot.translate, "en") - if hasattr(event, 'body'): - self.bot.chat_history[event_id] = ChatItem(event_id, event.server_timestamp, room.user_name(event.sender), is_own_message, relates_to, translated_message) + language = self.bot.translate + if 'original_message' in event.source["content"]: + english_original_message = event.source["content"]['original_message'] + else: + english_original_message = None + + chat_message = self.bot.chat_history.room(room.display_name).add(event.event_id, event.server_timestamp, room.user_name(event.sender), is_own_message, relates_to, event.body, language, english_original_message) if self.bot.not_synced: return print( @@ -64,10 +67,8 @@ class Callbacks(object): room.display_name, room.user_name(event.sender), event.body ) ) - os.makedirs("./chatlogs", exist_ok=True) - with open("chatlogs/" + self.bot.name + "_" + room.display_name + ".txt", "a") as f: - f.write("{}: {}\n".format(room.user_name(event.sender), event.body)) - await self.client.room_read_markers(room.room_id, event_id, event_id) + + await self.client.room_read_markers(room.room_id, event.event_id, event.event_id) # Ignore messages from ourselves if is_own_message: return @@ -89,56 +90,47 @@ class Callbacks(object): await self.bot.send_image(self.client, room.room_id, imagefile) return elif event.body.startswith('!begin'): - self.bot.chat_history = {} + self.bot.chat_history.room(room.display_name).clear() self.bot.timestamp = time.time() await self.bot.write_conf2(self.bot.name) await self.bot.send_message(self.client, room.room_id, self.bot.greeting) return elif event.body.startswith('!!!'): - if len(self.bot.chat_history) < 3: + if self.bot.chat_history.room(room.display_name).getLen() < 3: return - chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current + chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) # current await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") - chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() + chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") - chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() + chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") return elif event.body.startswith('!!'): - if len(self.bot.chat_history) < 3: + if self.bot.chat_history.room(room.display_name).getLen() < 3: return - chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current + chat_history_item = self.bot.chat_history.room(room.display_name).remove(1)# current await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") - chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() + chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") - chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # new current - self.bot.chat_history[chat_history_event_id] = chat_history_item - event_id = chat_history_item.event_id - message = chat_history_item.message - translated_message = message + chat_message = self.bot.chat_history.room(room.display_name).getLastItem() # new current # don't return, we generate a new answer elif event.body.startswith('!replace'): - if len(self.bot.chat_history) < 3: + if self.bot.chat_history.room(room.display_name).getLen() < 3: return - chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current + chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) # current await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") - chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() + chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") new_answer = event.body.removeprefix('!replace').strip() - new_translated_answer = new_answer - if not (self.bot.translate is None): - new_translated_answer = translate.translate(new_answer, "en", self.bot.translate) - await self.bot.send_message(self.client, room.room_id, new_translated_answer, reply_to=chat_history_item.relates_to_event, original_message=new_answer) - else: - await self.bot.send_message(self.client, room.room_id, new_answer, reply_to=chat_history_item.relates_to_event) + await self.bot.send_message(self.client, room.room_id, new_answer, reply_to=chat_history_item.relates_to_event) return # Other commands - if re.search("^(?=.*\bsend\b)(?=.*\bpicture\b).*$", message): + if re.search("^(?=.*\bsend\b)(?=.*\bpicture\b).*$", event.body): # send, mail, drop, snap picture, photo, image, portrait pass - full_prompt = await ai.get_full_prompt(translated_message, self.bot) + full_prompt = await ai.get_full_prompt(chat_message.getTranslation("en"), self.bot, self.bot.chat_history.room(room.display_name)) num_tokens = await ai.num_tokens(full_prompt) logger.info(full_prompt) logger.info(f"num tokens:" + str(num_tokens)) @@ -161,12 +153,11 @@ class Callbacks(object): answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot.name) answer = answer.strip() await self.client.room_typing(room.room_id, False) - translated_answer = answer if not (self.bot.translate is None): translated_answer = translate.translate(answer, "en", self.bot.translate) - await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=event_id, original_message=answer) + await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=chat_message.event_id, original_message=answer) else: - await self.bot.send_message(self.client, room.room_id, answer, reply_to=event_id) + await self.bot.send_message(self.client, room.room_id, answer, reply_to=chat_message.event_id) @@ -203,7 +194,7 @@ class ChatBot(object): self.scenario = None self.greeting = None self.events = [] - self.chat_history = {} + self.chat_history = None if STORE_PATH and not os.path.isdir(STORE_PATH): os.mkdir(STORE_PATH) @@ -232,6 +223,7 @@ class ChatBot(object): async def login(self): self.config = AsyncClientConfig(store_sync_tokens=True) self.client = AsyncClient(self.homeserver, self.user_id, store_path=STORE_PATH, config=self.config) + self.chat_history = BotChatHistory(self.name) self.callbacks = Callbacks(self.client, self) self.client.add_event_callback(self.callbacks.message_cb, RoomMessageText) self.client.add_event_callback(self.callbacks.invite_cb, InviteEvent) diff --git a/matrix_pygmalion_bot/helpers.py b/matrix_pygmalion_bot/helpers.py index bf27704..fae01b0 100644 --- a/matrix_pygmalion_bot/helpers.py +++ b/matrix_pygmalion_bot/helpers.py @@ -1,19 +1,5 @@ import time -class ChatItem: - def __init__(self, event_id, timestamp, user_name, is_own_message, relates_to_event, message): - self.event_id = event_id - self.timestamp = timestamp - self.user_name = user_name - self.is_own_message = is_own_message - self.relates_to_event = relates_to_event - self.message = message - self.num_tokens = None - def __str__(self): - return str("{}: {}".format(self.user_name, self.message)) - def getLine(self): - return str("{}: {}".format(self.user_name, self.message)) - class Event: def __init__(self, time_start, time_stop, chance, command): self.time_start = time_start