You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
94 lines
2.8 KiB
94 lines
2.8 KiB
"""KoboldCpp LLM wrapper for testing purposes."""
|
|
import logging
|
|
import time
|
|
from typing import Any, List, Mapping, Optional
|
|
|
|
import json
|
|
import requests
|
|
|
|
from langchain.llms.base import LLM
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class KoboldCpp(LLM):
    """KoboldCpp LLM wrapper for testing purposes.

    Each call POSTs the prompt together with the sampling settings below as
    a JSON body to ``endpoint_url`` and returns the text of the first entry
    in the ``results`` list of the JSON reply.
    """

    endpoint_url: str = "http://172.16.85.10:5001/api/latest/generate"
    """The URL the generation request is POSTed to."""

    temperature: Optional[float] = 0.8
    """The temperature to use for sampling."""

    max_tokens: Optional[int] = 256
    """The maximum number of tokens to generate."""

    top_p: Optional[float] = 0.90
    """The top-p value to use for sampling."""

    repeat_penalty: Optional[float] = 1.1
    """The penalty to apply to repeated tokens."""

    top_k: Optional[int] = 40
    """The top-k value to use for sampling."""

    stop: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "KoboldCpp"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Send ``prompt`` to the endpoint and return the generated text.

        Args:
            prompt: The text to send as the ``prompt`` field of the request.
            stop: Stop sequences for this call. When given and non-empty
                they replace the instance-level ``stop`` list.

        Returns:
            The ``text`` of the first result, with trailing stop sequences
            and surrounding whitespace removed.

        Raises:
            ValueError: If the HTTP request fails, the endpoint answers with
                a status other than 200 or 503, the endpoint is still
                answering 503 after every attempt, or the reply does not
                contain ``results[0]["text"]``.
        """
        # A per-call ``stop`` wins over the instance default. ``None`` is
        # normalised to an empty list so both the payload and the trimming
        # loop at the end always see a list.
        stop_sequences = stop or self.stop or []

        input_data = {
            "prompt": prompt,
            "max_context_length": 2048,
            "max_length": self.max_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "rep_pen": self.repeat_penalty,
            "rep_pen_range": 256,
            "stop_sequence": stop_sequences,
        }
        headers = {"Content-Type": "application/json"}

        logger.info("sending request to koboldcpp.")

        max_tries = 30
        for _ in range(max_tries):
            try:
                r = requests.post(
                    self.endpoint_url,
                    json=input_data,
                    headers=headers,
                    timeout=600,
                )
            except requests.exceptions.RequestException as e:
                raise ValueError("http connection error.") from e

            # Look at the status before touching the body: a 503 reply is
            # retried regardless of whether its body is valid JSON.
            if r.status_code == 503:
                logger.info("api is busy. waiting...")
                time.sleep(5)
                continue
            if r.status_code != 200:
                raise ValueError("http error. unknown response code")

            try:
                r_json = r.json()
            except ValueError as e:
                raise ValueError("response body is not valid JSON.") from e
            logger.info(r_json)

            try:
                response = r_json["results"][0]["text"]
            except (KeyError, IndexError, TypeError) as e:
                # Covers a missing key, an empty ``results`` list and a
                # payload that is not shaped like a dict of lists.
                raise ValueError(
                    "LangChain requires 'results' key in response."
                ) from e
            break
        else:
            # Every attempt was answered with 503; without this branch
            # ``response`` would be unbound below.
            raise ValueError(f"api still busy after {max_tries} tries.")

        # Drop a stop sequence left at the end of the generated text.
        for s in stop_sequences:
            response = response.removesuffix(s).rstrip()
        return response.lstrip()

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the identifying parameters; this wrapper reports none."""
        return {}
|
|
|