Browse Source

more work on langchain memory

master
Hendrik Langer 2 years ago
parent
commit
ee1ffbac6b
  1. 81
      matrix_pygmalion_bot/bot/ai/langchain.py
  2. 167
      matrix_pygmalion_bot/bot/ai/langchain_memory.py
  3. 1
      matrix_pygmalion_bot/bot/ai/prompts.py
  4. 9
      matrix_pygmalion_bot/bot/core.py
  5. 43
      matrix_pygmalion_bot/bot/utilities/messages.py
  6. 7
      matrix_pygmalion_bot/bot/wrappers/langchain_koboldcpp.py

81
matrix_pygmalion_bot/bot/ai/langchain.py

@ -1,7 +1,7 @@
import asyncio import asyncio
import os, time import os, time
from .prompts import * from .prompts import *
#from .langchain_memory import BotConversationSummaryBufferWindowMemory, TestMemory from .langchain_memory import CustomMemory # BotConversationSummaryBufferWindowMemory, TestMemory
from ..utilities.messages import Message from ..utilities.messages import Message
from langchain import PromptTemplate from langchain import PromptTemplate
@ -90,7 +90,7 @@ class AI(object):
from ..wrappers.langchain_koboldcpp import KoboldCpp from ..wrappers.langchain_koboldcpp import KoboldCpp
self.llm_chat = KoboldCpp(temperature=self.bot.temperature, endpoint_url="http://172.16.85.10:5001/api/latest/generate", stop=['<|endoftext|>']) self.llm_chat = KoboldCpp(temperature=self.bot.temperature, endpoint_url="http://172.16.85.10:5001/api/latest/generate", stop=['<|endoftext|>'])
self.llm_summary = KoboldCpp(temperature=0.2, endpoint_url="http://172.16.85.10:5001/api/latest/generate", stop=['<|endoftext|>'], max_tokens=512) self.llm_summary = KoboldCpp(temperature=0.2, endpoint_url="http://172.16.85.10:5002/api/latest/generate", stop=['<|endoftext|>'], max_tokens=512)
self.text_wrapper = text_wrapper self.text_wrapper = text_wrapper
self.image_wrapper = image_wrapper self.image_wrapper = image_wrapper
self.embeddings = SentenceTransformerEmbeddings() self.embeddings = SentenceTransformerEmbeddings()
@ -102,7 +102,7 @@ class AI(object):
def get_memory(self, room_id, human_prefix="Human"): def get_memory(self, room_id, human_prefix="Human"):
if not room_id in self.rooms: if not room_id in self.rooms:
self.rooms[room_id] = {} self.rooms[room_id] = {}
memory = ConversationBufferMemory(memory_key="chat_history", input_key="input", human_prefix=human_prefix, ai_prefix=self.bot.name) memory = CustomMemory(memory_key="chat_history", input_key="input", human_prefix=human_prefix, ai_prefix=self.bot.name, llm=self.llm_summary, summary_prompt=prompt_progressive_summary, max_len=1200, min_len=200)
self.rooms[room_id]["memory"] = memory self.rooms[room_id]["memory"] = memory
self.rooms[room_id]["summary"] = "No previous events." self.rooms[room_id]["summary"] = "No previous events."
memory.chat_memory.add_ai_message(self.bot.greeting) memory.chat_memory.add_ai_message(self.bot.greeting)
@ -181,25 +181,15 @@ class AI(object):
async def generate_roleplay(self, message, reply_fn, typing_fn): async def generate_roleplay(self, message, reply_fn, typing_fn):
langchain_human_message = HumanMessage(
content=message.message,
additional_kwargs={
"timestamp": message.timestamp,
"user_name": message.user_name,
"event_id": message.event_id,
"user_id": message.user_id,
"room_name": message.room_name,
"room_id": message.room_id,
}
)
chat_ai_name = self.bot.name chat_ai_name = self.bot.name
chat_human_name = message.user_name chat_human_name = message.additional_kwargs['user_name']
room_id = message.additional_kwargs['room_id']
if False: # model is vicuna if False: # model is vicuna
chat_ai_name = "### Assistant" chat_ai_name = "### Assistant"
chat_human_name = "### Human" chat_human_name = "### Human"
conversation_memory = self.get_memory(message.room_id, message.user_name) conversation_memory = self.get_memory(room_id, chat_human_name)
conversation_memory.human_prefix = chat_human_name conversation_memory.human_prefix = chat_human_name
readonlymemory = ReadOnlySharedMemory(memory=conversation_memory) readonlymemory = ReadOnlySharedMemory(memory=conversation_memory)
summary_memory = ConversationSummaryMemory(llm=self.llm_summary, memory_key="summary", input_key="input") summary_memory = ConversationSummaryMemory(llm=self.llm_summary, memory_key="summary", input_key="input")
@ -211,11 +201,11 @@ class AI(object):
async def make_progressive_summary(previous_summary, chat_history_text_string): async def make_progressive_summary(previous_summary, chat_history_text_string):
await asyncio.sleep(0) # yield for matrix-nio await asyncio.sleep(0) # yield for matrix-nio
#self.rooms[message.room_id]["summary"] = summary_memory.predict_new_summary(conversation_memory.chat_memory.messages, previous_summary).strip() #self.rooms[room_id]["summary"] = summary_memory.predict_new_summary(conversation_memory.chat_memory.messages, previous_summary).strip()
summary_chain = LLMChain(llm=self.llm_summary, prompt=prompt_progressive_summary, verbose=True) summary_chain = LLMChain(llm=self.llm_summary, prompt=prompt_progressive_summary, verbose=True)
self.rooms[message.room_id]["summary"] = await summary_chain.apredict(summary=previous_summary, chat_history=chat_history_text_string) self.rooms[room_id]["summary"] = await summary_chain.apredict(summary=previous_summary, chat_history=chat_history_text_string)
# ToDo: maybe add an add_task_done callback and don't access the variable directly from here? # ToDo: maybe add an add_task_done callback and don't access the variable directly from here?
logger.info(f"New summary is: \"{self.rooms[message.room_id]['summary']}\"") logger.info(f"New summary is: \"{self.rooms[room_id]['summary']}\"")
conversation_memory.chat_memory.messages = conversation_memory.chat_memory.messages[-k * 2 :] conversation_memory.chat_memory.messages = conversation_memory.chat_memory.messages[-k * 2 :]
conversation_memory.load_memory_variables({}) conversation_memory.load_memory_variables({})
#summary = summarize(conversation_memory.buffer) #summary = summarize(conversation_memory.buffer)
@ -224,11 +214,11 @@ class AI(object):
logger.info("memory progressive summary scheduled...") logger.info("memory progressive summary scheduled...")
await self.bot.schedule(self.bot.queue, make_progressive_summary, self.rooms[message.room_id]["summary"], conversation_memory.buffer) #.add_done_callback( await self.bot.schedule(self.bot.queue, make_progressive_summary, self.rooms[room_id]["summary"], conversation_memory.buffer) #.add_done_callback(
#t = datetime.fromtimestamp(message.timestamp) #t = datetime.fromtimestamp(message.additional_kwargs['timestamp'])
#when = humanize.naturaltime(t) #when = humanize.naturaltime(t)
#print(when) #print(when)
@ -241,8 +231,8 @@ class AI(object):
ai_name=self.bot.name, ai_name=self.bot.name,
persona=self.bot.persona, persona=self.bot.persona,
scenario=self.bot.scenario, scenario=self.bot.scenario,
summary=self.rooms[message.room_id]["summary"], summary=self.rooms[room_id]["summary"],
human_name=message.user_name, human_name=chat_human_name,
#example_dialogue=replace_all(self.bot.example_dialogue, {"{{user}}": chat_human_name, "{{char}}": chat_ai_name}) #example_dialogue=replace_all(self.bot.example_dialogue, {"{{user}}": chat_human_name, "{{char}}": chat_ai_name})
ai_name_chat=chat_ai_name, ai_name_chat=chat_ai_name,
) )
@ -252,48 +242,44 @@ class AI(object):
prompt=prompt, prompt=prompt,
verbose=True, verbose=True,
memory=readonlymemory, memory=readonlymemory,
#stop=['<|endoftext|>', '\nYou:', f"\n{message.user_name}:"], #stop=['<|endoftext|>', '\nYou:', f"\n{chat_human_name}:"],
) )
# output = llm_chain(inputs={"ai_name": self.bot.name, "persona": self.bot.persona, "scenario": self.bot.scenario, "human_name": message.user_name, "ai_name_chat": self.bot.name, "chat_history": "", "input": message.message})['results'][0]['text'] # output = llm_chain(inputs={"ai_name": self.bot.name, "persona": self.bot.persona, "scenario": self.bot.scenario, "human_name": chat_human_name, "ai_name_chat": self.bot.name, "chat_history": "", "input": message.content})['results'][0]['text']
#roleplay_chain = RoleplayChain(llm_chain=chain, character_name=self.bot.name, persona=self.bot.persona, scenario=self.bot.scenario, ai_name_chat=chat_ai_name, human_name_chat=chat_human_name) #roleplay_chain = RoleplayChain(llm_chain=chain, character_name=self.bot.name, persona=self.bot.persona, scenario=self.bot.scenario, ai_name_chat=chat_ai_name, human_name_chat=chat_human_name)
stop = ['<|endoftext|>', f"\n{chat_human_name}"] stop = ['<|endoftext|>', f"\n{chat_human_name}"]
#print(f"Message is: \"{message.message}\"") #print(f"Message is: \"{message.content}\"")
await asyncio.sleep(0) await asyncio.sleep(0)
output = await chain.arun({"input":message.message, "stop": stop}) output = await chain.arun({"input":message.content, "stop": stop})
output = output.replace("<BOT>", self.bot.name).replace("<USER>", message.user_name) output = output.replace("<BOT>", self.bot.name).replace("<USER>", chat_human_name)
output = output.replace("### Assistant", self.bot.name) output = output.replace("### Assistant", self.bot.name)
output = output.replace(f"\n{self.bot.name}: ", " ") output = output.replace(f"\n{self.bot.name}: ", " ")
output = output.strip() output = output.strip()
if "*activates the neural uplink*" in output.casefold():
pass # call agent
own_message_resp = await reply_fn(output)
langchain_ai_message = AIMessage( langchain_ai_message = AIMessage(
content=output, content=output,
additional_kwargs={ additional_kwargs={
"timestamp": datetime.now().timestamp(), "timestamp": datetime.now().timestamp(),
"user_name": self.bot.name, "user_name": self.bot.name,
"event_id": None, "event_id": own_message_resp.event_id,
"user_id": None, "user_id": None,
"room_name": message.room_name, "room_name": message.additional_kwargs['room_name'],
"room_id": message.room_id, "room_id": own_message_resp.room_id,
} }
) )
if "*activates the neural uplink*" in output.casefold(): conversation_memory.save_context({"input": message.content}, {"ouput": output})
pass # call agent
#conversation_memory.chat_memory.messages.append(ChatMessage(content=message, role=message.user_name))
conversation_memory.chat_memory.add_user_message(message.message)
conversation_memory.chat_memory.add_ai_message(output)
conversation_memory.load_memory_variables({}) conversation_memory.load_memory_variables({})
if not "messages_today" in self.rooms[message.room_id]: return output
self.rooms[message.room_id]["messages_today"] = []
self.rooms[message.room_id]["messages_today"].append(langchain_human_message)
self.rooms[message.room_id]["messages_today"].append(langchain_ai_message)
return output.strip()
async def summarize(self, text): async def summarize(self, text):
@ -306,10 +292,11 @@ class AI(object):
async def diary(self, room_id): async def diary(self, room_id):
await asyncio.sleep(0) # yield for matrix-nio await asyncio.sleep(0) # yield for matrix-nio
diary_chain = LLMChain(llm=self.llm_summary, prompt=prompt_outline, verbose=True) diary_chain = LLMChain(llm=self.llm_summary, prompt=prompt_outline, verbose=True)
conversation_memory = self.get_memory(room_id)
#self.rooms[message.room_id]["summary"] #self.rooms[message.room_id]["summary"]
string_messages = [] string_messages = []
for m in self.rooms[room_id]["messages_today"]: for m in conversation_memory.chat_memory_day.messages:
string_messages.append(f"{message.user_name}: {message.message}") string_messages.append(f"{message.role}: {message.content}")
return await diary_chain.apredict(text="\n".join(string_messages)) return await diary_chain.apredict(text="\n".join(string_messages))
@ -397,12 +384,12 @@ class AI(object):
# Summarize the last day and save a diary entry # Summarize the last day and save a diary entry
yesterday = ( datetime.now() - timedelta(days=1) ).strftime('%Y-%m-%d') yesterday = ( datetime.now() - timedelta(days=1) ).strftime('%Y-%m-%d')
for room_id in self.rooms.keys(): for room_id in self.rooms.keys():
if "messages_today" in self.rooms[room_id]: if len(conversation_memory.chat_memory_day.messages) > 0:
self.bot.rooms[room_id]["diary"][yesterday] = await self.diary(room_id) self.bot.rooms[room_id]["diary"][yesterday] = await self.diary(room_id)
# Calculate new goals for the character # Calculate new goals for the character
# Update stats # Update stats
# Let background tasks run # Let background tasks run
self.rooms[room_id]["messages_today"] = [] conversation_memory.chat_memory_day.clear()
await self.bot.write_conf2(self.bot.rooms) await self.bot.write_conf2(self.bot.rooms)

167
matrix_pygmalion_bot/bot/ai/langchain_memory.py

@ -1,15 +1,170 @@
from typing import Any, Dict, List import asyncio
from typing import Any, Dict, List, Tuple, Optional
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
from langchain.memory.prompt import SUMMARY_PROMPT from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string from langchain.schema import BaseLanguageModel, BaseMessage, BaseChatMessageHistory, BaseMemory, get_buffer_string
from langchain.schema import AIMessage, HumanMessage, SystemMessage, ChatMessage
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from ..utilities.messages import Message from ..utilities.messages import Message
class ChatMessageHistory(BaseModel):
messages: List[Message] = [] class ChatMessageHistoryCustom(BaseChatMessageHistory, BaseModel):
messages: List[BaseMessage] = []
def add_user_message(self, message: str) -> None:
self.messages.append(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
self.messages.append(AIMessage(content=message))
def add_system_message(self, message: str) -> None:
self.messages.append(SystemMessage(content=message))
def add_chat_message(self, message: str) -> None:
self.messages.append(ChatMessage(content=message))
def clear(self) -> None:
self.messages = []
class CustomMemory(BaseMemory):
"""Buffer for storing conversation memory."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
memory_key: str = "history" #: :meta private:
chat_memory: BaseChatMessageHistory = Field(default_factory=ChatMessageHistoryCustom)
chat_memory_day: BaseChatMessageHistory = Field(default_factory=ChatMessageHistoryCustom)
output_key: Optional[str] = None
input_key: Optional[str] = None
return_messages: bool = False
max_len: int = 1200
min_len: int = 200
#length_function: Callable[[str], int] = len,
#length_function: Callable[[str], int] = self.llm.get_num_tokens_from_messages,
moving_summary_buffer: str = ""
llm: BaseLanguageModel
summary_prompt: BasePromptTemplate = SUMMARY_PROMPT
#summary_message_cls: Type[BaseMessage] = SystemMessage
def _get_input_output(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> Tuple[str, str]:
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
return inputs[prompt_input_key], outputs[output_key]
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
self.chat_memory.add_user_message(input_str)
self.chat_memory.add_ai_message(output_str)
self.chat_memory_day.add_user_message(input_str)
self.chat_memory_day.add_ai_message(output_str)
# Prune buffer if it exceeds max token limit
buffer = self.chat_memory.messages
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
if curr_buffer_length > self.max_len:
pruned_memory = []
while curr_buffer_length > self.min_len:
pruned_memory.append(buffer.pop(0))
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
loop = asyncio.get_event_loop()
self.moving_summary_buffer = loop.run_until_complete(
self.apredict_new_summary(pruned_memory, self.moving_summary_buffer)
)
def clear(self) -> None:
"""Clear memory contents."""
self.chat_memory.clear()
self.chat_memory_day.clear()
self.moving_summary_buffer = ""
def get_buffer_string(self, messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI") -> str:
"""Get buffer string of messages."""
string_messages = []
for m in messages:
if isinstance(m, HumanMessage):
role = human_prefix
elif isinstance(m, AIMessage):
role = ai_prefix
elif isinstance(m, SystemMessage):
role = "System"
elif isinstance(m, ChatMessage):
role = m.role
else:
raise ValueError(f"Got unsupported message type: {m}")
string_messages.append(f"{role}: {m.content}")
return "\n".join(string_messages)
@property
def buffer(self) -> Any:
"""String buffer of memory."""
if self.return_messages:
return self.chat_memory.messages
else:
return self.get_buffer_string(
self.chat_memory.messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
@property
def buffer_day(self) -> Any:
"""String buffer of memory."""
if self.return_messages:
return self.chat_memory_day.messages
else:
return self.get_buffer_string(
self.chat_memory_day.messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
return {self.memory_key: self.buffer}
async def apredict_new_summary(self, messages: List[BaseMessage], existing_summary: str) -> str:
new_lines = self.get_buffer_string(
messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
chain = LLMChain(llm=self.llm, prompt=self.summary_prompt)
await asyncio.sleep(0)
return await chain.apredict(summary=existing_summary, new_lines=new_lines)
class ChatMessageHistoryMessage(BaseModel):
#messages: List[Message] = []
messages = []
def add_user_message(self, message: Message) -> None: def add_user_message(self, message: Message) -> None:
self.messages.append(message) self.messages.append(message)
@ -32,7 +187,7 @@ class TestMemory(BaseMemory):
human_prefix: str = "Human" human_prefix: str = "Human"
ai_prefix: str = "AI" ai_prefix: str = "AI"
chat_memory: ChatMessageHistory = Field(default_factory=ChatMessageHistory) chat_memory: ChatMessageHistory = Field(default_factory=ChatMessageHistoryMessage)
# buffer: str = "" # buffer: str = ""
output_key: Optional[str] = None output_key: Optional[str] = None
input_key: Optional[str] = None input_key: Optional[str] = None

1
matrix_pygmalion_bot/bot/ai/prompts.py

@ -128,6 +128,7 @@ New summary:
""" """
) )
#Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary. #Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary.
#only include relevant facts for {{char}}'s long-term memory / future
prompt_outline = PromptTemplate.from_template( prompt_outline = PromptTemplate.from_template(
"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

9
matrix_pygmalion_bot/bot/core.py

@ -154,9 +154,12 @@ class ChatBot(object):
if message.is_from(self.connection.user_id): if message.is_from(self.connection.user_id):
"""Skip messages from ouselves""" """Skip messages from ouselves"""
message.role = "ai"
self.chatlog.save(message) self.chatlog.save(message)
await self.connection.room_read_markers(room.room_id, event.event_id, event.event_id) await self.connection.room_read_markers(room.room_id, event.event_id, event.event_id)
return return
else:
message.role = "human"
# if event.decrypted: # if event.decrypted:
# encrypted_symbol = "🛡 " # encrypted_symbol = "🛡 "
@ -191,6 +194,8 @@ class ChatBot(object):
await self.schedule(self.queue, self.process_command, message, reply_fn, typing_fn) await self.schedule(self.queue, self.process_command, message, reply_fn, typing_fn)
# elif re.search("^(?=.*\bsend\b)(?=.*\bpicture\b).*$", event.body, flags=re.IGNORECASE): # elif re.search("^(?=.*\bsend\b)(?=.*\bpicture\b).*$", event.body, flags=re.IGNORECASE):
# # send, mail, drop, snap picture, photo, image, portrait # # send, mail, drop, snap picture, photo, image, portrait
elif message.is_error():
return
else: else:
await self.schedule(self.queue, self.process_message, message, reply_fn, typing_fn) await self.schedule(self.queue, self.process_message, message, reply_fn, typing_fn)
self.rooms[room.room_id]['num_messages'] += 1 self.rooms[room.room_id]['num_messages'] += 1
@ -257,10 +262,10 @@ class ChatBot(object):
async def process_message(self, message, reply_fn, typing_fn): async def process_message(self, message, reply_fn, typing_fn):
output = await self.ai.generate_roleplay(message, reply_fn, typing_fn) output = await self.ai.generate_roleplay(message.to_langchain(), reply_fn, typing_fn)
#output = await self.ai.generate(message, reply_fn, typing_fn) #output = await self.ai.generate(message, reply_fn, typing_fn)
# typing false # typing false
await reply_fn(output) #await reply_fn(output)
async def event_loop(self): async def event_loop(self):

43
matrix_pygmalion_bot/bot/utilities/messages.py

@ -2,7 +2,7 @@ from langchain.schema import AIMessage, HumanMessage, SystemMessage, ChatMessage
class Message(object): class Message(object):
def __init__(self, timestamp, user_name, message, event_id=None, user_id=None, room_name=None, room_id=None): def __init__(self, timestamp, user_name, message, event_id=None, user_id=None, room_name=None, room_id=None, role=None):
self.timestamp = timestamp self.timestamp = timestamp
self.user_name = user_name self.user_name = user_name
self.message = message self.message = message
@ -10,12 +10,53 @@ class Message(object):
self.user_id = user_id self.user_id = user_id
self.room_name = room_name self.room_name = room_name
self.room_id = room_id self.room_id = room_id
self.role = role
@classmethod @classmethod
def from_matrix(cls, room, event): def from_matrix(cls, room, event):
return cls(event.server_timestamp/1000, room.user_name(event.sender), event.body, event.event_id, event.sender, room.display_name, room.room_id) return cls(event.server_timestamp/1000, room.user_name(event.sender), event.body, event.event_id, event.sender, room.display_name, room.room_id)
def to_langchain(self): def to_langchain(self):
if self.role == "human":
return HumanMessage(
content=self.message,
role=self.user_name, # "chat"
additional_kwargs={
"timestamp": self.timestamp,
"user_name": self.user_name,
"event_id": self.event_id,
"user_id": self.user_id,
"room_name": self.room_name,
"room_id": self.room_id,
}
)
elif self.role == "ai":
return AIMessage(
content=self.message,
role=self.user_name, # "chat"
additional_kwargs={
"timestamp": self.timestamp,
"user_name": self.user_name,
"event_id": self.event_id,
"user_id": self.user_id,
"room_name": self.room_name,
"room_id": self.room_id,
}
)
elif self.role == "system":
return SystemMessage(
content=self.message,
role=self.user_name, # "chat"
additional_kwargs={
"timestamp": self.timestamp,
"user_name": self.user_name,
"event_id": self.event_id,
"user_id": self.user_id,
"room_name": self.room_name,
"room_id": self.room_id,
}
)
else:
return ChatMessage( return ChatMessage(
content=self.message, content=self.message,
role=self.user_name, # "chat" role=self.user_name, # "chat"

7
matrix_pygmalion_bot/bot/wrappers/langchain_koboldcpp.py

@ -10,7 +10,7 @@ import functools
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.schema import BaseMessage from langchain.schema import BaseMessage, AIMessage, HumanMessage, SystemMessage, ChatMessage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -164,8 +164,9 @@ class KoboldCpp(LLM):
for message in messages_dict: for message in messages_dict:
num_tokens += tokens_per_message num_tokens += tokens_per_message
for key, value in message.items(): for key, value in message.items():
num_tokens += len(self.get_num_tokens(value)) if key == "content":
if key == "name": num_tokens += self.get_num_tokens(value)
elif key == "name":
num_tokens += tokens_per_name num_tokens += tokens_per_name
num_tokens += 3 num_tokens += 3
return num_tokens return num_tokens

Loading…
Cancel
Save