Chatbot
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

207 lines
7.1 KiB

2 years ago
"""KoboldCpp LLM wrapper for testing purposes."""
2 years ago
import asyncio
2 years ago
import time
2 years ago
import logging
2 years ago
from typing import Any, List, Mapping, Optional
import json
import requests
2 years ago
import functools
2 years ago
from langchain.llms.base import LLM
from langchain.schema import BaseMessage, AIMessage, HumanMessage, SystemMessage, ChatMessage
2 years ago
2 years ago
logger = logging.getLogger(__name__)

# Request headers shared by the sync and async code paths.
_JSON_HEADERS = {"Content-Type": "application/json"}
# Seconds to wait before retrying a busy (HTTP 503) endpoint or, in the async
# path, a failed connection.
_RETRY_DELAY_SECONDS = 5


class KoboldCpp(LLM):
    """LangChain LLM wrapper around a KoboldCpp text-generation endpoint.

    Each call POSTs the prompt plus the sampling settings below as JSON to
    ``endpoint_url`` and returns ``results[0]["text"]`` from the reply, with
    trailing stop sequences stripped. HTTP 503 ("busy") replies are retried
    after a short pause.
    """

    # URL of the KoboldCpp generate endpoint.
    endpoint_url: str = "http://172.16.33.10:5001/api/latest/generate"
    # The temperature to use for sampling.
    temperature: Optional[float] = 0.7
    # The maximum context size (sent as ``max_context_length``).
    max_context: Optional[int] = 2048
    # The maximum number of tokens to generate (sent as ``max_length``).
    max_tokens: Optional[int] = 256
    # The top-p value to use for sampling.
    top_p: Optional[float] = 0.90
    # The penalty to apply to repeated tokens (sent as ``rep_pen``).
    repeat_penalty: Optional[float] = 1.15
    # Sent as ``rep_pen_range``; defaults to the previously hard-coded 1024.
    repeat_penalty_range: Optional[int] = 1024
    # The top-k value to use for sampling.
    top_k: Optional[int] = 20
    # Default stop strings; a non-empty per-call ``stop`` argument overrides it.
    stop: Optional[List[str]] = []

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "KoboldCpp"

    def _build_payload(self, prompt: str, stop: Optional[List[str]]) -> dict:
        """Return the JSON request body for the generate endpoint.

        A non-empty ``stop`` overrides ``self.stop``. A missing stop list is
        sent as ``[]`` so the later suffix-stripping step can always iterate it.
        """
        return {
            "prompt": prompt,
            "max_context_length": self.max_context,
            "max_length": self.max_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "rep_pen": self.repeat_penalty,
            "rep_pen_range": self.repeat_penalty_range,
            "stop_sequence": stop or self.stop or [],
        }

    @staticmethod
    def _text_from_response(r: requests.Response, stop_sequence: List[str]) -> str:
        """Extract the generated text from a non-503 KoboldCpp reply.

        Args:
            r: The HTTP response returned by ``requests.post``.
            stop_sequence: Stop strings to trim from the end of the text.

        Returns:
            ``results[0]["text"]`` with each stop sequence (and trailing
            whitespace) removed from the end.

        Raises:
            ValueError: if the status is not 200, the body is not JSON, or the
                body lacks ``results[0]["text"]``.
        """
        if r.status_code != 200:
            raise ValueError(f"http error. unknown response code {r.status_code}")
        try:
            r_json = r.json()
        except ValueError as e:  # JSON decode errors subclass ValueError
            raise ValueError("http error. response is not valid JSON.") from e
        logger.info(r_json)
        try:
            text = r_json["results"][0]["text"]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError("LangChain requires 'results' key in response.") from e
        # Trim stop sequences (applied in order) plus trailing whitespace.
        for s in stop_sequence:
            text = text.removesuffix(s).rstrip()
        return text

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call KoboldCpp's generate endpoint synchronously.

        This blocks the calling thread for the whole generation; prefer
        ``_acall`` from async code.

        Args:
            prompt: Text to send as the prompt.
            stop: Optional stop sequences overriding ``self.stop``.

        Returns:
            The generated text with trailing stop sequences stripped.

        Raises:
            ValueError: on a request failure, an unexpected status code, a
                malformed response, or if the API is still busy after all
                retries.
        """
        input_data = self._build_payload(prompt, stop)
        logger.info("sending request to koboldcpp.")
        logger.warning("WARNING: request is blocking. try to use llm's _acall()")
        tries = 30
        for _ in range(tries):
            try:
                r = requests.post(
                    self.endpoint_url,
                    json=input_data,
                    headers=_JSON_HEADERS,
                    timeout=600,
                )
            except requests.exceptions.RequestException as e:
                raise ValueError("http connection error.") from e
            if r.status_code == 503:
                logger.info("api is busy. waiting...")
                time.sleep(_RETRY_DELAY_SECONDS)
                continue
            return self._text_from_response(r, input_data["stop_sequence"])
        # Every attempt was answered with 503.
        raise ValueError(f"http error. api still busy after {tries} tries.")

    async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call KoboldCpp's generate endpoint asynchronously.

        The blocking ``requests.post`` runs in the event loop's default
        executor so the loop stays responsive. Busy (503) replies and
        connection errors are retried; timeouts are not.

        Args:
            prompt: Text to send as the prompt.
            stop: Optional stop sequences overriding ``self.stop``.

        Returns:
            The generated text with trailing stop sequences stripped.

        Raises:
            ValueError: on a timeout or other request failure, an unexpected
                status code, a malformed response, or if the API is still busy
                or unreachable after all retries.
        """
        input_data = self._build_payload(prompt, stop)
        logger.info("sending request to koboldcpp.")
        tries = 60
        request_timeout = 20 * 60
        loop = asyncio.get_running_loop()
        post = functools.partial(
            requests.post,
            self.endpoint_url,
            json=input_data,
            headers=_JSON_HEADERS,
            timeout=request_timeout,
        )
        for _ in range(tries):
            try:
                r = await loop.run_in_executor(None, post)
            except requests.exceptions.ConnectionError as e:
                # Checked before Timeout: requests' ConnectTimeout derives from
                # both, and connection-level failures are worth retrying.
                logger.warning("Error Connecting: %s", e)
                await asyncio.sleep(_RETRY_DELAY_SECONDS)
                continue
            except requests.exceptions.Timeout as e:
                raise ValueError("http timeout error.") from e
            except requests.exceptions.RequestException as e:
                raise ValueError("http connection error.") from e
            if r.status_code == 503:
                logger.info("api is busy. waiting...")
                await asyncio.sleep(_RETRY_DELAY_SECONDS)
                continue
            return self._text_from_response(r, input_data["stop_sequence"])
        # Every attempt was a 503 or a connection failure.
        raise ValueError(f"http error. api still busy or unreachable after {tries} tries.")

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the settings that determine this LLM's requests.

        Exposing the real settings (rather than an empty mapping) keeps
        differently configured instances distinguishable.
        NOTE(review): LangChain presumably builds ``dict()`` / LLM-cache keys
        from this mapping — confirm against the LangChain version in use.
        """
        return {
            "endpoint_url": self.endpoint_url,
            "max_context": self.max_context,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "repeat_penalty": self.repeat_penalty,
            "repeat_penalty_range": self.repeat_penalty_range,
            "stop": self.stop,
        }

    def get_num_tokens(self, text: str) -> int:
        """Estimate num tokens as one per four characters, plus one."""
        return len(text) // 4 + 1

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Estimate the token count of a list of chat messages.

        Adds a fixed overhead per message (3) and per ``name`` field (1) to the
        character-based estimate of each message's content, plus 3 at the end.
        These constants are heuristics, not values reported by KoboldCpp.
        """
        tokens_per_message = 3
        tokens_per_name = 1
        num_tokens = 0
        for message in (_convert_message_to_dict(m) for m in messages):
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "content":
                    num_tokens += self.get_num_tokens(value)
                elif key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3
        return num_tokens
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Map a LangChain chat message to an OpenAI-style message dict.

    The result always has ``role`` and ``content`` keys, plus ``name`` when the
    message's ``additional_kwargs`` contains one. A ``ChatMessage`` keeps its
    own role; human, AI and system messages map to ``"user"``, ``"assistant"``
    and ``"system"`` respectively.

    Raises:
        ValueError: if ``message`` is none of the supported message types.
    """
    if isinstance(message, ChatMessage):
        role = message.role
    else:
        # Checked in order; the first matching class decides the role.
        fixed_roles = (
            (HumanMessage, "user"),
            (AIMessage, "assistant"),
            (SystemMessage, "system"),
        )
        role = next(
            (candidate for cls, candidate in fixed_roles if isinstance(message, cls)),
            None,
        )
        if role is None:
            raise ValueError(f"Got unknown type {message}")

    converted = {"role": role, "content": message.content}
    extras = message.additional_kwargs
    if "name" in extras:
        converted["name"] = extras["name"]
    return converted