"""KoboldCpp LLM wrapper for testing purposes."""
import asyncio
import time
import logging
from typing import Any, Dict, List, Mapping, Optional
import json

import requests
from langchain.llms.base import LLM
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)

logger = logging.getLogger(__name__)

# Headers sent with every generate request.
_HEADERS = {"Content-Type": "application/json"}


class KoboldCpp(LLM):
    """KoboldCpp LLM wrapper for testing purposes.

    Each prompt is POSTed as JSON to ``endpoint_url``. The generated text is
    read from ``results[0]["text"]`` of the reply. A 503 reply is treated as
    "server busy" and retried up to ``max_retries`` times, ``retry_delay``
    seconds apart.
    """

    endpoint_url: str = "http://172.16.85.10:5001/api/latest/generate"
    """URL of the KoboldCpp generate endpoint."""
    temperature: Optional[float] = 0.8
    """The temperature to use for sampling."""
    max_tokens: Optional[int] = 256
    """The maximum number of tokens to generate (sent as ``max_length``)."""
    top_p: Optional[float] = 0.90
    """The top-p value to use for sampling."""
    repeat_penalty: Optional[float] = 1.1
    """The penalty to apply to repeated tokens (sent as ``rep_pen``)."""
    top_k: Optional[int] = 40
    """The top-k value to use for sampling."""
    stop: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""
    max_context_length: int = 2048
    """Value sent as ``max_context_length`` with every request."""
    rep_pen_range: int = 256
    """Value sent as ``rep_pen_range`` with every request."""
    request_timeout: float = 600.0
    """Per-request HTTP timeout in seconds."""
    max_retries: int = 30
    """How many times a request is sent while the server keeps answering 503."""
    retry_delay: float = 5.0
    """Seconds to wait after a 503 before trying again."""

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "KoboldCpp"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the settings that distinguish one configured instance from another."""
        return {
            "endpoint_url": self.endpoint_url,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "repeat_penalty": self.repeat_penalty,
            "top_k": self.top_k,
            "stop": self.stop,
            "max_context_length": self.max_context_length,
            "rep_pen_range": self.rep_pen_range,
        }

    def _build_payload(
        self, prompt: str, stop: Optional[List[str]]
    ) -> Dict[str, Any]:
        """Build the JSON body for a generate request.

        A non-empty per-call ``stop`` replaces (does not extend) the
        instance's ``stop`` list. An unset instance list is sent as ``[]``.
        """
        return {
            "prompt": prompt,
            "max_context_length": self.max_context_length,
            "max_length": self.max_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "rep_pen": self.repeat_penalty,
            "rep_pen_range": self.rep_pen_range,
            "stop_sequence": list(stop or self.stop or []),
        }

    def _post(self, payload: Dict[str, Any]) -> requests.Response:
        """Send one blocking POST to the endpoint.

        Raises:
            ValueError: if the HTTP request itself fails (connection, timeout, ...).
        """
        try:
            return requests.post(
                self.endpoint_url,
                json=payload,
                headers=_HEADERS,
                timeout=self.request_timeout,
            )
        except requests.exceptions.RequestException as e:
            raise ValueError("http connection error.") from e

    @staticmethod
    def _parse_response(r: requests.Response) -> Optional[str]:
        """Extract the generated text from a KoboldCpp reply.

        Returns:
            The text, or ``None`` if the server answered 503 (busy) and the
            caller should retry.

        Raises:
            ValueError: on any other non-200 status, a non-JSON body, or a
                body without ``results[0]["text"]``.
        """
        if r.status_code == 503:
            return None
        if r.status_code != 200:
            raise ValueError(f"http error. unknown response code {r.status_code}")
        try:
            r_json = r.json()
        except ValueError as e:  # requests' JSON decode errors subclass ValueError
            raise ValueError("koboldcpp returned a non-JSON response.") from e
        logger.info(r_json)
        try:
            return r_json["results"][0]["text"]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError("LangChain requires 'results' key in response.") from e

    @staticmethod
    def _strip_stop_sequences(text: str, stops: List[str]) -> str:
        """Remove each stop string once from the end of ``text``, in list order.

        Trailing whitespace is stripped after each one. With an empty list the
        text is returned unchanged.
        """
        for s in stops:
            text = text.removesuffix(s).rstrip()
        return text

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """Call the KoboldCpp completion endpoint synchronously.

        ``run_manager`` and ``**kwargs`` are accepted for compatibility with
        LangChain callers and are not used.

        Raises:
            ValueError: on connection failure, an unexpected status or body, or
                if the server is still busy after ``max_retries`` attempts.
        """
        payload = self._build_payload(prompt, stop)
        logger.info("sending request to koboldcpp.")
        logger.warning("WARNING: request is blocking. try to use llm's _acall()")
        for _ in range(self.max_retries):
            text = self._parse_response(self._post(payload))
            if text is not None:
                return self._strip_stop_sequences(text, payload["stop_sequence"])
            logger.info("api is busy. waiting...")
            time.sleep(self.retry_delay)
        raise ValueError(f"api still busy after {self.max_retries} attempts.")

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """Call the KoboldCpp completion endpoint asynchronously.

        The blocking HTTP request runs in a worker thread so the event loop
        is not stalled. ``run_manager`` and ``**kwargs`` are accepted for
        compatibility with LangChain callers and are not used.

        Raises:
            ValueError: same conditions as :meth:`_call`.
        """
        payload = self._build_payload(prompt, stop)
        logger.info("sending request to koboldcpp.")
        for _ in range(self.max_retries):
            r = await asyncio.to_thread(self._post, payload)
            text = self._parse_response(r)
            if text is not None:
                return self._strip_stop_sequences(text, payload["stop_sequence"])
            logger.info("api is busy. waiting...")
            await asyncio.sleep(self.retry_delay)
        raise ValueError(f"api still busy after {self.max_retries} attempts.")

    def get_num_tokens(self, text: str) -> int:
        """Estimate num tokens with a rough ~4 characters-per-token heuristic."""
        return len(text) // 4 + 1

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Estimate the token count of a list of chat messages.

        Each field value is counted with :meth:`get_num_tokens`. Fixed
        overheads are then added: 3 per message, 1 extra per ``name`` field
        and 3 for the whole list. These constants appear to be borrowed from
        OpenAI's chat-format accounting, so treat the result as an estimate
        only.
        """
        tokens_per_message = 3
        tokens_per_name = 1
        num_tokens = 0
        for message in (_convert_message_to_dict(m) for m in messages):
            num_tokens += tokens_per_message
            for key, value in message.items():
                num_tokens += self.get_num_tokens(value)
                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3
        return num_tokens


def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
    """Convert a LangChain message into a ``{"role", "content"[, "name"]}`` dict.

    Raises:
        ValueError: if ``message`` is not one of the supported message types.
    """
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    else:
        raise ValueError(f"Got unknown type {message}")
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict