You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
498 lines
24 KiB
498 lines
24 KiB
import asyncio
|
|
import nio
|
|
from nio import (AsyncClient, AsyncClientConfig, MatrixRoom, RoomMessageText, InviteEvent, UploadResponse, RedactionEvent)
|
|
|
|
import os, sys
|
|
import time
|
|
import importlib
|
|
import configparser
|
|
import logging
|
|
|
|
import aiofiles.os
|
|
import magic
|
|
from PIL import Image
|
|
import re
|
|
import json
|
|
|
|
from .helpers import Event
|
|
from .chatlog import BotChatHistory
|
|
image_ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod_pygmalion")
|
|
text_ai = importlib.import_module("matrix_pygmalion_bot.ai.koboldcpp")
|
|
#ai = importlib.import_module("matrix_pygmalion_bot.ai.stablehorde")
|
|
#from .llama_cpp import generate, get_full_prompt, get_full_prompt_chat_style
|
|
#from .runpod_pygmalion import generate_sync, get_full_prompt
|
|
import matrix_pygmalion_bot.translate as translate
|
|
|
|
# Directory given to matrix-nio's AsyncClient as its store path
# (ChatBot.__init__ creates it when missing).
STORE_PATH: str = "./.store/"

logger: logging.Logger = logging.getLogger(__name__)

# Parsed contents of bot.conf; populated by main().
config: configparser.ConfigParser = configparser.ConfigParser()

# Every ChatBot started by main(); redaction_cb walks this list.
bots: list = []

# Strong references to fire-and-forget asyncio tasks (see ChatBot.login) so
# they are not garbage-collected while still running.
background_tasks: set = set()
|
|
|
|
class Callbacks(object):
    """Matrix event handlers bound to one ``AsyncClient`` and its ``ChatBot``.

    ``ChatBot.login`` registers :meth:`message_cb`, :meth:`invite_cb` and
    :meth:`redaction_cb` with ``client.add_event_callback``.
    """

    # "!image", "!image3", "!image2 some prompt", ...  Raw string so the regex
    # escapes reach the ``re`` engine untouched.
    _IMAGE_COMMAND_RE = re.compile(r"(?s)^!image(?P<num>[0-9])?(\s(?P<cmd>.*))?$")

    # Highest N for which ``image_ai.generate_image<N>`` may be used, per service.
    _IMAGE_GENERATORS = {"runpod": 8, "stablehorde": 3}

    def __init__(self, client: AsyncClient, bot):
        self.client = client
        self.bot = bot

    def _history(self, room):
        """Return the bot's chat-history object for ``room`` (keyed by display name)."""
        return self.bot.chat_history.room(room.display_name)

    async def _report_error(self, room, err):
        """Stop the typing indicator, log ``err`` and post it to the room."""
        await self.client.room_typing(room.room_id, False)
        errormessage = f"<ERROR> {err=}, {type(err)=}"
        logger.error(errormessage)
        await self.bot.send_message(self.client, room.room_id, errormessage)

    async def _pop_and_redact(self, room, count):
        """Remove the newest ``count`` history entries and redact each in the room.

        Returns the last entry removed (the oldest of the ``count``).
        """
        history = self._history(room)
        removed = None
        for _ in range(count):
            removed = history.remove(1)
            await self.client.room_redact(room.room_id, removed.event_id, reason="user-request")
        return removed

    @staticmethod
    def _default_negative_prompt(num):
        """Built-in negative prompt used for image generator ``num``."""
        if num == 1:
            return "out of frame, (ugly:1.3), (fused fingers), (too many fingers), (bad anatomy:1.5), (watermark:1.5), (words), letters, untracked eyes, asymmetric eyes, floating head, (logo:1.5), (bad hands:1.3), (mangled hands:1.2), (missing hands), (missing arms), backward hands, floating jewelry, unattached jewelry, floating head, doubled head, unattached head, doubled head, head in body, (misshapen body:1.1), (badly fitted headwear:1.2), floating arms, (too many arms:1.5), limbs fused with body, (facial blemish:1.5), badly fitted clothes, imperfect eyes, untracked eyes, crossed eyes, hair growing from clothes, partial faces, hair not attached to head"
        if num == 5:
            return "anime, cartoon, penis, fake, drawing, illustration, boring, 3d render, long neck, out of frame, extra fingers, mutated hands, monochrome, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, glitchy, bokeh, (((long neck))), 3D, 3DCG, cgstation, red eyes, multiple subjects, extra heads, close up, watermarks, logo"
        return "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face"

    async def _generate_images(self, num, prompt, negative_prompt, start_typing):
        """Call ``image_ai.generate_image<num>`` for the configured service.

        Raises:
            ValueError: the service is not configured, or ``num`` is not a
                generator that service offers.
        """
        # getattr: these attributes are only set by main() when present in bot.conf.
        service = getattr(self.bot, "service", None)
        if service == "runpod":
            api_key = self.bot.runpod_api_key
        elif service == "stablehorde":
            api_key = getattr(self.bot, "stablehorde_api_key", None)
        else:
            raise ValueError('remote image generation not configured properly')
        if not 1 <= num <= self._IMAGE_GENERATORS[service]:
            raise ValueError('no image generator with that number')
        generate = getattr(image_ai, f"generate_image{num}")
        return await generate(prompt, negative_prompt, api_key, start_typing)

    async def _handle_image_command(self, room, match):
        """Generate images for a matched ``!image[N] [prompt]`` command and post them."""
        num = int(match['num']) if match['num'] else 1
        image_prompt = getattr(self.bot, "image_prompt", None)
        if match['cmd']:
            prompt = match['cmd'].strip()
            if image_prompt:
                # Swap the character's name for its configured visual description.
                prompt = prompt.replace(self.bot.name, image_prompt)
        else:
            prompt = image_prompt or "a beautiful woman"
        negative_prompt = self.bot.negative_prompt or self._default_negative_prompt(num)

        def start_typing():
            return self.client.room_typing(room.room_id, True, 15000)

        try:
            output = await self._generate_images(num, prompt, negative_prompt, start_typing)
        except ValueError as err:
            await self._report_error(room, err)
            return

        await self.client.room_typing(room.room_id, False)
        for imagefile in output:
            await self.bot.send_image(self.client, room.room_id, imagefile)

    async def message_cb(self, room: MatrixRoom, event: RoomMessageText) -> None:
        """Record an incoming text message and, when appropriate, respond to it.

        Every message is appended to the bot's per-room chat history. After the
        first sync, messages that pass the checks below are handled:

        * the room is not ``!stop``-ed (``!start`` is always allowed);
        * the message was not sent by the bot itself;
        * the sender is the configured owner, if one is set.

        Such a message is then either handled as a ``!command`` or sent to
        the text model, and the model's reply is posted to the room.
        """
        if not hasattr(event, 'body'):
            return
        room_cfg = self.bot.room_config.setdefault(room.room_id, {})
        room_cfg.setdefault("tick", 0)

        content = event.source["content"]
        # Standard Matrix rich replies nest the target under "m.in_reply_to" and
        # carry no top-level "event_id"; .get() keeps those from raising KeyError.
        relates_to = (content.get('m.relates_to') or {}).get("event_id")
        is_command = event.body.startswith('!')
        language = self.bot.translate if self.bot.translate is not None and not is_command else "en"
        english_original_message = content.get('original_message')

        chat_message = self._history(room).add(event.event_id, event.server_timestamp, room.user_name(event.sender), event.sender == self.client.user, is_command, relates_to, event.body, language, english_original_message)
        if self.bot.not_synced:
            # Still replaying history from the initial sync: record only, never reply.
            return

        logger.info(
            "Message received for room {} | {}: {}".format(
                room.display_name, room.user_name(event.sender), event.body
            )
        )

        api_endpoint = "pygmalion-6b"
        await self.client.room_read_markers(room.room_id, event.event_id, event.event_id)
        body = event.body

        # Ignore messages when disabled
        if room_cfg.get("disabled") is True and not body.startswith('!start'):
            return
        # Ignore messages from ourselves
        if chat_message.is_own_message:
            return
        # Ignore messages from strangers
        if self.bot.owner is not None and event.sender != self.bot.owner:
            return

        self.bot.user_name = room.user_name(event.sender)

        if body.startswith('!replybot'):
            print(event)
            await self.bot.send_message(self.client, room.room_id, "Hello World!")
            return
        elif (image_match := self._IMAGE_COMMAND_RE.search(body)):
            await self._handle_image_command(room, image_match)
            return
        elif body.startswith('!image_negative_prompt'):
            negative_prompt = body.removeprefix('!image_negative_prompt').strip()
            self.bot.negative_prompt = negative_prompt or None
            return
        elif body.startswith('!temperature'):
            # A settings command: must not fall through into text generation.
            try:
                self.bot.temperature = float(body.removeprefix('!temperature').strip())
            except ValueError as err:
                await self._report_error(room, err)
            return
        elif body.startswith('!begin'):
            self._history(room).clear()
            room_cfg["tick"] = 0
            await self.bot.write_conf2(self.bot.name)
            await self.bot.send_message(self.client, room.room_id, self.bot.greeting)
            return
        elif body.startswith('!start'):
            room_cfg["disabled"] = False
            return
        elif body.startswith('!stop'):
            room_cfg["disabled"] = True
            return
        elif body.startswith('!!!'):
            # Drop this command, the bot's last answer and the message it answered.
            if self._history(room).getLen() >= 3:
                await self._pop_and_redact(room, 3)
            return
        elif body.startswith('!!'):
            # Drop this command and the bot's last answer, then regenerate below.
            if self._history(room).getLen() < 3:
                return
            await self._pop_and_redact(room, 2)
            chat_message = self._history(room).getLastItem()  # new current
        elif body.startswith('!replace'):
            if self._history(room).getLen() < 3:
                return
            replaced = await self._pop_and_redact(room, 2)
            new_answer = body.removeprefix('!replace').strip()
            await self.bot.send_message(self.client, room.room_id, new_answer, reply_to=replaced.relates_to_event)
            return
        elif body.startswith('!2'):
            chat_message.updateText(body.removeprefix('!2').strip())
            api_endpoint = "ynznznpn6qz6yh"
        elif body.startswith('!'):
            await self.bot.send_message(self.client, room.room_id, "<ERROR> UNKNOWN COMMAND")
            return

        # TODO: detect "send ... picture"-style requests and route them to image generation.

        full_prompt = await text_ai.get_full_prompt(chat_message.getTranslation("en"), self.bot, self._history(room))
        num_tokens = await text_ai.num_tokens(full_prompt)
        logger.debug(full_prompt)
        logger.info("Prompt has %s tokens", num_tokens)

        def start_typing():
            return self.client.room_typing(room.room_id, True, 15000)

        try:
            answer = await text_ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot, start_typing, api_endpoint)
            answer = answer.strip()
            await self.client.room_typing(room.room_id, False)
            if self.bot.translate is not None:
                translated_answer = translate.translate(answer, "en", self.bot.translate)
                await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=chat_message.event_id, original_message=answer)
            else:
                await self.bot.send_message(self.client, room.room_id, answer, reply_to=chat_message.event_id)
            room_cfg["message_count"] = room_cfg.get("message_count", 0) + 1
        except ValueError as err:
            await self._report_error(room, err)

    async def invite_cb(self, room: MatrixRoom, event: InviteEvent) -> None:
        """Automatically join every room the bot is invited to."""
        result = await self.client.join(room.room_id)
        if isinstance(result, nio.responses.JoinResponse):
            logger.info('Invited and joined room: {} {}'.format(room.name, room.room_id))
        else:
            logger.error("Error joining room: {}".format(str(result)))

    async def redaction_cb(self, room: MatrixRoom, event: RedactionEvent) -> None:
        """Drop a redacted event from the chat history of every bot that holds it."""
        logger.info(f"event redacted in room {room.room_id}. event_id: {event.redacts}")
        for bot in bots:
            history = bot.chat_history.chat_rooms.get(room.display_name)
            if history is None:
                continue
            logger.info("room found")
            if history.exists_id(event.redacts):
                logger.info("found it")
                history.remove_id(event.redacts)
|
|
|
|
class ChatBot(object):
    """One Matrix account driven by a character persona and remote AI backends."""

    def __init__(self, homeserver, user_id, password, device_name="matrix-nio"):
        self.homeserver = homeserver
        self.user_id = user_id
        self.password = password
        self.device_name = device_name

        # Remote-service settings; main() overwrites them from bot.conf when present.
        self.service = None
        self.runpod_api_key = None
        self.stablehorde_api_key = None

        self.client = None
        self.callbacks = None
        self.config = None
        # Cleared by watch_for_sync() once the client's first sync has completed.
        self.not_synced = True

        self.owner = None
        self.translate = None
        self.user_name = "You"

        self.name = None
        self.persona = None
        self.scenario = None
        self.greeting = None
        self.example_dialogue = []
        self.temperature = 0.90
        self.events = []
        self.global_tick = 0
        self.chat_history = None
        # room_id -> per-room settings dict; persisted to bot.conf2.
        self.room_config = {}

        self.image_prompt = None
        self.negative_prompt = None

        if STORE_PATH:
            # exist_ok avoids a check-then-create race; still raises if a
            # non-directory already occupies the path.
            os.makedirs(STORE_PATH, exist_ok=True)

    def character_init(self, name, persona, scenario, greeting, example_dialogue=None):
        """Set the character the bot plays.

        ``example_dialogue`` defaults to a fresh empty list per bot; a mutable
        ``[]`` default would be shared by every ChatBot instance.
        """
        self.name = name
        self.persona = persona
        self.scenario = scenario
        self.greeting = greeting
        self.example_dialogue = [] if example_dialogue is None else example_dialogue

    async def event_loop(self):
        """Advance per-room ticks once a minute and persist room config.

        Every 60 seconds each registered event's ``loop(self, tick)`` is called
        for every room in ``room_config``, and that room's tick is advanced.
        ``bot.conf2`` is written every 10th pass and again when the loop exits.
        """
        try:
            while True:
                await asyncio.sleep(60)
                for room_id in self.room_config:
                    # Rooms loaded from an older/edited bot.conf2 may lack a tick.
                    self.room_config[room_id].setdefault("tick", 0)
                    for event in self.events:
                        event.loop(self, self.room_config[room_id]["tick"])
                    self.room_config[room_id]["tick"] += 1
                self.global_tick += 1
                if self.global_tick % 10 == 0:
                    await self.write_conf2(self.name)
        finally:
            await self.write_conf2(self.name)

    async def add_event(self, event_string):
        """Parse one line of the ``events`` option and register the resulting Event.

        The line holds five comma-separated fields, ``int, int, float, int,
        text``; only the first four commas split, so the text may contain
        commas. Field meanings are defined by ``helpers.Event``.

        Raises:
            ValueError: the line has fewer than five fields or a non-numeric
                numeric field.
        """
        fields = [field.strip() for field in event_string.split(',', 4)]
        if len(fields) != 5:
            raise ValueError(f"malformed event definition (expected 5 comma-separated fields): {event_string!r}")
        event = Event(int(fields[0]), int(fields[1]), float(fields[2]), int(fields[3]), fields[4])
        self.events.append(event)
        logger.debug("event added to event_loop")

    async def login(self):
        """Create the nio client, register callbacks, start background tasks and log in."""
        self.config = AsyncClientConfig(store_sync_tokens=True)
        self.client = AsyncClient(self.homeserver, self.user_id, store_path=STORE_PATH, config=self.config)
        self.chat_history = BotChatHistory(self.name)
        self.callbacks = Callbacks(self.client, self)
        self.client.add_event_callback(self.callbacks.message_cb, RoomMessageText)
        self.client.add_event_callback(self.callbacks.invite_cb, InviteEvent)
        self.client.add_event_callback(self.callbacks.redaction_cb, RedactionEvent)

        # The event loop only keeps weak references to tasks, so both background
        # tasks are held in background_tasks until they finish.
        for coro in (self.watch_for_sync(self.client.synced), self.event_loop()):
            task = asyncio.create_task(coro)
            background_tasks.add(task)
            task.add_done_callback(background_tasks.discard)

        try:
            response = await self.client.login(self.password, device_name=self.device_name)
        except (asyncio.CancelledError, KeyboardInterrupt):
            logger.error("Received interrupt while login.")
            await self.client.close()
            # Propagate the interruption instead of carrying on with a dead client.
            raise
        if isinstance(response, nio.responses.LoginError):
            logger.error(response)
        else:
            logger.info(response)

    async def watch_for_sync(self, sync_event):
        """Wait for ``sync_event`` to be set, then mark the bot as synced."""
        logger.debug("Awaiting sync")
        await sync_event.wait()
        logger.debug("Client is synced")
        self.not_synced = False

    async def read_conf2(self, section):
        """Load ``room_config`` from ``bot.conf2`` if that file exists.

        NOTE(review): ``section`` is unused. Every bot shares the one file, so
        with several bots the last writer wins.
        """
        if not os.path.isfile("bot.conf2"):
            return
        with open("bot.conf2", "r", encoding="utf-8") as f:
            self.room_config = json.load(f)

    async def write_conf2(self, section):
        """Write ``room_config`` to ``bot.conf2`` as JSON (``section`` is unused)."""
        with open("bot.conf2", "w", encoding="utf-8") as f:
            json.dump(self.room_config, f)

    async def send_message(self, client, room_id, message, reply_to=None, original_message=None):
        """Send a plain-text message, optionally tagged with the event it replies to
        and with the untranslated original text."""
        content = {"msgtype": "m.text", "body": message}
        if reply_to:
            content["m.relates_to"] = {"event_id": reply_to, "rel_type": "de.xd0.mpygbot.in_reply_to"}
        if original_message:
            content["original_message"] = original_message

        await client.room_send(
            room_id=room_id,
            message_type="m.room.message",
            content=content,
        )

    async def send_image(self, client, room_id, image):
        """Upload the image file at path ``image`` and post it to the room.

        https://matrix-nio.readthedocs.io/en/latest/examples.html#sending-an-image
        """
        mime_type = magic.from_file(image, mime=True)  # e.g. "image/jpeg"
        if not mime_type.startswith("image/"):
            logger.error("Drop message because file does not have an image mime type.")
            return

        # Context manager closes the underlying file handle PIL keeps open.
        with Image.open(image) as im:
            width, height = im.size

        # First upload the image, then send its content URI to the room.
        file_stat = await aiofiles.os.stat(image)
        async with aiofiles.open(image, "rb") as f:
            resp, _keys = await client.upload(
                f,
                content_type=mime_type,  # image/jpeg
                filename=os.path.basename(image),
                filesize=file_stat.st_size,
            )
        if isinstance(resp, UploadResponse):
            logger.info("Image was uploaded successfully to server. ")
        else:
            # An error response has no content_uri; there is nothing to send.
            logger.error(f"Failed to upload image. Failure response: {resp}")
            return

        content = {
            "body": os.path.basename(image),  # descriptive title
            "info": {
                "size": file_stat.st_size,
                "mimetype": mime_type,
                "thumbnail_info": None,  # TODO
                "w": width,  # width in pixel
                "h": height,  # height in pixel
                "thumbnail_url": None,  # TODO
            },
            "msgtype": "m.image",
            "url": resp.content_uri,
        }

        try:
            await client.room_send(room_id, message_type="m.room.message", content=content)
            logger.info("Image was sent successfully")
        except Exception:
            logger.exception(f"Image send of file {image} failed.")
|
|
|
|
async def main() -> None:
    """Start one ChatBot per bot section of ``bot.conf`` and sync them all forever.

    The ``Common`` section is skipped; it is not a bot definition. ``service``
    and the API keys are read from ``[DEFAULT]``.
    """
    config.read('bot.conf')
    logging.basicConfig(level=logging.INFO)

    def unescape(text):
        # bot.conf stores newlines in character texts as a literal backslash-n.
        return text.replace("\\n", "\n")

    for section in config.sections():
        # sections() never lists DEFAULT, but guard anyway; 'Common' is not a bot.
        if section == 'DEFAULT' or section == 'Common':
            continue
        botname = section
        section_cfg = config[section]
        homeserver = section_cfg['url']
        user_id = section_cfg['username']
        password = section_cfg['password']
        device_name = section_cfg.get('device_name', "matrix-nio")
        bot = ChatBot(homeserver, user_id, password, device_name)

        if config.has_option(section, 'example_dialogue'):
            example_dialogue = json.loads(section_cfg['example_dialogue'])
        else:
            example_dialogue = []
        bot.character_init(botname, unescape(section_cfg['persona']), unescape(section_cfg['scenario']), unescape(section_cfg['greeting']), example_dialogue)

        if config.has_option(section, 'temperature'):
            # Parse as a number; a raw config string would reach the model API as text.
            bot.temperature = config.getfloat(section, 'temperature')
        if config.has_option(section, 'owner'):
            bot.owner = section_cfg['owner']
        if config.has_option(section, 'translate'):
            bot.translate = section_cfg['translate']
            translate.init(bot.translate, "en")
            translate.init("en", bot.translate)
        if config.has_option(section, 'image_prompt'):
            bot.image_prompt = section_cfg['image_prompt']
        if config.has_option(section, 'events'):
            for event in section_cfg['events'].strip().split('\n'):
                await bot.add_event(event)

        if config.has_option('DEFAULT', 'service'):
            bot.service = config['DEFAULT']['service']
        if config.has_option('DEFAULT', 'runpod_api_key'):
            bot.runpod_api_key = config['DEFAULT']['runpod_api_key']
        if config.has_option('DEFAULT', 'stablehorde_api_key'):
            bot.stablehorde_api_key = config['DEFAULT']['stablehorde_api_key']

        await bot.read_conf2(section)
        bots.append(bot)
        await bot.login()

    if sys.version_info < (3, 11):
        # asyncio.TaskGroup only exists from Python 3.11 on.
        await asyncio.gather(
            *(bot.client.sync_forever(timeout=30000, full_state=True) for bot in bots)
        )
    else:
        async with asyncio.TaskGroup() as tg:
            for bot in bots:
                tg.create_task(bot.client.sync_forever(timeout=30000, full_state=True))
|
|
|
|
if __name__ == "__main__":
    # asyncio.run creates, runs and closes a fresh event loop. Calling
    # get_event_loop() with no running loop is deprecated since Python 3.10.
    asyncio.run(main())
|
|
|