diff --git a/matrix_pygmalion_bot/ai/stablehorde.py b/matrix_pygmalion_bot/ai/stablehorde.py index 8ac6109..c060d6a 100644 --- a/matrix_pygmalion_bot/ai/stablehorde.py +++ b/matrix_pygmalion_bot/ai/stablehorde.py @@ -7,6 +7,9 @@ import requests from transformers import AutoTokenizer, AutoConfig from huggingface_hub import hf_hub_download +from .pygmalion_helpers import get_full_prompt, num_tokens +#from .llama_helpers import get_full_prompt, num_tokens + logger = logging.getLogger(__name__) @@ -96,71 +99,6 @@ async def generate_sync( else: return f" {status}" -async def get_full_prompt(simple_prompt: str, bot, chat_history): - - # Prompt without history - prompt = bot.name + "'s Persona: " + bot.get_persona() + "\n" - prompt += "Scenario: " + bot.get_scenario() + "\n" - prompt += "" + "\n" - #prompt += bot.name + ": " + bot.greeting + "\n" - prompt += "You: " + simple_prompt + "\n" - prompt += bot.name + ":" - - MAX_TOKENS = 2048 - max_new_tokens = 200 - total_num_tokens = await num_tokens(prompt) - visible_history = [] - current_message = True - for key, chat_item in reversed(chat_history.chat_history.items()): - if current_message: - current_message = False - continue - if chat_item.message["en"].startswith('!begin'): - break - if chat_item.message["en"].startswith('!'): - continue - if chat_item.message["en"].startswith(''): - continue - #if chat_item.message["en"] == bot.greeting: - # continue - if chat_item.num_tokens == None: - chat_item.num_tokens = await num_tokens("{}: {}".format(chat_item.user_name, chat_item.message["en"])) - # TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens?? - logger.debug(f"History: " + str(chat_item) + " [" + str(chat_item.num_tokens) + "]") - if total_num_tokens + chat_item.num_tokens < MAX_TOKENS - max_new_tokens: - visible_history.append(chat_item) - total_num_tokens += chat_item.num_tokens - else: - break - visible_history = reversed(visible_history) - - prompt = bot.name + "'s Persona: " + bot.get_persona() + "\n" - prompt += "Scenario: " + bot.get_scenario() + "\n" - prompt += "" + "\n" - #prompt += bot.name + ": " + bot.greeting + "\n" - for chat_item in visible_history: - if chat_item.is_own_message: - prompt += bot.name + ": " + chat_item.message["en"] + "\n" - else: - prompt += "You" + ": " + chat_item.message["en"] + "\n" - prompt += "You: " + simple_prompt + "\n" - prompt += bot.name + ":" - - return prompt - - -async def num_tokens(input_text: str): -# os.makedirs("./models/pygmalion-6b", exist_ok=True) -# hf_hub_download(repo_id="PygmalionAI/pygmalion-6b", filename="config.json", cache_dir="./models/pygmalion-6b") -# config = AutoConfig.from_pretrained("./models/pygmalion-6b/config.json") - tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b") - encoding = tokenizer.encode(input_text, add_special_tokens=False) - max_input_size = tokenizer.max_model_input_sizes - return len(encoding) - -async def estimate_num_tokens(input_text: str): - return len(input_text)//4+1 - async def generate_image(input_prompt: str, negative_prompt: str, model: str, api_key: str):