#!/usr/bin/env python3
"""Entry point: start one Matrix chat bot per bot.conf section plus the web UI."""
import asyncio
import functools
import json
import logging
import os
import signal
import sys
import traceback

from .utilities.config_parser import read_config
from .bot.core import ChatBot
from .connections.matrix import ChatClient
from .connections.webui import WebUI

logger = logging.getLogger(__name__)

DATA_DIR = './.data'

# Timeout handed to ChatClient.sync_forever (an earlier value was 30000).
# NOTE(review): presumably milliseconds, as in matrix-nio — confirm.
_SYNC_TIMEOUT = 180000

# Every bot that was created and connected by main(), in config order.
bots = []

# Level names accepted for the DEFAULT section's ``log_level`` option.
_LOG_LEVELS = {
    'DEBUG': logging.DEBUG,
    'INFO': logging.INFO,
    'WARNING': logging.WARNING,
    'ERROR': logging.ERROR,
    'CRITICAL': logging.CRITICAL,
}


def _configure_logging(config) -> None:
    """Configure root logging from the optional ``[DEFAULT] log_level`` option.

    Does nothing when the option is absent. Unknown level names are reported
    with a warning and otherwise ignored (logging stays unconfigured).
    """
    if not config.has_option('DEFAULT', 'log_level'):
        return
    level_name = config['DEFAULT']['log_level']
    level = _LOG_LEVELS.get(level_name.upper())
    if level is None:
        logger.warning("Ignoring unknown log_level %r in bot.conf", level_name)
        return
    logging.basicConfig(level=level)


async def _create_bot(section: str, bot_config) -> ChatBot:
    """Build, persist, configure and connect the bot for one config section.

    NOTE(review): assumes ``read_config`` returns a ``configparser``
    parser, so ``bot_config`` is a ``SectionProxy`` with ``getboolean`` /
    ``getfloat`` — confirm.

    Raises:
        KeyError: a required option (matrix_homeserver, matrix_username,
            matrix_password, persona, scenario, greeting,
            available_text_endpoints, available_image_endpoints) is missing.
        ValueError: ``nsfw``/``temperature`` cannot be parsed, or one of the
            JSON-valued options is malformed.
    """
    connection = ChatClient(
        bot_config['matrix_homeserver'],
        bot_config['matrix_username'],
        bot_config['matrix_password'],
        bot_config.get('matrix_device_name', 'matrix-nio'),
    )
    await connection.persist(f"{DATA_DIR}/{section}/matrix")

    bot = ChatBot(section, connection)
    await bot.persist(f"{DATA_DIR}/{section}")
    bot.init_character(
        bot_config['persona'],
        bot_config['scenario'],
        bot_config['greeting'],
        json.loads(bot_config.get('example_dialogue', "[]")),
        # Parse explicitly: a plain .get() hands back the raw string, and
        # e.g. "false" would be truthy.
        bot_config.getboolean('nsfw', fallback=False),
        bot_config.getfloat('temperature', fallback=0.72),
    )

    owner = bot_config.get('owner')
    if owner is not None:
        bot.owner = owner
    # NOTE: per-bot translation (a 'translate' option) is currently disabled.

    await bot.load_ai(
        json.loads(bot_config['available_text_endpoints']),
        json.loads(bot_config['available_image_endpoints']),
    )
    await bot.connect()
    return bot


async def main() -> None:
    """Create and connect every configured bot, then run them with the web UI.

    Reads ``bot.conf``; every section (other than DEFAULT) describes one bot.
    Runs each bot's ``sync_forever`` loop alongside the web UI until they
    finish or a SIGHUP/SIGTERM/SIGINT triggers a graceful shutdown.
    """
    config = read_config('bot.conf')
    _configure_logging(config)

    loop = asyncio.get_running_loop()
    # NOTE(review): asyncio debug mode is always on; it adds overhead —
    # consider enabling it only when log_level is DEBUG.
    loop.set_debug(True)

    os.makedirs(DATA_DIR, exist_ok=True)

    for section in config.sections():
        bots.append(await _create_bot(section, config[section]))

    webui = WebUI()
    shutdown_task = None

    async def shutdown(sig):
        """Stop the web UI, disconnect all bots, then cancel every other task."""
        logger.info("Received exit signal %s ...", sig.name)
        await webui.stop()
        for bot in bots:
            await bot.disconnect()
        current = asyncio.current_task()
        tasks = [t for t in asyncio.all_tasks() if t is not current]
        for task in tasks:
            task.cancel()
        logger.info("Cancelling %d outstanding tasks", len(tasks))
        await asyncio.gather(*tasks, return_exceptions=True)
        # No loop.stop(): asyncio.run() owns the loop and stops it once main()
        # (one of the tasks cancelled above) returns.

    def on_signal(sig):
        """Start a single shutdown; later signals are logged and ignored."""
        nonlocal shutdown_task
        if shutdown_task is not None:
            logger.info("Shutdown already in progress; ignoring %s", sig.name)
            return
        # Keep a strong reference so the task cannot be garbage-collected.
        shutdown_task = loop.create_task(shutdown(sig))

    for sig in (signal.SIGHUP, signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, on_signal, sig)

    try:
        if sys.version_info < (3, 11):
            # asyncio.TaskGroup is unavailable before 3.11: gather manually.
            tasks = [
                asyncio.create_task(
                    bot.connection.sync_forever(timeout=_SYNC_TIMEOUT, full_state=True))
                for bot in bots
            ]
            webui.task = asyncio.create_task(webui.run_task())
            tasks.append(webui.task)
            await asyncio.gather(*tasks)
        else:
            async with asyncio.TaskGroup() as tg:
                for bot in bots:
                    tg.create_task(
                        bot.connection.sync_forever(timeout=_SYNC_TIMEOUT, full_state=True))
                webui.task = tg.create_task(webui.run_task())
    except (asyncio.CancelledError, KeyboardInterrupt):
        # shutdown() cancels this task too, which lands here for a clean exit.
        print("Received keyboard interrupt.")


if __name__ == "__main__":
    asyncio.run(main())