From 64bdbc9c0f94ef4c43faa3653f4c5e3274c03871 Mon Sep 17 00:00:00 2001 From: Hendrik Langer Date: Fri, 5 May 2023 19:21:09 +0200 Subject: [PATCH] memory --- matrix_pygmalion_bot/bot/ai/langchain.py | 92 +++++++++---------- .../bot/ai/langchain_memory.py | 23 ++++- matrix_pygmalion_bot/bot/core.py | 12 ++- matrix_pygmalion_bot/main.py | 12 +++ 4 files changed, 81 insertions(+), 58 deletions(-) diff --git a/matrix_pygmalion_bot/bot/ai/langchain.py b/matrix_pygmalion_bot/bot/ai/langchain.py index 55166a9..c596e77 100644 --- a/matrix_pygmalion_bot/bot/ai/langchain.py +++ b/matrix_pygmalion_bot/bot/ai/langchain.py @@ -102,24 +102,23 @@ class AI(object): def get_memory(self, room_id, human_prefix="Human"): if not room_id in self.rooms: self.rooms[room_id] = {} - memory = CustomMemory(memory_key="chat_history", input_key="input", human_prefix=human_prefix, ai_prefix=self.bot.name, llm=self.llm_summary, summary_prompt=prompt_progressive_summary, max_len=1200, min_len=200) + if "moving_summary" in self.bot.rooms[room_id]: + moving_summary = self.bot.rooms[room_id]['moving_summary'] + else: + moving_summary = "No previous events." + memory = CustomMemory(memory_key="chat_history", input_key="input", human_prefix=human_prefix, ai_prefix=self.bot.name, llm=self.llm_summary, summary_prompt=prompt_progressive_summary, moving_summary_buffer=moving_summary, max_len=1200, min_len=200) self.rooms[room_id]["memory"] = memory - self.rooms[room_id]["summary"] = "No previous events." - memory.chat_memory.add_ai_message(self.bot.greeting) - #memory.save_context({"input": None, "output": self.bot.greeting}) - memory.load_memory_variables({}) + #memory.chat_memory.add_ai_message(self.bot.greeting) else: memory = self.rooms[room_id]["memory"] - #print(f"memory: {memory.load_memory_variables({})}") - #print(f"memory has an estimated {self.llm_chat.get_num_tokens(memory.buffer)} number of tokens") + if human_prefix != memory.human_prefix: + memory.human_prefix = human_prefix return memory async def add_chat_message(self, message): - conversation_memory = self.get_memory(message.room_id) - langchain_message = message.to_langchain() - if message.user_id == self.bot.connection.user_id: - langchain_message.role = self.bot.name - conversation_memory.chat_memory.messages.append(langchain_message) + room_id = message.additional_kwargs['room_id'] + conversation_memory = self.get_memory(room_id) + conversation_memory.chat_memory.messages.append(message) async def clear(self, room_id): conversation_memory = self.get_memory(room_id) @@ -176,7 +175,7 @@ class AI(object): llm=self.llm_chat, prompt=PromptTemplate.from_template(prompt_template), ) - output = await chain.arun(message.message) + output = await chain.arun(message.content) return output.strip() @@ -190,33 +189,11 @@ class AI(object): chat_human_name = "### Human" conversation_memory = self.get_memory(room_id, chat_human_name) - conversation_memory.human_prefix = chat_human_name readonlymemory = ReadOnlySharedMemory(memory=conversation_memory) - summary_memory = ConversationSummaryMemory(llm=self.llm_summary, memory_key="summary", input_key="input") + #summary_memory = ConversationSummaryMemory(llm=self.llm_summary, memory_key="summary", input_key="input") #combined_memory = CombinedMemory(memories=[conversation_memory, summary_memory]) - k = 1 # 5 - max_k = 3 # 12 - if len(conversation_memory.chat_memory.messages) > max_k*2: - - async def make_progressive_summary(previous_summary, chat_history_text_string): - await asyncio.sleep(0) # yield for matrix-nio - 
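The langchain.py hunks above revolve around two token budgets: before generation the rendered prompt is measured against the model context (2000 tokens, with 256 reserved for the reply), and after generation the per-room moving summary is re-summarized once it grows past 400 tokens. Stripped of the bot plumbing, the logic is roughly the sketch below; count_tokens stands in for llm.get_num_tokens() and refine for the summarize() chain, so all names here are illustrative, not part of the codebase.

CONTEXT_WINDOW = 2000      # context budget assumed by the patch
RESERVED_FOR_REPLY = 256   # head-room kept for the generated answer
SUMMARY_LIMIT = 400        # refine the moving summary past this size

def count_tokens(text: str) -> int:
    # crude whitespace approximation; the real code asks the LLM wrapper
    return len(text.split())

def prompt_fits(prompt_text: str) -> bool:
    # mirrors the "prompt_len+256 > 2000" warning above, inverted
    return count_tokens(prompt_text) + RESERVED_FOR_REPLY <= CONTEXT_WINDOW

async def maybe_refine(summary: str, refine) -> str:
    # refine() is any async str -> str summarizer (here: an LLMChain call)
    if count_tokens(summary) > SUMMARY_LIMIT:
        summary = await refine(summary)
    return summary

Persisting the refined summary back into self.bot.rooms[room_id]['moving_summary'] is what lets get_memory() reseed the CustomMemory after a restart instead of starting from "No previous events."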
diff --git a/matrix_pygmalion_bot/bot/ai/langchain_memory.py b/matrix_pygmalion_bot/bot/ai/langchain_memory.py
index be5a4e8..5161fab 100644
--- a/matrix_pygmalion_bot/bot/ai/langchain_memory.py
+++ b/matrix_pygmalion_bot/bot/ai/langchain_memory.py
@@ -88,9 +88,25 @@ class CustomMemory(BaseMemory):
                 curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
             loop = asyncio.get_event_loop()
             self.moving_summary_buffer = loop.run_until_complete(
-                self.apredict_new_summary(pruned_memory, self.moving_summary_buffer)
+                self.predict_new_summary(pruned_memory, self.moving_summary_buffer)
             )

+    async def asave_context(self, input_msg: BaseMessage, output_msg: BaseMessage) -> None:
+        """Save context from this conversation to buffer."""
+        self.chat_memory.messages.append(input_msg)
+        self.chat_memory.messages.append(output_msg)
+        self.chat_memory_day.messages.append(input_msg)
+        self.chat_memory_day.messages.append(output_msg)
+        # Prune buffer if it exceeds max token limit
+        buffer = self.chat_memory.messages
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+        if curr_buffer_length > self.max_len:
+            pruned_memory = []
+            while curr_buffer_length > self.min_len:
+                pruned_memory.append(buffer.pop(0))
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+            self.moving_summary_buffer = await self.apredict_new_summary(pruned_memory, self.moving_summary_buffer)
+
     def clear(self) -> None:
         """Clear memory contents."""
@@ -157,9 +173,10 @@ class CustomMemory(BaseMemory):
             ai_prefix=self.ai_prefix,
         )

-        chain = LLMChain(llm=self.llm, prompt=self.summary_prompt)
+        chain = LLMChain(llm=self.llm, prompt=self.summary_prompt, verbose=True)
         await asyncio.sleep(0)
-        return await chain.apredict(summary=existing_summary, new_lines=new_lines)
+        output = await chain.apredict(summary=existing_summary, chat_history=new_lines)
+        return output.strip()


 class ChatMessageHistoryMessage(BaseModel):
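The new asave_context mirrors the synchronous pruning path above: messages are appended, and once the buffer crosses max_len tokens the oldest messages are popped until it is back under min_len, with the popped slice folded into the moving summary. The same hysteresis as a self-contained sketch over plain strings (the real code works on BaseMessage lists; all names here are illustrative):

import asyncio
from typing import Awaitable, Callable, List, Tuple

async def prune_and_summarize(
    messages: List[str],
    summary: str,
    count_tokens: Callable[[List[str]], int],
    summarize: Callable[[List[str], str], Awaitable[str]],
    max_len: int = 1200,
    min_len: int = 200,
) -> Tuple[List[str], str]:
    # Prune only when the buffer overflows, then cut deep (down to min_len,
    # not max_len) so the summarizer is not invoked on every single turn.
    if count_tokens(messages) > max_len:
        pruned: List[str] = []
        while count_tokens(messages) > min_len:
            pruned.append(messages.pop(0))  # oldest first
        summary = await summarize(pruned, summary)
    return messages, summary

async def _demo() -> None:
    count = lambda msgs: sum(len(m.split()) for m in msgs)
    fold = lambda pruned, old: asyncio.sleep(0, result=f"{old} (+{len(pruned)} msgs)")
    msgs, summary = await prune_and_summarize(
        ["hello there"] * 700, "No previous events.", count, fold)
    print(len(msgs), summary)  # 100 messages kept, 600 folded into the summary

asyncio.run(_demo())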
Refining...") + conversation_memory.moving_summary_buffer = await self.summarize(conversation_memory.moving_summary_buffer) + new_summary_len = self.llm_chat.get_num_tokens(conversation_memory.moving_summary_buffer) + logger.info(f"Refined summary from {summary_len} tokens to {new_summary_len} tokens ({new_summary_len-summary_len} tokens)") + self.bot.rooms[room_id]['moving_summary'] = conversation_memory.moving_summary_buffer return output @@ -293,11 +283,13 @@ class AI(object): await asyncio.sleep(0) # yield for matrix-nio diary_chain = LLMChain(llm=self.llm_summary, prompt=prompt_outline, verbose=True) conversation_memory = self.get_memory(room_id) - #self.rooms[message.room_id]["summary"] - string_messages = [] - for m in conversation_memory.chat_memory_day.messages: - string_messages.append(f"{message.role}: {message.content}") - return await diary_chain.apredict(text="\n".join(string_messages)) + + if self.llm_summary.get_num_tokens(conversation_memory.buffer_day) < 1600: + input_text = conversation_memory.buffer_day + else: + input_text = conversation_memory.moving_summary_buffer + + return await diary_chain.apredict(text=input_text) async def agent(self): @@ -371,7 +363,7 @@ class AI(object): content=f"~~~~ {datetime.now().strftime('%A, %B %d, %Y')} ~~~~", additional_kwargs={ "timestamp": datetime.now().timestamp(), - "user_name": self.bot.name, + "user_name": None, "event_id": None, "user_id": None, "room_name": None, diff --git a/matrix_pygmalion_bot/bot/ai/langchain_memory.py b/matrix_pygmalion_bot/bot/ai/langchain_memory.py index be5a4e8..5161fab 100644 --- a/matrix_pygmalion_bot/bot/ai/langchain_memory.py +++ b/matrix_pygmalion_bot/bot/ai/langchain_memory.py @@ -88,9 +88,25 @@ class CustomMemory(BaseMemory): curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer) loop = asyncio.get_event_loop() self.moving_summary_buffer = loop.run_until_complete( - self.apredict_new_summary(pruned_memory, self.moving_summary_buffer) + self.predict_new_summary(pruned_memory, self.moving_summary_buffer) ) + async def asave_context(self, input_msg: BaseMessage, output_msg: BaseMessage) -> None: + """Save context from this conversation to buffer.""" + self.chat_memory.messages.append(input_msg) + self.chat_memory.messages.append(output_msg) + self.chat_memory_day.messages.append(input_msg) + self.chat_memory_day.messages.append(output_msg) + # Prune buffer if it exceeds max token limit + buffer = self.chat_memory.messages + curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer) + if curr_buffer_length > self.max_len: + pruned_memory = [] + while curr_buffer_length > self.min_len: + pruned_memory.append(buffer.pop(0)) + curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer) + self.moving_summary_buffer = await self.apredict_new_summary(pruned_memory, self.moving_summary_buffer) + def clear(self) -> None: """Clear memory contents.""" @@ -157,9 +173,10 @@ class CustomMemory(BaseMemory): ai_prefix=self.ai_prefix, ) - chain = LLMChain(llm=self.llm, prompt=self.summary_prompt) + chain = LLMChain(llm=self.llm, prompt=self.summary_prompt, verbose = True) await asyncio.sleep(0) - return await chain.apredict(summary=existing_summary, new_lines=new_lines) + output = await chain.apredict(summary=existing_summary, chat_history=new_lines) + return output.strip() class ChatMessageHistoryMessage(BaseModel): diff --git a/matrix_pygmalion_bot/bot/core.py b/matrix_pygmalion_bot/bot/core.py index b084682..54dbae2 100644 --- a/matrix_pygmalion_bot/bot/core.py +++ 
diff --git a/matrix_pygmalion_bot/main.py b/matrix_pygmalion_bot/main.py
index 7e809e0..08cec80 100644
--- a/matrix_pygmalion_bot/main.py
+++ b/matrix_pygmalion_bot/main.py
@@ -61,6 +61,13 @@ async def main() -> None:

     try:

+#        loop = asyncio.get_running_loop()
+#
+#        for signame in {'SIGINT', 'SIGTERM'}:
+#            loop.add_signal_handler(
+#                getattr(signal, signame),
+#                functools.partial(ask_exit, signame, loop))
+
         if sys.version_info[0] == 3 and sys.version_info[1] < 11:
             tasks = []
             for bot in bots:
@@ -81,6 +88,11 @@ async def main() -> None:
             await bot.disconnect()
     sys.exit(0)

+#def ask_exit(signame, loop):
+#    print("got signal %s: exit" % signame)
+#    loop.stop()
+
+
 if __name__ == "__main__":
     asyncio.run(main())
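The commented-out block sketches graceful shutdown on SIGINT/SIGTERM. If it is enabled later, note that calling loop.stop() while asyncio.run() is still driving main() raises "RuntimeError: Event loop stopped before Future completed"; a variant that trips an asyncio.Event sidesteps that. A hedged sketch, not part of the patch:

import asyncio
import signal

async def main() -> None:
    loop = asyncio.get_running_loop()
    stop = asyncio.Event()
    for signame in ("SIGINT", "SIGTERM"):
        # add_signal_handler is POSIX-only; Windows loops raise NotImplementedError
        loop.add_signal_handler(getattr(signal, signame), stop.set)
    await stop.wait()  # stand-in for running the bots until shutdown

if __name__ == "__main__":
    asyncio.run(main())

Letting main() return normally gives asyncio.run() the chance to cancel remaining tasks and close the loop, which pairs well with the write_conf2() persistence added to disconnect() in core.py.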