Chatbot
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

85 lines
2.3 KiB

import asyncio
import os, tempfile
import logging
import json
import requests
from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download
import io
import base64
from PIL import Image, PngImagePlugin
from .pygmalion_helpers import get_full_prompt, num_tokens
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def setup():
    """Clone and build koboldcpp and download the Pygmalion 6B model file.

    Runs two shell command lines through ``os.system``:

    1. create ``repositories/`` and ``git clone`` koboldcpp into it;
    2. build it with ``make LLAMA_OPENBLAS=1`` and ``wget`` the quantized
       model into its ``models`` directory.

    The call is best-effort: a non-zero exit status is logged as a warning
    instead of being silently discarded, and the remaining commands are
    still attempted. Nothing is raised and nothing is returned.

    Original author's note on running the result (presumably from inside
    ``repositories/koboldcpp``)::

        python3 koboldcpp.py models/pygmalion-6b-v3-ggml-ggjt-q4_0.bin
    """
    commands = (
        "mkdir -p repositories && (cd repositories && git clone https://github.com/LostRuins/koboldcpp.git)",
        "(cd repositories/koboldcpp && make LLAMA_OPENBLAS=1 && cd models && wget https://huggingface.co/concedo/pygmalion-6bv3-ggml-ggjt/resolve/main/pygmalion-6b-v3-ggml-ggjt-q4_0.bin)",
    )
    for command in commands:
        status = os.system(command)
        if status != 0:
            # Surface the failure without aborting, so setup stays best-effort.
            logger.warning(
                "setup command exited with status %s: %s", status, command
            )
async def generate_sync(
    prompt: str,
    api_key: str,
    bot,
    typing_fn,
    api_endpoint = "pygmalion-6b"
):
    """Request a completion for *prompt* from koboldcpp and clean it up.

    Posts the prompt and sampling settings to the hard-coded koboldcpp
    ``/api/latest/generate`` URL, then trims the generated text down to a
    single reply.

    Args:
        prompt: Full prompt text sent to the server.
        api_key: Not used by this function.
        bot: Object providing ``temperature`` (sent as the sampling
            temperature) and ``name`` (used when cleaning the reply).
        typing_fn: Not used by this function.
        api_endpoint: Not used by this function; the request URL is fixed.

    Returns:
        The reply text, cut at the first ``"\\nYou:"`` marker (or with a
        trailing ``<|endoftext|>`` removed), with ``<BOT>`` / ``<USER>``
        placeholders substituted and surrounding whitespace stripped.

    Raises:
        ValueError: If the HTTP request fails (``"<HTTP ERROR> ..."``), or
            if the server answers with a non-200 status or an unexpected
            payload (``"<ERROR>"``).
    """
    # Set the API endpoint URL
    endpoint = "http://172.16.85.10:5001/api/latest/generate"

    # Set the headers for the request
    headers = {
        "Content-Type": "application/json",
    }

    max_new_tokens = 120
    # The count is not needed to build the request; it is logged for
    # diagnostics. The call itself is kept from the original code.
    prompt_num_tokens = await num_tokens(prompt)
    logger.debug("prompt_num_tokens=%s", prompt_num_tokens)

    # Define your inputs
    input_data = {
        "prompt": prompt,
        "max_context_length": 2048,
        "max_length": max_new_tokens,
        "temperature": bot.temperature,
        "top_k": 0,
        "top_p": 0,
        "rep_pen": 1.08,
        "rep_pen_range": 1024,
        "quiet": True,
    }

    logger.info("sending request to koboldcpp")

    # requests.post blocks for up to the full timeout; run it in a worker
    # thread so the event loop keeps servicing other tasks meanwhile.
    try:
        r = await asyncio.to_thread(
            requests.post,
            endpoint,
            json=input_data,
            headers=headers,
            timeout=360,
        )
    except requests.exceptions.RequestException as e:
        raise ValueError(f"<HTTP ERROR> {e}") from e

    # Check the status before parsing, so an error page that is not JSON
    # is reported as "<ERROR>" rather than as a decode failure.
    if r.status_code != 200:
        logger.warning("koboldcpp returned status %s: %s", r.status_code, r.text)
        raise ValueError("<ERROR>")

    r_json = r.json()
    logger.info(r_json)

    try:
        reply = r_json["results"][0]["text"]
    except (KeyError, IndexError, TypeError) as e:
        # Unexpected payload shape: report it the same way as other failures.
        raise ValueError("<ERROR>") from e

    # Drop everything from the point where the model starts writing the
    # user's next turn; otherwise only strip the end-of-text marker.
    idx = reply.find("\nYou:")
    if idx != -1:
        reply = reply[:idx].strip()
    else:
        reply = reply.removesuffix('<|endoftext|>').strip()

    bot_name = str(bot.name)
    reply = reply.replace(f"\n{bot_name}: ", " ")
    reply = reply.replace("\n<BOT>: ", " ")
    # Substitute the real name; the original inserted the literal text
    # "{bot.name}" because the replacement string was not an f-string.
    reply = reply.replace("<BOT>", bot_name)
    reply = reply.replace("<USER>", "You")
    return reply.strip()
async def generate_image(
    input_prompt: str,
    negative_prompt: str,
    api_url: str,
    api_key: str,
    typing_fn,
):
    """Placeholder with no implementation.

    Every argument is accepted and ignored; awaiting the coroutine yields
    ``None``.
    """
    return None