Chatbot
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

206 lines
7.1 KiB

"""KoboldCpp LLM wrapper for testing purposes."""
import asyncio
import time
import logging
from typing import Any, List, Mapping, Optional
import json
import requests
import functools
from langchain.llms.base import LLM
from langchain.schema import BaseMessage, AIMessage, HumanMessage, SystemMessage, ChatMessage
# Module-level logger, named after this module so handlers can target it.
logger = logging.getLogger(name=__name__)
class KoboldCpp(LLM):
    """LangChain LLM wrapper around a KoboldCpp server's generate endpoint.

    Each prompt is POSTed as JSON to ``endpoint_url``; the ``text`` of the
    first entry in the reply's ``results`` list is returned as the completion.
    """

    # URL of the KoboldCpp generate endpoint that requests are POSTed to.
    endpoint_url: str = "http://172.16.33.10:5001/api/latest/generate"
    temperature: Optional[float] = 0.7
    """The temperature to use for sampling."""
    max_context: Optional[int] = 2048
    """The maximum context size."""
    max_tokens: Optional[int] = 256
    """The maximum number of tokens to generate."""
    top_p: Optional[float] = 0.90
    """The top-p value to use for sampling."""
    repeat_penalty: Optional[float] = 1.15
    """The penalty to apply to repeated tokens."""
    # Sent as ``rep_pen_range``; previously hard-coded to 1024.
    repeat_penalty_range: Optional[int] = 1024
    top_k: Optional[int] = 20
    """The top-k value to use for sampling."""
    stop: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "KoboldCpp"

    def _build_payload(self, prompt: str, stop: Optional[List[str]]) -> dict:
        """Build the JSON request body for one generate call.

        A non-empty ``stop`` argument replaces the instance-level ``self.stop``.
        """
        return {
            "prompt": prompt,
            "max_context_length": self.max_context,
            "max_length": self.max_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "rep_pen": self.repeat_penalty,
            "rep_pen_range": self.repeat_penalty_range,
            "stop_sequence": stop if stop else self.stop,
        }

    @staticmethod
    def _extract_text(r: requests.Response) -> str:
        """Return ``results[0]["text"]`` from a successful (200) response.

        Raises:
            ValueError: if the body is not JSON or lacks the expected fields.
        """
        try:
            r_json = r.json()
        except ValueError as e:
            # Both json.JSONDecodeError and requests' JSONDecodeError
            # subclass ValueError.
            raise ValueError("koboldcpp returned a non-JSON response.") from e
        logger.info(r_json)
        try:
            return r_json["results"][0]["text"]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError("LangChain requires 'results' key in response.") from e

    @staticmethod
    def _strip_stop_sequences(text: str, stops: Optional[List[str]]) -> str:
        """Remove each stop sequence (in order) from the end of ``text``.

        Trailing whitespace is stripped after every removal. ``None`` is
        treated as "no stop sequences".
        """
        for s in stops or ():
            text = text.removesuffix(s).rstrip()
        return text

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call KoboldCpp's generate endpoint synchronously (blocking).

        While the server answers 503 (busy), waits 5 seconds and retries,
        for at most 30 attempts.

        Args:
            prompt: The prompt to complete.
            stop: Stop sequences; when non-empty they replace ``self.stop``.

        Returns:
            The generated text with trailing stop sequences removed.

        Raises:
            ValueError: on any transport error, an unexpected status code,
                a malformed response body, or if the server is still busy
                after the final attempt.
        """
        input_data = self._build_payload(prompt, stop)
        headers = {"Content-Type": "application/json"}
        logger.info("sending request to koboldcpp.")
        logger.warning("WARNING: request is blocking. try to use llm's _acall()")
        tries = 30
        for _ in range(tries):
            try:
                r = requests.post(
                    self.endpoint_url, json=input_data, headers=headers, timeout=600
                )
            except requests.exceptions.RequestException as e:
                raise ValueError("http connection error.") from e
            if r.status_code == 503:
                logger.info("api is busy. waiting...")
                time.sleep(5)
                continue
            if r.status_code != 200:
                raise ValueError(f"http error. unknown response code {r.status_code}")
            response = self._extract_text(r)
            return self._strip_stop_sequences(response, input_data["stop_sequence"])
        # Previously fell through to an UnboundLocalError on `response`.
        raise ValueError(f"api is still busy after {tries} tries.")

    async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call KoboldCpp's generate endpoint asynchronously.

        The blocking ``requests.post`` runs in the event loop's default
        executor. Busy (503) replies and connection errors are retried after
        a 5 second pause, for at most 60 attempts. A read timeout (20 minutes
        per request) is not retried.

        Args:
            prompt: The prompt to complete.
            stop: Stop sequences; when non-empty they replace ``self.stop``.

        Returns:
            The generated text with trailing stop sequences removed.

        Raises:
            ValueError: on a timeout, any other transport error, an
                unexpected status code, a malformed response body, or if the
                server is still unavailable after the final attempt.
        """
        input_data = self._build_payload(prompt, stop)
        headers = {"Content-Type": "application/json"}
        logger.info("sending request to koboldcpp.")
        tries = 60
        request_timeout = 20 * 60
        loop = asyncio.get_running_loop()
        post = functools.partial(
            requests.post,
            self.endpoint_url,
            json=input_data,
            headers=headers,
            timeout=request_timeout,
        )
        for _ in range(tries):
            try:
                r = await loop.run_in_executor(None, post)
            except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
                # ConnectTimeout subclasses ConnectionError, so it is retried
                # here rather than treated as a fatal timeout below.
                logger.warning("koboldcpp connection problem, retrying: %s", e)
                await asyncio.sleep(5)
                continue
            except requests.exceptions.Timeout as e:
                raise ValueError("http timeout error.") from e
            except requests.exceptions.RequestException as e:
                raise ValueError("http connection error.") from e
            if r.status_code == 503:
                logger.info("api is busy. waiting...")
                await asyncio.sleep(5)
                continue
            if r.status_code != 200:
                raise ValueError(f"http error. unknown response code {r.status_code}")
            response = self._extract_text(r)
            return self._strip_stop_sequences(response, input_data["stop_sequence"])
        # Previously fell through to an UnboundLocalError on `response`.
        raise ValueError(f"api unavailable after {tries} tries.")

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the settings that distinguish one configuration from another.

        LangChain uses this mapping, for example, when building cache keys
        and in ``str(llm)``.
        """
        return {
            "endpoint_url": self.endpoint_url,
            "temperature": self.temperature,
            "max_context": self.max_context,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "repeat_penalty": self.repeat_penalty,
            "repeat_penalty_range": self.repeat_penalty_range,
            "top_k": self.top_k,
            "stop": self.stop,
        }

    def get_num_tokens(self, text: str) -> int:
        """Estimate num tokens as roughly one token per four characters."""
        return len(text) // 4 + 1

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Estimate the token count of a chat transcript.

        Counts 3 tokens of overhead per message, 1 extra token for each
        message that carries a ``name``, the ``get_num_tokens`` estimate of
        each message's content, and a fixed 3 tokens for the transcript.
        """
        tokens_per_message = 3
        tokens_per_name = 1
        num_tokens = 0
        for message in map(_convert_message_to_dict, messages):
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "content":
                    num_tokens += self.get_num_tokens(value)
                elif key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3
        return num_tokens
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Translate a LangChain message into a ``{"role", "content"}`` dict.

    ``ChatMessage`` keeps its own role. Human, AI and system messages map to
    ``"user"``, ``"assistant"`` and ``"system"``. If the message's
    ``additional_kwargs`` contains a ``"name"`` key, it is copied over.

    Raises:
        ValueError: if ``message`` is none of the supported message types.
    """
    if isinstance(message, ChatMessage):
        role = message.role
    else:
        fixed_roles = (
            (HumanMessage, "user"),
            (AIMessage, "assistant"),
            (SystemMessage, "system"),
        )
        for message_cls, fixed_role in fixed_roles:
            if isinstance(message, message_cls):
                role = fixed_role
                break
        else:
            raise ValueError(f"Got unknown type {message}")
    converted = {"role": role, "content": message.content}
    extras = message.additional_kwargs
    if "name" in extras:
        converted["name"] = extras["name"]
    return converted