You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
83 lines
2.9 KiB
83 lines
2.9 KiB
2 years ago
|
import asyncio
import base64
import functools
import io
import json
import logging
import os
import tempfile

import requests
from huggingface_hub import hf_hub_download
from PIL import Image, PngImagePlugin
from transformers import AutoConfig, AutoTokenizer
|
||
|
|
||
|
# Logger named after this module; used for debug output while assembling prompts.
logger = logging.getLogger(name=__name__)
|
||
|
|
||
|
|
||
|
def _build_prompt(bot, simple_prompt: str, history=()) -> str:
    """Render the persona header, the given history items (in order) and the new user turn."""
    parts = [
        f"{bot.name}'s Persona: {bot.persona}\n",
        f"Scenario: {bot.scenario}\n",
        "<START>\n",
    ]
    for item in history:
        speaker = bot.name if item.is_own_message else "You"
        parts.append(f"{speaker}: {item.message['en']}\n")
    parts.append(f"You: {simple_prompt}\n")
    parts.append(f"{bot.name}:")
    return "".join(parts)


async def get_full_prompt(
    simple_prompt: str,
    bot,
    chat_history,
    *,
    max_tokens: int = 2048,
    max_new_tokens: int = 200,
):
    """Build the generation prompt for ``bot``, including as much recent history as fits.

    Layout of the returned prompt::

        <bot.name>'s Persona: <bot.persona>
        Scenario: <bot.scenario>
        <START>
        <selected history lines, oldest first>
        You: <simple_prompt>
        <bot.name>:

    History is walked from newest to oldest:

    * the newest entry is always skipped (presumably it is the message being
      answered, which is rendered from ``simple_prompt`` instead — confirm
      with the caller);
    * walking stops at the first message starting with ``!begin``;
    * other ``!``-commands and ``<ERROR>`` messages are skipped;
    * entries are kept while the running token count stays strictly below
      ``max_tokens - max_new_tokens``; the first entry that does not fit
      ends the walk.

    The bot greeting is currently not included in the prompt.

    Args:
        simple_prompt: The user's new message.
        bot: Object with ``name``, ``persona`` and ``scenario`` string attributes.
        chat_history: Object whose ``chat_history`` attribute is a dict of chat
            items, presumably in chronological insertion order. Each item exposes
            ``message["en"]``, ``user_name``, ``is_own_message`` and a cached
            ``num_tokens``. ``num_tokens`` is ``None`` until computed, and this
            function fills it in (it mutates the items).
        max_tokens: Total context size of the model, in tokens.
        max_new_tokens: Tokens reserved for the model's reply.

    Returns:
        The assembled prompt string.
    """
    # The context window must hold both the prompt and the generated reply,
    # so the prompt budget is the context size minus the reply reservation.
    budget = max_tokens - max_new_tokens

    # Start from the cost of the prompt with no history at all.
    total_num_tokens = await num_tokens(_build_prompt(bot, simple_prompt))

    selected = []
    newest_first = reversed(chat_history.chat_history.values())
    next(newest_first, None)  # skip the newest entry (see docstring)
    for chat_item in newest_first:
        text = chat_item.message["en"]
        if text.startswith("!begin"):
            break
        if text.startswith(("!", "<ERROR>")):
            continue
        if chat_item.num_tokens is None:
            # NOTE: counts "<user_name>: <text>", while the prompt renders the
            # speaker as "You"/bot.name plus a newline. The count is therefore
            # approximate; the strict "<" below leaves a little slack.
            chat_item.num_tokens = await num_tokens(f"{chat_item.user_name}: {text}")
        logger.debug("History: %s [%s]", chat_item, chat_item.num_tokens)
        if total_num_tokens + chat_item.num_tokens >= budget:
            break
        selected.append(chat_item)
        total_num_tokens += chat_item.num_tokens

    selected.reverse()  # restore oldest-first order for rendering
    return _build_prompt(bot, simple_prompt, selected)
|
||
|
|
||
|
|
||
|
_DEFAULT_TOKENIZER = "PygmalionAI/pygmalion-6b"


@functools.lru_cache(maxsize=4)
def _get_tokenizer(model_name: str = _DEFAULT_TOKENIZER):
    """Load the tokenizer for ``model_name`` once and reuse it on later calls."""
    return AutoTokenizer.from_pretrained(model_name)


async def num_tokens(input_text: str, *, model_name: str = _DEFAULT_TOKENIZER) -> int:
    """Return how many tokens ``input_text`` encodes to.

    Args:
        input_text: Text to tokenize.
        model_name: Hugging Face model id whose tokenizer is used. The default
            keeps the original Pygmalion-6B behaviour.

    Returns:
        The number of token ids produced for ``input_text``. Special tokens are
        not added (``add_special_tokens=False``).
    """
    # Loading the tokenizer is expensive and this is called once per history
    # item, so the instance is cached rather than rebuilt on every call.
    tokenizer = _get_tokenizer(model_name)
    return len(tokenizer.encode(input_text, add_special_tokens=False))
|
||
|
|
||
|
|
||
|
async def estimate_num_tokens(input_text: str):
    """Cheaply approximate the token count of ``input_text`` without a tokenizer.

    Counts one token per full group of four characters and adds one, so the
    result is always at least 1 (even for an empty string).
    """
    full_groups, _leftover = divmod(len(input_text), 4)
    return full_groups + 1
|