diff --git a/matrix_pygmalion_bot/bot/ai/langchain.py b/matrix_pygmalion_bot/bot/ai/langchain.py index 0326331..a5509df 100644 --- a/matrix_pygmalion_bot/bot/ai/langchain.py +++ b/matrix_pygmalion_bot/bot/ai/langchain.py @@ -73,7 +73,7 @@ class AI(object): else: memory = self.rooms[message.room_id]["memory"] print(f"memory: {memory.load_memory_variables({})}") - print(f"memory has an estimated {estimate_num_tokens(memory.buffer)} number of tokens") + print(f"memory has an estimated {self.llm_chat.get_num_tokens(memory.buffer)} number of tokens") return memory @@ -117,7 +117,7 @@ class AI(object): llm=self.llm_chat, prompt=PromptTemplate.from_template(prompt_template), ) - output = chain.run(message.message) + output = await chain.arun(message.message) return output.strip() async def generate_roleplay(self, message, reply_fn, typing_fn): @@ -147,14 +147,11 @@ class AI(object): stop = ['<|endoftext|>', f"\n{message.user_name}:"] print(f"Message is: \"{message.message}\"") - output = chain.run({"input":message.message, "stop": stop}) + output = await chain.arun({"input":message.message, "stop": stop}) return output.strip() -def estimate_num_tokens(input_text: str): - return len(input_text)//4+1 - def replace_all(text, dic): for i, j in dic.items(): text = text.replace(i, j) diff --git a/matrix_pygmalion_bot/bot/wrappers/langchain_koboldcpp.py b/matrix_pygmalion_bot/bot/wrappers/langchain_koboldcpp.py index 8b6091f..6f75d63 100644 --- a/matrix_pygmalion_bot/bot/wrappers/langchain_koboldcpp.py +++ b/matrix_pygmalion_bot/bot/wrappers/langchain_koboldcpp.py @@ -1,6 +1,7 @@ """KoboldCpp LLM wrapper for testing purposes.""" -import logging +import asyncio import time +import logging from typing import Any, List, Mapping, Optional import json @@ -8,6 +9,8 @@ import requests from langchain.llms.base import LLM +from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage + logger = logging.getLogger(__name__) @@ -87,8 +90,93 @@ class KoboldCpp(LLM): raise ValueError(f"http error. 
unknown response code") for s in input_data["stop_sequence"]: response = response.removesuffix(s).rstrip() - return response.lstrip() + return response + + + async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str: + """Call out to KoboldCpp's completion endpoint asynchronously.""" + + #params = self.model_kwargs or {} + input_data = { + "prompt": prompt, + "max_context_length": 2048, + "max_length": self.max_tokens, + "temperature": self.temperature, + "top_k": self.top_k, + "top_p": self.top_p, + "rep_pen": self.repeat_penalty, + "rep_pen_range": 256, + "stop_sequence": self.stop, + } + + if stop: + input_data["stop_sequence"] = stop + + headers = { + "Content-Type": "application/json", + } + + logger.info(f"sending request to koboldcpp.") + + TRIES = 30 + for i in range(TRIES): + try: + r = requests.post(self.endpoint_url, json=input_data, headers=headers, timeout=600) + r_json = r.json() + except requests.exceptions.RequestException as e: + raise ValueError(f"http connection error.") + logger.info(r_json) + if r.status_code == 200: + try: + response = r_json["results"][0]["text"] + except KeyError: + raise ValueError(f"LangChain requires 'results' key in response.") + break + elif r.status_code == 503: + logger.info(f"api is busy. waiting...") + await asyncio.sleep(5) + else: + raise ValueError(f"http error. 
unknown response code") + for s in input_data["stop_sequence"]: + response = response.removesuffix(s).rstrip() + return response + @property def _identifying_params(self) -> Mapping[str, Any]: return {} + + def get_num_tokens(self, text: str) -> int: + """Estimate num tokens.""" + return len(text)//4+1 + + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + """Estimate num tokens.""" + tokens_per_message = 3 + tokens_per_name = 1 + num_tokens = 0 + messages_dict = [_convert_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + num_tokens += self.get_num_tokens(value) + if key == "name": + num_tokens += tokens_per_name + num_tokens += 3 + return num_tokens + +def _convert_message_to_dict(message: BaseMessage) -> dict: + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + else: + raise ValueError(f"Got unknown type {message}") + if "name" in message.additional_kwargs: + message_dict["name"] = message.additional_kwargs["name"] + return message_dict diff --git a/matrix_pygmalion_bot/main.py b/matrix_pygmalion_bot/main.py index 415e9d0..8fca87a 100644 --- a/matrix_pygmalion_bot/main.py +++ b/matrix_pygmalion_bot/main.py @@ -64,13 +64,13 @@ async def main() -> None: if sys.version_info[0] == 3 and sys.version_info[1] < 11: tasks = [] for bot in bots: - task = asyncio.create_task(bot.connection.sync_forever(timeout=0, full_state=True)) # timeout 30000 + task = asyncio.create_task(bot.connection.sync_forever(timeout=30000, full_state=True)) tasks.append(task) await asyncio.gather(*tasks) 
else: async with asyncio.TaskGroup() as tg: for bot in bots: - task = tg.create_task(bot.connection.sync_forever(timeout=0, full_state=True)) # timeout 30000 + task = tg.create_task(bot.connection.sync_forever(timeout=30000, full_state=True)) except Exception: print(traceback.format_exc())