You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
109 lines
3.4 KiB
109 lines
3.4 KiB
import asyncio
|
|
import os, tempfile
|
|
import logging
|
|
|
|
import json
|
|
import requests
|
|
|
|
from transformers import AutoTokenizer, AutoConfig
|
|
from huggingface_hub import hf_hub_download
|
|
|
|
import io
|
|
import base64
|
|
from PIL import Image, PngImagePlugin
|
|
|
|
from .pygmalion_helpers import get_full_prompt, num_tokens
|
|
#from .llama_helpers import get_full_prompt, num_tokens
|
|
|
|
# Module-level logger named after this module; used by the functions below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def setup():
    """Clone and build koboldcpp and download the Pygmalion 6B GGML model.

    All paths are relative to the current working directory:
    ``repositories/koboldcpp`` receives the koboldcpp checkout (built with
    ``make LLAMA_OPENBLAS=1``) and ``repositories/koboldcpp/models`` receives
    the model file.

    The function is safe to call repeatedly: the clone is skipped when the
    checkout already exists, and the download is skipped when the model file
    is already present (``make`` itself is incremental). The download goes to
    a ``.part`` file that is only renamed into place once ``wget`` succeeds,
    so an interrupted download is never mistaken for a complete model.

    A failing step is logged and the remaining steps are not attempted; no
    exception is raised (matching the original best-effort behaviour).

    After setup, the server can be started from ``repositories/koboldcpp``:
        python3 koboldcpp.py models/pygmalion-6b-v3-ggml-ggjt-q4_0.bin
    """
    import subprocess

    repos_dir = "repositories"
    repo_dir = os.path.join(repos_dir, "koboldcpp")
    models_dir = os.path.join(repo_dir, "models")
    model_name = "pygmalion-6b-v3-ggml-ggjt-q4_0.bin"
    model_url = "https://huggingface.co/concedo/pygmalion-6bv3-ggml-ggjt/resolve/main/" + model_name

    def _run(args, cwd):
        """Run *args* (no shell) in *cwd*; log and return False on failure."""
        try:
            subprocess.run(args, cwd=cwd, check=True)
        except (OSError, subprocess.CalledProcessError) as e:
            logger.error("setup step %s failed: %s", args, e)
            return False
        return True

    os.makedirs(repos_dir, exist_ok=True)

    if not os.path.isdir(repo_dir):
        if not _run(["git", "clone", "https://github.com/LostRuins/koboldcpp.git"], cwd=repos_dir):
            return

    if not _run(["make", "LLAMA_OPENBLAS=1"], cwd=repo_dir):
        return

    model_path = os.path.join(models_dir, model_name)
    if os.path.exists(model_path):
        return

    os.makedirs(models_dir, exist_ok=True)
    part_name = model_name + ".part"
    # Download under a temporary name so a partial file is never used as the model.
    if _run(["wget", "-O", part_name, model_url], cwd=models_dir):
        os.replace(os.path.join(models_dir, part_name), model_path)
|
|
|
|
async def generate_sync(
    prompt: str,
    api_key: str,
    bot,
    typing_fn,
    api_endpoint="pygmalion-6b",
):
    """Generate a bot reply from a koboldcpp server using pseudo-streaming.

    The reply is requested in small chunks of at most ``chunk_tokens`` tokens.
    Each chunk is appended to the prompt for the next request, and
    ``typing_fn`` is awaited after every chunk so the caller can show a
    typing indicator. Generation stops when:

    * the server returns an empty chunk,
    * ``max_new_tokens`` tokens have been requested, or
    * a stop sequence (``"\\nYou:"``, ``"\\n### Human:"``,
      ``"\\n<bot.user_name>:"`` or ``"<|endoftext|>"``) appears in the
      accumulated reply. The reply is cut at the earliest such occurrence.

    HTTP 503 ("model busy") responses are retried after ``DELAY`` seconds,
    for at most ``TIMEOUT // DELAY`` requests in total.

    Args:
        prompt: Full prompt text sent to the model.
        api_key: Not used by this backend (kept for interface compatibility).
        bot: Object providing ``temperature``, ``name`` and ``user_name``.
        typing_fn: Zero-argument coroutine function awaited after each chunk.
        api_endpoint: Not used by this backend (kept for interface
            compatibility).

    Returns:
        The reply text with role placeholders replaced and whitespace stripped.

    Raises:
        ValueError: On a network error, an unexpected HTTP status, an empty
            reply, or when no reply was produced within the retry window.
    """
    # koboldcpp generate endpoint (fixed address of the local server).
    endpoint = "http://172.16.85.10:5001/api/latest/generate"
    headers = {
        "Content-Type": "application/json",
    }

    max_new_tokens = 200
    chunk_tokens = 16  # tokens per request ("pseudo streaming")
    TIMEOUT = 360
    DELAY = 5

    input_data = {
        "prompt": prompt,
        "max_context_length": 2048,
        "max_length": chunk_tokens,
        "temperature": bot.temperature,
        "top_k": 50,
        "top_p": 0.85,
        "rep_pen": 1.08,
        "rep_pen_range": 1024,
        "quiet": True,
    }
    stop_sequences = ["\nYou:", "\n### Human:", f"\n{bot.user_name}:", "<|endoftext|>"]

    logger.info("sending request to koboldcpp")

    tokens = 0  # tokens requested so far (not necessarily generated)
    finished = False
    complete_reply = ""

    for _ in range(TIMEOUT // DELAY):
        # Never ask for more than the remaining budget on the last chunk.
        input_data["max_length"] = min(chunk_tokens, max_new_tokens - tokens)
        try:
            # requests is blocking; run it in a worker thread so the event
            # loop (and typing indicators of other tasks) keeps running.
            r = await asyncio.to_thread(
                requests.post, endpoint, json=input_data, headers=headers, timeout=TIMEOUT
            )
        except requests.exceptions.RequestException as e:
            raise ValueError(f"<ERROR> HTTP ERROR {e}") from e

        if r.status_code == 503:
            # Model busy: wait and retry (body may not be JSON, so don't parse it).
            await asyncio.sleep(DELAY)
            continue
        if r.status_code != 200:
            raise ValueError(f"<ERROR> HTTP status {r.status_code}")

        r_json = r.json()
        logger.info(r_json)

        partial_reply = r_json["results"][0]["text"]
        input_data["prompt"] += partial_reply
        complete_reply += partial_reply
        tokens += input_data["max_length"]
        await typing_fn()

        # Check stop sequences BEFORE the budget/empty checks so the final
        # chunk is truncated too. The whole reply is searched, which catches
        # sequences split across chunk boundaries.
        hits = [idx for idx in (complete_reply.find(t) for t in stop_sequences) if idx != -1]
        if hits:
            complete_reply = complete_reply[:min(hits)].strip()
            finished = True
            break

        if not partial_reply or tokens >= max_new_tokens:
            finished = True
            break

    if not complete_reply:
        if finished:
            raise ValueError("<ERROR> Empty reply")
        raise ValueError("<ERROR> Timeout")

    complete_reply = complete_reply.removesuffix("<|endoftext|>")
    complete_reply = complete_reply.replace("<BOT>", f"{bot.name}")
    complete_reply = complete_reply.replace("<USER>", "You")
    complete_reply = complete_reply.replace("### Assistant", f"{bot.name}")
    complete_reply = complete_reply.replace(f"\n{bot.name}: ", " ")
    return complete_reply.strip()
|
|
|
|
|
|
async def generate_image(input_prompt: str, negative_prompt: str, api_url: str, api_key: str, typing_fn):
    """Placeholder: image generation is not implemented in this module.

    Every argument is accepted (so callers share one signature across
    backends) and ignored. The coroutine performs no work and resolves to
    ``None``.
    """
    return None
|
|
|
|
|