Browse Source

parallel tasks

master
Hendrik Langer 2 years ago
parent
commit
a5339f2486
  1. 3
      matrix_pygmalion_bot/ai/runpod_pygmalion.py
  2. 71
      matrix_pygmalion_bot/core.py
  3. 14
      matrix_pygmalion_bot/helpers.py

3
matrix_pygmalion_bot/ai/runpod_pygmalion.py

@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)
async def generate_sync( async def generate_sync(
prompt: str, prompt: str,
api_key: str, api_key: str,
bot_name: str,
): ):
# Set the API endpoint URL # Set the API endpoint URL
endpoint = "https://api.runpod.ai/v2/pygmalion-6b/runsync" endpoint = "https://api.runpod.ai/v2/pygmalion-6b/runsync"
@ -54,6 +55,7 @@ async def generate_sync(
reply = answer[:idx].strip() reply = answer[:idx].strip()
else: else:
reply = answer.removesuffix('<|endoftext|>').strip() reply = answer.removesuffix('<|endoftext|>').strip()
reply.replace("\n{bot_name}: ", " ")
return reply return reply
elif status == 'IN_PROGRESS' or status == 'IN_QUEUE': elif status == 'IN_PROGRESS' or status == 'IN_QUEUE':
job_id = r_json["id"] job_id = r_json["id"]
@ -79,6 +81,7 @@ async def generate_sync(
reply = answer[:idx].strip() reply = answer[:idx].strip()
else: else:
reply = answer.removesuffix('<|endoftext|>').strip() reply = answer.removesuffix('<|endoftext|>').strip()
reply.replace("\n{bot_name}: ", " ")
return reply return reply
else: else:
return "<ERROR>" return "<ERROR>"

71
matrix_pygmalion_bot/core.py

@ -3,6 +3,7 @@ import nio
from nio import (AsyncClient, AsyncClientConfig, MatrixRoom, RoomMessageText, InviteEvent, UploadResponse) from nio import (AsyncClient, AsyncClientConfig, MatrixRoom, RoomMessageText, InviteEvent, UploadResponse)
import os, sys import os, sys
import time
import importlib import importlib
import configparser import configparser
import logging import logging
@ -10,6 +11,7 @@ import logging
import aiofiles.os import aiofiles.os
import magic import magic
from PIL import Image from PIL import Image
import re
from .helpers import ChatItem from .helpers import ChatItem
ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod_pygmalion") ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod_pygmalion")
@ -23,6 +25,7 @@ STORE_PATH = "./.store/"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
config = configparser.ConfigParser() config = configparser.ConfigParser()
bots = [] bots = []
background_tasks = set()
class Callbacks(object): class Callbacks(object):
"""Class to pass client to callback methods.""" """Class to pass client to callback methods."""
@ -32,6 +35,7 @@ class Callbacks(object):
self.bot = bot self.bot = bot
async def message_cb(self, room: MatrixRoom, event: RoomMessageText) -> None: async def message_cb(self, room: MatrixRoom, event: RoomMessageText) -> None:
event_id = event.event_id
message = event.body message = event.body
is_own_message = False is_own_message = False
if event.sender == self.client.user: if event.sender == self.client.user:
@ -39,6 +43,9 @@ class Callbacks(object):
is_command = False is_command = False
if event.body.startswith('!'): if event.body.startswith('!'):
is_command = True is_command = True
if not (self.bot.owner is None):
if not (event.sender == self.bot.owner or is_own_message):
return
relates_to = None relates_to = None
if 'm.relates_to' in event.source["content"]: if 'm.relates_to' in event.source["content"]:
relates_to = event.source["content"]['m.relates_to']["event_id"] relates_to = event.source["content"]['m.relates_to']["event_id"]
@ -49,7 +56,7 @@ class Callbacks(object):
else: else:
translated_message = translate.translate(message, self.bot.translate, "en") translated_message = translate.translate(message, self.bot.translate, "en")
if hasattr(event, 'body'): if hasattr(event, 'body'):
self.bot.chat_history[event.event_id] = ChatItem(event.event_id, event.server_timestamp, room.user_name(event.sender), is_own_message, relates_to, translated_message) self.bot.chat_history[event_id] = ChatItem(event_id, event.server_timestamp, room.user_name(event.sender), is_own_message, relates_to, translated_message)
if self.bot.not_synced: if self.bot.not_synced:
return return
print( print(
@ -57,7 +64,7 @@ class Callbacks(object):
room.display_name, room.user_name(event.sender), event.body room.display_name, room.user_name(event.sender), event.body
) )
) )
await self.client.room_read_markers(room.room_id, event.event_id, event.event_id) await self.client.room_read_markers(room.room_id, event_id, event_id)
# Ignore messages from ourselves # Ignore messages from ourselves
if is_own_message: if is_own_message:
return return
@ -81,23 +88,26 @@ class Callbacks(object):
return return
elif event.body.startswith('!!!'): elif event.body.startswith('!!!'):
return return
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current
# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() chat_history_event_id, chat_history_item = self.bot.chat_history.popitem()
# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() chat_history_event_id, chat_history_item = self.bot.chat_history.popitem()
# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
# return
elif event.body.startswith('!!'):
return return
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current elif event.body.startswith('!!'):
# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") if len(self.bot.chat_history) < 3:
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() return
# await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # current
# chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # new current await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
# self.bot.chat_history[chat_history_event_id] = chat_history_item chat_history_event_id, chat_history_item = self.bot.chat_history.popitem()
# message = chat_history_item.message await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
# # don't return, we generate a new answer chat_history_event_id, chat_history_item = self.bot.chat_history.popitem() # new current
self.bot.chat_history[chat_history_event_id] = chat_history_item
event_id = chat_history_item.event_id
message = chat_history_item.message
translated_message = message
# don't return, we generate a new answer
full_prompt = await ai.get_full_prompt(translated_message, self.bot) full_prompt = await ai.get_full_prompt(translated_message, self.bot)
num_tokens = await ai.num_tokens(full_prompt) num_tokens = await ai.num_tokens(full_prompt)
@ -119,15 +129,15 @@ class Callbacks(object):
# answer = answer.strip() # answer = answer.strip()
# print("") # print("")
await self.client.room_typing(room.room_id, True, 15000) await self.client.room_typing(room.room_id, True, 15000)
answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key) answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot.name)
answer = answer.strip() answer = answer.strip()
await self.client.room_typing(room.room_id, False) await self.client.room_typing(room.room_id, False)
translated_answer = answer translated_answer = answer
if not (self.bot.translate is None): if not (self.bot.translate is None):
translated_answer = translate.translate(answer, "en", self.bot.translate) translated_answer = translate.translate(answer, "en", self.bot.translate)
await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=event.event_id, original_message=answer) await self.bot.send_message(self.client, room.room_id, translated_answer, reply_to=event_id, original_message=answer)
else: else:
await self.bot.send_message(self.client, room.room_id, answer, reply_to=event.event_id) await self.bot.send_message(self.client, room.room_id, answer, reply_to=event_id)
@ -162,6 +172,7 @@ class ChatBot(object):
self.persona = None self.persona = None
self.scenario = None self.scenario = None
self.greeting = None self.greeting = None
self.events = []
self.chat_history = {} self.chat_history = {}
if STORE_PATH and not os.path.isdir(STORE_PATH): if STORE_PATH and not os.path.isdir(STORE_PATH):
@ -173,6 +184,13 @@ class ChatBot(object):
self.scenario = scenario self.scenario = scenario
self.greeting = greeting self.greeting = greeting
async def event_loop(self):
while True:
await asyncio.sleep(60)
print(time.time())
for event in self.events:
event.loop()
async def login(self): async def login(self):
self.config = AsyncClientConfig(store_sync_tokens=True) self.config = AsyncClientConfig(store_sync_tokens=True)
self.client = AsyncClient(self.homeserver, self.user_id, store_path=STORE_PATH, config=self.config) self.client = AsyncClient(self.homeserver, self.user_id, store_path=STORE_PATH, config=self.config)
@ -181,14 +199,18 @@ class ChatBot(object):
self.client.add_event_callback(self.callbacks.invite_cb, InviteEvent) self.client.add_event_callback(self.callbacks.invite_cb, InviteEvent)
sync_task = asyncio.create_task(self.watch_for_sync(self.client.synced)) sync_task = asyncio.create_task(self.watch_for_sync(self.client.synced))
event_loop = asyncio.create_task(self.event_loop())
background_tasks.add(event_loop)
event_loop.add_done_callback(background_tasks.discard)
try: try:
response = await self.client.login(self.password) response = await self.client.login(self.password)
print(response) print(response)
await self.client.sync_forever(timeout=30000, full_state=True) #sync_forever_task = asyncio.create_task(self.client.sync_forever(timeout=30000, full_state=True))
except (asyncio.CancelledError, KeyboardInterrupt): except (asyncio.CancelledError, KeyboardInterrupt):
print("Received interrupt.") print("Received interrupt.")
await self.client.close() await self.client.close()
#return sync_forever_task
async def watch_for_sync(self, sync_event): async def watch_for_sync(self, sync_event):
print("Awaiting sync") print("Awaiting sync")
@ -279,6 +301,9 @@ async def main() -> None:
bot.runpod_api_key = config['DEFAULT']['runpod_api_key'] bot.runpod_api_key = config['DEFAULT']['runpod_api_key']
bots.append(bot) bots.append(bot)
await bot.login() await bot.login()
print("logged in") print("gather")
async with asyncio.TaskGroup() as tg:
for bot in bots:
task = tg.create_task(bot.client.sync_forever(timeout=30000, full_state=True))
asyncio.get_event_loop().run_until_complete(main()) asyncio.get_event_loop().run_until_complete(main())

14
matrix_pygmalion_bot/helpers.py

@ -1,4 +1,4 @@
import time
class ChatItem: class ChatItem:
def __init__(self, event_id, timestamp, user_name, is_own_message, relates_to_event, message): def __init__(self, event_id, timestamp, user_name, is_own_message, relates_to_event, message):
@ -13,3 +13,15 @@ class ChatItem:
return str("{}: {}".format(self.user_name, self.message)) return str("{}: {}".format(self.user_name, self.message))
def getLine(self): def getLine(self):
return str("{}: {}".format(self.user_name, self.message)) return str("{}: {}".format(self.user_name, self.message))
class Event:
def __init__(self, timestamp_start, timestamp_stop=None, chance=1):
self.timestamp_start = timestamp_start
self.timestamp_stop = timestamp_stop
self.chance = chance
self.executed = 0
def __str__(self):
return str("Event starting at timestamp {}".format(self.timestamp_start))
def loop(self):
if timestamp_start > time.time():
pass

Loading…
Cancel
Save