You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
186 lines
6.3 KiB
186 lines
6.3 KiB
"""KoboldCpp LLM wrapper for testing purposes."""
|
|
import asyncio
|
|
import time
|
|
import logging
|
|
from typing import Any, List, Mapping, Optional
|
|
|
|
import json
|
|
import requests
|
|
import functools
|
|
|
|
from langchain.llms.base import LLM
|
|
|
|
from langchain.schema import BaseMessage
|
|
|
|
# Module-level logger named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class KoboldCpp(LLM):
    """KoboldCpp LLM wrapper for testing purposes.

    Each call POSTs the prompt plus the sampling settings below to
    ``endpoint_url`` and returns ``results[0]["text"]`` from the JSON reply.
    While the server answers HTTP 503 (busy), the request is re-sent up to
    ``max_retries`` times, waiting ``retry_delay`` seconds between attempts.
    """

    endpoint_url: str = "http://172.16.85.10:5001/api/latest/generate"
    """URL of the KoboldCpp generate endpoint."""

    temperature: Optional[float] = 0.8
    """The temperature to use for sampling."""

    max_tokens: Optional[int] = 256
    """The maximum number of tokens to generate."""

    top_p: Optional[float] = 0.90
    """The top-p value to use for sampling."""

    repeat_penalty: Optional[float] = 1.1
    """The penalty to apply to repeated tokens."""

    top_k: Optional[int] = 40
    """The top-k value to use for sampling."""

    stop: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    max_context_length: int = 2048
    """Value sent to the server as ``max_context_length``."""

    rep_pen_range: int = 256
    """Value sent to the server as ``rep_pen_range``."""

    request_timeout: float = 600
    """Timeout in seconds passed to each ``requests.post`` call."""

    max_retries: int = 30
    """Maximum number of attempts while the server keeps answering 503."""

    retry_delay: float = 5
    """Seconds to wait before re-sending after a 503 reply."""

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "KoboldCpp"

    def _build_payload(self, prompt: str, stop: Optional[List[str]]) -> dict:
        """Assemble the JSON body for the generate endpoint.

        A non-empty per-call ``stop`` list replaces the ``stop`` field; if both
        are empty/None, an empty list is sent.
        """
        return {
            "prompt": prompt,
            "max_context_length": self.max_context_length,
            "max_length": self.max_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "rep_pen": self.repeat_penalty,
            "rep_pen_range": self.rep_pen_range,
            "stop_sequence": stop or self.stop or [],
        }

    def _post_kwargs(self, payload: dict) -> dict:
        """Return the keyword arguments for ``requests.post`` shared by both call paths."""
        return {
            "json": payload,
            "headers": {"Content-Type": "application/json"},
            "timeout": self.request_timeout,
        }

    @staticmethod
    def _extract_text(r: requests.Response) -> str:
        """Return ``results[0]["text"]`` from a response that is not a 503.

        Raises:
            ValueError: if the status is not 200, the body is not JSON, or the
                JSON has no ``results[0]["text"]``.
        """
        if r.status_code != 200:
            raise ValueError(f"http error. unknown response code {r.status_code}")
        try:
            r_json = r.json()
        except ValueError as e:  # requests' JSONDecodeError subclasses ValueError
            raise ValueError("http error. response body is not valid JSON.") from e
        logger.info(r_json)
        try:
            return r_json["results"][0]["text"]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError("LangChain requires 'results' key in response.") from e

    @staticmethod
    def _strip_stop_sequences(text: str, stop_sequences: Optional[List[str]]) -> str:
        """Remove each stop sequence (in list order) from the end of ``text``.

        Trailing whitespace is stripped after each removal attempt, so it is
        stripped whenever at least one stop sequence is given.
        """
        for seq in stop_sequences or []:
            text = text.removesuffix(seq).rstrip()
        return text

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a completion for ``prompt`` with a blocking HTTP request.

        Args:
            prompt: Text sent to the model.
            stop: Optional stop sequences; when non-empty they replace ``self.stop``.

        Returns:
            The generated text with trailing stop sequences removed.

        Raises:
            ValueError: on a connection error, an unexpected status code, a
                malformed response, or when every attempt got HTTP 503.
        """
        payload = self._build_payload(prompt, stop)
        post_kwargs = self._post_kwargs(payload)

        logger.info("sending request to koboldcpp.")
        logger.warning("WARNING: request is blocking. try to use llm's _acall()")

        for attempt in range(1, self.max_retries + 1):
            try:
                r = requests.post(self.endpoint_url, **post_kwargs)
            except requests.exceptions.RequestException as e:
                raise ValueError("http connection error.") from e
            if r.status_code == 503:
                logger.info("api is busy. waiting...")
                # No point sleeping after the final attempt.
                if attempt < self.max_retries:
                    time.sleep(self.retry_delay)
                continue
            return self._strip_stop_sequences(
                self._extract_text(r), payload["stop_sequence"]
            )

        # Previously this fell through to an unbound `response` (UnboundLocalError).
        raise ValueError(f"api is still busy after {self.max_retries} attempts.")

    async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to KoboldCpp's completion endpoint asynchronously.

        Same arguments, return value and errors as ``_call``. The blocking
        ``requests.post`` runs in the event loop's default executor, and waits
        between busy retries use ``asyncio.sleep``.
        """
        payload = self._build_payload(prompt, stop)
        post = functools.partial(
            requests.post, self.endpoint_url, **self._post_kwargs(payload)
        )
        loop = asyncio.get_running_loop()

        logger.info("sending request to koboldcpp.")

        for attempt in range(1, self.max_retries + 1):
            try:
                r = await loop.run_in_executor(None, post)
            except requests.exceptions.RequestException as e:
                raise ValueError("http connection error.") from e
            if r.status_code == 503:
                logger.info("api is busy. waiting...")
                if attempt < self.max_retries:
                    await asyncio.sleep(self.retry_delay)
                continue
            return self._strip_stop_sequences(
                self._extract_text(r), payload["stop_sequence"]
            )

        raise ValueError(f"api is still busy after {self.max_retries} attempts.")

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the settings that affect what this LLM generates."""
        return {
            "endpoint_url": self.endpoint_url,
            "max_context_length": self.max_context_length,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "repeat_penalty": self.repeat_penalty,
            "rep_pen_range": self.rep_pen_range,
            "stop": self.stop,
        }

    def get_num_tokens(self, text: str) -> int:
        """Estimate num tokens as ``len(text) // 4 + 1``."""
        return len(text) // 4 + 1

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Estimate num tokens for a list of chat messages.

        Each message costs ``tokens_per_message`` plus the estimate of each of
        its dict values, plus ``tokens_per_name`` if it carries a name. A fixed
        3 tokens are added once for the whole list.
        """
        tokens_per_message = 3
        tokens_per_name = 1
        num_tokens = 0
        for message in (_convert_message_to_dict(m) for m in messages):
            num_tokens += tokens_per_message
            for key, value in message.items():
                # get_num_tokens already returns an int count; add it directly.
                num_tokens += self.get_num_tokens(value)
                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3
        return num_tokens
|
|
|
|
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain chat message into a plain role/content dict.

    The result has ``"role"`` and ``"content"`` keys, plus ``"name"`` when the
    message's ``additional_kwargs`` contain one.

    Raises:
        ValueError: if ``message`` is not a ChatMessage, HumanMessage,
            AIMessage or SystemMessage.
    """
    # Imported locally: these message classes are not among the module-level
    # imports, so referencing them unqualified would raise NameError.
    from langchain.schema import AIMessage, ChatMessage, HumanMessage, SystemMessage

    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    else:
        raise ValueError(f"Got unknown type {message}")
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict
|
|
|