|
|
@ -3,6 +3,7 @@ import nio |
|
|
|
from nio import (AsyncClient, AsyncClientConfig, MatrixRoom, RoomMessageText, InviteEvent, UploadResponse) |
|
|
|
|
|
|
|
import os, sys |
|
|
|
import time |
|
|
|
import importlib |
|
|
|
import configparser |
|
|
|
import logging |
|
|
@ -10,6 +11,7 @@ import logging |
|
|
|
import aiofiles.os |
|
|
|
import magic |
|
|
|
from PIL import Image |
|
|
|
import re |
|
|
|
|
|
|
|
from .helpers import ChatItem |
|
|
|
ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod_pygmalion") |
|
|
@ -23,6 +25,7 @@ STORE_PATH = "./.store/" |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
config = configparser.ConfigParser() |
|
|
|
bots = [] |
|
|
|
background_tasks = set() |
|
|
|
|
|
|
|
class Callbacks(object): |
|
|
|
"""Class to pass client to callback methods.""" |
|
|
@ -32,6 +35,7 @@ class Callbacks(object): |
|
|
|
self.bot = bot |
|
|
|
|
|
|
|
async def message_cb(self, room: MatrixRoom, event: RoomMessageText) -> None: |
|
|
|
event_id = event.event_id |
|
|
|
message = event.body |
|
|
|
is_own_message = False |
|
|
|
if event.sender == self.client.user: |
|
|
@ -39,6 +43,9 @@ class Callbacks(object): |
|
|
|
is_command = False |
|
|
|
if event.body.startswith('!'): |
|
|
|
is_command = True |
|
|
|
if not (self.bot.owner is None): |
|
|
|
if not (event.sender == self.bot.owner or is_own_message): |
|
|
|
return |
|
|
|
relates_to = None |
|
|
|
if 'm.relates_to' in event.source["content"]: |
|
|
|
relates_to = event.source["content"]['m.relates_to']["event_id"] |
|
|
@ -49,7 +56,7 @@ class Callbacks(object): |
|
|
|
else: |
|
|
|
translated_message = translate.translate(message, self.bot.translate, "en") |
|
|
|
if hasattr(event, 'body'): |
|
|
|
self.bot.chat_history[event.event_id] = ChatItem(event.event_id, event.server_timestamp, room.user_name(event.sender), is_own_message, relates_to, translated_message) |
|
|
|
self.bot.chat_history[event_id] = ChatItem(event_id, event.server_timestamp, room.user_name(event.sender), is_own_message, relates_to, translated_message) |
|
|
|
if self.bot.not_synced: |
|
|
|
return |
|
|
|
print( |
|
|
@ -57,7 +64,7 @@ class Callbacks(object): |
|
|
|
room.display_name, room.user_name(event.sender), event.body |
|
|
|
) |
|
|
|
) |
|
|
|
await self.client.room_read_markers(room.room_id, event.event_id, event.event_id) |
|
|
|
await self.client.room_read_markers(room.room_id, event_id, event_id) |
|
|
|
# Ignore messages from ourselves |
|
|
|
if is_own_message: |
|
|
|
return |
|
|
@ -81,23 +88,26 @@ class Callbacks(object): |
|
|
|
return |
|
|
|
elif event.body.startswith('!!!'): |
|
|
|
return |
|
|
|
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current |
|
|
|
# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
|
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() |
|
|
|
# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
|
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() |
|
|
|
# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
|
# return |
|
|
|
elif event.body.startswith('!!'): |
|
|
|
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current |
|
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
|
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() |
|
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
|
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() |
|
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
|
return |
|
|
|
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current |
|
|
|
# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
|
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() |
|
|
|
# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
|
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # new current |
|
|
|
# self.bot.chat_history[chat_history_event_id] = chat_history_item |
|
|
|
# message = chat_history_item.message |
|
|
|
# # don't return, we generate a new answer |
|
|
|
elif event.body.startswith('!!'): |
|
|
|
if len(self.bot.chat_history) < 3: |
|
|
|
return |
|
|
|
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current |
|
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
|
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() |
|
|
|
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") |
|
|
|
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # new current |
|
|
|
self.bot.chat_history[chat_history_event_id] = chat_history_item |
|
|
|
event_id = chat_history_item.event_id |
|
|
|
message = chat_history_item.message |
|
|
|
translated_message = message |
|
|
|
# don't return, we generate a new answer |
|
|
|
|
|
|
|
full_prompt = await ai.get_full_prompt(translated_message, self.bot) |
|
|
|
num_tokens = await ai.num_tokens(full_prompt) |
|
|
@ -119,15 +129,15 @@ class Callbacks(object): |
|
|
|
# answer = answer.strip() |
|
|
|
# print("") |
|
|
|
await self.client.room_typing(room.room_id, True, 15000) |
|
|
|
answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key) |
|
|
|
answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot.name) |
|
|
|
answer = answer.strip() |
|
|
|
await self.client.room_typing(room.room_id, False) |
|
|
|
translated_answer = answer |
|
|
|
if not (self.bot.translate is None): |
|
|
|
translated_answer = translate.translate(answer, "en", self.bot.translate) |
|
|
|
await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=event.event_id, original_message=answer) |
|
|
|
await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=event_id, original_message=answer) |
|
|
|
else: |
|
|
|
await self.bot.send_message(self.client, room.room_id, answer, reply_to=event.event_id) |
|
|
|
await self.bot.send_message(self.client, room.room_id, answer, reply_to=event_id) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -162,6 +172,7 @@ class ChatBot(object): |
|
|
|
self.persona = None |
|
|
|
self.scenario = None |
|
|
|
self.greeting = None |
|
|
|
self.events = [] |
|
|
|
self.chat_history = {} |
|
|
|
|
|
|
|
if STORE_PATH and not os.path.isdir(STORE_PATH): |
|
|
@ -173,6 +184,13 @@ class ChatBot(object): |
|
|
|
self.scenario = scenario |
|
|
|
self.greeting = greeting |
|
|
|
|
|
|
|
async def event_loop(self): |
|
|
|
while True: |
|
|
|
await asyncio.sleep(60) |
|
|
|
print(time.time()) |
|
|
|
for event in self.events: |
|
|
|
event.loop() |
|
|
|
|
|
|
|
async def login(self): |
|
|
|
self.config = AsyncClientConfig(store_sync_tokens=True) |
|
|
|
self.client = AsyncClient(self.homeserver, self.user_id, store_path=STORE_PATH, config=self.config) |
|
|
@ -181,14 +199,18 @@ class ChatBot(object): |
|
|
|
self.client.add_event_callback(self.callbacks.invite_cb, InviteEvent) |
|
|
|
|
|
|
|
sync_task = asyncio.create_task(self.watch_for_sync(self.client.synced)) |
|
|
|
event_loop = asyncio.create_task(self.event_loop()) |
|
|
|
background_tasks.add(event_loop) |
|
|
|
event_loop.add_done_callback(background_tasks.discard) |
|
|
|
|
|
|
|
try: |
|
|
|
response = await self.client.login(self.password) |
|
|
|
print(response) |
|
|
|
await self.client.sync_forever(timeout=30000, full_state=True) |
|
|
|
#sync_forever_task = asyncio.create_task(self.client.sync_forever(timeout=30000, full_state=True)) |
|
|
|
except (asyncio.CancelledError, KeyboardInterrupt): |
|
|
|
print("Received interrupt.") |
|
|
|
await self.client.close() |
|
|
|
#return sync_forever_task |
|
|
|
|
|
|
|
async def watch_for_sync(self, sync_event): |
|
|
|
print("Awaiting sync") |
|
|
@ -279,6 +301,9 @@ async def main() -> None: |
|
|
|
bot.runpod_api_key = config['DEFAULT']['runpod_api_key'] |
|
|
|
bots.append(bot) |
|
|
|
await bot.login() |
|
|
|
print("logged in") |
|
|
|
print("gather") |
|
|
|
async with asyncio.TaskGroup() as tg: |
|
|
|
for bot in bots: |
|
|
|
task = tg.create_task(bot.client.sync_forever(timeout=30000, full_state=True)) |
|
|
|
|
|
|
|
asyncio.get_event_loop().run_until_complete(main()) |
|
|
|