
context size

Hendrik Langer, 2 years ago · branch master · commit 0b52bb31b7
  1. matrix_pygmalion_bot/bot/ai/langchain.py (7 lines changed)
  2. matrix_pygmalion_bot/bot/ai/langchain_memory.py (18 lines changed)
  3. matrix_pygmalion_bot/bot/utilities/messages.py (7 lines changed)

matrix_pygmalion_bot/bot/ai/langchain.py (7 lines changed)

@@ -119,6 +119,7 @@ class AI(object):
         room_id = message.additional_kwargs['room_id']
         conversation_memory = self.get_memory(room_id)
         conversation_memory.chat_memory.messages.append(message)
+        conversation_memory.chat_memory_day.messages.append(message)

     async def clear(self, room_id):
         conversation_memory = self.get_memory(room_id)
@@ -219,6 +220,8 @@ class AI(object):
         if prompt_len+256 > 2000:
             logger.warning(f"Prompt too large. Estimated {prompt_len} tokens")
+            await reply_fn(f"<WARNING> Prompt too large. Estimated {prompt_len} tokens")
+            await conversation_memory.prune_memory(conversation_memory.min_len)

         #roleplay_chain = RoleplayChain(llm_chain=chain, character_name=self.bot.name, persona=self.bot.persona, scenario=self.bot.scenario, ai_name_chat=chat_ai_name, human_name_chat=chat_human_name)
@@ -370,6 +373,8 @@ class AI(object):
                 "room_id": room_id,
             }
         )
+        if conversation_memory.chat_memory.messages[-1].content.startswith('~~~~ '):
+            conversation_memory.chat_memory.messages.pop()
         conversation_memory.chat_memory.messages.append(message)
         #conversation_memory.chat_memory.add_system_message(message)
@@ -377,6 +382,8 @@ class AI(object):
         yesterday = ( datetime.now() - timedelta(days=1) ).strftime('%Y-%m-%d')
         for room_id in self.rooms.keys():
+            if len(conversation_memory.chat_memory_day.messages) > 0:
                 if not "diary" in self.bot.rooms[room_id]:
                     self.bot.rooms[room_id]['diary'] = {}
                 self.bot.rooms[room_id]["diary"][yesterday] = await self.diary(room_id)
                 # Calculate new goals for the character
                 # Update stats
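Note on the second hunk above: when the estimated prompt length plus 256 tokens of reply headroom exceeds the assumed 2000-token context window, the bot now both warns the room and prunes the conversation memory down to its minimum size. A minimal sketch of that guard follows; estimate_tokens, guard_prompt_size, and the reply_fn/conversation_memory parameters are illustrative stand-ins, not the bot's real API.

    # Sketch of the prompt-size guard, with placeholder names.
    CONTEXT_WINDOW = 2000   # context size assumed by the check in the hunk
    REPLY_HEADROOM = 256    # tokens reserved for the model's reply

    def estimate_tokens(text: str) -> int:
        # Crude stand-in for the LLM tokenizer: roughly 4 characters per token.
        return max(1, len(text) // 4)

    async def guard_prompt_size(prompt: str, conversation_memory, reply_fn) -> None:
        prompt_len = estimate_tokens(prompt)
        if prompt_len + REPLY_HEADROOM > CONTEXT_WINDOW:
            # Mirror the two added lines: warn the room, then prune the
            # memory down to its minimum length.
            await reply_fn(f"<WARNING> Prompt too large. Estimated {prompt_len} tokens")
            await conversation_memory.prune_memory(conversation_memory.min_len)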

matrix_pygmalion_bot/bot/ai/langchain_memory.py (18 lines changed)

@@ -91,22 +91,24 @@ class CustomMemory(BaseMemory):
             self.predict_new_summary(pruned_memory, self.moving_summary_buffer)
         )

-    async def asave_context(self, input_msg: BaseMessage, output_msg: BaseMessage) -> None:
-        """Save context from this conversation to buffer."""
-        self.chat_memory.messages.append(input_msg)
-        self.chat_memory.messages.append(output_msg)
-        self.chat_memory_day.messages.append(input_msg)
-        self.chat_memory_day.messages.append(output_msg)
-
     async def prune_memory(self, max_len):
         # Prune buffer if it exceeds max token limit
         buffer = self.chat_memory.messages
         curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
-        if curr_buffer_length > self.max_len:
+        if curr_buffer_length > max_len:
             pruned_memory = []
-            while curr_buffer_length > self.min_len:
+            while curr_buffer_length > self.min_len and len(buffer) > 0:
                 pruned_memory.append(buffer.pop(0))
                 curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
             self.moving_summary_buffer = await self.apredict_new_summary(pruned_memory, self.moving_summary_buffer)

+    async def asave_context(self, input_msg: BaseMessage, output_msg: BaseMessage) -> None:
+        """Save context from this conversation to buffer."""
+        self.chat_memory.messages.append(input_msg)
+        self.chat_memory.messages.append(output_msg)
+        self.chat_memory_day.messages.append(input_msg)
+        self.chat_memory_day.messages.append(output_msg)
+        await self.prune_memory(self.max_len)

     def clear(self) -> None:
         """Clear memory contents."""

matrix_pygmalion_bot/bot/utilities/messages.py (7 lines changed)

@@ -77,7 +77,12 @@ class Message(object):
         return self.message.startswith('!')

     def is_error(self):
-        return self.message.startswith('<ERROR>')
+        if self.message.startswith('<ERROR>'):
+            return True
+        elif self.message.startswith('<WARNING>'):
+            return True
+        else:
+            return False

     def __str__(self):
         return str("{}: {}".format(self.user_name, self.message))
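is_error now also returns True for the new <WARNING> messages. As a side note, str.startswith accepts a tuple of prefixes, so an equivalent check fits in one line (a sketch only, written here as a standalone function rather than the Message method):

    def is_error(message: str) -> bool:
        # str.startswith accepts a tuple of prefixes, so the extended check
        # collapses to a single expression.
        return message.startswith(('<ERROR>', '<WARNING>'))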
