Browse Source

parallel tasks

master
Hendrik Langer 2 years ago
parent
commit
a5339f2486
  1. 3
      matrix_pygmalion_bot/ai/runpod_pygmalion.py
  2. 71
      matrix_pygmalion_bot/core.py
  3. 14
      matrix_pygmalion_bot/helpers.py

3
matrix_pygmalion_bot/ai/runpod_pygmalion.py

@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)
async def generate_sync(
prompt: str,
api_key: str,
bot_name: str,
):
# Set the API endpoint URL
endpoint = "https://api.runpod.ai/v2/pygmalion-6b/runsync"
@ -54,6 +55,7 @@ async def generate_sync(
reply = answer[:idx].strip()
else:
reply = answer.removesuffix('<|endoftext|>').strip()
reply.replace("\n{bot_name}: ", " ")
return reply
elif status == 'IN_PROGRESS' or status == 'IN_QUEUE':
job_id = r_json["id"]
@ -79,6 +81,7 @@ async def generate_sync(
reply = answer[:idx].strip()
else:
reply = answer.removesuffix('<|endoftext|>').strip()
reply.replace("\n{bot_name}: ", " ")
return reply
else:
return "<ERROR>"

71
matrix_pygmalion_bot/core.py

@ -3,6 +3,7 @@ import nio
from nio import (AsyncClient, AsyncClientConfig, MatrixRoom, RoomMessageText, InviteEvent, UploadResponse)
import os, sys
import time
import importlib
import configparser
import logging
@ -10,6 +11,7 @@ import logging
import aiofiles.os
import magic
from PIL import Image
import re
from .helpers import ChatItem
ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod_pygmalion")
@ -23,6 +25,7 @@ STORE_PATH = "./.store/"
logger = logging.getLogger(__name__)
config = configparser.ConfigParser()
bots = []
background_tasks = set()
class Callbacks(object):
"""Class to pass client to callback methods."""
@ -32,6 +35,7 @@ class Callbacks(object):
self.bot = bot
async def message_cb(self, room: MatrixRoom, event: RoomMessageText) -> None:
event_id = event.event_id
message = event.body
is_own_message = False
if event.sender == self.client.user:
@ -39,6 +43,9 @@ class Callbacks(object):
is_command = False
if event.body.startswith('!'):
is_command = True
if not (self.bot.owner is None):
if not (event.sender == self.bot.owner or is_own_message):
return
relates_to = None
if 'm.relates_to' in event.source["content"]:
relates_to = event.source["content"]['m.relates_to']["event_id"]
@ -49,7 +56,7 @@ class Callbacks(object):
else:
translated_message = translate.translate(message, self.bot.translate, "en")
if hasattr(event, 'body'):
self.bot.chat_history[event.event_id] = ChatItem(event.event_id, event.server_timestamp, room.user_name(event.sender), is_own_message, relates_to, translated_message)
self.bot.chat_history[event_id] = ChatItem(event_id, event.server_timestamp, room.user_name(event.sender), is_own_message, relates_to, translated_message)
if self.bot.not_synced:
return
print(
@ -57,7 +64,7 @@ class Callbacks(object):
room.display_name, room.user_name(event.sender), event.body
)
)
await self.client.room_read_markers(room.room_id, event.event_id, event.event_id)
await self.client.room_read_markers(room.room_id, event_id, event_id)
# Ignore messages from ourselves
if is_own_message:
return
@ -81,23 +88,26 @@ class Callbacks(object):
return
elif event.body.startswith('!!!'):
return
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current
# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem()
# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem()
# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
# return
elif event.body.startswith('!!'):
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem()
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem()
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
return
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current
# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem()
# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # new current
# self.bot.chat_history[chat_history_event_id] = chat_history_item
# message = chat_history_item.message
# # don't return, we generate a new answer
elif event.body.startswith('!!'):
if len(self.bot.chat_history) < 3:
return
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem()
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # new current
self.bot.chat_history[chat_history_event_id] = chat_history_item
event_id = chat_history_item.event_id
message = chat_history_item.message
translated_message = message
# don't return, we generate a new answer
full_prompt = await ai.get_full_prompt(translated_message, self.bot)
num_tokens = await ai.num_tokens(full_prompt)
@ -119,15 +129,15 @@ class Callbacks(object):
# answer = answer.strip()
# print("")
await self.client.room_typing(room.room_id, True, 15000)
answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key)
answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot.name)
answer = answer.strip()
await self.client.room_typing(room.room_id, False)
translated_answer = answer
if not (self.bot.translate is None):
translated_answer = translate.translate(answer, "en", self.bot.translate)
await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=event.event_id, original_message=answer)
await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=event_id, original_message=answer)
else:
await self.bot.send_message(self.client, room.room_id, answer, reply_to=event.event_id)
await self.bot.send_message(self.client, room.room_id, answer, reply_to=event_id)
@ -162,6 +172,7 @@ class ChatBot(object):
self.persona = None
self.scenario = None
self.greeting = None
self.events = []
self.chat_history = {}
if STORE_PATH and not os.path.isdir(STORE_PATH):
@ -173,6 +184,13 @@ class ChatBot(object):
self.scenario = scenario
self.greeting = greeting
async def event_loop(self):
while True:
await asyncio.sleep(60)
print(time.time())
for event in self.events:
event.loop()
async def login(self):
self.config = AsyncClientConfig(store_sync_tokens=True)
self.client = AsyncClient(self.homeserver, self.user_id, store_path=STORE_PATH, config=self.config)
@ -181,14 +199,18 @@ class ChatBot(object):
self.client.add_event_callback(self.callbacks.invite_cb, InviteEvent)
sync_task = asyncio.create_task(self.watch_for_sync(self.client.synced))
event_loop = asyncio.create_task(self.event_loop())
background_tasks.add(event_loop)
event_loop.add_done_callback(background_tasks.discard)
try:
response = await self.client.login(self.password)
print(response)
await self.client.sync_forever(timeout=30000, full_state=True)
#sync_forever_task = asyncio.create_task(self.client.sync_forever(timeout=30000, full_state=True))
except (asyncio.CancelledError, KeyboardInterrupt):
print("Received interrupt.")
await self.client.close()
#return sync_forever_task
async def watch_for_sync(self, sync_event):
print("Awaiting sync")
@ -279,6 +301,9 @@ async def main() -> None:
bot.runpod_api_key = config['DEFAULT']['runpod_api_key']
bots.append(bot)
await bot.login()
print("logged in")
print("gather")
async with asyncio.TaskGroup() as tg:
for bot in bots:
task = tg.create_task(bot.client.sync_forever(timeout=30000, full_state=True))
asyncio.get_event_loop().run_until_complete(main())

14
matrix_pygmalion_bot/helpers.py

@ -1,4 +1,4 @@
import time
class ChatItem:
def __init__(self, event_id, timestamp, user_name, is_own_message, relates_to_event, message):
@ -13,3 +13,15 @@ class ChatItem:
return str("{}: {}".format(self.user_name, self.message))
def getLine(self):
return str("{}: {}".format(self.user_name, self.message))
class Event:
def __init__(self, timestamp_start, timestamp_stop=None, chance=1):
self.timestamp_start = timestamp_start
self.timestamp_stop = timestamp_stop
self.chance = chance
self.executed = 0
def __str__(self):
return str("Event starting at timestamp {}".format(self.timestamp_start))
def loop(self):
if timestamp_start > time.time():
pass

Loading…
Cancel
Save