From 5703dcd1753e698f39c7fe3c1258484a03115fe3 Mon Sep 17 00:00:00 2001 From: Hendrik Langer Date: Sun, 16 Apr 2023 13:27:20 +0200 Subject: [PATCH] rewrite model helpers --- matrix_pygmalion_bot/ai/koboldcpp.py | 5 +- .../ai/{llama_helpers.py => model_helpers.py} | 111 +++++++++++++--- matrix_pygmalion_bot/ai/pygmalion_helpers.py | 124 ------------------ .../ai/{runpod_pygmalion.py => runpod.py} | 4 +- matrix_pygmalion_bot/ai/stablehorde.py | 5 +- matrix_pygmalion_bot/core.py | 78 +++++++---- 6 files changed, 155 insertions(+), 172 deletions(-) rename matrix_pygmalion_bot/ai/{llama_helpers.py => model_helpers.py} (56%) delete mode 100644 matrix_pygmalion_bot/ai/pygmalion_helpers.py rename matrix_pygmalion_bot/ai/{runpod_pygmalion.py => runpod.py} (99%) diff --git a/matrix_pygmalion_bot/ai/koboldcpp.py b/matrix_pygmalion_bot/ai/koboldcpp.py index 4739f9d..986667e 100644 --- a/matrix_pygmalion_bot/ai/koboldcpp.py +++ b/matrix_pygmalion_bot/ai/koboldcpp.py @@ -12,8 +12,7 @@ import io import base64 from PIL import Image, PngImagePlugin -from .pygmalion_helpers import get_full_prompt, num_tokens -#from .llama_helpers import get_full_prompt, num_tokens +from .model_helpers import get_full_prompt, num_tokens logger = logging.getLogger(__name__) @@ -41,7 +40,7 @@ async def generate_sync( } max_new_tokens = 200 - prompt_num_tokens = await num_tokens(prompt) + prompt_num_tokens = await num_tokens(prompt, bot.model) # Define your inputs input_data = { diff --git a/matrix_pygmalion_bot/ai/llama_helpers.py b/matrix_pygmalion_bot/ai/model_helpers.py similarity index 56% rename from matrix_pygmalion_bot/ai/llama_helpers.py rename to matrix_pygmalion_bot/ai/model_helpers.py index 4709ccf..d6aee2f 100644 --- a/matrix_pygmalion_bot/ai/llama_helpers.py +++ b/matrix_pygmalion_bot/ai/model_helpers.py @@ -15,9 +15,10 @@ from PIL import Image, PngImagePlugin logger = logging.getLogger(__name__) -async def get_full_prompt(simple_prompt: str, bot, chat_history): - ai_name = "### Assistant" # bot.name - user_name = "### Human" # bot.user_name +gptj_tokenizer = None + + +async def get_full_prompt(simple_prompt: str, bot, chat_history, model_name: str): # https://github.com/ggerganov/llama.cpp/tree/master/examples ## prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n" @@ -33,19 +34,84 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history): #prompt += f"{bot.name} is also very curious and will ask the user a lot of questions about themselves and their life, she will also try to make the user like her.\n" #prompt += f"\n" - prompt = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" - prompt += f"### Instruction:\nGiven the following character description and scenario, write a script for a dialogue between the human user {bot.user_name} and the fictional AI assistant {bot.name}. Play the role of the character {bot.name}\n\n" - prompt += "### Input:\n" - prompt += bot.name + "'s Persona: " + bot.get_persona() + "\n" - prompt += "Scenario: " + bot.get_scenario() + "\n\n" - prompt += "### Response:\n" + # Names + if model_name.startswith("pygmalion"): + ai_name = bot.name + user_name = "You" + elif model_name.startswith("vicuna"): + ai_name = "### Assistant" + user_name = "### Human" + elif model_name.startswith("alpaca"): + ai_name = bot.name # ToDo + user_name = bot.user_name # ToDo + else: + ai_name = bot.name + user_name = bot.user_name + + # First line / Task description + if model_name.startswith("pygmalion"): + prompt = "" + elif model_name.startswith("vicuna"): + prompt = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" + prompt += f"### Instruction:\nGiven the following character description and scenario, write a script for a dialogue between the human user {bot.user_name} and the fictional AI assistant {bot.name}. Play the role of the character {bot.name}\n\n" # ToDo + prompt += "### Input:\n" + elif model_name.startswith("alpaca"): + prompt = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" + prompt += f"### Instruction:\nGiven the following character description and scenario, write a script for a dialogue between the human user {bot.user_name} and the fictional AI assistant {bot.name}. Play the role of the character {bot.name}\n\n" + prompt += "### Input:\n" + else: + prompt = "" + + # Character description + if model_name.startswith("pygmalion"): + prompt = bot.name + "'s Persona: " + bot.get_persona() + "\n" + elif model_name.startswith("vicuna"): + prompt += bot.name + "'s Persona: " + bot.get_persona() + "\n" # ToDo + elif model_name.startswith("alpaca"): + prompt += bot.name + "'s Persona: " + bot.get_persona() + "\n" # ToDo + else: + prompt += bot.name + "'s Persona: " + bot.get_persona() + "\n" # ToDo + + # Scenario + if model_name.startswith("pygmalion"): + prompt += "Scenario: " + bot.get_scenario() + "\n\n" + elif model_name.startswith("vicuna"): + prompt += "Scenario: " + bot.get_scenario() + "\n\n" # ToDo + elif model_name.startswith("alpaca"): + prompt += "Scenario: " + bot.get_scenario() + "\n\n" # ToDo + else: + prompt += "Scenario: " + bot.get_scenario() + "\n\n" # ToDo + + # Response delimiter + if model_name.startswith("pygmalion"): + pass + elif model_name.startswith("vicuna"): + prompt += "### Response:\n" # ToDo + elif model_name.startswith("alpaca"): + prompt += "### Response:\n" + else: + pass + + # Example dialogue for dialogue_item in bot.get_example_dialogue(): - #prompt += "" + "\n" + if model_name.startswith("pygmalion"): + prompt += "" + "\n" dialogue_item = dialogue_item.replace('{{user}}', user_name) dialogue_item = dialogue_item.replace('{{char}}', ai_name) prompt += dialogue_item + "\n\n" - prompt += "" + "\n" + + # Dialogue start + if model_name.startswith("pygmalion"): + prompt += "" + "\n" + elif model_name.startswith("vicuna"): + pass # ToDo + elif model_name.startswith("alpaca"): + pass # ToDo + else: + pass # ToDo + + #prompt += f"{ai_name}: {bot.greeting}\n" #prompt += f"{user_name}: {simple_prompt}\n" #prompt += f"{ai_name}:" @@ -53,8 +119,8 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history): MAX_TOKENS = 2048 WINDOW = 600 max_new_tokens = 200 - total_num_tokens = await num_tokens(prompt) - input_num_tokens = await num_tokens(f"{user_name}: {simple_prompt}\n{ai_name}:") + total_num_tokens = await num_tokens(prompt, model_name) + input_num_tokens = await num_tokens(f"{user_name}: {simple_prompt}\n{ai_name}:", model_name) total_num_tokens += input_num_tokens visible_history = [] num_message = 0 @@ -74,7 +140,7 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history): #if chat_item.message["en"] == bot.greeting: # continue if chat_item.num_tokens == None: - chat_history.chat_history[key].num_tokens = await num_tokens("{}: {}".format(user_name, chat_item.message["en"])) + chat_history.chat_history[key].num_tokens = await num_tokens("{}: {}".format(user_name, chat_item.message["en"]), model_name) chat_item = chat_history.chat_history[key] # TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens?? logger.debug(f"History: " + str(chat_item) + " [" + str(chat_item.num_tokens) + "]") @@ -86,7 +152,7 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history): visible_history = reversed(visible_history) if not hasattr(bot, "greeting_num_tokens"): - bot.greeting_num_tokens = await num_tokens(bot.greeting) + bot.greeting_num_tokens = await num_tokens(bot.greeting, model_name) if total_num_tokens + bot.greeting_num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens: prompt += f"{ai_name}: {bot.greeting}\n" total_num_tokens += bot.greeting_num_tokens @@ -124,8 +190,19 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history): return prompt -async def num_tokens(input_text: str): - return await estimate_num_tokens(input_text) +async def num_tokens(input_text: str, model_name: str): + if model_name.startswith("pygmalion"): + #os.makedirs("./models/pygmalion-6b", exist_ok=True) + #hf_hub_download(repo_id="PygmalionAI/pygmalion-6b", filename="config.json", cache_dir="./models/pygmalion-6b") + #config = AutoConfig.from_pretrained("./models/pygmalion-6b/config.json") + global gptj_tokenizer + if not gptj_tokenizer: + gptj_tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b") + encoding = gptj_tokenizer.encode(input_text, add_special_tokens=False) + max_input_size = gptj_tokenizer.max_model_input_sizes + return len(encoding) + else: + return await estimate_num_tokens(input_text) async def estimate_num_tokens(input_text: str): diff --git a/matrix_pygmalion_bot/ai/pygmalion_helpers.py b/matrix_pygmalion_bot/ai/pygmalion_helpers.py deleted file mode 100644 index b0bf12e..0000000 --- a/matrix_pygmalion_bot/ai/pygmalion_helpers.py +++ /dev/null @@ -1,124 +0,0 @@ -import asyncio -import os, tempfile -import logging - -import json -import requests - -from transformers import AutoTokenizer, AutoConfig -from huggingface_hub import hf_hub_download - -import io -import base64 -from PIL import Image, PngImagePlugin - -logger = logging.getLogger(__name__) - - -tokenizer = None - - -async def get_full_prompt(simple_prompt: str, bot, chat_history): - - # Prompt without history - prompt = bot.name + "'s Persona: " + bot.get_persona() + "\n" - prompt += "Scenario: " + bot.get_scenario() + "\n\n" - - for dialogue_item in bot.get_example_dialogue(): - prompt += "" + "\n" - dialogue_item = dialogue_item.replace('{{user}}', 'You') - dialogue_item = dialogue_item.replace('{{char}}', bot.name) - prompt += dialogue_item + "\n\n" - prompt += "" + "\n" - #prompt += bot.name + ": " + bot.greeting + "\n" - #prompt += "You: " + simple_prompt + "\n" - #prompt += bot.name + ":" - - MAX_TOKENS = 2048 - WINDOW = 800 - max_new_tokens = 200 - total_num_tokens = await num_tokens(prompt) - input_num_tokens = await num_tokens(f"You: " + simple_prompt + "\n{bot.name}:") - total_num_tokens += input_num_tokens - visible_history = [] - num_message = 0 - for key, chat_item in reversed(chat_history.chat_history.items()): - num_message += 1 - if num_message == 1: - # skip current_message - continue - if chat_item.stop_here: - break - if chat_item.message["en"].startswith('!begin'): - break - if chat_item.message["en"].startswith('!'): - continue - if chat_item.message["en"].startswith(''): - continue - #if chat_item.message["en"] == bot.greeting: - # continue - if chat_item.num_tokens == None: - chat_history.chat_history[key].num_tokens = await num_tokens("{}: {}".format(chat_item.user_name, chat_item.message["en"])) - chat_item = chat_history.chat_history[key] - # TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens?? - logger.debug(f"History: " + str(chat_item) + " [" + str(chat_item.num_tokens) + "]") - if total_num_tokens + chat_item.num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens: - visible_history.append(chat_item) - total_num_tokens += chat_item.num_tokens - else: - break - visible_history = reversed(visible_history) - - if not hasattr(bot, "greeting_num_tokens"): - bot.greeting_num_tokens = await num_tokens(bot.greeting) - if total_num_tokens + bot.greeting_num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens: - prompt += bot.name + ": " + bot.greeting + "\n" - total_num_tokens += bot.greeting_num_tokens - - for chat_item in visible_history: - if chat_item.is_own_message: - line = bot.name + ": " + chat_item.message["en"] + "\n" - else: - line = "You" + ": " + chat_item.message["en"] + "\n" - prompt += line - if chat_history.getSavedPrompt() and not chat_item.is_in_saved_prompt: - logger.info(f"adding to saved prompt: \"{line}\"") - chat_history.setSavedPrompt( chat_history.getSavedPrompt() + line, chat_history.saved_context_num_tokens + chat_item.num_tokens ) - chat_item.is_in_saved_prompt = True - - if chat_history.saved_context_num_tokens: - logger.info(f"saved_context has {chat_history.saved_context_num_tokens+input_num_tokens} tokens. new context would be {total_num_tokens}. Limit is {MAX_TOKENS}") - if chat_history.getSavedPrompt(): - if chat_history.saved_context_num_tokens+input_num_tokens > MAX_TOKENS - max_new_tokens: - chat_history.setFastForward(False) - if chat_history.getFastForward(): - logger.info("using saved prompt") - prompt = chat_history.getSavedPrompt() - if not chat_history.getSavedPrompt() or not chat_history.getFastForward(): - logger.info("regenerating prompt") - chat_history.setSavedPrompt(prompt, total_num_tokens) - for key, chat_item in chat_history.chat_history.items(): - if key != list(chat_history.chat_history)[-1]: # exclude current item - chat_history.chat_history[key].is_in_saved_prompt = True - chat_history.setFastForward(True) - - prompt += "You: " + simple_prompt + "\n" - prompt += bot.name + ":" - - return prompt - - -async def num_tokens(input_text: str): -# os.makedirs("./models/pygmalion-6b", exist_ok=True) -# hf_hub_download(repo_id="PygmalionAI/pygmalion-6b", filename="config.json", cache_dir="./models/pygmalion-6b") -# config = AutoConfig.from_pretrained("./models/pygmalion-6b/config.json") - global tokenizer - if not tokenizer: - tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b") - encoding = tokenizer.encode(input_text, add_special_tokens=False) - max_input_size = tokenizer.max_model_input_sizes - return len(encoding) - - -async def estimate_num_tokens(input_text: str): - return len(input_text)//4+1 diff --git a/matrix_pygmalion_bot/ai/runpod_pygmalion.py b/matrix_pygmalion_bot/ai/runpod.py similarity index 99% rename from matrix_pygmalion_bot/ai/runpod_pygmalion.py rename to matrix_pygmalion_bot/ai/runpod.py index 78fdba9..eaf751e 100644 --- a/matrix_pygmalion_bot/ai/runpod_pygmalion.py +++ b/matrix_pygmalion_bot/ai/runpod.py @@ -12,7 +12,7 @@ import io import base64 from PIL import Image, PngImagePlugin -from .pygmalion_helpers import get_full_prompt, num_tokens +from .model_helpers import get_full_prompt, num_tokens logger = logging.getLogger(__name__) @@ -34,7 +34,7 @@ async def generate_sync( } max_new_tokens = 200 - prompt_num_tokens = await num_tokens(prompt) + prompt_num_tokens = await num_tokens(prompt, bot.model) # Define your inputs input_data = { diff --git a/matrix_pygmalion_bot/ai/stablehorde.py b/matrix_pygmalion_bot/ai/stablehorde.py index c060d6a..081ba83 100644 --- a/matrix_pygmalion_bot/ai/stablehorde.py +++ b/matrix_pygmalion_bot/ai/stablehorde.py @@ -7,8 +7,7 @@ import requests from transformers import AutoTokenizer, AutoConfig from huggingface_hub import hf_hub_download -from .pygmalion_helpers import get_full_prompt, num_tokens -#from .llama_helpers import get_full_prompt, num_tokens +from .model_helpers import get_full_prompt, num_tokens logger = logging.getLogger(__name__) @@ -29,7 +28,7 @@ async def generate_sync( } max_new_tokens = 200 - prompt_num_tokens = await num_tokens(prompt) + prompt_num_tokens = await num_tokens(prompt, bot.model) # Define your inputs input_data = { diff --git a/matrix_pygmalion_bot/core.py b/matrix_pygmalion_bot/core.py index 253154b..234d8fe 100644 --- a/matrix_pygmalion_bot/core.py +++ b/matrix_pygmalion_bot/core.py @@ -16,11 +16,7 @@ import json from .helpers import Event from .chatlog import BotChatHistory -image_ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod_pygmalion") -text_ai = importlib.import_module("matrix_pygmalion_bot.ai.koboldcpp") -#ai = importlib.import_module("matrix_pygmalion_bot.ai.stablehorde") -#from .llama_cpp import generate, get_full_prompt, get_full_prompt_chat_style -#from .runpod_pygmalion import generate_sync, get_full_prompt + import matrix_pygmalion_bot.translate as translate STORE_PATH = "./.store/" @@ -41,6 +37,7 @@ class Callbacks(object): async def message_cb(self, room: MatrixRoom, event: RoomMessageText) -> None: if not hasattr(event, 'body'): return + if not room.room_id in self.bot.room_config: self.bot.room_config[room.room_id] = {} self.bot.room_config[room.room_id]["tick"] = 0 @@ -136,32 +133,32 @@ class Callbacks(object): # negative_prompt = "ugly, deformed, out of frame" try: typing = lambda : self.client.room_typing(room.room_id, True, 15000) - if self.bot.service == "runpod": + if self.bot.service_image == "runpod": if num == 1: - output = await image_ai.generate_image1(prompt, negative_prompt, self.bot.runpod_api_key, typing) + output = await self.bot.image_ai.generate_image1(prompt, negative_prompt, self.bot.runpod_api_key, typing) elif num == 2: - output = await image_ai.generate_image2(prompt, negative_prompt, self.bot.runpod_api_key, typing) + output = await self.bot.image_ai.generate_image2(prompt, negative_prompt, self.bot.runpod_api_key, typing) elif num == 3: - output = await image_ai.generate_image3(prompt, negative_prompt, self.bot.runpod_api_key, typing) + output = await self.bot.image_ai.generate_image3(prompt, negative_prompt, self.bot.runpod_api_key, typing) elif num == 4: - output = await image_ai.generate_image4(prompt, negative_prompt, self.bot.runpod_api_key, typing) + output = await self.bot.image_ai.generate_image4(prompt, negative_prompt, self.bot.runpod_api_key, typing) elif num == 5: - output = await image_ai.generate_image5(prompt, negative_prompt, self.bot.runpod_api_key, typing) + output = await self.bot.image_ai.generate_image5(prompt, negative_prompt, self.bot.runpod_api_key, typing) elif num == 6: - output = await image_ai.generate_image6(prompt, negative_prompt, self.bot.runpod_api_key, typing) + output = await self.bot.image_ai.generate_image6(prompt, negative_prompt, self.bot.runpod_api_key, typing) elif num == 7: - output = await image_ai.generate_image7(prompt, negative_prompt, self.bot.runpod_api_key, typing) + output = await self.bot.image_ai.generate_image7(prompt, negative_prompt, self.bot.runpod_api_key, typing) elif num == 8: - output = await image_ai.generate_image8(prompt, negative_prompt, self.bot.runpod_api_key, typing) + output = await self.bot.image_ai.generate_image8(prompt, negative_prompt, self.bot.runpod_api_key, typing) else: raise ValueError('no image generator with that number') - elif self.bot.service == "stablehorde": + elif self.bot.service_image == "stablehorde": if num == 1: - output = await image_ai.generate_image1(prompt, negative_prompt, self.bot.stablehorde_api_key, typing) + output = await self.bot.image_ai.generate_image1(prompt, negative_prompt, self.bot.stablehorde_api_key, typing) elif num == 2: - output = await image_ai.generate_image2(prompt, negative_prompt, self.bot.stablehorde_api_key, typing) + output = await self.bot.image_ai.generate_image2(prompt, negative_prompt, self.bot.stablehorde_api_key, typing) elif num == 3: - output = await image_ai.generate_image3(prompt, negative_prompt, self.bot.stablehorde_api_key, typing) + output = await self.bot.image_ai.generate_image3(prompt, negative_prompt, self.bot.stablehorde_api_key, typing) else: raise ValueError('no image generator with that number') else: @@ -240,8 +237,8 @@ class Callbacks(object): # send, mail, drop, snap picture, photo, image, portrait pass - full_prompt = await text_ai.get_full_prompt(chat_message.getTranslation("en"), self.bot, self.bot.chat_history.room(room.room_id)) - num_tokens = await text_ai.num_tokens(full_prompt) + full_prompt = await self.bot.text_ai.get_full_prompt(chat_message.getTranslation("en"), self.bot, self.bot.chat_history.room(room.room_id), self.bot.model) + num_tokens = await self.bot.text_ai.num_tokens(full_prompt, self.bot.model) logger.debug(full_prompt) logger.info(f"Prompt has " + str(num_tokens) + " tokens") # answer = "" @@ -261,7 +258,7 @@ class Callbacks(object): # print("") try: typing = lambda : self.client.room_typing(room.room_id, True, 15000) - answer = await text_ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot, typing, api_endpoint) + answer = await self.bot.text_ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot, typing, api_endpoint) answer = answer.strip() await self.client.room_typing(room.room_id, False) if not (self.bot.translate is None): @@ -307,7 +304,12 @@ class ChatBot(object): self.password = password self.device_name = device_name + self.service_text = "other" + self.service_image = "other" + self.model = "other" self.runpod_api_key = None + self.text_ai = None + self.image_ai = None self.client = None self.callbacks = None @@ -398,6 +400,25 @@ class ChatBot(object): await self.client.close() #return sync_forever_task + async def load_ai(self): + if self.service_text == "runpod": + self.text_ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod") + elif self.service_text == "stablehorde": + self.text_ai = importlib.import_module("matrix_pygmalion_bot.ai.stablehorde") + elif self.service_text == "koboldcpp": + self.text_ai = importlib.import_module("matrix_pygmalion_bot.ai.koboldcpp") + else: + raise ValueError(f"no text service with the name \"{self.bot.service_text}\"") + + if self.service_image == "runpod": + self.image_ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod") + elif self.service_image == "stablehorde": + self.image_ai = importlib.import_module("matrix_pygmalion_bot.ai.stablehorde") + elif self.service_image == "koboldcpp": + self.image_ai = importlib.import_module("matrix_pygmalion_bot.ai.koboldcpp") + else: + raise ValueError(f"no image service with the name \"{self.bot.service_text}\"") + async def watch_for_sync(self, sync_event): logger.debug("Awaiting sync") await sync_event.wait() @@ -511,14 +532,25 @@ async def main() -> None: events = config[section]['events'].strip().split('\n') for event in events: await bot.add_event(event) - if config.has_option('DEFAULT', 'service'): - bot.service = config['DEFAULT']['service'] + if config.has_option('DEFAULT', 'service_text'): + bot.service_text = config['DEFAULT']['service_text'] + if config.has_option(section, 'service_text'): + bot.service_text = config[section]['service_text'] + if config.has_option('DEFAULT', 'service_image'): + bot.service_image = config['DEFAULT']['service_image'] + if config.has_option(section, 'service_image'): + bot.service_image = config[section]['service_image'] + if config.has_option('DEFAULT', 'model'): + bot.model = config['DEFAULT']['model'] + if config.has_option(section, 'model'): + bot.model = config[section]['model'] if config.has_option('DEFAULT', 'runpod_api_key'): bot.runpod_api_key = config['DEFAULT']['runpod_api_key'] if config.has_option('DEFAULT', 'stablehorde_api_key'): bot.stablehorde_api_key = config['DEFAULT']['stablehorde_api_key'] await bot.read_conf2(section) bots.append(bot) + await bot.load_ai() await bot.login() #logger.info("gather") if sys.version_info[0] == 3 and sys.version_info[1] < 11: