From 0fabd77408dde8201956766c338268b4bfd66d8e Mon Sep 17 00:00:00 2001
From: Hendrik Langer
Date: Mon, 8 May 2023 16:37:33 +0200
Subject: [PATCH] map-reduce summarization

---
 matrix_pygmalion_bot/bot/ai/langchain.py      | 50 ++++++++++++++++----
 .../bot/ai/langchain_memory.py                |  2 +
 matrix_pygmalion_bot/bot/ai/prompts.py        | 14 +++++-
 3 files changed, 54 insertions(+), 12 deletions(-)

diff --git a/matrix_pygmalion_bot/bot/ai/langchain.py b/matrix_pygmalion_bot/bot/ai/langchain.py
index c4305ae..36cda61 100644
--- a/matrix_pygmalion_bot/bot/ai/langchain.py
+++ b/matrix_pygmalion_bot/bot/ai/langchain.py
@@ -9,9 +9,11 @@
 from langchain import LLMChain, ConversationChain
 from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory, CombinedMemory, ConversationSummaryMemory
 from langchain.chains.base import Chain
-from typing import Dict, List, Union
+from langchain.chains.summarize import load_summarize_chain
+from typing import Any, Dict, List, Optional, Union
 
 from langchain.document_loaders import TextLoader
+from langchain.docstore.document import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.embeddings import SentenceTransformerEmbeddings
 from langchain.vectorstores import Chroma
@@ -25,6 +27,7 @@
 from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
 import humanize
 from datetime import datetime, timedelta
+from termcolor import colored
 
 import logging
 logger = logging.getLogger(__name__)
@@ -262,12 +265,12 @@ class AI(object):
             memory=custom_memory,
             #stop=['<|endoftext|>', '\nYou:', f"\n{chat_human_name}:"],
         )
-
+
 # output = llm_chain(inputs={"ai_name": self.bot.name, "persona": self.bot.persona, "scenario": self.bot.scenario, "human_name": chat_human_name, "ai_name_chat": self.bot.name, "chat_history": "", "input": message.content})['results'][0]['text']
-        stop = ['<|endoftext|>', f"\n{chat_human_name}"]
+        stop = ['<|endoftext|>', f"\n{chat_human_name}:"]
         if chat_human_name != message.additional_kwargs['user_name']:
-            stop.append(f"\n{message.additional_kwargs['user_name']}")
+            stop.append(f"\n{message.additional_kwargs['user_name']}:")
         #print(f"Message is: \"{message.content}\"")
         await asyncio.sleep(0)
         output = await chain.arun({"input":message.content, "stop": stop})
@@ -307,11 +310,35 @@ class AI(object):
 
         return output
 
-    async def summarize(self, text):
+    async def summarize(self, text, map_prompt=prompt_summary2, combine_prompt=prompt_summary2):
+        #metadata = {"source": "internet", "date": "Friday"}
+        #doc = Document(page_content=text, metadata=metadata)
+        docs = [ Document(page_content=text) ]
+
+        map_chain = LLMChain(llm=self.llm_summary, prompt=map_prompt, verbose=True)
+        reduce_chain = LLMChain(llm=self.llm_summary, prompt=combine_prompt, verbose=True)
+        text_splitter = RecursiveCharacterTextSplitter(
+            #separators = ["\n\n", "\n", " ", ""],
+            chunk_size = 1600,
+            chunk_overlap = 80,
+            length_function = self.llm_chat.get_num_tokens,
+        )
+
+        for _ in range(2):  # at most two map passes before the final reduce
+            docs = text_splitter.split_documents(docs)
+            if len(docs) > 1:
+                #results = await map_chain.aapply([{"text": d.page_content} for d in docs])
+                #docs = [Document(page_content=r['output'], metadata=docs[i].metadata) for i, r in enumerate(results)]
+                for d in docs:
+                    await asyncio.sleep(0) # yield for matrix-nio
+                    d.page_content = await map_chain.arun(d.page_content)
+                combined = "\n".join([d.page_content for d in docs])
+                docs = [ Document(page_content=combined) ]
+            else:
+                break
+
         await asyncio.sleep(0) # yield for matrix-nio
-        summary_chain = LLMChain(llm=self.llm_summary, prompt=prompt_summary, verbose=True)
-        return await summary_chain.arun(text=text)
-        #ToDo: We can summarize the whole dialogue here, let half of it in the buffer but skip doing a summary until this is flushed, too?
+        return await reduce_chain.arun(text=docs[0].page_content) #ToDo: max_tokens and stop
 
 
     async def diary(self, room_id):
@@ -326,7 +353,7 @@ class AI(object):
         )
         docs = text_splitter.create_documents([conversation_memory.buffer_day])
         string_diary = []
-        for i in range(len(docs)):
-            logger.info("Writing diary... page {i} of {len(docs)}.")
+        for i, doc in enumerate(docs):
+            logger.info(f"Writing diary... page {i} of {len(docs)}.")
             await asyncio.sleep(0) # yield for matrix-nio
-            diary_chunk = await diary_chain.apredict(text=docs[i].page_content, ai_name=self.bot.name)
+            diary_chunk = await diary_chain.apredict(text=doc.page_content, ai_name=self.bot.name)
@@ -453,3 +480,4 @@ def replace_all(text, dic):
     for i, j in dic.items():
         text = text.replace(i, j)
     return text
+
diff --git a/matrix_pygmalion_bot/bot/ai/langchain_memory.py b/matrix_pygmalion_bot/bot/ai/langchain_memory.py
index c016784..f90301a 100644
--- a/matrix_pygmalion_bot/bot/ai/langchain_memory.py
+++ b/matrix_pygmalion_bot/bot/ai/langchain_memory.py
@@ -92,9 +92,11 @@
             self.moving_summary_buffer = loop.run_until_complete(
                 self.predict_new_summary(pruned_memory, self.moving_summary_buffer)
             )
+            # loop.run_in_executor(None, self.predict_new_summary, pruned_memory, self.moving_summary_buffer)
 
     async def prune_memory(self, max_len):
         # Prune buffer if it exceeds max token limit
+        #ToDo: We can summarize the whole dialogue here, let half of it in the buffer but skip doing a summary until this is flushed, too?
         buffer = self.chat_memory.messages
         curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
         if curr_buffer_length > max_len:
diff --git a/matrix_pygmalion_bot/bot/ai/prompts.py b/matrix_pygmalion_bot/bot/ai/prompts.py
index 99c2c43..1a7a8fd 100644
--- a/matrix_pygmalion_bot/bot/ai/prompts.py
+++ b/matrix_pygmalion_bot/bot/ai/prompts.py
@@ -97,7 +97,6 @@ template_question_simple = """Question: {question}
 
 Answer: Let's think step by step."""
 
-
 prompt_summary = PromptTemplate.from_template(
 """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 
@@ -111,6 +110,19 @@ Summarize the following text in one paragraph.
 """
 )
 
+prompt_summary2 = PromptTemplate.from_template(
+"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+### Instruction:
+Write a concise summary of the following:
+
+### Input:
+{text}
+
+### Response:
+"""
+)
+
 prompt_progressive_summary = PromptTemplate.from_template(
 """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
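
Note on the summarize() rework above: the map and reduce steps are driven by
hand with two LLMChain instances so the bot can await one chunk at a time and
yield back to matrix-nio between model calls (the asyncio.sleep(0) lines).
The hunk also imports load_summarize_chain without using it yet. For
comparison, below is a minimal sketch of the same map-reduce flow built on
that helper. It reuses self.llm_summary, self.llm_chat and prompt_summary2
from this patch; the method name summarize_builtin is hypothetical, and the
keyword arguments assume a 2023-era langchain release.

    from langchain.chains.summarize import load_summarize_chain
    from langchain.docstore.document import Document
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    async def summarize_builtin(self, text):
        # Hypothetical alternative method for the AI class; not part of this patch.
        # Same chunking parameters as the hand-rolled summarize() above.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1600,
            chunk_overlap=80,
            length_function=self.llm_chat.get_num_tokens,
        )
        docs = text_splitter.split_documents([Document(page_content=text)])
        # chain_type="map_reduce" summarizes every chunk with map_prompt and
        # then combines the partial summaries with combine_prompt.
        chain = load_summarize_chain(
            self.llm_summary,
            chain_type="map_reduce",
            map_prompt=prompt_summary2,     # from matrix_pygmalion_bot/bot/ai/prompts.py
            combine_prompt=prompt_summary2,
            verbose=True,
        )
        # The chain's single input key is "input_documents", so the split
        # documents can be passed positionally.
        return await chain.arun(docs)

One likely reason for keeping the hand-rolled loop instead: the built-in
chain exposes no hook between map calls, so there is nowhere to insert the
asyncio.sleep(0) yields that keep the matrix-nio event loop responsive.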