|
@ -13,7 +13,8 @@ import magic |
|
|
from PIL import Image |
|
|
from PIL import Image |
|
|
import re |
|
|
import re |
|
|
|
|
|
|
|
|
from .helpers import ChatItem, Event |
|
|
from .helpers import Event |
|
|
|
|
|
from .chatlog import BotChatHistory |
|
|
ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod_pygmalion") |
|
|
ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod_pygmalion") |
|
|
#from .llama_cpp import generate, get_full_prompt, get_full_prompt_chat_style |
|
|
#from .llama_cpp import generate, get_full_prompt, get_full_prompt_chat_style |
|
|
#from .runpod_pygmalion import generate_sync, get_full_prompt |
|
|
#from .runpod_pygmalion import generate_sync, get_full_prompt |
|
@ -35,8 +36,9 @@ class Callbacks(object): |
|
|
self.bot = bot |
|
|
self.bot = bot |
|
|
|
|
|
|
|
|
async def message_cb(self, room: MatrixRoom, event: RoomMessageText) -> None: |
|
|
async def message_cb(self, room: MatrixRoom, event: RoomMessageText) -> None: |
|
|
event_id = event.event_id |
|
|
if not hasattr(event, 'body'): |
|
|
message = event.body |
|
|
return |
|
|
|
|
|
#message = event.body |
|
|
is_own_message = False |
|
|
is_own_message = False |
|
|
if event.sender == self.client.user: |
|
|
if event.sender == self.client.user: |
|
|
is_own_message = True |
|
|
is_own_message = True |
|
@ -49,14 +51,15 @@ class Callbacks(object): |
|
|
relates_to = None |
|
|
relates_to = None |
|
|
if 'm.relates_to' in event.source["content"]: |
|
|
if 'm.relates_to' in event.source["content"]: |
|
|
relates_to = event.source["content"]['m.relates_to']["event_id"] |
|
|
relates_to = event.source["content"]['m.relates_to']["event_id"] |
|
|
translated_message = message |
|
|
language = "en" |
|
|
if not (self.bot.translate is None) and not is_command: |
|
|
if not (self.bot.translate is None) and not is_command: |
|
|
|
|
|
language = self.bot.translate |
|
|
if 'original_message' in event.source["content"]: |
|
|
if 'original_message' in event.source["content"]: |
|
|
translated_message = event.source["content"]['original_message'] |
|
|
english_original_message = event.source["content"]['original_message'] |
|
|
else: |
|
|
else: |
|
|
translated_message = translate.translate(message, self.bot.translate, "en") |
|
|
english_original_message = None |
|
|
if hasattr(event, 'body'): |
|
|
|
|
|
self.bot.chat_history[event_id] = ChatItem(event_id, event.server_timestamp, room.user_name(event.sender), is_own_message, relates_to, translated_message) |
|
|
chat_message = self.bot.chat_history.room(room.display_name).add(event.event_id, event.server_timestamp, room.user_name(event.sender), is_own_message, relates_to, event.body, language, english_original_message) |
|
|
if self.bot.not_synced: |
|
|
if self.bot.not_synced: |
|
|
return |
|
|
return |
|
|
print( |
|
|
print( |
|
@ -64,10 +67,8 @@ class Callbacks(object): |
|
|
room.display_name, room.user_name(event.sender), event.body |
|
|
room.display_name, room.user_name(event.sender), event.body |
|
|
) |
|
|
) |
|
|
) |
|
|
) |
|
|
os.makedirs("./chatlogs", exist_ok=True) |
|
|
|
|
|
with open("chatlogs/" + self.bot.name + "_" + room.display_name + ".txt", "a") as f: |
|
|
await self.client.room_read_markers(room.room_id, event.event_id, event.event_id) |
|
|
f.write("{}: {}\n".format(room.user_name(event.sender), event.body)) |
|
|
|
|
|
await self.client.room_read_markers(room.room_id, event_id, event_id) |
|
|
|
|
|
# Ignore messages from ourselves |
|
|
# Ignore messages from ourselves |
|
|
if is_own_message: |
|
|
if is_own_message: |
|
|
return |
|
|
return |
|
@ -89,56 +90,47 @@ class Callbacks(object): |
|
|
await self.bot.send_image(self.client, room.room_id, imagefile) |
|
|
await self.bot.send_image(self.client, room.room_id, imagefile) |
|
|
return |
|
|
return |
|
|
elif event.body.startswith('!begin'): |
|
|
elif event.body.startswith('!begin'): |
|
|
self.bot.chat_history = {} |
|
|
self.bot.chat_history.room(room.display_name).clear() |
|
|
self.bot.timestamp = time.time() |
|
|
self.bot.timestamp = time.time() |
|
|
await self.bot.write_conf2(self.bot.name) |
|
|
await self.bot.write_conf2(self.bot.name) |
|
|
await self.bot.send_message(self.client, room.room_id, self.bot.greeting) |
|
|
await self.bot.send_message(self.client, room.room_id, self.bot.greeting) |
|
|
return |
|
|
return |
|
|
elif event.body.startswith('!!!'): |
|
|
elif event.body.startswith('!!!'): |
|
|
if len(self.bot.chat_history) < 3: |
|
|
if self.bot.chat_history.room(room.display_name).getLen() < 3: |
|
|
return |
|
|
return |
|
|
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current |
|
|
chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) # current |
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() |
|
|
chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) |
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() |
|
|
chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) |
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
return |
|
|
return |
|
|
elif event.body.startswith('!!'): |
|
|
elif event.body.startswith('!!'): |
|
|
if len(self.bot.chat_history) < 3: |
|
|
if self.bot.chat_history.room(room.display_name).getLen() < 3: |
|
|
return |
|
|
return |
|
|
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current |
|
|
chat_history_item = self.bot.chat_history.room(room.display_name).remove(1)# current |
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() |
|
|
chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) |
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # new current |
|
|
chat_message = self.bot.chat_history.room(room.display_name).getLastItem() # new current |
|
|
self.bot.chat_history[chat_history_event_id] = chat_history_item |
|
|
|
|
|
event_id = chat_history_item.event_id |
|
|
|
|
|
message = chat_history_item.message |
|
|
|
|
|
translated_message = message |
|
|
|
|
|
# don't return, we generate a new answer |
|
|
# don't return, we generate a new answer |
|
|
elif event.body.startswith('!replace'): |
|
|
elif event.body.startswith('!replace'): |
|
|
if len(self.bot.chat_history) < 3: |
|
|
if self.bot.chat_history.room(room.display_name).getLen() < 3: |
|
|
return |
|
|
return |
|
|
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current |
|
|
chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) # current |
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() |
|
|
chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) |
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
new_answer = event.body.removeprefix('!replace').strip() |
|
|
new_answer = event.body.removeprefix('!replace').strip() |
|
|
new_translated_answer = new_answer |
|
|
|
|
|
if not (self.bot.translate is None): |
|
|
|
|
|
new_translated_answer = translate.translate(new_answer, "en", self.bot.translate) |
|
|
|
|
|
await self.bot.send_message(self.client, room.room_id, new_translated_answer, reply_to=chat_history_item.relates_to_event, original_message=new_answer) |
|
|
|
|
|
else: |
|
|
|
|
|
await self.bot.send_message(self.client, room.room_id, new_answer, reply_to=chat_history_item.relates_to_event) |
|
|
await self.bot.send_message(self.client, room.room_id, new_answer, reply_to=chat_history_item.relates_to_event) |
|
|
return |
|
|
return |
|
|
|
|
|
|
|
|
# Other commands |
|
|
# Other commands |
|
|
if re.search("^(?=.*\bsend\b)(?=.*\bpicture\b).*$", message): |
|
|
if re.search("^(?=.*\bsend\b)(?=.*\bpicture\b).*$", event.body): |
|
|
# send, mail, drop, snap picture, photo, image, portrait |
|
|
# send, mail, drop, snap picture, photo, image, portrait |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
full_prompt = await ai.get_full_prompt(translated_message, self.bot) |
|
|
full_prompt = await ai.get_full_prompt(chat_message.getTranslation("en"), self.bot, self.bot.chat_history.room(room.display_name)) |
|
|
num_tokens = await ai.num_tokens(full_prompt) |
|
|
num_tokens = await ai.num_tokens(full_prompt) |
|
|
logger.info(full_prompt) |
|
|
logger.info(full_prompt) |
|
|
logger.info(f"num tokens:" + str(num_tokens)) |
|
|
logger.info(f"num tokens:" + str(num_tokens)) |
|
@ -161,12 +153,11 @@ class Callbacks(object): |
|
|
answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot.name) |
|
|
answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot.name) |
|
|
answer = answer.strip() |
|
|
answer = answer.strip() |
|
|
await self.client.room_typing(room.room_id, False) |
|
|
await self.client.room_typing(room.room_id, False) |
|
|
translated_answer = answer |
|
|
|
|
|
if not (self.bot.translate is None): |
|
|
if not (self.bot.translate is None): |
|
|
translated_answer = translate.translate(answer, "en", self.bot.translate) |
|
|
translated_answer = translate.translate(answer, "en", self.bot.translate) |
|
|
await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=event_id, original_message=answer) |
|
|
await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=chat_message.event_id, original_message=answer) |
|
|
else: |
|
|
else: |
|
|
await self.bot.send_message(self.client, room.room_id, answer, reply_to=event_id) |
|
|
await self.bot.send_message(self.client, room.room_id, answer, reply_to=chat_message.event_id) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -203,7 +194,7 @@ class ChatBot(object): |
|
|
self.scenario = None |
|
|
self.scenario = None |
|
|
self.greeting = None |
|
|
self.greeting = None |
|
|
self.events = [] |
|
|
self.events = [] |
|
|
self.chat_history = {} |
|
|
self.chat_history = None |
|
|
|
|
|
|
|
|
if STORE_PATH and not os.path.isdir(STORE_PATH): |
|
|
if STORE_PATH and not os.path.isdir(STORE_PATH): |
|
|
os.mkdir(STORE_PATH) |
|
|
os.mkdir(STORE_PATH) |
|
@ -232,6 +223,7 @@ class ChatBot(object): |
|
|
async def login(self): |
|
|
async def login(self): |
|
|
self.config = AsyncClientConfig(store_sync_tokens=True) |
|
|
self.config = AsyncClientConfig(store_sync_tokens=True) |
|
|
self.client = AsyncClient(self.homeserver, self.user_id, store_path=STORE_PATH, config=self.config) |
|
|
self.client = AsyncClient(self.homeserver, self.user_id, store_path=STORE_PATH, config=self.config) |
|
|
|
|
|
self.chat_history = BotChatHistory(self.name) |
|
|
self.callbacks = Callbacks(self.client, self) |
|
|
self.callbacks = Callbacks(self.client, self) |
|
|
self.client.add_event_callback(self.callbacks.message_cb, RoomMessageText) |
|
|
self.client.add_event_callback(self.callbacks.message_cb, RoomMessageText) |
|
|
self.client.add_event_callback(self.callbacks.invite_cb, InviteEvent) |
|
|
self.client.add_event_callback(self.callbacks.invite_cb, InviteEvent) |
|
|