diff --git a/matrix_pygmalion_bot/bot/ai/langchain.py b/matrix_pygmalion_bot/bot/ai/langchain.py index 3d4359c..c4305ae 100644 --- a/matrix_pygmalion_bot/bot/ai/langchain.py +++ b/matrix_pygmalion_bot/bot/ai/langchain.py @@ -315,16 +315,33 @@ class AI(object): #ToDo: max_tokens and stop async def diary(self, room_id): - await asyncio.sleep(0) # yield for matrix-nio diary_chain = LLMChain(llm=self.llm_summary, prompt=prompt_outline, verbose=True) conversation_memory = self.get_memory(room_id) - if self.llm_summary.get_num_tokens(conversation_memory.buffer_day) < 1600: - input_text = conversation_memory.buffer_day - else: - input_text = conversation_memory.moving_summary_buffer - - return await diary_chain.apredict(text=input_text, ai_name=self.bot.name) + text_splitter = RecursiveCharacterTextSplitter( + separators = ["\n", " ", ""], + chunk_size = 1600, + chunk_overlap = 40, + length_function = self.llm_summary.get_num_tokens, + ) + docs = text_splitter.create_documents([conversation_memory.buffer_day]) + string_diary = [] + for i in range(len(docs)): + logger.info("Writing diary... page {i} of {len(docs)}.") + await asyncio.sleep(0) # yield for matrix-nio + diary_chunk = await diary_chain.apredict(text=docs[i].page_content, ai_name=self.bot.name) + string_diary.append(diary_chunk) + diary_entry = "\n".join(string_diary) + + if self.llm_summary.get_num_tokens(diary_entry) > 1400: + logger.info("Summarizing diary entry.") + await asyncio.sleep(0) + diary_entry = await self.summarize(diary_entry) + if self.llm_summary.get_num_tokens(diary_entry) > 1600: + logger.warning("Diary entry too long. Discarding.") + diary_entry = conversation_memory.moving_summary_buffer + + return diary_entry async def agent(self):