Hendrik Langer
2 years ago
4 changed files with 181 additions and 68 deletions
@@ -0,0 +1,84 @@
import asyncio
import os, tempfile
import logging

import json
import requests

from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download

import io
import base64
from PIL import Image, PngImagePlugin

from .pygmalion_helpers import get_full_prompt, num_tokens

logger = logging.getLogger(__name__)

def setup():
    # Clone and build koboldcpp, then download the quantized Pygmalion-6B
    # ggml model into its models/ directory.
    os.system("mkdir -p repositories && (cd repositories && git clone https://github.com/LostRuins/koboldcpp.git)")
    os.system("(cd repositories/koboldcpp && make LLAMA_OPENBLAS=1 && cd models && wget https://huggingface.co/concedo/pygmalion-6bv3-ggml-ggjt/resolve/main/pygmalion-6b-v3-ggml-ggjt-q4_0.bin)")
    # Start the server manually with:
    #python3 koboldcpp.py models/pygmalion-6b-v3-ggml-ggjt-q4_0.bin

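# A sketch of an error-checked alternative to the os.system() calls in setup()
# (same repository, make flags, and model URL; setup_checked is a hypothetical
# name, not part of this commit). subprocess.run(check=True) raises
# CalledProcessError on failure instead of silently ignoring the exit status.
import subprocess

def setup_checked():
    os.makedirs("repositories", exist_ok=True)
    subprocess.run(["git", "clone", "https://github.com/LostRuins/koboldcpp.git"],
                   cwd="repositories", check=True)
    subprocess.run(["make", "LLAMA_OPENBLAS=1"],
                   cwd="repositories/koboldcpp", check=True)
    subprocess.run(["wget", "https://huggingface.co/concedo/pygmalion-6bv3-ggml-ggjt/resolve/main/pygmalion-6b-v3-ggml-ggjt-q4_0.bin"],
                   cwd="repositories/koboldcpp/models", check=True)
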
async def generate_sync(
    prompt: str,
    api_key: str,
    bot,
    typing_fn,
    api_endpoint = "pygmalion-6b"
):
    # Set the API endpoint URL (hardcoded koboldcpp host)
    endpoint = "http://172.16.85.10:5001/api/latest/generate"

    # Set the headers for the request
    headers = {
        "Content-Type": "application/json",
    }

    max_new_tokens = 120
    prompt_num_tokens = await num_tokens(prompt)

    # Define the request payload
    input_data = {
        "prompt": prompt,
        "max_context_length": 2048,
        "max_length": max_new_tokens,
        "temperature": bot.temperature,
        "top_k": 0,
        "top_p": 0,
        "rep_pen": 1.08,
        "rep_pen_range": 1024,
        "quiet": True,
    }

    logger.info("sending request to koboldcpp")

    # Make the request
    try:
        r = requests.post(endpoint, json=input_data, headers=headers, timeout=360)
    except requests.exceptions.RequestException as e:
        raise ValueError(f"<HTTP ERROR> {e}")
    r_json = r.json()
    logger.info(r_json)
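
    # Expected response shape, inferred from the parsing below (not taken
    # from the koboldcpp API docs): {"results": [{"text": "<generated text>"}]}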

    if r.status_code == 200:
        reply = r_json["results"][0]["text"]
        # Cut the reply where the model starts writing the user's next turn.
        idx = reply.find("\nYou:")
        if idx != -1:
            reply = reply[:idx].strip()
        else:
            reply = reply.removesuffix('<|endoftext|>').strip()
        # Normalize placeholder speaker tags in the generated text.
        reply = reply.replace(f"\n{bot.name}: ", " ")
        reply = reply.replace("\n<BOT>: ", " ")
        reply = reply.replace("<BOT>", bot.name)
        reply = reply.replace("<USER>", "You")
        return reply.strip()
    else:
        raise ValueError(f"<ERROR> HTTP status {r.status_code}")


async def generate_image(input_prompt: str, negative_prompt: str, api_url: str, api_key: str, typing_fn):
    # Not implemented for the koboldcpp backend yet.
    pass

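# A minimal usage sketch (illustrative only, not part of this commit; the
# code above relies on bot.name and bot.temperature, and get_full_prompt
# additionally uses bot.persona and bot.scenario):
#
#   prompt = await get_full_prompt("Hello!", bot, chat_history)
#   reply = await generate_sync(prompt, api_key="", bot=bot, typing_fn=None)
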
@@ -0,0 +1,82 @@
import asyncio
import os, tempfile
import logging

import json
import requests

from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download

import io
import base64
from PIL import Image, PngImagePlugin

logger = logging.getLogger(__name__)

async def get_full_prompt(simple_prompt: str, bot, chat_history):

    # Prompt without history
    prompt = bot.name + "'s Persona: " + bot.persona + "\n"
    prompt += "Scenario: " + bot.scenario + "\n"
    prompt += "<START>" + "\n"
    #prompt += bot.name + ": " + bot.greeting + "\n"
    prompt += "You: " + simple_prompt + "\n"
    prompt += bot.name + ":"

    MAX_TOKENS = 2048
    max_new_tokens = 200
    total_num_tokens = await num_tokens(prompt)
    visible_history = []
    current_message = True
    # Walk the history newest-first, skipping the current message, and keep
    # items until the token budget is exhausted.
    for key, chat_item in reversed(chat_history.chat_history.items()):
        if current_message:
            current_message = False
            continue
        if chat_item.message["en"].startswith('!begin'):
            break
        if chat_item.message["en"].startswith('!'):
            continue
        if chat_item.message["en"].startswith('<ERROR>'):
            continue
        #if chat_item.message["en"] == bot.greeting:
        #    continue
        if chat_item.num_tokens is None:
            chat_item.num_tokens = await num_tokens("{}: {}".format(chat_item.user_name, chat_item.message["en"]))
        # The budget is MAX_TOKENS - max_new_tokens: the reply is generated
        # inside the same 2048-token context, so room must be left for it.
        logger.debug(f"History: {chat_item} [{chat_item.num_tokens}]")
        if total_num_tokens + chat_item.num_tokens < MAX_TOKENS - max_new_tokens:
            visible_history.append(chat_item)
            total_num_tokens += chat_item.num_tokens
        else:
            break
    visible_history = reversed(visible_history)

    # Rebuild the prompt, now with as much history as fits.
    prompt = bot.name + "'s Persona: " + bot.persona + "\n"
    prompt += "Scenario: " + bot.scenario + "\n"
    prompt += "<START>" + "\n"
    #prompt += bot.name + ": " + bot.greeting + "\n"
    for chat_item in visible_history:
        if chat_item.is_own_message:
            prompt += bot.name + ": " + chat_item.message["en"] + "\n"
        else:
            prompt += "You" + ": " + chat_item.message["en"] + "\n"
    prompt += "You: " + simple_prompt + "\n"
    prompt += bot.name + ":"

    return prompt

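# For illustration, the assembled prompt has this shape ("Julia" and the
# bracketed messages are placeholders):
#
#   Julia's Persona: <persona text>
#   Scenario: <scenario text>
#   <START>
#   You: <older message>
#   Julia: <older reply>
#   You: <current message>
#   Julia:
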
async def num_tokens(input_text: str):
    # os.makedirs("./models/pygmalion-6b", exist_ok=True)
    # hf_hub_download(repo_id="PygmalionAI/pygmalion-6b", filename="config.json", cache_dir="./models/pygmalion-6b")
    # config = AutoConfig.from_pretrained("./models/pygmalion-6b/config.json")
    # Note: this loads the tokenizer from the Hugging Face cache on every call.
    tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")
    encoding = tokenizer.encode(input_text, add_special_tokens=False)
    max_input_size = tokenizer.max_model_input_sizes  # currently unused
    return len(encoding)

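# A minimal sketch of a cached variant (_tokenizer and num_tokens_cached are
# hypothetical names, not part of this commit), so the tokenizer is only
# loaded once per process:
_tokenizer = None

async def num_tokens_cached(input_text: str):
    # Load the tokenizer lazily on first use, then reuse it.
    global _tokenizer
    if _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")
    return len(_tokenizer.encode(input_text, add_special_tokens=False))
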
async def estimate_num_tokens(input_text: str):
    # Cheap heuristic: English text averages roughly four characters per token.
    return len(input_text)//4+1
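# For example, await estimate_num_tokens("Hello there, how are you?") returns
# 25 // 4 + 1 == 7.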