You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
59 lines
2.2 KiB
59 lines
2.2 KiB
import asyncio
|
|
import requests
|
|
import json
|
|
import logging
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
|
|
|
|
class KoboldCppTextWrapper(object):
    """Base Class for koboldcpp.

    Small client that POSTs generation requests to a running koboldcpp
    server at ``http://{endpoint_name}/api/latest/generate``.

    Attributes:
        endpoint_name: host (and optional port) interpolated into the
            request URL.
        model_name: stored on the instance; not read by any method here.
    """

    def __init__(self, endpoint_name: str, model_name: str):
        self.endpoint_name = endpoint_name
        self.model_name = model_name

    @staticmethod
    def setup():
        """Clone and build koboldcpp and download a model under ./repositories.

        Runs shell commands through ``os.system``. Their exit statuses are
        not checked, so a failing step does not raise.
        """
        # `os` is not imported at module level; import it here so this
        # method does not fail with NameError.
        import os

        os.system("mkdir -p repositories && (cd repositories && git clone https://github.com/LostRuins/koboldcpp.git)")
        os.system("apt update && apt-get install libopenblas-dev libclblast-dev libmkl-dev")
        os.system("(cd repositories/koboldcpp && make LLAMA_OPENBLAS=1 && cd models && wget https://huggingface.co/concedo/pygmalion-6bv3-ggml-ggjt/resolve/main/pygmalion-6b-v3-ggml-ggjt-q4_0.bin)")
        # Server start commands (presumably run from repositories/koboldcpp):
        #python3 koboldcpp.py models/pygmalion-6b-v3-ggml-ggjt-q4_0.bin
        #python3 koboldcpp.py --smartcontext models/pygmalion-6b-v3-ggml-ggjt-q4_0.bin

    async def generate(self, prompt: str, typing_fn, temperature=0.72,
                       max_new_tokens=200, timeout=180, *,
                       tries=30, retry_delay=5):
        """Request a completion for *prompt* from the koboldcpp server.

        Args:
            prompt: text sent as the ``prompt`` field.
            typing_fn: accepted for interface compatibility; not used here.
            temperature: sampling temperature sent to the server.
            max_new_tokens: sent as ``max_length``.
            timeout: per-request timeout (seconds) passed to ``requests.post``.
            tries: maximum number of POST attempts.
            retry_delay: seconds to wait before retrying after a non-200 reply.

        Returns:
            The ``text`` of the first entry of the response's ``results`` list.

        Raises:
            ValueError: if no attempt returned HTTP 200.
            Exceptions raised by ``requests.post`` (connection errors,
            timeouts) or by parsing a 200 body propagate and are not retried.
        """
        endpoint = f"http://{self.endpoint_name}/api/latest/generate"

        headers = {
            "Content-Type": "application/json",
        }

        input_data = {
            "prompt": prompt,
            "max_context_length": 2048,
            "max_length": max_new_tokens,
            "temperature": temperature,
            "top_k": 50,
            "top_p": 0.85,
            "rep_pen": 1.08,
            "rep_pen_range": 1024,
            "stop_sequence": ['<|endoftext|>'],
        }

        logger.info('sending request to koboldcpp. endpoint="%s"', self.endpoint_name)

        loop = asyncio.get_running_loop()
        for attempt in range(tries):
            # requests is synchronous; run it in the default executor so the
            # event loop is not blocked for up to `timeout` seconds.
            r = await loop.run_in_executor(
                None,
                lambda: requests.post(endpoint, json=input_data, headers=headers, timeout=timeout),
            )

            if r.status_code == 200:
                # Only parse JSON on success: error/busy bodies may not be JSON.
                r_json = r.json()
                logger.info(r_json)
                return r_json["results"][0]["text"]

            if r.status_code == 503:
                logger.info("api is busy. waiting...")
            else:
                logger.warning("unexpected status %s from koboldcpp: %s", r.status_code, r.text)

            # Must be awaited: calling asyncio.sleep() without await only
            # creates a coroutine object and does not wait at all.
            if attempt + 1 < tries:
                await asyncio.sleep(retry_delay)

        raise ValueError("<ERROR> TIMEOUT / NO OUTPUT")
|
|
|
|
|