From 0b52bb31b74b6de4a96fa2376a8ac4136c01de1e Mon Sep 17 00:00:00 2001 From: Hendrik Langer Date: Fri, 5 May 2023 20:09:06 +0200 Subject: [PATCH] context size --- matrix_pygmalion_bot/bot/ai/langchain.py | 7 +++++++ .../bot/ai/langchain_memory.py | 18 ++++++++++-------- matrix_pygmalion_bot/bot/utilities/messages.py | 7 ++++++- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/matrix_pygmalion_bot/bot/ai/langchain.py b/matrix_pygmalion_bot/bot/ai/langchain.py index c596e77..a91afb7 100644 --- a/matrix_pygmalion_bot/bot/ai/langchain.py +++ b/matrix_pygmalion_bot/bot/ai/langchain.py @@ -119,6 +119,7 @@ class AI(object): room_id = message.additional_kwargs['room_id'] conversation_memory = self.get_memory(room_id) conversation_memory.chat_memory.messages.append(message) + conversation_memory.chat_memory_day.messages.append(message) async def clear(self, room_id): conversation_memory = self.get_memory(room_id) @@ -219,6 +220,8 @@ class AI(object): if prompt_len+256 > 2000: logger.warning(f"Prompt too large. Estimated {prompt_len} tokens") + await reply_fn(f" Prompt too large. Estimated {prompt_len} tokens") + await conversation_memory.prune_memory(conversation_memory.min_len) #roleplay_chain = RoleplayChain(llm_chain=chain, character_name=self.bot.name, persona=self.bot.persona, scenario=self.bot.scenario, ai_name_chat=chat_ai_name, human_name_chat=chat_human_name) @@ -370,6 +373,8 @@ class AI(object): "room_id": room_id, } ) + if conversation_memory.chat_memory.messages[-1].content.startswith('~~~~ '): + conversation_memory.chat_memory.messages.pop() conversation_memory.chat_memory.messages.append(message) #conversation_memory.chat_memory.add_system_message(message) @@ -377,6 +382,8 @@ class AI(object): yesterday = ( datetime.now() - timedelta(days=1) ).strftime('%Y-%m-%d') for room_id in self.rooms.keys(): if len(conversation_memory.chat_memory_day.messages) > 0: + if not "diary" in self.bot.rooms[room_id]: + self.bot.rooms[room_id]['diary'] = {} self.bot.rooms[room_id]["diary"][yesterday] = await self.diary(room_id) # Calculate new goals for the character # Update stats diff --git a/matrix_pygmalion_bot/bot/ai/langchain_memory.py b/matrix_pygmalion_bot/bot/ai/langchain_memory.py index 5161fab..ed3abce 100644 --- a/matrix_pygmalion_bot/bot/ai/langchain_memory.py +++ b/matrix_pygmalion_bot/bot/ai/langchain_memory.py @@ -91,22 +91,24 @@ class CustomMemory(BaseMemory): self.predict_new_summary(pruned_memory, self.moving_summary_buffer) ) - async def asave_context(self, input_msg: BaseMessage, output_msg: BaseMessage) -> None: - """Save context from this conversation to buffer.""" - self.chat_memory.messages.append(input_msg) - self.chat_memory.messages.append(output_msg) - self.chat_memory_day.messages.append(input_msg) - self.chat_memory_day.messages.append(output_msg) + async def prune_memory(self, max_len): # Prune buffer if it exceeds max token limit buffer = self.chat_memory.messages curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer) - if curr_buffer_length > self.max_len: + if curr_buffer_length > max_len: pruned_memory = [] - while curr_buffer_length > self.min_len: + while curr_buffer_length > self.min_len and len(buffer) > 0: pruned_memory.append(buffer.pop(0)) curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer) self.moving_summary_buffer = await self.apredict_new_summary(pruned_memory, self.moving_summary_buffer) + async def asave_context(self, input_msg: BaseMessage, output_msg: BaseMessage) -> None: + """Save context from this conversation to buffer.""" + self.chat_memory.messages.append(input_msg) + self.chat_memory.messages.append(output_msg) + self.chat_memory_day.messages.append(input_msg) + self.chat_memory_day.messages.append(output_msg) + await self.prune_memory(self.max_len) def clear(self) -> None: """Clear memory contents.""" diff --git a/matrix_pygmalion_bot/bot/utilities/messages.py b/matrix_pygmalion_bot/bot/utilities/messages.py index 6ef3f28..6166b1b 100644 --- a/matrix_pygmalion_bot/bot/utilities/messages.py +++ b/matrix_pygmalion_bot/bot/utilities/messages.py @@ -77,7 +77,12 @@ class Message(object): return self.message.startswith('!') def is_error(self): - return self.message.startswith('') + if self.message.startswith(''): + return True + elif self.message.startswith(''): + return True + else: + return False def __str__(self): return str("{}: {}".format(self.user_name, self.message))