You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
285 lines
12 KiB
285 lines
12 KiB
2 years ago
|
import asyncio
|
||
|
import nio
|
||
|
from nio import (AsyncClient, AsyncClientConfig, MatrixRoom, RoomMessageText, InviteEvent, UploadResponse)
|
||
|
|
||
|
import os, sys
|
||
|
import importlib
|
||
|
import configparser
|
||
|
import logging
|
||
|
|
||
|
import aiofiles.os
|
||
|
import magic
|
||
|
from PIL import Image
|
||
|
|
||
|
from .helpers import ChatItem
|
||
|
ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod_pygmalion")
|
||
|
#from .llama_cpp import generate, get_full_prompt, get_full_prompt_chat_style
|
||
|
#from .runpod_pygmalion import generate_sync, get_full_prompt
|
||
|
import matrix_pygmalion_bot.translate as translate
|
||
|
|
||
|
# Directory handed to the matrix-nio AsyncClient as its store_path
# (see ChatBot.login); ChatBot.__init__ creates it when missing.
STORE_PATH = "./.store/"

logger = logging.getLogger(__name__)

# Shared parser; main() fills it from bot.conf.
config = configparser.ConfigParser()

# Every ChatBot that main() creates is appended here.
bots = list()
class Callbacks(object):
    """Class to pass client to callback methods.

    Holds the nio client and the owning ChatBot so that the bound methods
    can be registered with ``AsyncClient.add_event_callback``.
    """

    # Negative prompt passed to the image backend for every ``!image`` request.
    _IMAGE_NEGATIVE_PROMPT = "out of frame, (ugly:1.3), (fused fingers), (too many fingers), (bad anatomy:1.5), (watermark:1.5), (words), letters, untracked eyes, asymmetric eyes, floating head, (logo:1.5), (bad hands:1.3), (mangled hands:1.2), (missing hands), (missing arms), backward hands, floating jewelry, unattached jewelry, floating head, doubled head, unattached head, doubled head, head in body, (misshapen body:1.1), (badly fitted headwear:1.2), floating arms, (too many arms:1.5), limbs fused with body, (facial blemish:1.5), badly fitted clothes, imperfect eyes, untracked eyes, crossed eyes, hair growing from clothes, partial faces, hair not attached to head"

    # Prompt used when ``!image`` is sent without any text.
    _DEFAULT_IMAGE_PROMPT = "a beautiful woman"

    def __init__(self, client: "AsyncClient", bot):
        self.client = client
        self.bot = bot

    @staticmethod
    def _relates_to_event_id(content):
        """Return the event id a message's content relates to, or None.

        Understands the bot's own relation (``m.relates_to.event_id`` with
        rel_type ``de.xd0.mpygbot.in_reply_to``, written by
        ChatBot.send_message) as well as standard Matrix rich replies, which
        nest the id under ``m.relates_to.m.in_reply_to.event_id``.
        Previously the nested form raised KeyError and aborted the callback.
        """
        relates_to = content.get("m.relates_to")
        if not isinstance(relates_to, dict):
            return None
        if "event_id" in relates_to:
            return relates_to["event_id"]
        in_reply_to = relates_to.get("m.in_reply_to")
        if isinstance(in_reply_to, dict):
            return in_reply_to.get("event_id")
        return None

    async def message_cb(self, room: "MatrixRoom", event: "RoomMessageText") -> None:
        """Handle an incoming text message.

        Records the message in ``self.bot.chat_history``. Once the client has
        synced, it marks the message read, then either runs a ``!`` command
        or asks the AI backend for an answer and sends it as a reply.
        """
        message = event.body
        content = event.source["content"]
        is_own_message = event.sender == self.client.user
        is_command = message.startswith('!')
        relates_to = self._relates_to_event_id(content)

        # The model works in English; convert user text when translation is on.
        translated_message = message
        if self.bot.translate is not None and not is_command:
            if 'original_message' in content:
                # The bot's own translated replies carry the English original.
                translated_message = content['original_message']
            else:
                translated_message = translate.translate(message, self.bot.translate, "en")

        self.bot.chat_history[event.event_id] = ChatItem(
            event.event_id,
            event.server_timestamp,
            room.user_name(event.sender),
            is_own_message,
            relates_to,
            translated_message,
        )

        # Messages replayed by the initial sync are recorded but not answered.
        if self.bot.not_synced:
            return

        print(
            "Message received for room {} | {}: {}".format(
                room.display_name, room.user_name(event.sender), event.body
            )
        )
        await self.client.room_read_markers(room.room_id, event.event_id, event.event_id)

        # Ignore messages from ourselves
        if is_own_message:
            return

        if message.startswith('!replybot'):
            print(event)
            await self.bot.send_message(self.client, room.room_id, "Hello World!")
            return

        if message.startswith('!image'):
            prompt = message.removeprefix('!image').strip() or self._DEFAULT_IMAGE_PROMPT
            output = await ai.generate_image(prompt, self._IMAGE_NEGATIVE_PROMPT, self.bot.runpod_api_key)
            for imagefile in output:
                await self.bot.send_image(self.client, room.room_id, imagefile)
            return

        if message.startswith('!begin'):
            self.bot.chat_history = {}
            await self.bot.send_message(self.client, room.room_id, self.bot.greeting)
            return

        if message.startswith('!!'):
            # '!!!' (redact the last exchange) and '!!' (regenerate the last
            # answer) are currently disabled; both are ignored.
            # TODO: re-implement them on top of self.bot.chat_history.
            return

        full_prompt = await ai.get_full_prompt(translated_message, self.bot)
        num_tokens = await ai.num_tokens(full_prompt)
        logger.info(full_prompt)
        logger.info("num tokens:%s", num_tokens)

        # Show a typing indicator (15000 ms) while the backend generates, and
        # always clear it, even when generation raises.
        await self.client.room_typing(room.room_id, True, 15000)
        try:
            answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key)
        finally:
            await self.client.room_typing(room.room_id, False)
        answer = answer.strip()

        if self.bot.translate is not None:
            translated_answer = translate.translate(answer, "en", self.bot.translate)
            # Keep the English text so it can be reused when the reply comes back.
            await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=event.event_id, original_message=answer)
        else:
            await self.bot.send_message(self.client, room.room_id, answer, reply_to=event.event_id)

    async def invite_cb(self, room: "MatrixRoom", event: "InviteEvent") -> None:
        """Automatically join all rooms we get invited to"""
        result = await self.client.join(room.room_id)
        print('Invited to room: {} {}'.format(room.name, room.room_id))
        if isinstance(result, nio.responses.JoinResponse):
            print('Joined')
        else:
            print("Error joining room: {}".format(str(result)))


class ChatBot(object):
    """Main chatbot

    One instance per configured character. It owns the nio client, the
    character definition (name, persona, scenario, greeting) and the
    in-memory chat history.
    """

    def __init__(self, homeserver, user_id, password):
        self.homeserver = homeserver
        self.user_id = user_id
        self.password = password

        self.runpod_api_key = None

        self.client = None
        self.callbacks = None
        self.config = None
        # Set to False by watch_for_sync() once the client reports it is synced.
        self.not_synced = True

        self.owner = None
        # Language used with translate (paired with "en"), or None to disable.
        self.translate = None
        # Optional, set from the config by main(); declared so it always exists.
        self.image_prompt = None

        self.name = None
        self.persona = None
        self.scenario = None
        self.greeting = None
        # event_id -> ChatItem, filled by the message callback.
        self.chat_history = {}

        if STORE_PATH:
            os.makedirs(STORE_PATH, exist_ok=True)

    def character_init(self, name, persona, scenario, greeting):
        """Set the character definition used for prompts and the greeting."""
        self.name = name
        self.persona = persona
        self.scenario = scenario
        self.greeting = greeting

    async def login(self):
        """Create the client, register callbacks, log in and sync until stopped.

        The client is closed on every exit path. A failed login is reported
        and returns instead of entering ``sync_forever`` without a token.
        """
        self.config = AsyncClientConfig(store_sync_tokens=True)
        self.client = AsyncClient(self.homeserver, self.user_id, store_path=STORE_PATH, config=self.config)
        self.callbacks = Callbacks(self.client, self)
        self.client.add_event_callback(self.callbacks.message_cb, RoomMessageText)
        self.client.add_event_callback(self.callbacks.invite_cb, InviteEvent)

        sync_task = asyncio.create_task(self.watch_for_sync(self.client.synced))

        try:
            response = await self.client.login(self.password)
            if not isinstance(response, nio.responses.LoginResponse):
                print(response)
                return
            # Avoid echoing the whole response object, which may contain the
            # access token.
            print("Logged in as {} (device {})".format(response.user_id, response.device_id))
            await self.client.sync_forever(timeout=30000, full_state=True)
        except (asyncio.CancelledError, KeyboardInterrupt):
            print("Received interrupt.")
        finally:
            sync_task.cancel()
            await self.client.close()

    async def watch_for_sync(self, sync_event):
        """Wait for the first sync, then allow the bot to answer messages."""
        print("Awaiting sync")
        await sync_event.wait()
        print("Client is synced")
        self.not_synced = False

    async def send_message(self, client, room_id, message, reply_to=None, original_message=None):
        """Send a text message to ``room_id``.

        :param reply_to: event id stored under ``m.relates_to`` with the
            custom rel_type ``de.xd0.mpygbot.in_reply_to``.
        :param original_message: untranslated text stored alongside the body.
        """
        content = {"msgtype": "m.text", "body": message}
        if reply_to:
            content["m.relates_to"] = {"event_id": reply_to, "rel_type": "de.xd0.mpygbot.in_reply_to"}
        if original_message:
            content["original_message"] = original_message

        await client.room_send(
            room_id=room_id,
            message_type="m.room.message",
            content=content,
        )

    async def send_image(self, client, room_id, image):
        """Send image to room

        Checks the file's MIME type, uploads it to the homeserver and then
        posts an ``m.image`` event that points at the uploaded content.
        https://matrix-nio.readthedocs.io/en/latest/examples.html#sending-an-image

        :param client: nio client used for the upload and the room send.
        :param room_id: target room.
        :param image: path of the image file on disk.
        """
        mime_type = magic.from_file(image, mime=True)  # e.g. "image/jpeg"
        if not mime_type.startswith("image/"):
            logger.error("Drop message because file does not have an image mime type.")
            return

        # Only the dimensions are needed; the context manager closes the file.
        with Image.open(image) as im:
            (width, height) = im.size  # im.size returns (width,height) tuple

        # first do an upload of image, then send URI of upload to room
        file_stat = await aiofiles.os.stat(image)
        # Read-only access is enough for uploading.
        async with aiofiles.open(image, "rb") as f:
            resp, _maybe_keys = await client.upload(
                f,
                content_type=mime_type,  # image/jpeg
                filename=os.path.basename(image),
                filesize=file_stat.st_size,
            )
        if isinstance(resp, UploadResponse):
            print("Image was uploaded successfully to server. ")
        else:
            print(f"Failed to upload image. Failure response: {resp}")
            # An error response has no content_uri, so there is nothing to send.
            return

        content = {
            "body": os.path.basename(image),  # descriptive title
            "info": {
                "size": file_stat.st_size,
                "mimetype": mime_type,
                "thumbnail_info": None,  # TODO
                "w": width,  # width in pixel
                "h": height,  # height in pixel
                "thumbnail_url": None,  # TODO
            },
            "msgtype": "m.image",
            "url": resp.content_uri,
        }

        try:
            await client.room_send(room_id, message_type="m.room.message", content=content)
            print("Image was sent successfully")
        except Exception:
            logger.exception("Image send of file %s failed.", image)


async def main() -> None:
    """Create one ChatBot per character section of ``bot.conf`` and run them.

    Each section other than ``DEFAULT``/``Common`` must provide ``url``,
    ``username``, ``password``, ``persona``, ``scenario`` and ``greeting``.
    The keys ``owner``, ``translate`` and ``image_prompt`` are optional.
    ``runpod_api_key`` is taken from the ``DEFAULT`` section when present.
    """
    logging.basicConfig(level=logging.INFO)
    if not config.read('bot.conf'):
        logger.warning("Could not read bot.conf; no bots will be started.")

    for section in config.sections():
        # These sections are not bot definitions. Previously this was `pass`,
        # so 'Common' fell through and failed on the missing 'url' key.
        if section in ('DEFAULT', 'Common'):
            continue
        botname = section
        homeserver = config[section]['url']
        user_id = config[section]['username']
        password = config[section]['password']
        bot = ChatBot(homeserver, user_id, password)
        bot.character_init(botname, config[section]['persona'], config[section]['scenario'], config[section]['greeting'])
        if config.has_option(section, 'owner'):
            bot.owner = config[section]['owner']
        if config.has_option(section, 'translate'):
            bot.translate = config[section]['translate']
            translate.init(bot.translate, "en")
            translate.init("en", bot.translate)
        if config.has_option(section, 'image_prompt'):
            bot.image_prompt = config[section]['image_prompt']
        if config.has_option('DEFAULT', 'runpod_api_key'):
            bot.runpod_api_key = config['DEFAULT']['runpod_api_key']
        bots.append(bot)

    # login() runs sync_forever and does not return while the bot is up.
    # Awaiting each bot in turn would only ever start the first one, so run
    # them all concurrently.
    await asyncio.gather(*(bot.login() for bot in bots))
    print("logged in")


# asyncio.get_event_loop() without a running loop is deprecated (and raises
# in Python 3.14). asyncio.run() creates, runs and closes a fresh loop.
# NOTE(review): this runs on import; if the module is only ever executed via
# `python -m`, consider wrapping it in `if __name__ == "__main__":`.
asyncio.run(main())