import asyncio
import base64
import io
import json
import logging
import os
import tempfile

import requests
from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download
from PIL import Image, PngImagePlugin

logger = logging.getLogger(__name__)

tokenizer = None
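
# Note: the module-level tokenizer is populated lazily by num_tokens() below,
# so importing this file does not trigger a download from the Hugging Face hub.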

async def get_full_prompt(simple_prompt: str, bot, chat_history):
    # Prompt without history
    prompt = bot.name + "'s Persona: " + bot.get_persona() + "\n"
    prompt += "Scenario: " + bot.get_scenario() + "\n\n"
    for dialogue_item in bot.get_example_dialogue():
        prompt += "<START>" + "\n"
        dialogue_item = dialogue_item.replace('{{user}}', 'You')
        dialogue_item = dialogue_item.replace('{{char}}', bot.name)
        prompt += dialogue_item + "\n\n"
    prompt += "<START>" + "\n"
    #prompt += bot.name + ": " + bot.greeting + "\n"
    #prompt += "You: " + simple_prompt + "\n"
    #prompt += bot.name + ":"
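    # The layout above follows the Pygmalion prompt convention:
    # "<name>'s Persona: ...", "Scenario: ...", then example dialogues
    # separated by <START> markers, with the {{user}}/{{char}} placeholders
    # replaced by the actual participant names.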

    # Token budget: MAX_TOKENS is the model's context size, WINDOW is headroom
    # kept free, and max_new_tokens is reserved for the generated reply.
    MAX_TOKENS = 2048
    WINDOW = 800
    max_new_tokens = 200
    total_num_tokens = await num_tokens(prompt)
    # Tokens taken by the new user line plus the bot-name prefix.
    input_num_tokens = await num_tokens(f"You: {simple_prompt}\n{bot.name}:")
    total_num_tokens += input_num_tokens
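    # With these values, persona + scenario + history may occupy at most
    # 2048 - 800 - 200 = 1048 tokens; older messages beyond that are dropped.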

    # Walk the history newest-to-oldest, collecting as many messages as fit
    # into the budget; the list is reversed again before being rendered.
    visible_history = []
    num_message = 0
    for key, chat_item in reversed(chat_history.chat_history.items()):
        num_message += 1
        if num_message == 1:
            # skip current_message
            continue
        if chat_item.stop_here:
            break
        if chat_item.message["en"].startswith('!begin'):
            break
        if chat_item.message["en"].startswith('!'):
            continue
        if chat_item.message["en"].startswith('<ERROR>'):
            continue
        #if chat_item.message["en"] == bot.greeting:
        #    continue
        if chat_item.num_tokens is None:
            chat_history.chat_history[key].num_tokens = await num_tokens("{}: {}".format(chat_item.user_name, chat_item.message["en"]))
            chat_item = chat_history.chat_history[key]
        # TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens??
        logger.debug(f"History: {chat_item} [{chat_item.num_tokens}]")
        if total_num_tokens + chat_item.num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens:
            visible_history.append(chat_item)
            total_num_tokens += chat_item.num_tokens
        else:
            break
    visible_history = reversed(visible_history)

    # Include the bot's greeting only if it still fits into the budget.
    if not hasattr(bot, "greeting_num_tokens"):
        bot.greeting_num_tokens = await num_tokens(bot.greeting)
    if total_num_tokens + bot.greeting_num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens:
        prompt += bot.name + ": " + bot.greeting + "\n"
        total_num_tokens += bot.greeting_num_tokens

    # Render the visible history oldest-first and extend the saved prompt
    # with any lines it does not contain yet.
    for chat_item in visible_history:
        if chat_item.is_own_message:
            line = bot.name + ": " + chat_item.message["en"] + "\n"
        else:
            line = "You" + ": " + chat_item.message["en"] + "\n"
        prompt += line
        if chat_history.getSavedPrompt() and not chat_item.is_in_saved_prompt:
            logger.info(f"adding to saved prompt: \"{line}\"")
            chat_history.setSavedPrompt(chat_history.getSavedPrompt() + line, chat_history.saved_context_num_tokens + chat_item.num_tokens)
            chat_item.is_in_saved_prompt = True

    # Fast-forward: reuse the cached ("saved") prompt instead of the freshly
    # built one, unless it would overflow the context, in which case the saved
    # prompt is regenerated from scratch.
    if chat_history.saved_context_num_tokens:
        logger.info(f"saved_context has {chat_history.saved_context_num_tokens + input_num_tokens} tokens. new context would be {total_num_tokens}. Limit is {MAX_TOKENS}")
    if chat_history.getSavedPrompt():
        if chat_history.saved_context_num_tokens + input_num_tokens > MAX_TOKENS - max_new_tokens:
            chat_history.setFastForward(False)
        if chat_history.getFastForward():
            logger.info("using saved prompt")
            prompt = chat_history.getSavedPrompt()
    if not chat_history.getSavedPrompt() or not chat_history.getFastForward():
        logger.info("regenerating prompt")
        chat_history.setSavedPrompt(prompt, total_num_tokens)
        for key, chat_item in chat_history.chat_history.items():
            if key != list(chat_history.chat_history)[-1]:  # exclude current item
                chat_history.chat_history[key].is_in_saved_prompt = True
        chat_history.setFastForward(True)

    prompt += "You: " + simple_prompt + "\n"
    prompt += bot.name + ":"
    return prompt
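

# Usage sketch (illustrative only; the exact shapes of `bot` and
# `chat_history` are assumptions inferred from the calls above, not part of
# this module). A caller would do something like:
#
#     prompt = await get_full_prompt("How are you?", bot, chat_history)
#     # `prompt` now ends with "You: How are you?\n<bot.name>:", so the model
#     # completes the bot's next line.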

async def num_tokens(input_text: str):
    # os.makedirs("./models/pygmalion-6b", exist_ok=True)
    # hf_hub_download(repo_id="PygmalionAI/pygmalion-6b", filename="config.json", cache_dir="./models/pygmalion-6b")
    # config = AutoConfig.from_pretrained("./models/pygmalion-6b/config.json")
    global tokenizer
    if not tokenizer:
        # Load the tokenizer once on first use and cache it module-wide.
        tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")
    encoding = tokenizer.encode(input_text, add_special_tokens=False)
    return len(encoding)

async def estimate_num_tokens(input_text: str):
    # Cheap approximation: English text averages roughly four characters per
    # token, so this avoids loading the tokenizer for rough estimates.
    return len(input_text) // 4 + 1
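
# Example (arithmetic only): "Hello, world!" is 13 characters, so
# estimate_num_tokens returns 13 // 4 + 1 == 4, which is in the right
# ballpark for a GPT-J-style tokenizer such as Pygmalion's.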