From 8213d50f1590cdd30ba95f89d49c5809c1751b64 Mon Sep 17 00:00:00 2001
From: Hendrik Langer
Date: Sat, 15 Apr 2023 00:39:07 +0200
Subject: [PATCH] Don't recalculate the prompt every time: use a window and
 remove a chunk of chat history when we get near the token limit

---
 matrix_pygmalion_bot/ai/koboldcpp.py         |  4 +-
 matrix_pygmalion_bot/ai/llama_helpers.py     | 47 +++++++++++++----
 matrix_pygmalion_bot/ai/pygmalion_helpers.py | 53 +++++++++++++++-----
 matrix_pygmalion_bot/chatlog.py              | 14 ++++++
 matrix_pygmalion_bot/core.py                 | 39 +++++++-------
 5 files changed, 117 insertions(+), 40 deletions(-)

diff --git a/matrix_pygmalion_bot/ai/koboldcpp.py b/matrix_pygmalion_bot/ai/koboldcpp.py
index d4b70d5..696ffe0 100644
--- a/matrix_pygmalion_bot/ai/koboldcpp.py
+++ b/matrix_pygmalion_bot/ai/koboldcpp.py
@@ -20,8 +20,10 @@ logger = logging.getLogger(__name__)
 
 def setup():
     os.system("mkdir -p repositories && (cd repositories && git clone https://github.com/LostRuins/koboldcpp.git)")
-    os.system("(cd repositories/koboldcpp && make LLAMA_OPENBLAS=1 && cd models && wget https://huggingface.co/concedo/pygmalion-6bv3-ggml-ggjt/resolve/main/pygmalion-6b-v3-ggml-ggjt-q4_0.bin)")
+    os.system("apt update && apt-get install libopenblas-dev libclblast-dev libmkl-dev")
+    os.system("(cd repositories/koboldcpp && make LLAMA_OPENBLAS=1 LLAMA_CLBLAST=1 && cd models && wget https://huggingface.co/concedo/pygmalion-6bv3-ggml-ggjt/resolve/main/pygmalion-6b-v3-ggml-ggjt-q4_0.bin)")
     #python3 koboldcpp.py models/pygmalion-6b-v3-ggml-ggjt-q4_0.bin
+    #python3 koboldcpp.py --smartcontext models/pygmalion-6b-v3-ggml-ggjt-q4_0.bin
 
 async def generate_sync(
     prompt: str,
diff --git a/matrix_pygmalion_bot/ai/llama_helpers.py b/matrix_pygmalion_bot/ai/llama_helpers.py
index 72afc1f..4709ccf 100644
--- a/matrix_pygmalion_bot/ai/llama_helpers.py
+++ b/matrix_pygmalion_bot/ai/llama_helpers.py
@@ -51,15 +51,20 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history):
     #prompt += f"{ai_name}:"
 
     MAX_TOKENS = 2048
+    WINDOW = 600
     max_new_tokens = 200
     total_num_tokens = await num_tokens(prompt)
-    total_num_tokens += await num_tokens(f"{user_name}: {simple_prompt}\n{ai_name}:")
+    input_num_tokens = await num_tokens(f"{user_name}: {simple_prompt}\n{ai_name}:")
+    total_num_tokens += input_num_tokens
     visible_history = []
-    current_message = True
+    num_message = 0
     for key, chat_item in reversed(chat_history.chat_history.items()):
-        if current_message:
-            current_message = False
+        num_message += 1
+        if num_message == 1:
+            # skip current_message
             continue
+        if chat_item.stop_here:
+            break
         if chat_item.message["en"].startswith('!begin'):
             break
         if chat_item.message["en"].startswith('!'):
@@ -69,10 +74,11 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history):
         #if chat_item.message["en"] == bot.greeting:
         #    continue
         if chat_item.num_tokens == None:
-            chat_item.num_tokens = await num_tokens("{}: {}".format(user_name, chat_item.message["en"]))
+            chat_history.chat_history[key].num_tokens = await num_tokens("{}: {}".format(user_name, chat_item.message["en"]))
+            chat_item = chat_history.chat_history[key]
         # TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens??
logger.debug(f"History: " + str(chat_item) + " [" + str(chat_item.num_tokens) + "]") - if total_num_tokens + chat_item.num_tokens < MAX_TOKENS - max_new_tokens: + if total_num_tokens + chat_item.num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens: visible_history.append(chat_item) total_num_tokens += chat_item.num_tokens else: @@ -81,14 +87,37 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history): if not hasattr(bot, "greeting_num_tokens"): bot.greeting_num_tokens = await num_tokens(bot.greeting) - if total_num_tokens + bot.greeting_num_tokens < MAX_TOKENS - max_new_tokens: + if total_num_tokens + bot.greeting_num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens: prompt += f"{ai_name}: {bot.greeting}\n" + total_num_tokens += bot.greeting_num_tokens for chat_item in visible_history: if chat_item.is_own_message: - prompt += f"{ai_name}: {chat_item.message['en']}\n" + line = f"{ai_name}: {chat_item.message['en']}\n" else: - prompt += f"{user_name}: {chat_item.message['en']}\n" + line = f"{user_name}: {chat_item.message['en']}\n" + prompt += line + if chat_history.getSavedPrompt() and not chat_item.is_in_saved_prompt: + logger.info(f"adding to saved prompt: \"line\"") + chat_history.setSavedPrompt( chat_history.getSavedPrompt() + line, chat_history.saved_context_num_tokens + chat_item.num_tokens ) + chat_item.is_in_saved_prompt = True + + if chat_history.saved_context_num_tokens: + logger.info(f"saved_context has {chat_history.saved_context_num_tokens+input_num_tokens} tokens. new context would be {total_num_tokens}. Limit is {MAX_TOKENS}") + if chat_history.getSavedPrompt(): + if chat_history.saved_context_num_tokens+input_num_tokens > MAX_TOKENS - max_new_tokens: + chat_history.setFastForward(False) + if chat_history.getFastForward(): + logger.info("using saved prompt") + prompt = chat_history.getSavedPrompt() + if not chat_history.getSavedPrompt() or not chat_history.getFastForward(): + logger.info("regenerating prompt") + chat_history.setSavedPrompt(prompt, total_num_tokens) + for key, chat_item in chat_history.chat_history.items(): + if key != list(chat_history.chat_history)[-1]: # exclude current item + chat_history.chat_history[key].is_in_saved_prompt = True + chat_history.setFastForward(True) + prompt += f"{user_name}: {simple_prompt}\n" prompt += f"{ai_name}:" diff --git a/matrix_pygmalion_bot/ai/pygmalion_helpers.py b/matrix_pygmalion_bot/ai/pygmalion_helpers.py index 78656e2..b0bf12e 100644 --- a/matrix_pygmalion_bot/ai/pygmalion_helpers.py +++ b/matrix_pygmalion_bot/ai/pygmalion_helpers.py @@ -15,6 +15,9 @@ from PIL import Image, PngImagePlugin logger = logging.getLogger(__name__) +tokenizer = None + + async def get_full_prompt(simple_prompt: str, bot, chat_history): # Prompt without history @@ -32,15 +35,20 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history): #prompt += bot.name + ":" MAX_TOKENS = 2048 + WINDOW = 800 max_new_tokens = 200 total_num_tokens = await num_tokens(prompt) - total_num_tokens += await num_tokens(f"You: " + simple_prompt + "\n{bot.name}:") + input_num_tokens = await num_tokens(f"You: " + simple_prompt + "\n{bot.name}:") + total_num_tokens += input_num_tokens visible_history = [] - current_message = True + num_message = 0 for key, chat_item in reversed(chat_history.chat_history.items()): - if current_message: - current_message = False + num_message += 1 + if num_message == 1: + # skip current_message continue + if chat_item.stop_here: + break if chat_item.message["en"].startswith('!begin'): break if 
chat_item.message["en"].startswith('!'): @@ -50,10 +58,11 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history): #if chat_item.message["en"] == bot.greeting: # continue if chat_item.num_tokens == None: - chat_item.num_tokens = await num_tokens("{}: {}".format(chat_item.user_name, chat_item.message["en"])) + chat_history.chat_history[key].num_tokens = await num_tokens("{}: {}".format(chat_item.user_name, chat_item.message["en"])) + chat_item = chat_history.chat_history[key] # TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens?? logger.debug(f"History: " + str(chat_item) + " [" + str(chat_item.num_tokens) + "]") - if total_num_tokens + chat_item.num_tokens < MAX_TOKENS - max_new_tokens: + if total_num_tokens + chat_item.num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens: visible_history.append(chat_item) total_num_tokens += chat_item.num_tokens else: @@ -62,23 +71,43 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history): if not hasattr(bot, "greeting_num_tokens"): bot.greeting_num_tokens = await num_tokens(bot.greeting) - if total_num_tokens + bot.greeting_num_tokens < MAX_TOKENS - max_new_tokens: + if total_num_tokens + bot.greeting_num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens: prompt += bot.name + ": " + bot.greeting + "\n" + total_num_tokens += bot.greeting_num_tokens for chat_item in visible_history: if chat_item.is_own_message: - prompt += bot.name + ": " + chat_item.message["en"] + "\n" + line = bot.name + ": " + chat_item.message["en"] + "\n" else: - prompt += "You" + ": " + chat_item.message["en"] + "\n" + line = "You" + ": " + chat_item.message["en"] + "\n" + prompt += line + if chat_history.getSavedPrompt() and not chat_item.is_in_saved_prompt: + logger.info(f"adding to saved prompt: \"{line}\"") + chat_history.setSavedPrompt( chat_history.getSavedPrompt() + line, chat_history.saved_context_num_tokens + chat_item.num_tokens ) + chat_item.is_in_saved_prompt = True + + if chat_history.saved_context_num_tokens: + logger.info(f"saved_context has {chat_history.saved_context_num_tokens+input_num_tokens} tokens. new context would be {total_num_tokens}. 
Limit is {MAX_TOKENS}") + if chat_history.getSavedPrompt(): + if chat_history.saved_context_num_tokens+input_num_tokens > MAX_TOKENS - max_new_tokens: + chat_history.setFastForward(False) + if chat_history.getFastForward(): + logger.info("using saved prompt") + prompt = chat_history.getSavedPrompt() + if not chat_history.getSavedPrompt() or not chat_history.getFastForward(): + logger.info("regenerating prompt") + chat_history.setSavedPrompt(prompt, total_num_tokens) + for key, chat_item in chat_history.chat_history.items(): + if key != list(chat_history.chat_history)[-1]: # exclude current item + chat_history.chat_history[key].is_in_saved_prompt = True + chat_history.setFastForward(True) + prompt += "You: " + simple_prompt + "\n" prompt += bot.name + ":" return prompt -tokenizer = None - - async def num_tokens(input_text: str): # os.makedirs("./models/pygmalion-6b", exist_ok=True) # hf_hub_download(repo_id="PygmalionAI/pygmalion-6b", filename="config.json", cache_dir="./models/pygmalion-6b") diff --git a/matrix_pygmalion_bot/chatlog.py b/matrix_pygmalion_bot/chatlog.py index ef44e9a..a135211 100644 --- a/matrix_pygmalion_bot/chatlog.py +++ b/matrix_pygmalion_bot/chatlog.py @@ -12,8 +12,10 @@ class ChatMessage: self.user_name = user_name self.is_own_message = is_own_message self.is_command = is_command + self.stop_here = False self.relates_to_event = relates_to_event self.num_tokens = None + self.is_in_saved_prompt = False self.message = {} self.message[language] = message if not (language == "en"): @@ -40,6 +42,9 @@ class ChatHistory: def __init__(self, bot_name, room_name): self.bot_name = bot_name self.room_name = room_name + self.context_fast_forward = False + self.saved_context = None + self.saved_context_num_tokens = None self.chat_history = {} def __str__(self): return str("Chat History for {} in room {}".format(self.bot_name, self.room_name)) @@ -70,6 +75,15 @@ class ChatHistory: def getLastItem(self): key, chat_item = list(self.chat_history.items())[-1] return chat_item + def setFastForward(self, value): + self.context_fast_forward = value + def getFastForward(self): + return self.context_fast_forward + def getSavedPrompt(self): + return self.saved_context + def setSavedPrompt(self, context, num_tokens): + self.saved_context = context + self.saved_context_num_tokens = num_tokens class BotChatHistory: def __init__(self, bot_name): diff --git a/matrix_pygmalion_bot/core.py b/matrix_pygmalion_bot/core.py index 67678ee..403db62 100644 --- a/matrix_pygmalion_bot/core.py +++ b/matrix_pygmalion_bot/core.py @@ -59,12 +59,14 @@ class Callbacks(object): else: english_original_message = None - chat_message = self.bot.chat_history.room(room.display_name).add(event.event_id, event.server_timestamp, room.user_name(event.sender), event.sender == self.client.user, is_command, relates_to, event.body, language, english_original_message) + chat_message = self.bot.chat_history.room(room.room_id).add(event.event_id, event.server_timestamp, room.user_name(event.sender), event.sender == self.client.user, is_command, relates_to, event.body, language, english_original_message) # parse keywords self.bot.extra_info = {"persona": [], "scenario": [], "example_dialogue": []} for i, keyword in enumerate(self.bot.keywords): if re.search(keyword["regex"], event.body, flags=re.IGNORECASE): + if not 'active' in self.bot.keywords[i] or self.bot.keywords[i]['active'] < 1: + self.bot.chat_history.room(room.room_id).setFastForward(False) self.bot.keywords[i]['active'] = int(keyword["duration"]) logger.info(f"keyword 
\"{keyword['regex']}\" detected") if 'active' in self.bot.keywords[i]: @@ -185,7 +187,7 @@ class Callbacks(object): elif event.body.startswith('!temperature'): self.bot.temperature = float( event.body.removeprefix('!temperature').strip() ) elif event.body.startswith('!begin'): - self.bot.chat_history.room(room.display_name).clear() + self.bot.chat_history.room(room.room_id).clear() self.bot.room_config[room.room_id]["tick"] = 0 await self.bot.write_conf2(self.bot.name) await self.bot.send_message(self.client, room.room_id, self.bot.greeting) @@ -197,30 +199,30 @@ class Callbacks(object): self.bot.room_config[room.room_id]["disabled"] = True return elif event.body.startswith('!!!'): - if self.bot.chat_history.room(room.display_name).getLen() < 3: + if self.bot.chat_history.room(room.room_id).getLen() < 3: return - chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) # current + chat_history_item = self.bot.chat_history.room(room.room_id).remove(1) # current await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") - chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) + chat_history_item = self.bot.chat_history.room(room.room_id).remove(1) await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") - chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) + chat_history_item = self.bot.chat_history.room(room.room_id).remove(1) await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") return elif event.body.startswith('!!'): - if self.bot.chat_history.room(room.display_name).getLen() < 3: + if self.bot.chat_history.room(room.room_id).getLen() < 3: return - chat_history_item = self.bot.chat_history.room(room.display_name).remove(1)# current + chat_history_item = self.bot.chat_history.room(room.room_id).remove(1)# current await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") - chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) + chat_history_item = self.bot.chat_history.room(room.room_id).remove(1) await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") - chat_message = self.bot.chat_history.room(room.display_name).getLastItem() # new current + chat_message = self.bot.chat_history.room(room.room_id).getLastItem() # new current # don't return, we generate a new answer elif event.body.startswith('!replace'): - if self.bot.chat_history.room(room.display_name).getLen() < 3: + if self.bot.chat_history.room(room.room_id).getLen() < 3: return - chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) # current + chat_history_item = self.bot.chat_history.room(room.room_id).remove(1) # current await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") - chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) + chat_history_item = self.bot.chat_history.room(room.room_id).remove(1) await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") new_answer = event.body.removeprefix('!replace').strip() await self.bot.send_message(self.client, room.room_id, new_answer, reply_to=chat_history_item.relates_to_event) @@ -237,7 +239,7 @@ class Callbacks(object): # send, mail, drop, snap picture, photo, image, portrait pass - full_prompt = await text_ai.get_full_prompt(chat_message.getTranslation("en"), self.bot, 
-        full_prompt = await text_ai.get_full_prompt(chat_message.getTranslation("en"), self.bot, self.bot.chat_history.room(room.display_name))
+        full_prompt = await text_ai.get_full_prompt(chat_message.getTranslation("en"), self.bot, self.bot.chat_history.room(room.room_id))
         num_tokens = await text_ai.num_tokens(full_prompt)
         logger.debug(full_prompt)
         logger.info(f"Prompt has " + str(num_tokens) + " tokens")
@@ -288,11 +290,12 @@ class Callbacks(object):
         logger.info(f"event redacted in room {room.room_id}. event_id: {event.redacts}")
         for bot in bots:
             # for room in bot.chat_history.chat_rooms.keys():
-            if room.display_name in bot.chat_history.chat_rooms:
+            if room.room_id in bot.chat_history.chat_rooms:
                 logger.info("room found")
-                if bot.chat_history.chat_rooms[room.display_name].exists_id(event.redacts):
+                if bot.chat_history.chat_rooms[room.room_id].exists_id(event.redacts):
                     logger.info("found it")
-                    bot.chat_history.chat_rooms[room.display_name].remove_id(event.redacts)
+                    bot.chat_history.chat_rooms[room.room_id].remove_id(event.redacts)
+                    self.bot.chat_history.room(room.room_id).setFastForward(False)
 
 class ChatBot(object):
     """Main chatbot"""
@@ -338,6 +341,7 @@ class ChatBot(object):
         self.scenario = scenario
         self.greeting = greeting
         self.example_dialogue = example_dialogue
+        self.chat_history = BotChatHistory(self.name)
 
     def get_persona(self):
         return ' '.join([self.persona, ' '.join(self.extra_info['persona'])])
@@ -374,7 +378,6 @@ class ChatBot(object):
     async def login(self):
         self.config = AsyncClientConfig(store_sync_tokens=True)
         self.client = AsyncClient(self.homeserver, self.user_id, store_path=STORE_PATH, config=self.config)
-        self.chat_history = BotChatHistory(self.name)
        self.callbacks = Callbacks(self.client, self)
         self.client.add_event_callback(self.callbacks.message_cb, RoomMessageText)
         self.client.add_event_callback(self.callbacks.invite_cb, InviteEvent)
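
Note (not part of the patch): the change above amounts to a windowed prompt cache. The assembled prompt is saved per room and reused ("fast forward") as long as it still fits the context; only when saved_context_num_tokens + input_num_tokens exceeds MAX_TOKENS - max_new_tokens is the prompt rebuilt against the tighter MAX_TOKENS - WINDOW - max_new_tokens budget, which drops a whole chunk of old history at once instead of re-trimming on every message. Below is a minimal sketch of that idea; the names (build_prompt, cache, count_tokens) are illustrative only, not the bot's actual API, and the fast path is simplified (the real code also appends newly arrived history lines onto the saved prompt).

# --- Illustrative sketch only, not part of the repository ------------------
MAX_TOKENS = 2048       # model context size
WINDOW = 600            # headroom: history is dropped in chunks about this big
MAX_NEW_TOKENS = 200    # tokens reserved for the model's reply


def build_prompt(base_prompt, history, user_input, count_tokens, cache):
    """history: list of (line, num_tokens) pairs, oldest first.
    cache: dict holding 'prompt', 'num_tokens' and 'fast_forward'."""
    input_tokens = count_tokens(user_input)

    # Fast path: the saved prompt still fits, so reuse it unchanged.
    if cache.get("fast_forward") and \
            cache["num_tokens"] + input_tokens <= MAX_TOKENS - MAX_NEW_TOKENS:
        return cache["prompt"] + user_input

    # Slow path: rebuild against the tighter windowed budget, keeping only as
    # much recent history as fits; everything older falls out in one go.
    budget = (MAX_TOKENS - WINDOW - MAX_NEW_TOKENS
              - count_tokens(base_prompt) - input_tokens)
    kept, used = [], 0
    for line, n in reversed(history):           # walk newest message first
        if used + n > budget:
            break                               # drop this and all older lines
        kept.append(line)
        used += n

    prompt = base_prompt + "".join(reversed(kept))
    cache.update(prompt=prompt, num_tokens=count_tokens(prompt), fast_forward=True)
    return prompt + user_input

Keeping the prompt prefix stable between calls is presumably also why the commented-out koboldcpp invocation above adds --smartcontext: that option only pays off when consecutive prompts share a long common prefix.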