Chatbot
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

95 lines
2.8 KiB

2 years ago
"""KoboldCpp LLM wrapper for testing purposes."""
import logging
import time
from typing import Any, List, Mapping, Optional
import json
import requests
from langchain.llms.base import LLM
# Module-level logger named after this module; KoboldCpp uses it to log
# outgoing requests, busy retries, and the decoded server responses.
logger = logging.getLogger(__name__)
class KoboldCpp(LLM):
    """KoboldCpp LLM wrapper for testing purposes.

    Sends the prompt to a KoboldCpp HTTP "generate" endpoint and returns the
    generated text. When the server answers 503 (busy) the request is retried
    a bounded number of times.
    """

    endpoint_url: str = "http://172.16.85.10:5001/api/latest/generate"
    """URL of the KoboldCpp generate endpoint that requests are POSTed to."""
    temperature: Optional[float] = 0.8
    """The temperature to use for sampling."""
    max_tokens: Optional[int] = 256
    """The maximum number of tokens to generate."""
    top_p: Optional[float] = 0.90
    """The top-p value to use for sampling."""
    repeat_penalty: Optional[float] = 1.1
    """The penalty to apply to repeated tokens."""
    top_k: Optional[int] = 40
    """The top-k value to use for sampling."""
    stop: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "KoboldCpp"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a completion for ``prompt`` via the KoboldCpp HTTP API.

        Args:
            prompt: The text sent to the model.
            stop: Optional stop sequences for this call. When non-empty they
                replace the instance-level ``self.stop`` list.

        Returns:
            The generated text. If any stop sequences are in effect, trailing
            whitespace and a trailing stop sequence are stripped from it.

        Raises:
            ValueError: If the HTTP request fails, the server replies with an
                unexpected status code, the body is not the expected JSON
                shape, or the server is still busy after all retries.
        """
        # Stop sequences actually in effect for this call: the per-call
        # override wins; otherwise the instance default (never None).
        stop_sequences = stop if stop else (self.stop or [])

        payload = {
            "prompt": prompt,
            "max_context_length": 2048,
            "max_length": self.max_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "rep_pen": self.repeat_penalty,
            "rep_pen_range": 256,
            "stop_sequence": stop_sequences,
        }
        headers = {
            "Content-Type": "application/json",
        }

        max_tries = 30
        busy_wait_seconds = 5

        logger.info("sending request to koboldcpp.")
        for attempt in range(max_tries):
            try:
                r = requests.post(
                    self.endpoint_url,
                    json=payload,
                    headers=headers,
                    timeout=600,
                )
            except requests.exceptions.RequestException as e:
                raise ValueError("http connection error.") from e

            # Check the status before decoding the body: a busy/error reply
            # is not guaranteed to carry JSON.
            if r.status_code == 503:
                logger.info("api is busy. waiting...")
                # No point sleeping after the final attempt.
                if attempt + 1 < max_tries:
                    time.sleep(busy_wait_seconds)
                continue
            if r.status_code != 200:
                logger.info("unexpected status code: %s", r.status_code)
                raise ValueError("http error. unknown response code")

            try:
                r_json = r.json()
            except ValueError as e:
                # Both json.JSONDecodeError and requests' JSONDecodeError
                # derive from ValueError.
                raise ValueError("http error. response body is not valid JSON.") from e
            logger.info("%s", r_json)

            try:
                response = r_json["results"][0]["text"]
            except (KeyError, IndexError, TypeError) as e:
                raise ValueError("LangChain requires 'results' key in response.") from e
            break
        else:
            # Every attempt returned 503; fail explicitly instead of falling
            # through with no response bound.
            raise ValueError(f"api is still busy after {max_tries} tries.")

        # Strip whatever stop sequence the server left at the end of the
        # text, using the same sequences that were sent with the request.
        for s in stop_sequences:
            response = response.rstrip().removesuffix(s)
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the parameters that distinguish this LLM configuration.

        Exposing the endpoint and sampling settings lets differently
        configured instances be told apart (e.g. when printed or cached).
        """
        return {
            "endpoint_url": self.endpoint_url,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repeat_penalty": self.repeat_penalty,
            "stop": self.stop,
        }