Chatbot
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

332 lines
16 KiB

2 years ago
import asyncio
2 years ago
import concurrent.futures
2 years ago
import os, sys
import time
import importlib
import re
import json
2 years ago
import logging
2 years ago
from datetime import datetime, timedelta
import psutil
2 years ago
from functools import partial
from .memory.chatlog import ChatLog
from .utilities.messages import Message
from .ai.langchain import AI
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class ChatBot(object):
    """Main chatbot.

    Ties together a chat connection, a per-bot ``ChatLog``, a set of text and
    image generator wrappers and the ``AI`` object.  Incoming messages are
    turned into jobs on an ``asyncio.Queue`` which a single background worker
    task processes one at a time; a second background task runs a
    once-a-minute housekeeping loop.
    """

    # "!image", optionally followed directly by a generator number, optionally
    # followed by whitespace and a free-form prompt.  Compiled once as a raw
    # string so that ``\s`` is a regex escape, not a (deprecated) string escape.
    _IMAGE_CMD_RE = re.compile(r"(?s)^!image(?P<num>[0-9]+)?(\s(?P<cmd>.*))?$", flags=re.DOTALL)

    # Prompts used by "!image" when the user supplies none.
    _DEFAULT_IMAGE_PROMPT = "a beautiful woman"
    _DEFAULT_NEGATIVE_PROMPT = "out of frame, (ugly:1.3), (fused fingers), (too many fingers), (bad anatomy:1.5), (watermark:1.5), (words), letters, untracked eyes, asymmetric eyes, floating head, (logo:1.5), (bad hands:1.3), (mangled hands:1.2), (missing hands), (missing arms), backward hands, floating jewelry, unattached jewelry, floating head, doubled head, unattached head, doubled head, head in body, (misshapen body:1.1), (badly fitted headwear:1.2), floating arms, (too many arms:1.5), limbs fused with body, (facial blemish:1.5), badly fitted clothes, imperfect eyes, untracked eyes, crossed eyes, hair growing from clothes, partial faces, hair not attached to head"

    def __init__(self, name, connection):
        """Create the bot and start its two background tasks.

        Must be called while an asyncio event loop is running, because it
        uses ``asyncio.create_task``.

        Args:
            name: Display name of the bot; also substituted for ``{{char}}``
                in character texts.
            connection: Connection object used for login/logout, sending
                messages and images, typing notifications, read markers and
                redactions.
        """
        self.name = name
        self.connection = connection
        self.chatlog = ChatLog(self.name)
        self.rooms = {}
        self.queue = asyncio.Queue(maxsize=0)
        # Keep strong references so the tasks are not garbage-collected while
        # running; each task removes itself from the set when it finishes.
        self.background_tasks = set()

        worker_task = asyncio.create_task(self.worker(f'worker-{self.name}', self.queue))
        self.background_tasks.add(worker_task)
        worker_task.add_done_callback(self.background_tasks.discard)

        event_loop_task = asyncio.create_task(self.event_loop())
        self.background_tasks.add(event_loop_task)
        event_loop_task.add_done_callback(self.background_tasks.discard)

    def init_character(self, persona, scenario, greeting, example_dialogue=None, nsfw=False, temperature=0.72):
        """Store the character definition on the bot.

        In every text the two-character sequence backslash + ``n`` is turned
        into a real newline and ``{{char}}`` is replaced by the bot's name.

        Args:
            persona: Character description text.
            scenario: Scenario text.
            greeting: Text sent as reply to the ``!begin`` command.
            example_dialogue: Optional iterable of example dialogue strings;
                ``None`` (the default) yields an empty list.
            nsfw: Stored unchanged as ``self.nsfw``.
            temperature: Stored unchanged as ``self.temperature``.
        """
        def expand(text):
            return text.replace('\\n', '\n').replace('{{char}}', self.name)

        self.persona = expand(persona)
        self.scenario = expand(scenario)
        self.greeting = expand(greeting)
        # A ``None`` default avoids sharing one list object between calls.
        self.example_dialogue = [expand(line) for line in (example_dialogue or [])]
        self.nsfw = nsfw
        self.temperature = temperature

    async def persist(self, data_dir):
        """Set up on-disk storage below *data_dir* and load the room config.

        Creates the ``chatlogs/``, ``images/`` and ``memory/`` directories if
        needed, enables chat logging into ``chatlogs/`` and replaces
        ``self.rooms`` with the content of ``rooms.conf`` (``{}`` if absent).
        """
        self.chatlog_path = os.path.join(data_dir, "chatlogs/")
        self.images_path = os.path.join(data_dir, "images/")
        self.memory_path = os.path.join(data_dir, "memory/")
        self.rooms_conf_file = os.path.join(data_dir, "rooms.conf")
        for directory in (self.chatlog_path, self.images_path, self.memory_path):
            os.makedirs(directory, exist_ok=True)
        self.chatlog.enable_logging(self.chatlog_path)
        self.rooms = await self.read_conf2()

    async def read_conf2(self):
        """Return the JSON content of the rooms config file, or ``{}`` if it does not exist."""
        if not os.path.isfile(self.rooms_conf_file):
            return {}
        with open(self.rooms_conf_file, "r") as f:
            return json.load(f)

    async def write_conf2(self, data):
        """Write *data* as JSON to the rooms config file, replacing its content."""
        with open(self.rooms_conf_file, "w") as f:
            json.dump(data, f)

    async def connect(self):
        """Register the message/redaction callbacks, log in and queue a hello print."""
        self.connection.callbacks.add_message_callback(self.message_cb, self.redaction_cb)
        await self.connection.login()
        await self.schedule(self.queue, print, f"Hello, I'm {self.name}")

    async def disconnect(self):
        """Drain the job queue, stop the background tasks, save the room config and log out."""
        # Wait until the queue is fully processed.
        await self.queue.join()
        # Work on a snapshot: finished tasks remove themselves from the set.
        tasks = list(self.background_tasks)
        for task in tasks:
            task.cancel()
        # Wait until all background tasks are cancelled.
        await asyncio.gather(*tasks, return_exceptions=True)
        await self.write_conf2(self.rooms)
        await self.connection.logout()

    async def load_ai(self, available_text_endpoints, available_image_endpoints):
        """Instantiate the generator wrappers and the ``AI`` object.

        Both arguments are iterables of dicts.  This method reads the keys
        ``id``, ``service``, ``endpoint`` and ``model`` from each entry, plus
        ``api_key`` for every service except ``koboldcpp``.  Wrappers are
        stored in ``self.text_generators`` / ``self.image_generators`` keyed
        by ``id``.

        Reads ``self.memory_path``, which is set by ``persist()``.

        Raises:
            ValueError: If an endpoint names an unknown service.
        """
        from .wrappers.runpod import RunpodTextWrapper
        from .wrappers.stablehorde import StableHordeTextWrapper
        from .wrappers.koboldcpp import KoboldCppTextWrapper

        self.text_generators = {}
        for text_endpoint in sorted(available_text_endpoints, key=lambda d: d['id']):
            service = text_endpoint['service']
            if service == "koboldcpp":
                text_generator = KoboldCppTextWrapper(text_endpoint['endpoint'], text_endpoint['model'])
            elif service == "stablehorde":
                text_generator = StableHordeTextWrapper(text_endpoint['api_key'], text_endpoint['endpoint'], text_endpoint['model'])
            elif service == "runpod":
                text_generator = RunpodTextWrapper(text_endpoint['api_key'], text_endpoint['endpoint'], text_endpoint['model'])
            else:
                raise ValueError(f"no text service with the name \"{service}\"")
            self.text_generators[text_endpoint['id']] = text_generator

        from .wrappers.runpod import RunpodImageWrapper, RunpodImageWrapper2
        from .wrappers.runpod import RunpodImageAutomaticWrapper
        from .wrappers.stablehorde import StableHordeImageWrapper

        self.image_generators = {}
        for image_endpoint in sorted(available_image_endpoints, key=lambda d: d['id']):
            service = image_endpoint['service']
            wrapper_args = (image_endpoint['api_key'], image_endpoint['endpoint'], image_endpoint['model'])
            if service == "runpod-automatic1111":
                image_generator = RunpodImageAutomaticWrapper(*wrapper_args)
            elif service == "runpod" and image_endpoint['model'].startswith('kandinsky'):
                # Models whose name starts with "kandinsky" use the second runpod wrapper.
                image_generator = RunpodImageWrapper2(*wrapper_args)
            elif service == "runpod":
                image_generator = RunpodImageWrapper(*wrapper_args)
            elif service == "stablehorde":
                image_generator = StableHordeImageWrapper(*wrapper_args)
            else:
                raise ValueError(f"no image service with the name \"{service}\"")
            self.image_generators[image_endpoint['id']] = image_generator

        self.ai = AI(self, self.text_generators, self.image_generators, self.memory_path)

    async def message_cb(self, room, event) -> None:
        """Handle one incoming room message.

        Registers unknown rooms, feeds history to the AI while the connection
        is not yet synced, skips the bot's own messages, optionally ignores
        everyone but ``self.owner``, honours the per-room ``disabled`` flag
        and finally queues the message as a command or a chat message.
        """
        message = Message.from_matrix(room, event)
        reply_fn = partial(self.connection.send_message, room.room_id)
        typing_fn = lambda: self.connection.room_typing(room.room_id, True, 15000)

        message.user_name = message.user_name.title()
        from_self = message.is_from(self.connection.user_id)
        if self.name.casefold() == message.user_name.casefold() and not from_self:
            # Bot and user have the same name
            message.user_name += " 2"  # or simply "You"
        message.role = "ai" if from_self else "human"

        if room.room_id not in self.rooms:
            self.rooms[room.room_id] = {'tick': 0, 'num_messages': 0, 'diary': {}}
            await self.write_conf2(self.rooms)
            # ToDo: set ticks 0 / start

        if not self.connection.synced:
            # Backlog received before the initial sync finished: record it,
            # but never answer it.
            if not message.is_command() and not message.is_error():
                await self.ai.add_chat_message(message.to_langchain())
                # NOTE(review): assumed to belong inside this condition — confirm.
                self.chatlog.save(message, False)
            return

        if from_self:
            # Skip messages from ourselves
            self.chatlog.save(message)
            await self.connection.room_read_markers(room.room_id, event.event_id, event.event_id)
            return

        # ``owner`` is an optional attribute; when set, only the owner is answered.
        if hasattr(self, "owner") and not message.is_from(self.owner):
            self.chatlog.save(message)
            return

        room_conf = self.rooms[message.room_id]
        is_enable_command = message.message.startswith(('!start', '!begin'))
        if room_conf.get("disabled") == True and not is_enable_command:  # noqa: E712
            return

        await self.connection.room_read_markers(room.room_id, event.event_id, event.event_id)

        if message.is_command():
            await self.schedule(self.queue, self.process_command, message, reply_fn, typing_fn)
        elif message.is_error():
            return
        else:
            await self.schedule(self.queue, self.process_message, message, reply_fn, typing_fn)
            # NOTE(review): the two bookkeeping lines below are assumed to
            # apply to chat messages only — confirm.
            self.rooms[room.room_id]['num_messages'] += 1
            self.last_conversation = datetime.now()

        self.chatlog.save(message)

    async def redaction_cb(self, room, event) -> None:
        """Remove the message with the id ``event.event_id`` from the chat log."""
        self.chatlog.remove_message_by_id(event.event_id)

    async def process_command(self, message, reply_fn, typing_fn):
        """Execute a ``!command`` message.

        Supported commands: ``!replybot``, ``!image[N] [prompt[|negative]]``,
        ``!image_negative_prompt``, ``!temperature``, ``!begin``, ``!start``,
        ``!stop``, ``!sleep`` and ``!!``.  Anything else is ignored.

        Args:
            message: The command message (``message.message`` holds the text).
            reply_fn: Awaitable callable that sends a text reply to the room.
            typing_fn: Callable passed on to the image generator.
        """
        text = message.message
        if text.startswith("!replybot"):
            await reply_fn("Hello World")
        elif (image_match := self._IMAGE_CMD_RE.search(text)):
            await self._handle_image_command(image_match, message, reply_fn, typing_fn)
        elif text.startswith("!image_negative_prompt"):
            # NOTE(review): stored here but not read by the "!image" command,
            # which always starts from the built-in negative prompt.
            self.negative_prompt = text.removeprefix('!image_negative_prompt').strip()
        elif text.startswith('!temperature'):
            try:
                self.temperature = float(text.removeprefix('!temperature').strip())
            except ValueError as err:
                # Report bad input the same way image errors are reported.
                errormessage = f"<ERROR> {err=}, {type(err)=}"
                logger.error(errormessage)
                await reply_fn(errormessage)
        elif text.startswith('!begin'):
            self.rooms[message.room_id]["disabled"] = False
            await self.write_conf2(self.rooms)
            self.chatlog.clear(message.room_id)
            await self.ai.clear(message.room_id)
            # ToDo reset time / ticks
            await reply_fn(self.greeting)
        elif text.startswith('!start'):
            self.rooms[message.room_id]["disabled"] = False
            await self.write_conf2(self.rooms)
        elif text.startswith('!stop'):
            self.rooms[message.room_id]["disabled"] = True
            await self.write_conf2(self.rooms)
        elif text.startswith('!sleep'):
            await self.schedule(self.queue, self.ai.sleep)
        elif text.startswith('!!'):
            if self.chatlog.chat_history_len(message.room_id) > 2:
                # Drop and redact two messages from the room's history.
                for _ in range(2):
                    old_message = self.chatlog.remove_message_in_room(message.room_id, 1)
                    await self.connection.room_redact(message.room_id, old_message.event_id, reason="user-request")
                # NOTE(review): the fetched message is not used afterwards.
                message = self.chatlog.get_last_message(message.room_id)

    async def _handle_image_command(self, match, message, reply_fn, typing_fn):
        """Generate images for a matched ``!image`` command and send them to the room."""
        num = int(match['num']) if match['num'] else 1
        prompt = self._DEFAULT_IMAGE_PROMPT
        negative_prompt = self._DEFAULT_NEGATIVE_PROMPT
        if match['cmd']:
            prompt = match['cmd'].strip()
            # "prompt | negative prompt": split on the last "|" only.
            prompt_split = prompt.rsplit('|', 1)
            if len(prompt_split) == 2:
                prompt = prompt_split[0].strip()
                negative_prompt = prompt_split[1].strip()
        try:
            output = await self.image_generators[num].generate(prompt, negative_prompt, typing_fn)
            await self.connection.room_typing(message.room_id, False)
            for imagefile in output:
                await self.connection.send_image(message.room_id, imagefile)
        except (KeyError, IndexError, ValueError) as err:
            # KeyError also covers an unknown generator number.
            errormessage = f"<ERROR> {err=}, {type(err)=}"
            logger.error(errormessage)
            await reply_fn(errormessage)

    async def process_message(self, message, reply_fn, typing_fn):
        """Pass a chat message to ``self.ai.generate_roleplay``; its return value is not used here."""
        await self.ai.generate_roleplay(message.to_langchain(), reply_fn, typing_fn)

    async def event_loop(self):
        """Once-a-minute housekeeping loop; runs until the task is cancelled.

        Each minute the ``tick`` counter of every known room is incremented.
        ``self.ai.sleep()`` is then awaited when all of the following hold:
        the local hour is 1-4, the last conversation is more than 30 minutes
        old (or there was none), the 1- and 5-minute load averages are below
        40 % of the CPU count, and the last sleep is more than 6 hours old
        (or there was none).
        """
        while True:
            await asyncio.sleep(60)
            for room_conf in self.rooms.values():
                # Tolerate room entries that have no counter yet.
                room_conf["tick"] = room_conf.get("tick", 0) + 1

            # NOTE(review): the checks below are assumed to run once per
            # minute, not once per room — confirm.
            now = datetime.now()
            if not 1 <= now.hour < 5:
                continue
            if hasattr(self, "last_conversation") and self.last_conversation + timedelta(minutes=30) >= now:
                continue
            load1, load5, _ = [x / psutil.cpu_count() * 100 for x in psutil.getloadavg()]
            if load5 >= 40 or load1 >= 40:
                continue
            if hasattr(self, "last_sleep") and self.last_sleep + timedelta(hours=6) >= now:
                continue
            await self.ai.sleep()
            self.last_sleep = datetime.now()

    async def worker(self, name: str, q: asyncio.Queue) -> None:
        """Process ``(cb, args, kwargs)`` jobs from *q* forever, one at a time.

        Coroutine functions are awaited, plain callables are called directly.
        An exception raised by a job is logged and the worker keeps running;
        ``q.task_done()`` is called for every job so that ``q.join()`` cannot
        hang on a failed job.
        """
        while True:
            cb, args, kwargs = await q.get()
            start = time.perf_counter()
            try:
                if asyncio.iscoroutinefunction(cb):
                    logger.info("queued task started (coroutine)")
                    await cb(*args, **kwargs)
                else:
                    logger.info("queued task started (function)")
                    cb(*args, **kwargs)
            except Exception:
                # Cancellation is a BaseException and still propagates.
                logger.exception("queued task failed")
            finally:
                q.task_done()
            elapsed = time.perf_counter() - start
            logger.info(f"Queued task done in {elapsed:0.5f} seconds.")
            logger.debug("queue item processed")

    async def schedule(self, q: asyncio.Queue, cb, *args, **kwargs) -> None:
        """Put the job ``(cb, args, kwargs)`` on queue *q* for the worker."""
        logger.info("queuing task")
        await q.put((cb, args, kwargs))

    async def schedule_task(self, done_callback, cb, *args, **kwargs):
        """Run coroutine function *cb* as a background task.

        *done_callback* is attached as the task's done-callback; the task is
        tracked in ``self.background_tasks`` until it finishes.
        """
        logger.info("creating background task")
        task = asyncio.create_task(cb(*args, **kwargs))
        task.add_done_callback(done_callback)
        self.background_tasks.add(task)
        task.add_done_callback(self.background_tasks.discard)

    # closure
    def outer_function(self, x):
        """Return a one-argument function that adds *x* to its argument."""
        def inner_function(y):
            return x + y
        return inner_function