
langchain async

Branch: master
Author: Hendrik Langer, 2 years ago
Commit: 00a2a65510

Changed files:
  1. matrix_pygmalion_bot/bot/ai/langchain.py (9 changed lines)
  2. matrix_pygmalion_bot/bot/wrappers/langchain_koboldcpp.py (92 changed lines)
  3. matrix_pygmalion_bot/main.py (4 changed lines)

matrix_pygmalion_bot/bot/ai/langchain.py

@@ -73,7 +73,7 @@ class AI(object):
         else:
             memory = self.rooms[message.room_id]["memory"]
         print(f"memory: {memory.load_memory_variables({})}")
-        print(f"memory has an estimated {estimate_num_tokens(memory.buffer)} number of tokens")
+        print(f"memory has an estimated {self.llm_chat.get_num_tokens(memory.buffer)} number of tokens")
         return memory
@@ -117,7 +117,7 @@ class AI(object):
             llm=self.llm_chat,
             prompt=PromptTemplate.from_template(prompt_template),
         )
-        output = chain.run(message.message)
+        output = await chain.arun(message.message)
         return output.strip()

     async def generate_roleplay(self, message, reply_fn, typing_fn):
@@ -147,14 +147,11 @@ class AI(object):
         stop = ['<|endoftext|>', f"\n{message.user_name}:"]
         print(f"Message is: \"{message.message}\"")
-        output = chain.run({"input":message.message, "stop": stop})
+        output = await chain.arun({"input":message.message, "stop": stop})
         return output.strip()

-def estimate_num_tokens(input_text: str):
-    return len(input_text)//4+1
-
 def replace_all(text, dic):
     for i, j in dic.items():
         text = text.replace(i, j)
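The run-to-arun switch above only works end to end if the wrapped LLM implements LangChain's async path: LLMChain.arun ultimately awaits the LLM's _acall, which is exactly what the next file adds to the KoboldCpp wrapper. A minimal sketch of that mechanism, assuming the legacy langchain 0.0.x custom-LLM API this project uses (EchoLLM is a hypothetical stand-in, not code from the repo):

import asyncio
from typing import Any, List, Mapping, Optional

from langchain.chains import LLMChain
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate

class EchoLLM(LLM):
    """Hypothetical stand-in for the KoboldCpp wrapper."""

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # synchronous path used by chain.run()
        return f"sync: {prompt}"

    async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # asynchronous path used by await chain.arun()
        await asyncio.sleep(0)
        return f"async: {prompt}"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {}

async def demo() -> None:
    chain = LLMChain(llm=EchoLLM(), prompt=PromptTemplate.from_template("{input}"))
    print(await chain.arun("hello"))   # resolves via _acall without blocking the event loop

asyncio.run(demo())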

matrix_pygmalion_bot/bot/wrappers/langchain_koboldcpp.py

@@ -1,6 +1,7 @@
 """KoboldCpp LLM wrapper for testing purposes."""
-import logging
+import asyncio
 import time
+import logging
 from typing import Any, List, Mapping, Optional
 import json
@@ -8,6 +9,8 @@ import requests
 from langchain.llms.base import LLM
+from langchain.schema import AIMessage, BaseMessage, ChatMessage, HumanMessage, SystemMessage
+
 logger = logging.getLogger(__name__)
@ -87,8 +90,93 @@ class KoboldCpp(LLM):
raise ValueError(f"http error. unknown response code") raise ValueError(f"http error. unknown response code")
for s in input_data["stop_sequence"]: for s in input_data["stop_sequence"]:
response = response.removesuffix(s).rstrip() response = response.removesuffix(s).rstrip()
return response.lstrip() return response
+
+    async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        """Call out to KoboldCpp's completion endpoint asynchronously."""
+        #params = self.model_kwargs or {}
+        input_data = {
+            "prompt": prompt,
+            "max_context_length": 2048,
+            "max_length": self.max_tokens,
+            "temperature": self.temperature,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "rep_pen": self.repeat_penalty,
+            "rep_pen_range": 256,
+            "stop_sequence": self.stop,
+        }
+        if stop:
+            input_data["stop_sequence"] = stop
+        headers = {
+            "Content-Type": "application/json",
+        }
+        logger.info(f"sending request to koboldcpp.")
+        TRIES = 30
+        for i in range(TRIES):
+            try:
+                r = requests.post(self.endpoint_url, json=input_data, headers=headers, timeout=600)
+                r_json = r.json()
+            except requests.exceptions.RequestException as e:
+                raise ValueError(f"http connection error.")
+            logger.info(r_json)
+            if r.status_code == 200:
+                try:
+                    response = r_json["results"][0]["text"]
+                except KeyError:
+                    raise ValueError(f"LangChain requires 'results' key in response.")
+                break
+            elif r.status_code == 503:
+                logger.info(f"api is busy. waiting...")
+                await asyncio.sleep(5)
+            else:
+                raise ValueError(f"http error. unknown response code")
+        for s in input_data["stop_sequence"]:
+            response = response.removesuffix(s).rstrip()
+        return response
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         return {}
+
+    def get_num_tokens(self, text: str) -> int:
+        """Estimate num tokens."""
+        return len(text)//4+1
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Estimate num tokens."""
+        tokens_per_message = 3
+        tokens_per_name = 1
+        num_tokens = 0
+        messages_dict = [_convert_message_to_dict(m) for m in messages]
+        for message in messages_dict:
+            num_tokens += tokens_per_message
+            for key, value in message.items():
+                num_tokens += self.get_num_tokens(value)
+                if key == "name":
+                    num_tokens += tokens_per_name
+        num_tokens += 3
+        return num_tokens
+
+def _convert_message_to_dict(message: BaseMessage) -> dict:
+    if isinstance(message, ChatMessage):
+        message_dict = {"role": message.role, "content": message.content}
+    elif isinstance(message, HumanMessage):
+        message_dict = {"role": "user", "content": message.content}
+    elif isinstance(message, AIMessage):
+        message_dict = {"role": "assistant", "content": message.content}
+    elif isinstance(message, SystemMessage):
+        message_dict = {"role": "system", "content": message.content}
+    else:
+        raise ValueError(f"Got unknown type {message}")
+    if "name" in message.additional_kwargs:
+        message_dict["name"] = message.additional_kwargs["name"]
+    return message_dict
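One caveat worth noting: the new _acall still issues the HTTP request with blocking requests.post, so the coroutine holds up the event loop for the duration of the KoboldCpp call (up to the 600 s timeout); only the 503 retry sleep is truly asynchronous. A rough sketch of a non-blocking variant, under the assumption that aiohttp is an acceptable extra dependency and that the endpoint and response shape mirror the synchronous _call above (this helper is hypothetical, not part of the commit):

import asyncio
import aiohttp

async def post_to_koboldcpp(endpoint_url: str, input_data: dict) -> str:
    """Hypothetical non-blocking replacement for the requests.post loop in _acall."""
    headers = {"Content-Type": "application/json"}
    timeout = aiohttp.ClientTimeout(total=600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        for _ in range(30):                      # same retry budget as TRIES above
            async with session.post(endpoint_url, json=input_data, headers=headers) as r:
                if r.status == 503:              # generator busy: wait and retry, like the sync path
                    await asyncio.sleep(5)
                    continue
                if r.status != 200:
                    raise ValueError("http error. unknown response code")
                r_json = await r.json()
                return r_json["results"][0]["text"]
    raise ValueError("KoboldCpp stayed busy after 30 tries")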

matrix_pygmalion_bot/main.py

@@ -64,13 +64,13 @@ async def main() -> None:
         if sys.version_info[0] == 3 and sys.version_info[1] < 11:
             tasks = []
             for bot in bots:
-                task = asyncio.create_task(bot.connection.sync_forever(timeout=0, full_state=True)) # timeout 30000
+                task = asyncio.create_task(bot.connection.sync_forever(timeout=30000, full_state=True))
                 tasks.append(task)
             await asyncio.gather(*tasks)
         else:
             async with asyncio.TaskGroup() as tg:
                 for bot in bots:
-                    task = tg.create_task(bot.connection.sync_forever(timeout=0, full_state=True)) # timeout 30000
+                    task = tg.create_task(bot.connection.sync_forever(timeout=30000, full_state=True))
     except Exception:
         print(traceback.format_exc())
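The timeout passed to sync_forever is matrix-nio's long-poll timeout for the Matrix /sync request, in milliseconds: with timeout=0 the homeserver answers immediately and the bot spins in a tight sync loop, whereas 30000 lets each request wait up to 30 s for new events. A minimal stand-alone sketch of the same call (hypothetical homeserver and credentials, not from the repo):

import asyncio
from nio import AsyncClient

async def run_bot() -> None:
    client = AsyncClient("https://matrix.example.org", "@bot:example.org")
    await client.login("s3cret")                # placeholder credentials
    # Long-poll /sync: each request waits up to 30 s (30000 ms) for new events.
    await client.sync_forever(timeout=30000, full_state=True)

asyncio.run(run_bot())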
