You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
320 lines
16 KiB
320 lines
16 KiB
import asyncio
|
|
import concurrent.futures
|
|
import os, sys
|
|
import time
|
|
import importlib
|
|
import re
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
import psutil
|
|
from functools import partial
|
|
from .memory.chatlog import ChatLog
|
|
from .utilities.messages import Message
|
|
from .ai.langchain import AI
|
|
|
|
# Module-level logger; ChatBot uses it for queue/worker progress messages
# and for reporting command errors.
logger = logging.getLogger(__name__)
|
|
|
|
class ChatBot(object):
    """Main chatbot.

    Connects a chat ``connection`` to an ``AI`` backend. Incoming messages
    arrive through :meth:`message_cb`. Slow work (command handling and reply
    generation) is put on ``self.queue`` and executed one item at a time by
    :meth:`worker`. A second background task, :meth:`event_loop`, wakes up
    once per minute for housekeeping.

    Typical lifecycle (presumably driven by the caller; confirm):
    ``ChatBot(...)`` -> :meth:`init_character` -> :meth:`persist` ->
    :meth:`load_ai` -> :meth:`connect` ... :meth:`disconnect`.
    """

    # "!image", "!image3", "!image a cat", "!image2 a cat", ...
    # Group "num" selects the image endpoint id; group "cmd" is the prompt.
    _IMAGE_COMMAND_RE = re.compile(r"^!image(?P<num>[0-9])?(\s(?P<cmd>.*))?$", flags=re.DOTALL)

    # Negative prompt used by "!image" unless "!image_negative_prompt" set one.
    DEFAULT_NEGATIVE_PROMPT = "out of frame, (ugly:1.3), (fused fingers), (too many fingers), (bad anatomy:1.5), (watermark:1.5), (words), letters, untracked eyes, asymmetric eyes, floating head, (logo:1.5), (bad hands:1.3), (mangled hands:1.2), (missing hands), (missing arms), backward hands, floating jewelry, unattached jewelry, floating head, doubled head, unattached head, doubled head, head in body, (misshapen body:1.1), (badly fitted headwear:1.2), floating arms, (too many arms:1.5), limbs fused with body, (facial blemish:1.5), badly fitted clothes, imperfect eyes, untracked eyes, crossed eyes, hair growing from clothes, partial faces, hair not attached to head"

    def __init__(self, name, connection):
        """Create the bot and start its background tasks.

        Must be called while an event loop is running, because it uses
        ``asyncio.create_task`` to start :meth:`worker` (which drains
        ``self.queue``) and :meth:`event_loop`.

        :param name: the bot character's name (also passed to ``ChatLog``).
        :param connection: chat connection object providing login/logout,
            message sending and callback registration.
        """
        self.name = name
        self.connection = connection
        self.chatlog = ChatLog(self.name)
        self.rooms = {}
        # Unbounded queue, so schedule() never has to wait for space.
        self.queue = asyncio.Queue(maxsize=0)
        self.background_tasks = set()
        self._track_background_task(asyncio.create_task(self.worker(f'worker-{self.name}', self.queue)))
        self._track_background_task(asyncio.create_task(self.event_loop()))

    def _track_background_task(self, task):
        """Hold a strong reference to *task* until it finishes, then drop it.

        The asyncio event loop only keeps weak references to tasks, so an
        untracked task could be garbage-collected mid-execution.
        """
        self.background_tasks.add(task)
        task.add_done_callback(self.background_tasks.discard)
        return task

    def _expand_placeholders(self, text):
        """Turn literal ``\\n`` sequences into newlines and ``{{char}}`` into the bot name."""
        return text.replace('\\n', '\n').replace('{{char}}', self.name)

    def init_character(self, persona, scenario, greeting, example_dialogue=None, nsfw=False, temperature=0.72):
        """Configure the character the bot role-plays.

        All text fields go through :meth:`_expand_placeholders`.

        :param persona: character description.
        :param scenario: scenario description.
        :param greeting: message sent to a room on ``!begin``.
        :param example_dialogue: iterable of example dialogue strings
            (default: none). ``None`` replaces the former shared mutable
            ``[]`` default.
        :param nsfw: stored as ``self.nsfw``.
        :param temperature: stored as ``self.temperature``; can be changed
            later with ``!temperature``.
        """
        if example_dialogue is None:
            example_dialogue = []
        self.persona = self._expand_placeholders(persona)
        self.scenario = self._expand_placeholders(scenario)
        self.greeting = self._expand_placeholders(greeting)
        self.example_dialogue = [self._expand_placeholders(d) for d in example_dialogue]
        self.nsfw = nsfw
        self.temperature = temperature

    async def persist(self, data_dir):
        """Enable on-disk state below *data_dir*.

        Creates the ``chatlogs/``, ``images/`` and ``memory/`` sub-directories,
        turns on chat-log writing and loads the per-room state from
        ``rooms.conf``. :meth:`load_ai` reads ``self.memory_path``, so this
        must run before it.
        """
        self.chatlog_path = os.path.join(data_dir, "chatlogs/")
        self.images_path = os.path.join(data_dir, "images/")
        self.memory_path = os.path.join(data_dir, "memory/")
        self.rooms_conf_file = os.path.join(data_dir, "rooms.conf")
        for directory in (self.chatlog_path, self.images_path, self.memory_path):
            os.makedirs(directory, exist_ok=True)
        self.chatlog.enable_logging(self.chatlog_path)
        self.rooms = await self.read_conf2()

    async def read_conf2(self):
        """Return the per-room state stored in ``rooms.conf`` ({} if missing)."""
        if not os.path.isfile(self.rooms_conf_file):
            return {}
        with open(self.rooms_conf_file, "r", encoding="utf-8") as f:
            return json.load(f)

    async def write_conf2(self, data):
        """Write *data* (the per-room state) to ``rooms.conf`` as JSON.

        The JSON is first written to a sibling ``.tmp`` file and then moved
        over the real file with ``os.replace``. A crash or serialization
        error mid-write therefore cannot leave a truncated ``rooms.conf``
        behind, which :meth:`read_conf2` would fail to parse on the next
        start.
        """
        tmp_path = self.rooms_conf_file + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f)
        os.replace(tmp_path, self.rooms_conf_file)

    async def connect(self):
        """Register the message/redaction callbacks, log in and queue a hello print."""
        self.connection.callbacks.add_message_callback(self.message_cb, self.redaction_cb)
        await self.connection.login()
        await self.schedule(self.queue, print, f"Hello, I'm {self.name}")

    async def disconnect(self):
        """Drain the work queue, cancel the background tasks and log out."""
        # Wait until the queue is fully processed.
        await self.queue.join()
        # Snapshot first: done callbacks remove tasks from the live set.
        tasks = list(self.background_tasks)
        for task in tasks:
            task.cancel()
        # Wait until all worker tasks are cancelled.
        await asyncio.gather(*tasks, return_exceptions=True)
        await self.connection.logout()

    async def load_ai(self, available_text_endpoints, available_image_endpoints):
        """Instantiate the text/image generator wrappers and the ``AI`` object.

        Each endpoint is a dict with at least ``id``, ``service``,
        ``endpoint`` and ``model`` keys (``api_key`` for all services except
        koboldcpp). Generators are stored by ``id`` in
        ``self.text_generators`` / ``self.image_generators``.

        :raises ValueError: if an endpoint names an unknown service.
        """
        # Imported lazily (presumably so the wrappers' dependencies are only
        # needed once AI is actually loaded).
        from .wrappers.runpod import RunpodTextWrapper
        from .wrappers.stablehorde import StableHordeTextWrapper
        from .wrappers.koboldcpp import KoboldCppTextWrapper

        self.text_generators = {}
        for text_endpoint in sorted(available_text_endpoints, key=lambda d: d['id']):
            service = text_endpoint['service']
            if service == "koboldcpp":
                text_generator = KoboldCppTextWrapper(text_endpoint['endpoint'], text_endpoint['model'])
            elif service == "stablehorde":
                text_generator = StableHordeTextWrapper(text_endpoint['api_key'], text_endpoint['endpoint'], text_endpoint['model'])
            elif service == "runpod":
                text_generator = RunpodTextWrapper(text_endpoint['api_key'], text_endpoint['endpoint'], text_endpoint['model'])
            else:
                raise ValueError(f"no text service with the name \"{service}\"")
            self.text_generators[text_endpoint['id']] = text_generator

        from .wrappers.runpod import RunpodImageWrapper
        from .wrappers.runpod import RunpodImageAutomaticWrapper
        from .wrappers.stablehorde import StableHordeImageWrapper

        self.image_generators = {}
        for image_endpoint in sorted(available_image_endpoints, key=lambda d: d['id']):
            service = image_endpoint['service']
            if service == "runpod":
                image_generator = RunpodImageWrapper(image_endpoint['api_key'], image_endpoint['endpoint'], image_endpoint['model'])
            elif service == "runpod-automatic1111":
                image_generator = RunpodImageAutomaticWrapper(image_endpoint['api_key'], image_endpoint['endpoint'], image_endpoint['model'])
            elif service == "stablehorde":
                image_generator = StableHordeImageWrapper(image_endpoint['api_key'], image_endpoint['endpoint'], image_endpoint['model'])
            else:
                raise ValueError(f"no image service with the name \"{service}\"")
            self.image_generators[image_endpoint['id']] = image_generator

        self.ai = AI(self, self.text_generators, self.image_generators, self.memory_path)

    async def message_cb(self, room, event) -> None:
        """Handle an incoming room message from the connection.

        - Registers rooms not seen before.
        - Records backlog received before the initial sync without replying.
        - Logs the bot's own messages.
        - Ignores non-owners (when ``self.owner`` is set) and disabled rooms.
        - Queues commands / ordinary messages for :meth:`worker`.
        """
        message = Message.from_matrix(room, event)
        reply_fn = partial(self.connection.send_message, room.room_id)
        typing_fn = lambda: self.connection.room_typing(room.room_id, True, 15000)

        message.user_name = message.user_name.title()
        if self.name.casefold() == message.user_name.casefold():
            # Bot and user have the same name: disambiguate the user.
            message.user_name += " 2"  # or simply "You"

        if room.room_id not in self.rooms:
            self.rooms[room.room_id] = {'tick': 0, 'num_messages': 0, 'diary': {}}
            await self.write_conf2(self.rooms)
            # ToDo: set ticks 0 / start

        if not self.connection.synced:
            # History delivered before the initial sync: record it, never reply.
            if not message.is_command() and not message.is_error():
                await self.ai.add_chat_message(message)
            self.chatlog.save(message, False)
            return

        if message.is_from(self.connection.user_id):
            # Skip messages from ourselves (but keep them in the chat log).
            message.role = "ai"
            self.chatlog.save(message)
            await self.connection.room_read_markers(room.room_id, event.event_id, event.event_id)
            return
        message.role = "human"

        if hasattr(self, "owner") and not message.is_from(self.owner):
            # Only the owner may talk to the bot; others are only logged.
            self.chatlog.save(message)
            return

        # A disabled room only reacts to the commands that re-enable it.
        if self.rooms[message.room_id].get("disabled") and not message.message.startswith(('!start', '!begin')):
            return

        await self.connection.room_read_markers(room.room_id, event.event_id, event.event_id)

        if message.is_command():
            await self.schedule(self.queue, self.process_command, message, reply_fn, typing_fn)
        elif message.is_error():
            return
        else:
            await self.schedule(self.queue, self.process_message, message, reply_fn, typing_fn)
        self.rooms[room.room_id]['num_messages'] += 1
        self.last_conversation = datetime.now()
        self.chatlog.save(message)
        logger.debug("done")

    async def redaction_cb(self, room, event) -> None:
        """Remove a redacted (deleted) message from the chat log."""
        self.chatlog.remove_message_by_id(event.event_id)

    async def _report_error(self, err, reply_fn):
        """Log *err* and send the same text back to the room via *reply_fn*."""
        errormessage = f"<ERROR> {err=}, {type(err)=}"
        logger.error(errormessage)
        await reply_fn(errormessage)

    async def _set_room_disabled(self, room_id, disabled):
        """Set a room's ``disabled`` flag and persist the room state."""
        self.rooms[room_id]["disabled"] = disabled
        await self.write_conf2(self.rooms)

    async def process_command(self, message, reply_fn, typing_fn):
        """Execute a ``!``-command (runs inside :meth:`worker`).

        Commands: ``!replybot``, ``!image[N] [prompt]``,
        ``!image_negative_prompt <text>``, ``!temperature <float>``,
        ``!begin``, ``!start``, ``!stop``, ``!sleep`` and ``!!``.
        """
        text = message.message
        if text.startswith("!replybot"):
            await reply_fn("Hello World")
        elif (m := self._IMAGE_COMMAND_RE.search(text)):
            num = int(m['num']) if m['num'] else 1
            prompt = m['cmd'].strip() if m['cmd'] else "a beautiful woman"
            # Honour a prompt set via "!image_negative_prompt" (previously ignored).
            negative_prompt = getattr(self, "negative_prompt", self.DEFAULT_NEGATIVE_PROMPT)
            try:
                output = await self.image_generators[num].generate(prompt, negative_prompt, typing_fn)
                await self.connection.room_typing(message.room_id, False)
                for imagefile in output:
                    await self.connection.send_image(message.room_id, imagefile)
            except (KeyError, IndexError, ValueError) as err:
                await self._report_error(err, reply_fn)
        elif text.startswith("!image_negative_prompt"):
            self.negative_prompt = text.removeprefix('!image_negative_prompt').strip()
        elif text.startswith('!temperature'):
            try:
                self.temperature = float(text.removeprefix('!temperature').strip())
            except ValueError as err:
                await self._report_error(err, reply_fn)
        elif text.startswith('!begin'):
            # Restart the conversation from scratch.
            await self._set_room_disabled(message.room_id, False)
            self.chatlog.clear(message.room_id)
            await self.ai.clear(message.room_id)
            # ToDo reset time / ticks
            await reply_fn(self.greeting)
        elif text.startswith('!start'):
            await self._set_room_disabled(message.room_id, False)
        elif text.startswith('!stop'):
            await self._set_room_disabled(message.room_id, True)
        elif text.startswith('!sleep'):
            await self.schedule(self.queue, self.ai.sleep)
        elif text.startswith('!!'):
            # Redact the two most recent logged messages of this room.
            if self.chatlog.chat_history_len(message.room_id) > 2:
                for _ in range(2):
                    old_message = self.chatlog.remove_message_in_room(message.room_id, 1)
                    await self.connection.room_redact(message.room_id, old_message.event_id, reason="user-request")
                # NOTE(review): result unused; presumably meant to regenerate a reply.
                last_message = self.chatlog.get_last_message(message.room_id)

    async def process_message(self, message, reply_fn, typing_fn):
        """Generate a role-play reply to an ordinary message (runs inside :meth:`worker`)."""
        await self.ai.generate_roleplay(message.to_langchain(), reply_fn, typing_fn)

    async def event_loop(self):
        """Background housekeeping; wakes up once per minute.

        - Increments every room's ``tick`` counter.
        - Between 01:00 and 04:59 local time, when the 1- and 5-minute load
          averages are both below 25% per CPU, awaits ``self.ai.sleep()``
          at most once every 6 hours.
        """
        while True:
            await asyncio.sleep(60)
            for room_conf in self.rooms.values():
                room_conf["tick"] = room_conf.get("tick", 0) + 1

            now = datetime.now()  # read once so both hour bounds agree
            if not 1 <= now.hour < 5:
                continue
            load1, load5, _load15 = [x / psutil.cpu_count() * 100 for x in psutil.getloadavg()]
            if load5 >= 25 or load1 >= 25:
                continue
            last_sleep = getattr(self, "last_sleep", None)
            if last_sleep is not None and last_sleep + timedelta(hours=6) >= now:
                continue
            try:
                await self.ai.sleep()
            except Exception:
                # An unhandled error would end this task for good (no more
                # ticks, no more sleeps); log it and retry 6 hours later.
                logger.exception("scheduled ai.sleep() failed")
            self.last_sleep = datetime.now()

    async def worker(self, name: str, q: asyncio.Queue) -> None:
        """Run queued ``(cb, args, kwargs)`` items from *q* one at a time, forever.

        Coroutine functions are awaited and plain callables are called
        directly. A failing callback is logged instead of ending the loop,
        and ``q.task_done()`` always runs, so ``q.join()`` (used by
        :meth:`disconnect`) cannot hang on a failed item.
        """
        while True:
            cb, args, kwargs = await q.get()
            start = time.perf_counter()
            try:
                if asyncio.iscoroutinefunction(cb):
                    logger.info("queued task started (coroutine)")
                    await cb(*args, **kwargs)
                else:
                    logger.info("queued task started (function)")
                    cb(*args, **kwargs)
            except Exception:
                logger.exception("queued task failed in %s", name)
            finally:
                q.task_done()
            elapsed = time.perf_counter() - start
            logger.info(f"Queued task done in {elapsed:0.5f} seconds.")
            logger.debug("queue item processed")

    async def schedule(self, q: asyncio.Queue, cb, *args, **kwargs) -> None:
        """Queue ``cb(*args, **kwargs)`` for sequential execution by :meth:`worker`."""
        logger.info("queuing task")
        await q.put((cb, args, kwargs))

    async def schedule_task(self, done_callback, cb, *args, **kwargs):
        """Run coroutine function *cb* as a tracked background task.

        *done_callback* is attached to the task and called when it finishes.
        """
        logger.info("creating background task")
        task = asyncio.create_task(cb(*args, **kwargs))
        task.add_done_callback(done_callback)
        self._track_background_task(task)

    # closure
    def outer_function(self, x):
        """Return a closure that adds *x* to its argument."""
        def inner_function(y):
            return x + y
        return inner_function
|
|
|