import asyncio
import os , tempfile
import logging
import json
import requests
from transformers import AutoTokenizer , AutoConfig
from huggingface_hub import hf_hub_download
import io
import base64
from PIL import Image , PngImagePlugin
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# Lazily-initialized HF tokenizer; populated on the first num_tokens() call.
tokenizer = None
async def get_full_prompt(simple_prompt: str, bot, chat_history):
    """Build the complete LLM prompt for one reply.

    Assembles: persona + scenario, example dialogue, the greeting (budget
    permitting), as much recent chat history as fits in the token budget,
    and finally the current user message plus the bot-name cue line.
    Maintains a "saved prompt" cache on *chat_history* so unchanged history
    can be fast-forwarded instead of re-tokenized.

    Args:
        simple_prompt: the current user message text.
        bot: bot object exposing name/greeting/get_persona()/get_scenario()/
            get_example_dialogue() — assumed project type, TODO confirm.
        chat_history: history container exposing .chat_history (ordered dict
            of chat items) and the saved-prompt accessors used below.

    Returns:
        The full prompt string, ending with "<bot.name>:".
    """
    # Static header: persona, scenario and example dialogue (no history yet).
    prompt = bot.name + "'s Persona: " + bot.get_persona() + "\n"
    prompt += "Scenario: " + bot.get_scenario() + "\n\n"
    for dialogue_item in bot.get_example_dialogue():
        prompt += "<START>" + "\n"
        dialogue_item = dialogue_item.replace('{{user}}', 'You')
        dialogue_item = dialogue_item.replace('{{char}}', bot.name)
        prompt += dialogue_item + "\n\n"
    prompt += "<START>" + "\n"

    MAX_TOKENS = 2048       # model context size
    WINDOW = 800            # reserved head-room so the cache stays reusable
    max_new_tokens = 200    # budget reserved for the generated reply

    total_num_tokens = await num_tokens(prompt)
    # BUG FIX: previously the second literal lacked the f-prefix, so the
    # token count was taken over the literal text "{bot.name}".
    input_num_tokens = await num_tokens(f"You: {simple_prompt}\n{bot.name}:")
    total_num_tokens += input_num_tokens

    # Walk the history newest-to-oldest, collecting items until the
    # remaining token budget is exhausted.
    visible_history = []
    num_message = 0
    for key, chat_item in reversed(chat_history.chat_history.items()):
        num_message += 1
        if num_message == 1:
            # Skip the current message; it is appended separately at the end.
            continue
        if chat_item.stop_here:
            break
        if chat_item.message["en"].startswith('!begin'):
            break
        if chat_item.message["en"].startswith('!'):
            # Other bot commands are not part of the conversation.
            continue
        if chat_item.message["en"].startswith('<ERROR>'):
            continue
        if chat_item.num_tokens is None:
            # Tokenize lazily and cache the count on the stored item.
            chat_history.chat_history[key].num_tokens = await num_tokens(
                "{}: {}".format(chat_item.user_name, chat_item.message["en"]))
            chat_item = chat_history.chat_history[key]
        # TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens??
        logger.debug("History: %s [%s]", chat_item, chat_item.num_tokens)
        if total_num_tokens + chat_item.num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens:
            visible_history.append(chat_item)
            total_num_tokens += chat_item.num_tokens
        else:
            break
    # Collected newest-first; flip back to chronological order.
    visible_history = reversed(visible_history)

    # Include the greeting only if it still fits in the budget.
    if not hasattr(bot, "greeting_num_tokens"):
        bot.greeting_num_tokens = await num_tokens(bot.greeting)
    if total_num_tokens + bot.greeting_num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens:
        prompt += bot.name + ": " + bot.greeting + "\n"
        total_num_tokens += bot.greeting_num_tokens

    for chat_item in visible_history:
        if chat_item.is_own_message:
            line = bot.name + ": " + chat_item.message["en"] + "\n"
        else:
            line = "You" + ": " + chat_item.message["en"] + "\n"
        prompt += line
        if chat_history.getSavedPrompt() and not chat_item.is_in_saved_prompt:
            logger.info(f"adding to saved prompt: \"{line}\"")
            chat_history.setSavedPrompt(
                chat_history.getSavedPrompt() + line,
                chat_history.saved_context_num_tokens + chat_item.num_tokens)
            chat_item.is_in_saved_prompt = True

    if chat_history.saved_context_num_tokens:
        logger.info(f"saved_context has {chat_history.saved_context_num_tokens + input_num_tokens} tokens. new context would be {total_num_tokens}. Limit is {MAX_TOKENS}")
    if chat_history.getSavedPrompt():
        if chat_history.saved_context_num_tokens + input_num_tokens > MAX_TOKENS - max_new_tokens:
            # Cached prompt no longer fits; force a full regeneration.
            chat_history.setFastForward(False)
        if chat_history.getFastForward():
            logger.info("using saved prompt")
            prompt = chat_history.getSavedPrompt()
    if not chat_history.getSavedPrompt() or not chat_history.getFastForward():
        logger.info("regenerating prompt")
        chat_history.setSavedPrompt(prompt, total_num_tokens)
        # Hoisted out of the loop: the last key is loop-invariant.
        last_key = list(chat_history.chat_history)[-1]
        for key, chat_item in chat_history.chat_history.items():
            if key != last_key:  # exclude current item
                chat_history.chat_history[key].is_in_saved_prompt = True
        chat_history.setFastForward(True)

    prompt += "You: " + simple_prompt + "\n"
    prompt += bot.name + ":"
    return prompt
async def num_tokens(input_text: str):
    """Return the exact token count of *input_text*.

    Uses the Pygmalion-6B tokenizer, downloaded/loaded lazily on the first
    call and cached in the module-level ``tokenizer`` global thereafter.
    Special tokens are excluded so the count matches raw prompt text.
    """
    global tokenizer
    if tokenizer is None:
        # First call: load once; subsequent calls reuse the cached instance.
        tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")
    encoding = tokenizer.encode(input_text, add_special_tokens=False)
    # NOTE: removed an unused `max_input_size = tokenizer.max_model_input_sizes`
    # local that was computed and discarded.
    return len(encoding)
async def estimate_num_tokens(input_text: str):
    """Cheap token-count heuristic: roughly one token per four characters.

    Avoids loading the real tokenizer; always returns at least 1.
    """
    return 1 + len(input_text) // 4