|
|
@ -7,6 +7,9 @@ import requests |
|
|
|
from transformers import AutoTokenizer, AutoConfig |
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
from .pygmalion_helpers import get_full_prompt, num_tokens |
|
|
|
#from .llama_helpers import get_full_prompt, num_tokens |
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
@ -96,71 +99,6 @@ async def generate_sync( |
|
|
|
else: |
|
|
|
return f"<ERROR> {status}" |
|
|
|
|
|
|
|
async def get_full_prompt(simple_prompt: str, bot, chat_history): |
|
|
|
|
|
|
|
# Prompt without history |
|
|
|
prompt = bot.name + "'s Persona: " + bot.get_persona() + "\n" |
|
|
|
prompt += "Scenario: " + bot.get_scenario() + "\n" |
|
|
|
prompt += "<START>" + "\n" |
|
|
|
#prompt += bot.name + ": " + bot.greeting + "\n" |
|
|
|
prompt += "You: " + simple_prompt + "\n" |
|
|
|
prompt += bot.name + ":" |
|
|
|
|
|
|
|
MAX_TOKENS = 2048 |
|
|
|
max_new_tokens = 200 |
|
|
|
total_num_tokens = await num_tokens(prompt) |
|
|
|
visible_history = [] |
|
|
|
current_message = True |
|
|
|
for key, chat_item in reversed(chat_history.chat_history.items()): |
|
|
|
if current_message: |
|
|
|
current_message = False |
|
|
|
continue |
|
|
|
if chat_item.message["en"].startswith('!begin'): |
|
|
|
break |
|
|
|
if chat_item.message["en"].startswith('!'): |
|
|
|
continue |
|
|
|
if chat_item.message["en"].startswith('<ERROR>'): |
|
|
|
continue |
|
|
|
#if chat_item.message["en"] == bot.greeting: |
|
|
|
# continue |
|
|
|
if chat_item.num_tokens == None: |
|
|
|
chat_item.num_tokens = await num_tokens("{}: {}".format(chat_item.user_name, chat_item.message["en"])) |
|
|
|
# TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens?? |
|
|
|
logger.debug(f"History: " + str(chat_item) + " [" + str(chat_item.num_tokens) + "]") |
|
|
|
if total_num_tokens + chat_item.num_tokens < MAX_TOKENS - max_new_tokens: |
|
|
|
visible_history.append(chat_item) |
|
|
|
total_num_tokens += chat_item.num_tokens |
|
|
|
else: |
|
|
|
break |
|
|
|
visible_history = reversed(visible_history) |
|
|
|
|
|
|
|
prompt = bot.name + "'s Persona: " + bot.get_persona() + "\n" |
|
|
|
prompt += "Scenario: " + bot.get_scenario() + "\n" |
|
|
|
prompt += "<START>" + "\n" |
|
|
|
#prompt += bot.name + ": " + bot.greeting + "\n" |
|
|
|
for chat_item in visible_history: |
|
|
|
if chat_item.is_own_message: |
|
|
|
prompt += bot.name + ": " + chat_item.message["en"] + "\n" |
|
|
|
else: |
|
|
|
prompt += "You" + ": " + chat_item.message["en"] + "\n" |
|
|
|
prompt += "You: " + simple_prompt + "\n" |
|
|
|
prompt += bot.name + ":" |
|
|
|
|
|
|
|
return prompt |
|
|
|
|
|
|
|
|
|
|
|
async def num_tokens(input_text: str): |
|
|
|
# os.makedirs("./models/pygmalion-6b", exist_ok=True) |
|
|
|
# hf_hub_download(repo_id="PygmalionAI/pygmalion-6b", filename="config.json", cache_dir="./models/pygmalion-6b") |
|
|
|
# config = AutoConfig.from_pretrained("./models/pygmalion-6b/config.json") |
|
|
|
tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b") |
|
|
|
encoding = tokenizer.encode(input_text, add_special_tokens=False) |
|
|
|
max_input_size = tokenizer.max_model_input_sizes |
|
|
|
return len(encoding) |
|
|
|
|
|
|
|
async def estimate_num_tokens(input_text: str): |
|
|
|
return len(input_text)//4+1 |
|
|
|
|
|
|
|
|
|
|
|
async def generate_image(input_prompt: str, negative_prompt: str, model: str, api_key: str): |
|
|
|
|
|
|
|