|
|
@ -1,6 +1,7 @@ |
|
|
|
"""KoboldCpp LLM wrapper for testing purposes.""" |
|
|
|
import asyncio
import json
import logging
import time
from typing import Any, List, Mapping, Optional

import requests
from langchain.llms.base import LLM
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
@ -87,8 +90,93 @@ class KoboldCpp(LLM): |
|
|
|
raise ValueError(f"http error. unknown response code") |
|
|
|
for s in input_data["stop_sequence"]: |
|
|
|
response = response.removesuffix(s).rstrip() |
|
|
|
return response.lstrip() |
|
|
|
return response |
|
|
|
|
|
|
|
|
|
|
|
async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str: |
|
|
|
"""Call out to KoboldCpp's completion endpoint asynchronuosly.""" |
|
|
|
|
|
|
|
#params = self.model_kwargs or {} |
|
|
|
input_data = { |
|
|
|
"prompt": prompt, |
|
|
|
"max_context_length": 2048, |
|
|
|
"max_length": self.max_tokens, |
|
|
|
"temperature": self.temperature, |
|
|
|
"top_k": self.top_k, |
|
|
|
"top_p": self.top_p, |
|
|
|
"rep_pen": self.repeat_penalty, |
|
|
|
"rep_pen_range": 256, |
|
|
|
"stop_sequence": self.stop, |
|
|
|
} |
|
|
|
|
|
|
|
if stop: |
|
|
|
input_data["stop_sequence"] = stop |
|
|
|
|
|
|
|
headers = { |
|
|
|
"Content-Type": "application/json", |
|
|
|
} |
|
|
|
|
|
|
|
logger.info(f"sending request to koboldcpp.") |
|
|
|
|
|
|
|
TRIES = 30 |
|
|
|
for i in range(TRIES): |
|
|
|
try: |
|
|
|
r = requests.post(self.endpoint_url, json=input_data, headers=headers, timeout=600) |
|
|
|
r_json = r.json() |
|
|
|
except requests.exceptions.RequestException as e: |
|
|
|
raise ValueError(f"http connection error.") |
|
|
|
logger.info(r_json) |
|
|
|
if r.status_code == 200: |
|
|
|
try: |
|
|
|
response = r_json["results"][0]["text"] |
|
|
|
except KeyError: |
|
|
|
raise ValueError(f"LangChain requires 'results' key in response.") |
|
|
|
break |
|
|
|
elif r.status_code == 503: |
|
|
|
logger.info(f"api is busy. waiting...") |
|
|
|
await asyncio.sleep(5) |
|
|
|
else: |
|
|
|
raise ValueError(f"http error. unknown response code") |
|
|
|
for s in input_data["stop_sequence"]: |
|
|
|
response = response.removesuffix(s).rstrip() |
|
|
|
return response |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
def _identifying_params(self) -> Mapping[str, Any]: |
|
|
|
return {} |
|
|
|
|
|
|
|
def get_num_tokens(self, text: str) -> int: |
|
|
|
"""Estimate num tokens.""" |
|
|
|
return len(text)//4+1 |
|
|
|
|
|
|
|
|
|
|
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: |
|
|
|
"""Estimate num tokens.""" |
|
|
|
tokens_per_message = 3 |
|
|
|
tokens_per_name = 1 |
|
|
|
num_tokens = 0 |
|
|
|
messages_dict = [_convert_message_to_dict(m) for m in messages] |
|
|
|
for message in messages_dict: |
|
|
|
num_tokens += tokens_per_message |
|
|
|
for key, value in message.items(): |
|
|
|
num_tokens += len(self.get_num_tokens(value)) |
|
|
|
if key == "name": |
|
|
|
num_tokens += tokens_per_name |
|
|
|
num_tokens += 3 |
|
|
|
return num_tokens |
|
|
|
|
|
|
|
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Map a LangChain message object to an OpenAI-style role/content dict.

    A "name" entry is copied over when present in ``additional_kwargs``.
    Raises ValueError for unrecognized message types.
    """
    if isinstance(message, ChatMessage):
        role = message.role
    elif isinstance(message, HumanMessage):
        role = "user"
    elif isinstance(message, AIMessage):
        role = "assistant"
    elif isinstance(message, SystemMessage):
        role = "system"
    else:
        raise ValueError(f"Got unknown type {message}")

    message_dict = {"role": role, "content": message.content}
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict
|
|
|