Chatbot
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

138 lines
4.3 KiB

#!/usr/bin/env python3
import asyncio
import os, sys
import json
from .utilities.config_parser import read_config
from .bot.core import ChatBot
from .connections.matrix import ChatClient
from .connections.webui import WebUI
import traceback
import signal
import functools
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Base directory for on-disk state; main() creates it and passes
# "<DATA_DIR>/<section>" sub-paths to each connection's and bot's persist().
DATA_DIR = './.data'
# Every bot that main() has connected. The same list is handed to the web UI
# and iterated by the shutdown handler to disconnect the bots.
bots = []
async def main() -> None:
    """Start every bot configured in ``bot.conf`` plus the web UI and run them.

    Steps, in order:

    1. Read ``bot.conf``; if ``DEFAULT.log_level`` is one of the standard
       level names, configure root logging with that level.
    2. For each config section build a ``ChatClient`` and a ``ChatBot``,
       call their ``persist()`` with paths below ``DATA_DIR``, initialise the
       character, load the AI endpoints, connect, and append the bot to the
       module-level ``bots`` list.
    3. Create the ``WebUI`` and connect it to the list of bots.
    4. Install SIGHUP/SIGTERM/SIGINT handlers that schedule ``shutdown``.
    5. Run each bot's ``sync_forever`` and the web UI task concurrently
       (``asyncio.TaskGroup`` on Python 3.11+, ``asyncio.gather`` before).

    Cancellation and ``KeyboardInterrupt`` are caught and reported on stdout;
    other exceptions propagate to the caller.
    """
    config = read_config('bot.conf')

    if config.has_option('DEFAULT', 'log_level'):
        log_level = config['DEFAULT']['log_level']
        # Only the exact upper-case standard names are accepted.
        if log_level in ('DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'):
            logging.basicConfig(level=getattr(logging, log_level))
        else:
            logger.warning("Ignoring unknown log_level %r in bot.conf", log_level)

    loop = asyncio.get_running_loop()
    loop.set_debug(True)

    os.makedirs(DATA_DIR, exist_ok=True)

    for section in config.sections():
        bot_config = config[section]
        connection = ChatClient(
            bot_config['matrix_homeserver'],
            bot_config['matrix_username'],
            bot_config['matrix_password'],
            bot_config.get('matrix_device_name', 'matrix-nio'),
        )
        await connection.persist(f"{DATA_DIR}/{section}/matrix")
        bot = ChatBot(section, connection)
        await bot.persist(f"{DATA_DIR}/{section}")
        # NOTE(review): if read_config() returns a configparser object (it is
        # used like one here), configured 'nsfw' / 'temperature' values arrive
        # as strings (e.g. "false" is truthy) while the fallbacks are a bool
        # and a float -- confirm init_character() converts them.
        bot.init_character(
            bot_config['persona'],
            bot_config['scenario'],
            bot_config['greeting'],
            json.loads(bot_config.get('example_dialogue', "[]")),
            bot_config.get('nsfw', False),
            bot_config.get('temperature', 0.72),
        )
        if config.has_option(section, 'owner'):
            bot.owner = config[section]['owner']
        # if config.has_option(section, 'translate'):
        #     bot.translate = config[section]['translate']
        #     translate.init(bot.translate, "en")
        #     translate.init("en", bot.translate)
        await bot.load_ai(
            json.loads(bot_config['available_text_endpoints']),
            json.loads(bot_config['available_image_endpoints']),
        )
        await bot.connect()
        bots.append(bot)

    webui = WebUI()
    await webui.connect(bots)

    async def shutdown(sig, loop):
        """Stop the web UI, disconnect all bots, cancel other tasks, stop the loop."""
        logger.info("Received exit signal %s ...", sig.name)
        await webui.stop()
        for bot in bots:
            await bot.disconnect()
        # Everything except this shutdown task itself gets cancelled, which
        # includes the task running main().
        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
        for task in tasks:
            task.cancel()
        logger.info("Cancelling %d outstanding tasks", len(tasks))
        await asyncio.gather(*tasks, return_exceptions=True)
        logger.info("Flushing metrics")
        # NOTE(review): under asyncio.run() the runner stops the loop itself
        # once main() returns -- confirm this explicit stop is still wanted.
        loop.stop()

    # The event loop keeps only weak references to tasks, so hold a strong
    # reference to each shutdown task until it has finished.
    shutdown_tasks = set()

    def request_shutdown(sig):
        task = asyncio.create_task(shutdown(sig, loop))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for s in {signal.SIGHUP, signal.SIGTERM, signal.SIGINT}:
        loop.add_signal_handler(s, request_shutdown, s)

    try:
        if sys.version_info < (3, 11):
            # asyncio.TaskGroup does not exist before Python 3.11.
            tasks = [
                asyncio.create_task(
                    bot.connection.sync_forever(timeout=180000, full_state=True))  # 30000
                for bot in bots
            ]
            webui.task = asyncio.create_task(webui.run_task())
            tasks.append(webui.task)
            await asyncio.gather(*tasks)
        else:
            async with asyncio.TaskGroup() as tg:
                for bot in bots:
                    tg.create_task(
                        bot.connection.sync_forever(timeout=180000, full_state=True))  # 30000
                webui.task = tg.create_task(webui.run_task())
    except (asyncio.CancelledError, KeyboardInterrupt):
        print("Received keyboard interrupt.")
# Script entry point: run the async launcher on a fresh event loop, but only
# when this module is executed directly rather than imported.
if __name__ == '__main__':
    asyncio.run(main())