You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
210 lines
9.4 KiB
210 lines
9.4 KiB
2 years ago
|
import asyncio
|
||
|
import os, tempfile
|
||
|
import logging
|
||
|
|
||
|
import json
|
||
|
import requests
|
||
|
|
||
|
from transformers import AutoTokenizer, AutoConfig
|
||
|
from huggingface_hub import hf_hub_download
|
||
|
|
||
|
import io
|
||
|
import base64
|
||
|
from PIL import Image, PngImagePlugin
|
||
|
|
||
|
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
2 years ago
|
# Module-level tokenizer cache: stays None until num_tokens() loads the
# "PygmalionAI/pygmalion-6b" tokenizer the first time a pygmalion model is used.
gptj_tokenizer = None
|
||
|
|
||
|
|
||
|
def _speaker_names(bot, model_name: str):
    """Return ``(ai_name, user_name)``, the speaker labels used for ``model_name``."""
    if model_name.startswith("pygmalion"):
        return bot.name, "You"
    if model_name.startswith("vicuna"):
        return "### Assistant", "### Human"
    # "alpaca" (ToDo: dedicated labels) and every other model use the bot's own names.
    return bot.name, bot.user_name


def _build_preamble(bot, model_name: str, ai_name: str, user_name: str) -> str:
    """Return the static prompt head: task text, persona, scenario, example dialogue."""
    is_pygmalion = model_name.startswith("pygmalion")
    # vicuna and alpaca currently share the same instruction-style template (ToDo: tune per model).
    is_instruct = model_name.startswith(("vicuna", "alpaca"))

    parts = []

    # First line / task description
    if is_instruct:
        parts.append(
            "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
        )
        parts.append(
            f"### Instruction:\nGiven the following character description and scenario, write a script for a dialogue between the human user {bot.user_name} and the fictional AI assistant {bot.name}. Play the role of the character {bot.name}\n\n"
        )
        parts.append("### Input:\n")

    # Character description and scenario are identical for every model.
    parts.append(bot.name + "'s Persona: " + bot.get_persona() + "\n")
    parts.append("Scenario: " + bot.get_scenario() + "\n\n")

    # Response delimiter
    if is_instruct:
        parts.append("### Response:\n")

    # Example dialogue, with the template placeholders replaced by the speaker labels
    for dialogue_item in bot.get_example_dialogue():
        if is_pygmalion:
            parts.append("<START>\n")
        dialogue_item = dialogue_item.replace("{{user}}", user_name)
        dialogue_item = dialogue_item.replace("{{char}}", ai_name)
        parts.append(dialogue_item + "\n\n")

    # Dialogue start (ToDo: markers for the non-pygmalion models)
    if is_pygmalion:
        parts.append("<START>\n")

    return "".join(parts)


async def get_full_prompt(simple_prompt: str, bot, chat_history, model_name: str) -> str:
    """Build the complete text prompt that is sent to the language model.

    The prompt consists of a model-specific preamble (task description,
    persona, scenario, example dialogue), the bot greeting and as many recent
    history messages as fit the token budget, followed by the current user
    message and an open ``"<ai_name>:"`` line for the model to complete.

    See https://github.com/ggerganov/llama.cpp/tree/master/examples for the
    prompt formats this is modelled on.

    Args:
        simple_prompt: The current user message.
        bot: Object providing ``name``, ``user_name``, ``greeting``,
            ``get_persona()``, ``get_scenario()`` and
            ``get_example_dialogue()``. ``greeting_num_tokens`` is set on it
            the first time it is needed.
        chat_history: Object whose ``chat_history`` attribute is an ordered
            mapping of message key to chat item, and which provides
            ``getSavedPrompt``/``setSavedPrompt``,
            ``getFastForward``/``setFastForward`` and
            ``saved_context_num_tokens``.
        model_name: Model identifier; its prefix ("pygmalion", "vicuna",
            "alpaca") selects the prompt format.

    Returns:
        The full prompt string.

    Side effects:
        Fills in missing ``num_tokens`` on chat items, marks items as
        ``is_in_saved_prompt``, and updates the saved prompt and the
        fast-forward flag on ``chat_history``.
    """
    ai_name, user_name = _speaker_names(bot, model_name)
    prompt = _build_preamble(bot, model_name, ai_name, user_name)

    MAX_TOKENS = 2048
    WINDOW = 600
    max_new_tokens = 200
    # TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens??
    history_budget = MAX_TOKENS - WINDOW - max_new_tokens

    total_num_tokens = await num_tokens(prompt, model_name)
    input_num_tokens = await num_tokens(f"{user_name}: {simple_prompt}\n{ai_name}:", model_name)
    total_num_tokens += input_num_tokens

    # Walk the history newest-first and collect messages until the budget is used up.
    visible_history = []
    for position, (key, chat_item) in enumerate(reversed(chat_history.chat_history.items())):
        if position == 0:
            # The newest entry is the current message; it is appended separately below.
            continue
        if chat_item.stop_here:
            break
        message = chat_item.message["en"]
        if message.startswith("!begin"):
            break
        # Other "!" commands and error markers are left out of the prompt.
        if message.startswith(("!", "<ERROR>")):
            continue
        if chat_item.num_tokens is None:
            # NOTE(review): the count always uses the user_name prefix, even for
            # the bot's own messages, so it is an approximation of the final line.
            chat_history.chat_history[key].num_tokens = await num_tokens(
                f"{user_name}: {message}", model_name
            )
            chat_item = chat_history.chat_history[key]
        logger.debug("History: %s [%s]", chat_item, chat_item.num_tokens)
        if total_num_tokens + chat_item.num_tokens > history_budget:
            break
        visible_history.append(chat_item)
        total_num_tokens += chat_item.num_tokens
    visible_history.reverse()  # back to chronological order

    # The greeting is only included while it still fits the budget.
    if not hasattr(bot, "greeting_num_tokens"):
        bot.greeting_num_tokens = await num_tokens(bot.greeting, model_name)
    if total_num_tokens + bot.greeting_num_tokens <= history_budget:
        prompt += f"{ai_name}: {bot.greeting}\n"
        total_num_tokens += bot.greeting_num_tokens

    for chat_item in visible_history:
        speaker = ai_name if chat_item.is_own_message else user_name
        line = f"{speaker}: {chat_item.message['en']}\n"
        prompt += line
        # Keep an existing saved prompt in sync by appending lines it has not seen yet.
        if chat_history.getSavedPrompt() and not chat_item.is_in_saved_prompt:
            logger.info('adding to saved prompt: "%s"', line)
            chat_history.setSavedPrompt(
                chat_history.getSavedPrompt() + line,
                chat_history.saved_context_num_tokens + chat_item.num_tokens,
            )
            chat_item.is_in_saved_prompt = True

    if chat_history.saved_context_num_tokens:
        logger.info(
            "saved_context has %s tokens. new context would be %s. Limit is %s",
            chat_history.saved_context_num_tokens + input_num_tokens,
            total_num_tokens,
            MAX_TOKENS,
        )
    if chat_history.getSavedPrompt():
        # Once the saved prompt no longer leaves room for the reply, stop reusing it.
        if chat_history.saved_context_num_tokens + input_num_tokens > MAX_TOKENS - max_new_tokens:
            chat_history.setFastForward(False)
        if chat_history.getFastForward():
            logger.info("using saved prompt")
            prompt = chat_history.getSavedPrompt()
    if not chat_history.getSavedPrompt() or not chat_history.getFastForward():
        logger.info("regenerating prompt")
        chat_history.setSavedPrompt(prompt, total_num_tokens)
        # Mark every item except the current (last) one as part of the saved prompt.
        for key in list(chat_history.chat_history)[:-1]:
            chat_history.chat_history[key].is_in_saved_prompt = True
        chat_history.setFastForward(True)

    prompt += f"{user_name}: {simple_prompt}\n"
    prompt += f"{ai_name}:"

    return prompt
|
||
|
|
||
|
|
||
2 years ago
|
async def num_tokens(input_text: str, model_name: str) -> int:
    """Return how many tokens ``input_text`` takes up for ``model_name``.

    For model names starting with "pygmalion" the text is encoded (without
    special tokens) with the "PygmalionAI/pygmalion-6b" tokenizer, which is
    loaded on first use and cached in the module-level ``gptj_tokenizer``.
    For every other model the count comes from ``estimate_num_tokens``.

    Args:
        input_text: Text to measure.
        model_name: Model identifier; only its prefix is inspected.

    Returns:
        The token count as an integer.
    """
    global gptj_tokenizer

    if not model_name.startswith("pygmalion"):
        return await estimate_num_tokens(input_text)

    if gptj_tokenizer is None:
        # NOTE(review): from_pretrained is a synchronous call inside a coroutine;
        # the first call presumably blocks the event loop while loading — confirm
        # whether that matters for callers.
        gptj_tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")

    encoding = gptj_tokenizer.encode(input_text, add_special_tokens=False)
    return len(encoding)
|
||
2 years ago
|
|
||
|
|
||
|
async def estimate_num_tokens(input_text: str):
    """Cheaply approximate the token count of ``input_text``.

    The heuristic is one token per four characters (rounded down), plus one,
    so even an empty string is reported as one token.
    """
    full_quads, _remainder = divmod(len(input_text), 4)
    return full_quads + 1
|