Chatbot
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

94 lines
2.8 KiB

"""KoboldCpp LLM wrapper for testing purposes."""
import logging
import time
from typing import Any, List, Mapping, Optional
import json
import requests
from langchain.llms.base import LLM
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class KoboldCpp(LLM):
    """KoboldCpp LLM wrapper for testing purposes.

    Sends the prompt to a KoboldCpp HTTP ``generate`` endpoint with
    ``requests.post`` and returns the generated text. The sampling settings
    below are forwarded in the JSON request body.
    """

    endpoint_url: str = "http://172.16.85.10:5001/api/latest/generate"
    """The URL the generation request is POSTed to."""

    temperature: Optional[float] = 0.8
    """The temperature to use for sampling."""

    max_tokens: Optional[int] = 256
    """The maximum number of tokens to generate."""

    top_p: Optional[float] = 0.90
    """The top-p value to use for sampling."""

    repeat_penalty: Optional[float] = 1.1
    """The penalty to apply to repeated tokens."""

    top_k: Optional[int] = 40
    """The top-k value to use for sampling."""

    stop: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "KoboldCpp"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a completion for ``prompt`` via the KoboldCpp endpoint.

        Args:
            prompt: The text to send as the ``prompt`` field of the request.
            stop: Optional per-call stop sequences. When non-empty they
                replace ``self.stop`` both in the request and when trimming
                the returned text.

        Returns:
            The generated text, with a trailing stop sequence removed.

        Raises:
            ValueError: If the HTTP request fails, the server answers with a
                status other than 200 or 503, the response body is not the
                expected JSON shape, or the server keeps answering 503 for
                every attempt.
        """
        max_tries = 30
        busy_wait_seconds = 5
        request_timeout = 600

        # Use one effective stop list for both the request and the trimming
        # step below, so that a per-call ``stop`` is honoured consistently.
        # ``self.stop`` is Optional, hence the final ``or []``.
        stop_sequences = stop or self.stop or []

        payload = {
            "prompt": prompt,
            "max_context_length": 2048,
            "max_length": self.max_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "rep_pen": self.repeat_penalty,
            "rep_pen_range": 256,
            "stop_sequence": stop_sequences,
        }
        headers = {
            "Content-Type": "application/json",
        }

        logger.info("sending request to koboldcpp.")
        for _ in range(max_tries):
            try:
                reply = requests.post(
                    self.endpoint_url,
                    json=payload,
                    headers=headers,
                    timeout=request_timeout,
                )
            except requests.exceptions.RequestException as e:
                raise ValueError("http connection error.") from e

            # Check the status before touching the body: a busy (503) reply
            # is retried without requiring its body to be valid JSON.
            if reply.status_code == 503:
                logger.info("api is busy. waiting...")
                time.sleep(busy_wait_seconds)
                continue
            if reply.status_code != 200:
                raise ValueError("http error. unknown response code")

            try:
                body = reply.json()
            except ValueError as e:
                # JSON decode errors raised by ``Response.json`` are
                # ValueError subclasses.
                raise ValueError("response body is not valid JSON.") from e
            logger.info(body)

            try:
                text = body["results"][0]["text"]
            except (KeyError, IndexError, TypeError) as e:
                raise ValueError(
                    "LangChain requires 'results' key in response."
                ) from e
            break
        else:
            # Every attempt returned 503; fail explicitly instead of falling
            # through with no generated text.
            raise ValueError(
                f"koboldcpp api still busy after {max_tries} tries."
            )

        for suffix in stop_sequences:
            text = text.rstrip().removesuffix(suffix)
        return text

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the settings that distinguish this LLM configuration."""
        return {
            "endpoint_url": self.endpoint_url,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repeat_penalty": self.repeat_penalty,
            "stop": self.stop,
        }