"""KoboldCpp LLM wrapper for testing purposes."""
import logging
import time
from typing import Any, List, Mapping, Optional
import json

import requests
from langchain.llms.base import LLM

logger = logging.getLogger(__name__)


class KoboldCpp(LLM):
    """KoboldCpp LLM wrapper for testing purposes.

    Sends prompts as JSON to a KoboldCpp generate endpoint over HTTP and
    returns the generated text. A 503 response is treated as "server busy"
    and retried up to ``max_retries`` times, waiting ``retry_delay`` seconds
    between attempts.
    """

    endpoint_url: str = "http://172.16.85.10:5001/api/latest/generate"
    """URL of the KoboldCpp generate endpoint (default is a fixed LAN host)."""
    temperature: Optional[float] = 0.8
    """The temperature to use for sampling."""
    max_tokens: Optional[int] = 256
    """The maximum number of tokens to generate."""
    top_p: Optional[float] = 0.90
    """The top-p value to use for sampling."""
    repeat_penalty: Optional[float] = 1.1
    """The penalty to apply to repeated tokens."""
    top_k: Optional[int] = 40
    """The top-k value to use for sampling."""
    stop: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""
    max_context_length: int = 2048
    """Sent to the server as ``max_context_length``."""
    rep_pen_range: int = 256
    """Sent to the server as ``rep_pen_range`` (presumably the repetition-penalty window in tokens)."""
    max_retries: int = 30
    """Maximum number of requests to send while the server keeps answering 503."""
    retry_delay: float = 5.0
    """Seconds to wait after a 503 response before retrying."""
    request_timeout: float = 600
    """Timeout in seconds passed to ``requests.post``."""

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "KoboldCpp"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a completion for ``prompt`` via the KoboldCpp HTTP API.

        Args:
            prompt: Text to complete.
            stop: Stop sequences for this call. When non-empty they replace
                ``self.stop``; otherwise ``self.stop`` is used.

        Returns:
            The generated text. When stop sequences are in effect, trailing
            whitespace and a trailing occurrence of each stop sequence are
            removed.

        Raises:
            ValueError: on a connection error, a status other than 200/503, a
                malformed response body, or if the server still answers 503
                after ``max_retries`` attempts.
        """
        # The same effective stop list is sent to the server and used for
        # post-processing, so a per-call ``stop`` is honored consistently.
        stop_sequences = list(stop or self.stop or [])

        input_data = {
            "prompt": prompt,
            "max_context_length": self.max_context_length,
            "max_length": self.max_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "rep_pen": self.repeat_penalty,
            "rep_pen_range": self.rep_pen_range,
            "stop_sequence": stop_sequences,
        }
        headers = {
            "Content-Type": "application/json",
        }

        logger.info("sending request to koboldcpp.")
        for attempt in range(self.max_retries):
            try:
                r = requests.post(
                    self.endpoint_url,
                    json=input_data,
                    headers=headers,
                    timeout=self.request_timeout,
                )
            except requests.exceptions.RequestException as e:
                raise ValueError("http connection error.") from e

            # Check the status before parsing: a busy/error response may not
            # carry a JSON body at all.
            if r.status_code == 503:
                logger.info("api is busy. waiting...")
                if attempt < self.max_retries - 1:
                    time.sleep(self.retry_delay)
                continue
            if r.status_code != 200:
                raise ValueError(
                    f"http error. unknown response code: {r.status_code}"
                )

            try:
                r_json = r.json()
            except ValueError as e:  # JSON decode errors subclass ValueError
                raise ValueError("koboldcpp returned a non-JSON response.") from e
            logger.info(r_json)

            try:
                response = r_json["results"][0]["text"]
            except (KeyError, IndexError, TypeError) as e:
                raise ValueError(
                    "LangChain requires 'results' key in response."
                ) from e
            break
        else:
            # Every attempt returned 503 (or max_retries <= 0); without this
            # ``response`` would be unbound below.
            raise ValueError(
                f"koboldcpp api still busy after {self.max_retries} attempts."
            )

        for s in stop_sequences:
            response = response.rstrip().removesuffix(s)
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the settings that affect generation.

        Exposing these keeps differently configured instances distinguishable
        instead of all reporting an empty parameter mapping.
        """
        return {
            "endpoint_url": self.endpoint_url,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "repeat_penalty": self.repeat_penalty,
            "top_k": self.top_k,
            "stop": self.stop,
            "max_context_length": self.max_context_length,
            "rep_pen_range": self.rep_pen_range,
        }