diff --git a/matrix_pygmalion_bot/bot/ai/langchain.py b/matrix_pygmalion_bot/bot/ai/langchain.py index a5509df..57ccc04 100644 --- a/matrix_pygmalion_bot/bot/ai/langchain.py +++ b/matrix_pygmalion_bot/bot/ai/langchain.py @@ -5,7 +5,7 @@ from .langchain_memory import BotConversationSummerBufferWindowMemory from langchain import PromptTemplate from langchain import LLMChain, ConversationChain -from langchain.memory import ConversationBufferMemory +from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory from langchain.chains.base import Chain from typing import Dict, List @@ -94,7 +94,7 @@ class AI(object): docs = text_splitter.split_documents(documents) - db = Chroma(persist_directory=f'{self.memory_path}/chroma-db', embedding_function=embeddings) + db = Chroma(persist_directory=os.path.join(self.memory_path, f'chroma-db'), embedding_function=embeddings) print(f"Indexing {len(docs)} documents") texts = [doc.page_content for doc in docs] @@ -123,6 +123,7 @@ class AI(object): async def generate_roleplay(self, message, reply_fn, typing_fn): memory = self.get_memory(message) + readonlymemory = ReadOnlySharedMemory(memory=memory) prompt = prompt_vicuna.partial( ai_name=self.bot.name, @@ -137,7 +138,7 @@ class AI(object): llm=self.llm_chat, prompt=prompt, verbose=True, - memory=memory, + memory=readonlymemory, #stop=['<|endoftext|>', '\nYou:', f"\n{message.user_name}:"], ) @@ -148,6 +149,10 @@ class AI(object): stop = ['<|endoftext|>', f"\n{message.user_name}:"] print(f"Message is: \"{message.message}\"") output = await chain.arun({"input":message.message, "stop": stop}) + memory.chat_memory.add_user_message(message.message) + memory.chat_memory.add_ai_message(output) + output = output.strip() + memory.load_memory_variables({}) return output.strip() diff --git a/matrix_pygmalion_bot/bot/core.py b/matrix_pygmalion_bot/bot/core.py index 77131b5..d7a7d2a 100644 --- a/matrix_pygmalion_bot/bot/core.py +++ b/matrix_pygmalion_bot/bot/core.py @@ -3,6 +3,7 @@ import os, sys import time import importlib import re +import json import logging from functools import partial from .memory.chatlog import ChatLog @@ -38,14 +39,26 @@ class ChatBot(object): #self.example_dialogue = self.example_dialogue.replace('{{char}}', self.name) - def persist(self, data_dir): - self.chatlog_path = f"{data_dir}/chatlogs" - self.images_path = f"{data_dir}/images" - self.memory_path = f"{data_dir}/memory" + async def persist(self, data_dir): + self.chatlog_path = os.path.join(data_dir, "chatlogs/") + self.images_path = os.path.join(data_dir, "images/") + self.memory_path = os.path.join(data_dir, "memory/") + self.rooms_conf_file = os.path.join(data_dir, "rooms.conf") os.makedirs(self.chatlog_path, exist_ok=True) os.makedirs(self.images_path, exist_ok=True) os.makedirs(self.memory_path, exist_ok=True) self.chatlog.enable_logging(self.chatlog_path) + self.rooms = await self.read_conf2() + + async def read_conf2(self): + if not os.path.isfile(self.rooms_conf_file): + return {} + with open(self.rooms_conf_file, "r") as f: + return json.load(f) + + async def write_conf2(self, data): + with open(self.rooms_conf_file, "w") as f: + json.dump(data, f) async def connect(self): await self.connection.login() @@ -114,6 +127,7 @@ class ChatBot(object): if not room.room_id in self.rooms: self.rooms[room.room_id] = {} + self.write_conf2(self.rooms) # ToDo: set ticks 0 / start if not self.connection.synced: @@ -209,8 +223,10 @@ class ChatBot(object): await reply_fn(self.greeting) elif message.message.startswith('!start'): self.rooms[message.room_id]["disabled"] = False + self.write_conf2(self.rooms) elif message.message.startswith('!stop'): self.rooms[message.room_id]["disabled"] = True + self.write_conf2(self.rooms) elif message.message.startswith('!!'): if self.chatlog.chat_history_len(message.room_id) > 2: for _ in range(2): diff --git a/matrix_pygmalion_bot/bot/memory/chatlog.py b/matrix_pygmalion_bot/bot/memory/chatlog.py index 8458418..1aa295b 100644 --- a/matrix_pygmalion_bot/bot/memory/chatlog.py +++ b/matrix_pygmalion_bot/bot/memory/chatlog.py @@ -1,4 +1,5 @@ from time import gmtime, localtime, strftime +import os from ..utilities.messages import Message class ChatLog(object): @@ -20,7 +21,7 @@ class ChatLog(object): room_id_sanitized = "".join(c for c in message.room_id if c.isalnum() or c in keepcharacters).strip() time_suffix = strftime("%Y-%m", localtime()) time = strftime("%a, %d %b %Y %H:%M:%S", localtime(message.timestamp)) - with open(f"{self.directory}/{message.room_name}_{room_id_sanitized}_{time_suffix}.txt", "a") as f: + with open(os.path.join(self.directory, f"{message.room_name}_{room_id_sanitized}_{time_suffix}.txt"), "a") as f: f.write("{} | {}: {}\n".format(time, message.user_name, message.message)) diff --git a/matrix_pygmalion_bot/connections/matrix.py b/matrix_pygmalion_bot/connections/matrix.py index 2ee706c..d7f34c9 100644 --- a/matrix_pygmalion_bot/connections/matrix.py +++ b/matrix_pygmalion_bot/connections/matrix.py @@ -154,7 +154,7 @@ class ChatClient(object): self.device_name = device_name self.synced = False - def persist(self, data_dir): + async def persist(self, data_dir): #self.data_dir = data_dir self.config_file = f"{data_dir}/matrix_credentials.json" self.store_path = f"{data_dir}/store/" diff --git a/matrix_pygmalion_bot/main.py b/matrix_pygmalion_bot/main.py index 8fca87a..3b82ee5 100644 --- a/matrix_pygmalion_bot/main.py +++ b/matrix_pygmalion_bot/main.py @@ -33,9 +33,9 @@ async def main() -> None: for section in config.sections(): bot_config = config[section] connection = ChatClient(bot_config['matrix_homeserver'], bot_config['matrix_username'], bot_config['matrix_password'], bot_config.get('matrix_device_name', 'matrix-nio')) - connection.persist(f"{DATA_DIR}/{section}/matrix") + await connection.persist(f"{DATA_DIR}/{section}/matrix") bot = ChatBot(section, connection) - bot.persist(f"{DATA_DIR}/{section}") + await bot.persist(f"{DATA_DIR}/{section}") bot.init_character( bot_config['persona'], bot_config['scenario'],