Hendrik Langer
2 years ago
6 changed files with 155 additions and 172 deletions
@ -1,124 +0,0 @@ |
|||||
import asyncio |
|
||||
import os, tempfile |
|
||||
import logging |
|
||||
|
|
||||
import json |
|
||||
import requests |
|
||||
|
|
||||
from transformers import AutoTokenizer, AutoConfig |
|
||||
from huggingface_hub import hf_hub_download |
|
||||
|
|
||||
import io |
|
||||
import base64 |
|
||||
from PIL import Image, PngImagePlugin |
|
||||
|
|
||||
logger = logging.getLogger(__name__) |
|
||||
|
|
||||
|
|
||||
tokenizer = None |
|
||||
|
|
||||
|
|
||||
async def get_full_prompt(simple_prompt: str, bot, chat_history): |
|
||||
|
|
||||
# Prompt without history |
|
||||
prompt = bot.name + "'s Persona: " + bot.get_persona() + "\n" |
|
||||
prompt += "Scenario: " + bot.get_scenario() + "\n\n" |
|
||||
|
|
||||
for dialogue_item in bot.get_example_dialogue(): |
|
||||
prompt += "<START>" + "\n" |
|
||||
dialogue_item = dialogue_item.replace('{{user}}', 'You') |
|
||||
dialogue_item = dialogue_item.replace('{{char}}', bot.name) |
|
||||
prompt += dialogue_item + "\n\n" |
|
||||
prompt += "<START>" + "\n" |
|
||||
#prompt += bot.name + ": " + bot.greeting + "\n" |
|
||||
#prompt += "You: " + simple_prompt + "\n" |
|
||||
#prompt += bot.name + ":" |
|
||||
|
|
||||
MAX_TOKENS = 2048 |
|
||||
WINDOW = 800 |
|
||||
max_new_tokens = 200 |
|
||||
total_num_tokens = await num_tokens(prompt) |
|
||||
input_num_tokens = await num_tokens(f"You: " + simple_prompt + "\n{bot.name}:") |
|
||||
total_num_tokens += input_num_tokens |
|
||||
visible_history = [] |
|
||||
num_message = 0 |
|
||||
for key, chat_item in reversed(chat_history.chat_history.items()): |
|
||||
num_message += 1 |
|
||||
if num_message == 1: |
|
||||
# skip current_message |
|
||||
continue |
|
||||
if chat_item.stop_here: |
|
||||
break |
|
||||
if chat_item.message["en"].startswith('!begin'): |
|
||||
break |
|
||||
if chat_item.message["en"].startswith('!'): |
|
||||
continue |
|
||||
if chat_item.message["en"].startswith('<ERROR>'): |
|
||||
continue |
|
||||
#if chat_item.message["en"] == bot.greeting: |
|
||||
# continue |
|
||||
if chat_item.num_tokens == None: |
|
||||
chat_history.chat_history[key].num_tokens = await num_tokens("{}: {}".format(chat_item.user_name, chat_item.message["en"])) |
|
||||
chat_item = chat_history.chat_history[key] |
|
||||
# TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens?? |
|
||||
logger.debug(f"History: " + str(chat_item) + " [" + str(chat_item.num_tokens) + "]") |
|
||||
if total_num_tokens + chat_item.num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens: |
|
||||
visible_history.append(chat_item) |
|
||||
total_num_tokens += chat_item.num_tokens |
|
||||
else: |
|
||||
break |
|
||||
visible_history = reversed(visible_history) |
|
||||
|
|
||||
if not hasattr(bot, "greeting_num_tokens"): |
|
||||
bot.greeting_num_tokens = await num_tokens(bot.greeting) |
|
||||
if total_num_tokens + bot.greeting_num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens: |
|
||||
prompt += bot.name + ": " + bot.greeting + "\n" |
|
||||
total_num_tokens += bot.greeting_num_tokens |
|
||||
|
|
||||
for chat_item in visible_history: |
|
||||
if chat_item.is_own_message: |
|
||||
line = bot.name + ": " + chat_item.message["en"] + "\n" |
|
||||
else: |
|
||||
line = "You" + ": " + chat_item.message["en"] + "\n" |
|
||||
prompt += line |
|
||||
if chat_history.getSavedPrompt() and not chat_item.is_in_saved_prompt: |
|
||||
logger.info(f"adding to saved prompt: \"{line}\"") |
|
||||
chat_history.setSavedPrompt( chat_history.getSavedPrompt() + line, chat_history.saved_context_num_tokens + chat_item.num_tokens ) |
|
||||
chat_item.is_in_saved_prompt = True |
|
||||
|
|
||||
if chat_history.saved_context_num_tokens: |
|
||||
logger.info(f"saved_context has {chat_history.saved_context_num_tokens+input_num_tokens} tokens. new context would be {total_num_tokens}. Limit is {MAX_TOKENS}") |
|
||||
if chat_history.getSavedPrompt(): |
|
||||
if chat_history.saved_context_num_tokens+input_num_tokens > MAX_TOKENS - max_new_tokens: |
|
||||
chat_history.setFastForward(False) |
|
||||
if chat_history.getFastForward(): |
|
||||
logger.info("using saved prompt") |
|
||||
prompt = chat_history.getSavedPrompt() |
|
||||
if not chat_history.getSavedPrompt() or not chat_history.getFastForward(): |
|
||||
logger.info("regenerating prompt") |
|
||||
chat_history.setSavedPrompt(prompt, total_num_tokens) |
|
||||
for key, chat_item in chat_history.chat_history.items(): |
|
||||
if key != list(chat_history.chat_history)[-1]: # exclude current item |
|
||||
chat_history.chat_history[key].is_in_saved_prompt = True |
|
||||
chat_history.setFastForward(True) |
|
||||
|
|
||||
prompt += "You: " + simple_prompt + "\n" |
|
||||
prompt += bot.name + ":" |
|
||||
|
|
||||
return prompt |
|
||||
|
|
||||
|
|
||||
async def num_tokens(input_text: str): |
|
||||
# os.makedirs("./models/pygmalion-6b", exist_ok=True) |
|
||||
# hf_hub_download(repo_id="PygmalionAI/pygmalion-6b", filename="config.json", cache_dir="./models/pygmalion-6b") |
|
||||
# config = AutoConfig.from_pretrained("./models/pygmalion-6b/config.json") |
|
||||
global tokenizer |
|
||||
if not tokenizer: |
|
||||
tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b") |
|
||||
encoding = tokenizer.encode(input_text, add_special_tokens=False) |
|
||||
max_input_size = tokenizer.max_model_input_sizes |
|
||||
return len(encoding) |
|
||||
|
|
||||
|
|
||||
async def estimate_num_tokens(input_text: str): |
|
||||
return len(input_text)//4+1 |
|
Loading…
Reference in new issue