You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
369 lines
16 KiB
369 lines
16 KiB
import asyncio
|
|
import nio
|
|
from nio import (AsyncClient, AsyncClientConfig, MatrixRoom, RoomMessageText, InviteEvent, UploadResponse)
|
|
|
|
import os, sys
|
|
import time
|
|
import importlib
|
|
import configparser
|
|
import logging
|
|
|
|
import aiofiles.os
|
|
import magic
|
|
from PIL import Image
|
|
import re
|
|
|
|
from .helpers import ChatItem, Event
|
|
ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod_pygmalion")
|
|
#from .llama_cpp import generate, get_full_prompt, get_full_prompt_chat_style
|
|
#from .runpod_pygmalion import generate_sync, get_full_prompt
|
|
import matrix_pygmalion_bot.translate as translate
|
|
|
|
# Local store directory handed to nio's AsyncClient (store_path=...);
# ChatBot creates it on demand.
STORE_PATH: str = "./.store/"

logger = logging.getLogger(__name__)

# Parsed contents of bot.conf, populated by main().
config = configparser.ConfigParser()

# Every ChatBot instance started by main().
bots: list = []

# Strong references to fire-and-forget asyncio tasks; asyncio itself only keeps
# weak references, and each task removes itself here once it is done.
background_tasks: set = set()
|
|
|
|
class Callbacks(object):
    """Class to pass client to callback methods.

    An instance pairs the nio ``AsyncClient`` with its owning ``ChatBot``;
    ``ChatBot.login`` registers :meth:`message_cb` for ``RoomMessageText``
    events and :meth:`invite_cb` for ``InviteEvent`` events.
    """

    # Negative prompt passed to ai.generate_image for the !image command.
    _IMAGE_NEGATIVE_PROMPT = "out of frame, (ugly:1.3), (fused fingers), (too many fingers), (bad anatomy:1.5), (watermark:1.5), (words), letters, untracked eyes, asymmetric eyes, floating head, (logo:1.5), (bad hands:1.3), (mangled hands:1.2), (missing hands), (missing arms), backward hands, floating jewelry, unattached jewelry, floating head, doubled head, unattached head, doubled head, head in body, (misshapen body:1.1), (badly fitted headwear:1.2), floating arms, (too many arms:1.5), limbs fused with body, (facial blemish:1.5), badly fitted clothes, imperfect eyes, untracked eyes, crossed eyes, hair growing from clothes, partial faces, hair not attached to head"

    # Detects "send ... picture" requests (not acted upon yet). This must be a
    # raw string: in a plain literal "\b" is a backspace character rather than
    # a regex word boundary, so the pattern could never match normal text.
    _SEND_PICTURE_RE = re.compile(r"^(?=.*\bsend\b)(?=.*\bpicture\b).*$")

    def __init__(self, client: AsyncClient, bot):
        self.client = client
        self.bot = bot

    async def _pop_and_redact(self, room_id):
        """Remove the newest chat-history entry, redact its event in *room_id*, return it."""
        _, item = self.bot.chat_history.popitem()
        await self.client.room_redact(room_id, item.event_id, reason="user-request")
        return item

    def _append_chatlog(self, room, event):
        """Append "<sender>: <body>" to chatlogs/<bot name>_<room display name>.txt."""
        os.makedirs("./chatlogs", exist_ok=True)
        # A path separator in the room name would point into a (missing)
        # sub-directory and make open() fail, so replace it.
        room_name = room.display_name
        for sep in (os.sep, os.altsep):
            if sep:
                room_name = room_name.replace(sep, "_")
        path = "chatlogs/" + self.bot.name + "_" + room_name + ".txt"
        with open(path, "a", encoding="utf-8") as f:
            f.write("{}: {}\n".format(room.user_name(event.sender), event.body))

    async def message_cb(self, room: MatrixRoom, event: RoomMessageText) -> None:
        """Handle an incoming text message.

        If ``bot.owner`` is set, messages from anyone other than the owner or
        the bot itself are ignored. Every other message is stored in
        ``bot.chat_history``. When ``bot.translate`` is set and the message is
        not a ``!`` command, it is first translated into "en". During the
        initial sync (``bot.not_synced``) nothing else happens.

        After the sync, the message is logged to ``chatlogs/`` and a read
        marker is set. Messages from other users are then either handled as
        commands or answered with text generated by the ``ai`` backend.

        Commands:

        * ``!replybot``: prints the event and replies "Hello World!".
        * ``!image [prompt]``: generates images and sends them.
        * ``!begin``: clears the history, resets ``bot.timestamp``, persists
          it and sends the greeting.
        * ``!!!``: redacts the three newest history entries.
        * ``!!``: redacts the two newest entries and regenerates an answer for
          the entry before them.
        * ``!replace <text>``: redacts the two newest entries and posts
          ``<text>`` as the answer instead.
        """
        event_id = event.event_id
        message = event.body
        is_own_message = event.sender == self.client.user
        is_command = message.startswith('!')

        # With an owner configured, only the owner and the bot itself are heard.
        if self.bot.owner is not None and event.sender != self.bot.owner and not is_own_message:
            return

        content = event.source["content"]
        relates_to = None
        if 'm.relates_to' in content:
            # .get(): standard Matrix rich replies nest the target under
            # "m.in_reply_to" and have no top-level "event_id".
            relates_to = content['m.relates_to'].get("event_id")

        translated_message = message
        if self.bot.translate is not None and not is_command:
            if 'original_message' in content:
                # Sent via ChatBot.send_message(original_message=...): the
                # pre-translation text is already attached.
                translated_message = content['original_message']
            else:
                translated_message = translate.translate(message, self.bot.translate, "en")

        self.bot.chat_history[event_id] = ChatItem(
            event_id,
            event.server_timestamp,
            room.user_name(event.sender),
            is_own_message,
            relates_to,
            translated_message,
        )
        if self.bot.not_synced:
            return

        print(
            "Message received for room {} | {}: {}".format(
                room.display_name, room.user_name(event.sender), event.body
            )
        )
        self._append_chatlog(room, event)
        await self.client.room_read_markers(room.room_id, event_id, event_id)

        # Ignore messages from ourselves
        if is_own_message:
            return

        if event.body.startswith('!replybot'):
            print(event)
            await self.bot.send_message(self.client, room.room_id, "Hello World!")
            return

        elif event.body.startswith('!image'):
            prompt = event.body.removeprefix('!image').strip()
            if len(prompt) == 0:
                prompt = "a beautiful woman"
            output = await ai.generate_image(prompt, self._IMAGE_NEGATIVE_PROMPT, self.bot.runpod_api_key)
            for imagefile in output:
                await self.bot.send_image(self.client, room.room_id, imagefile)
            return

        elif event.body.startswith('!begin'):
            self.bot.chat_history = {}
            self.bot.timestamp = time.time()
            await self.bot.write_conf2(self.bot.name)
            await self.bot.send_message(self.client, room.room_id, self.bot.greeting)
            return

        elif event.body.startswith('!!!'):
            if len(self.bot.chat_history) < 3:
                return
            # The three newest entries: this command first, then the two before it.
            for _ in range(3):
                await self._pop_and_redact(room.room_id)
            return

        elif event.body.startswith('!!'):
            if len(self.bot.chat_history) < 3:
                return
            # Drop this command and the entry before it...
            for _ in range(2):
                await self._pop_and_redact(room.room_id)
            # ...then regenerate an answer for the newest remaining entry
            # (peeked at, it stays in the history).
            chat_history_item = next(reversed(self.bot.chat_history.values()))
            event_id = chat_history_item.event_id
            message = chat_history_item.message
            translated_message = message
            # don't return, we generate a new answer

        elif event.body.startswith('!replace'):
            if len(self.bot.chat_history) < 3:
                return
            await self._pop_and_redact(room.room_id)  # this command
            chat_history_item = await self._pop_and_redact(room.room_id)  # entry being replaced
            new_answer = event.body.removeprefix('!replace').strip()
            if self.bot.translate is not None:
                new_translated_answer = translate.translate(new_answer, "en", self.bot.translate)
                await self.bot.send_message(self.client, room.room_id, new_translated_answer, reply_to=chat_history_item.relates_to_event, original_message=new_answer)
            else:
                await self.bot.send_message(self.client, room.room_id, new_answer, reply_to=chat_history_item.relates_to_event)
            return

        # Other commands
        if self._SEND_PICTURE_RE.search(message):
            # TODO: send, mail, drop, snap picture, photo, image, portrait
            pass

        full_prompt = await ai.get_full_prompt(translated_message, self.bot)
        num_tokens = await ai.num_tokens(full_prompt)
        logger.info(full_prompt)
        logger.info("num tokens:%s", num_tokens)

        await self.client.room_typing(room.room_id, True, 15000)
        try:
            answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot.name)
        finally:
            # Clear the typing notice even when generation fails.
            await self.client.room_typing(room.room_id, False)
        answer = answer.strip()

        if self.bot.translate is not None:
            translated_answer = translate.translate(answer, "en", self.bot.translate)
            await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=event_id, original_message=answer)
        else:
            await self.bot.send_message(self.client, room.room_id, answer, reply_to=event_id)

    async def invite_cb(self, room: MatrixRoom, event: InviteEvent) -> None:
        """Automatically join all rooms we get invited to"""
        # NOTE(review): unlike message_cb there is no owner check here, so any
        # user can invite the bot into a room.
        result = await self.client.join(room.room_id)
        print('Invited to room: {} {}'.format(room.name, room.room_id))
        if isinstance(result, nio.responses.JoinResponse):
            print('Joined')
        else:
            print("Error joining room: {}".format(str(result)))
|
|
|
|
class ChatBot(object):
    """Main chatbot

    One Matrix account playing one character: connection settings, the
    character card (name/persona/scenario/greeting), chat history and timed
    events.
    """

    def __init__(self, homeserver, user_id, password):
        self.homeserver = homeserver
        self.user_id = user_id
        self.password = password

        self.runpod_api_key = None

        self.client = None
        self.callbacks = None
        self.config = None
        # Cleared by watch_for_sync(); until then message_cb only records history.
        self.not_synced = True
        self.timestamp = time.time()

        self.owner = None
        self.translate = None
        # Optional image prompt; main() sets it when bot.conf provides one.
        self.image_prompt = None

        self.name = None
        self.persona = None
        self.scenario = None
        self.greeting = None
        self.events = []
        self.chat_history = {}

        if STORE_PATH:
            os.makedirs(STORE_PATH, exist_ok=True)

    def character_init(self, name, persona, scenario, greeting):
        """Set the character card used by this bot."""
        self.name = name
        self.persona = persona
        self.scenario = scenario
        self.greeting = greeting

    async def event_loop(self):
        """Every 60 seconds, call ``loop(self)`` on each registered Event."""
        while True:
            await asyncio.sleep(60)
            for event in self.events:
                try:
                    event.loop(self)
                except Exception:
                    # One failing event must not stop all timed events for good.
                    logger.exception("Event %r failed", event)

    @staticmethod
    def _parse_event_string(event_string):
        """Split "a, b, c, text" into ``(float(a), float(b), float(c), "text")``.

        Only the first three commas separate fields, so the text may itself
        contain commas. Every field is stripped of surrounding whitespace.

        Raises:
            ValueError: fewer than four fields, or a numeric field is invalid.
        """
        items = [item.strip() for item in event_string.split(',', 3)]
        if len(items) < 4:
            raise ValueError("event needs 4 comma-separated fields, got: {!r}".format(event_string))
        return float(items[0]), float(items[1]), float(items[2]), items[3]

    async def add_event(self, event_string):
        """Parse *event_string* (see ``_parse_event_string``) and register it as an Event."""
        event = Event(*self._parse_event_string(event_string))
        self.events.append(event)

    async def login(self):
        """Create the nio client, register callbacks, start helper tasks and log in.

        The long-running sync is started separately by main() through
        ``client.sync_forever``.
        """
        self.config = AsyncClientConfig(store_sync_tokens=True)
        self.client = AsyncClient(self.homeserver, self.user_id, store_path=STORE_PATH, config=self.config)
        self.callbacks = Callbacks(self.client, self)
        self.client.add_event_callback(self.callbacks.message_cb, RoomMessageText)
        self.client.add_event_callback(self.callbacks.invite_cb, InviteEvent)

        # asyncio keeps only weak references to tasks; hold strong ones in
        # background_tasks until each task finishes (sync watcher included).
        for coro in (self.watch_for_sync(self.client.synced), self.event_loop()):
            task = asyncio.create_task(coro)
            background_tasks.add(task)
            task.add_done_callback(background_tasks.discard)

        try:
            response = await self.client.login(self.password)
        except (asyncio.CancelledError, KeyboardInterrupt):
            print("Received interrupt.")
            await self.client.close()
            # Propagate so the caller actually shuts down.
            raise
        print(response)
        if isinstance(response, nio.responses.LoginError):
            logger.error("Login failed for %s: %s", self.user_id, response)

    async def watch_for_sync(self, sync_event):
        """Wait for *sync_event*, then let message_cb start answering."""
        print("Awaiting sync")
        await sync_event.wait()
        print("Client is synced")
        self.not_synced = False

    async def read_conf2(self, section):
        """Load ``timestamp`` from *section* of bot.conf2, or use the current time."""
        config2 = configparser.ConfigParser()
        config2.read('bot.conf2')
        if config2.has_section(section) and config2.has_option(section, 'timestamp'):
            self.timestamp = float(config2[section]['timestamp'])
        else:
            self.timestamp = time.time()

    async def write_conf2(self, section):
        """Store ``self.timestamp`` as the only key of *section* in bot.conf2 (other sections kept)."""
        config2 = configparser.ConfigParser()
        config2.read('bot.conf2')
        config2[section] = {}
        config2[section]['timestamp'] = str(self.timestamp)
        with open('bot.conf2', 'w') as configfile:
            config2.write(configfile)

    async def send_message(self, client, room_id, message, reply_to=None, original_message=None):
        """Send an m.text message to *room_id*.

        ``reply_to`` adds a custom ``de.xd0.mpygbot.in_reply_to`` relation;
        ``original_message`` is stored as an extra content key (callers pass
        the text before translation).
        """
        content = {"msgtype": "m.text", "body": message}
        if reply_to:
            content["m.relates_to"] = {"event_id": reply_to, "rel_type": "de.xd0.mpygbot.in_reply_to"}
        if original_message:
            content["original_message"] = original_message

        await client.room_send(
            room_id=room_id,
            message_type="m.room.message",
            content=content,
        )

    async def send_image(self, client, room_id, image):
        """Send image to room

        Uploads the file at path *image* and posts it as an m.image event.
        https://matrix-nio.readthedocs.io/en/latest/examples.html#sending-an-image
        """
        mime_type = magic.from_file(image, mime=True)  # e.g. "image/jpeg"
        if not mime_type.startswith("image/"):
            logger.error("Drop message because file does not have an image mime type.")
            return

        # Only the dimensions are needed; close the image file right away.
        with Image.open(image) as im:
            (width, height) = im.size

        # first do an upload of image, then send URI of upload to room
        file_stat = await aiofiles.os.stat(image)
        async with aiofiles.open(image, "rb") as f:
            resp, _maybe_keys = await client.upload(
                f,
                content_type=mime_type,  # image/jpeg
                filename=os.path.basename(image),
                filesize=file_stat.st_size,
            )
        if isinstance(resp, UploadResponse):
            print("Image was uploaded successfully to server. ")
        else:
            print(f"Failed to upload image. Failure response: {resp}")
            # An error response has no content_uri, so there is nothing to send.
            return

        content = {
            "body": os.path.basename(image),  # descriptive title
            "info": {
                "size": file_stat.st_size,
                "mimetype": mime_type,
                "thumbnail_info": None,  # TODO
                "w": width,  # width in pixel
                "h": height,  # height in pixel
                "thumbnail_url": None,  # TODO
            },
            "msgtype": "m.image",
            "url": resp.content_uri,
        }

        try:
            await client.room_send(room_id, message_type="m.room.message", content=content)
            print("Image was sent successfully")
        except Exception:
            logger.exception(f"Image send of file {image} failed.")
|
|
|
|
async def main() -> None:
    """Create one ChatBot per bot section of bot.conf, log each in, then sync all forever."""
    config.read('bot.conf')
    logging.basicConfig(level=logging.INFO)
    for section in config.sections():
        # Skip non-bot sections. ('DEFAULT' is never returned by sections(),
        # but is listed for clarity.)
        if section == 'DEFAULT' or section == 'Common':
            continue
        botname = section
        homeserver = config[section]['url']
        user_id = config[section]['username']
        password = config[section]['password']
        bot = ChatBot(homeserver, user_id, password)
        bot.character_init(botname, config[section]['persona'], config[section]['scenario'], config[section]['greeting'])
        if config.has_option(section, 'owner'):
            bot.owner = config[section]['owner']
        if config.has_option(section, 'translate'):
            bot.translate = config[section]['translate']
            translate.init(bot.translate, "en")
            translate.init("en", bot.translate)
        if config.has_option(section, 'image_prompt'):
            bot.image_prompt = config[section]['image_prompt']
        if config.has_option(section, 'events'):
            for event in config[section]['events'].strip().split('\n'):
                # Multi-line config values may contain blank lines; they are not events.
                if event.strip():
                    await bot.add_event(event)
        if config.has_option('DEFAULT', 'runpod_api_key'):
            bot.runpod_api_key = config['DEFAULT']['runpod_api_key']
        await bot.read_conf2(section)
        await bot.write_conf2(section)
        bots.append(bot)
        await bot.login()

    print("gather")
    async with asyncio.TaskGroup() as tg:
        for bot in bots:
            tg.create_task(bot.client.sync_forever(timeout=30000, full_state=True))
|
|
|
|
# asyncio.run() creates a fresh event loop, runs main() and closes the loop
# afterwards; asyncio.get_event_loop() without a running loop is deprecated.
asyncio.run(main())
|
|
|