Browse Source

Don't recalculate the prompt every time. Use a window and remove a chunk of chat history when we get near the token limit.

master
Hendrik Langer 2 years ago
parent
commit
8213d50f15
  1. 4
      matrix_pygmalion_bot/ai/koboldcpp.py
  2. 47
      matrix_pygmalion_bot/ai/llama_helpers.py
  3. 53
      matrix_pygmalion_bot/ai/pygmalion_helpers.py
  4. 14
      matrix_pygmalion_bot/chatlog.py
  5. 39
      matrix_pygmalion_bot/core.py

4
matrix_pygmalion_bot/ai/koboldcpp.py

@ -20,8 +20,10 @@ logger = logging.getLogger(__name__)
def setup(): def setup():
os.system("mkdir -p repositories && (cd repositories && git clone https://github.com/LostRuins/koboldcpp.git)") os.system("mkdir -p repositories && (cd repositories && git clone https://github.com/LostRuins/koboldcpp.git)")
os.system("(cd repositories/koboldcpp && make LLAMA_OPENBLAS=1 && cd models && wget https://huggingface.co/concedo/pygmalion-6bv3-ggml-ggjt/resolve/main/pygmalion-6b-v3-ggml-ggjt-q4_0.bin)") os.system("apt update && apt-get install libopenblas-dev libclblast-dev libmkl-dev")
os.system("(cd repositories/koboldcpp && make LLAMA_OPENBLAS=1 LLAMA_CLBLAST=1 && cd models && wget https://huggingface.co/concedo/pygmalion-6bv3-ggml-ggjt/resolve/main/pygmalion-6b-v3-ggml-ggjt-q4_0.bin)")
#python3 koboldcpp.py models/pygmalion-6b-v3-ggml-ggjt-q4_0.bin #python3 koboldcpp.py models/pygmalion-6b-v3-ggml-ggjt-q4_0.bin
#python3 koboldcpp.py --smartcontext models/pygmalion-6b-v3-ggml-ggjt-q4_0.bin
async def generate_sync( async def generate_sync(
prompt: str, prompt: str,

47
matrix_pygmalion_bot/ai/llama_helpers.py

@ -51,15 +51,20 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history):
#prompt += f"{ai_name}:" #prompt += f"{ai_name}:"
MAX_TOKENS = 2048 MAX_TOKENS = 2048
WINDOW = 600
max_new_tokens = 200 max_new_tokens = 200
total_num_tokens = await num_tokens(prompt) total_num_tokens = await num_tokens(prompt)
total_num_tokens += await num_tokens(f"{user_name}: {simple_prompt}\n{ai_name}:") input_num_tokens = await num_tokens(f"{user_name}: {simple_prompt}\n{ai_name}:")
total_num_tokens += input_num_tokens
visible_history = [] visible_history = []
current_message = True num_message = 0
for key, chat_item in reversed(chat_history.chat_history.items()): for key, chat_item in reversed(chat_history.chat_history.items()):
if current_message: num_message += 1
current_message = False if num_message == 1:
# skip current_message
continue continue
if chat_item.stop_here:
break
if chat_item.message["en"].startswith('!begin'): if chat_item.message["en"].startswith('!begin'):
break break
if chat_item.message["en"].startswith('!'): if chat_item.message["en"].startswith('!'):
@ -69,10 +74,11 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history):
#if chat_item.message["en"] == bot.greeting: #if chat_item.message["en"] == bot.greeting:
# continue # continue
if chat_item.num_tokens == None: if chat_item.num_tokens == None:
chat_item.num_tokens = await num_tokens("{}: {}".format(user_name, chat_item.message["en"])) chat_history.chat_history[key].num_tokens = await num_tokens("{}: {}".format(user_name, chat_item.message["en"]))
chat_item = chat_history.chat_history[key]
# TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens?? # TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens??
logger.debug(f"History: " + str(chat_item) + " [" + str(chat_item.num_tokens) + "]") logger.debug(f"History: " + str(chat_item) + " [" + str(chat_item.num_tokens) + "]")
if total_num_tokens + chat_item.num_tokens < MAX_TOKENS - max_new_tokens: if total_num_tokens + chat_item.num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens:
visible_history.append(chat_item) visible_history.append(chat_item)
total_num_tokens += chat_item.num_tokens total_num_tokens += chat_item.num_tokens
else: else:
@ -81,14 +87,37 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history):
if not hasattr(bot, "greeting_num_tokens"): if not hasattr(bot, "greeting_num_tokens"):
bot.greeting_num_tokens = await num_tokens(bot.greeting) bot.greeting_num_tokens = await num_tokens(bot.greeting)
if total_num_tokens + bot.greeting_num_tokens < MAX_TOKENS - max_new_tokens: if total_num_tokens + bot.greeting_num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens:
prompt += f"{ai_name}: {bot.greeting}\n" prompt += f"{ai_name}: {bot.greeting}\n"
total_num_tokens += bot.greeting_num_tokens
for chat_item in visible_history: for chat_item in visible_history:
if chat_item.is_own_message: if chat_item.is_own_message:
prompt += f"{ai_name}: {chat_item.message['en']}\n" line = f"{ai_name}: {chat_item.message['en']}\n"
else: else:
prompt += f"{user_name}: {chat_item.message['en']}\n" line = f"{user_name}: {chat_item.message['en']}\n"
prompt += line
if chat_history.getSavedPrompt() and not chat_item.is_in_saved_prompt:
logger.info(f"adding to saved prompt: \"line\"")
chat_history.setSavedPrompt( chat_history.getSavedPrompt() + line, chat_history.saved_context_num_tokens + chat_item.num_tokens )
chat_item.is_in_saved_prompt = True
if chat_history.saved_context_num_tokens:
logger.info(f"saved_context has {chat_history.saved_context_num_tokens+input_num_tokens} tokens. new context would be {total_num_tokens}. Limit is {MAX_TOKENS}")
if chat_history.getSavedPrompt():
if chat_history.saved_context_num_tokens+input_num_tokens > MAX_TOKENS - max_new_tokens:
chat_history.setFastForward(False)
if chat_history.getFastForward():
logger.info("using saved prompt")
prompt = chat_history.getSavedPrompt()
if not chat_history.getSavedPrompt() or not chat_history.getFastForward():
logger.info("regenerating prompt")
chat_history.setSavedPrompt(prompt, total_num_tokens)
for key, chat_item in chat_history.chat_history.items():
if key != list(chat_history.chat_history)[-1]: # exclude current item
chat_history.chat_history[key].is_in_saved_prompt = True
chat_history.setFastForward(True)
prompt += f"{user_name}: {simple_prompt}\n" prompt += f"{user_name}: {simple_prompt}\n"
prompt += f"{ai_name}:" prompt += f"{ai_name}:"

53
matrix_pygmalion_bot/ai/pygmalion_helpers.py

@ -15,6 +15,9 @@ from PIL import Image, PngImagePlugin
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
tokenizer = None
async def get_full_prompt(simple_prompt: str, bot, chat_history): async def get_full_prompt(simple_prompt: str, bot, chat_history):
# Prompt without history # Prompt without history
@ -32,15 +35,20 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history):
#prompt += bot.name + ":" #prompt += bot.name + ":"
MAX_TOKENS = 2048 MAX_TOKENS = 2048
WINDOW = 800
max_new_tokens = 200 max_new_tokens = 200
total_num_tokens = await num_tokens(prompt) total_num_tokens = await num_tokens(prompt)
total_num_tokens += await num_tokens(f"You: " + simple_prompt + "\n{bot.name}:") input_num_tokens = await num_tokens(f"You: " + simple_prompt + "\n{bot.name}:")
total_num_tokens += input_num_tokens
visible_history = [] visible_history = []
current_message = True num_message = 0
for key, chat_item in reversed(chat_history.chat_history.items()): for key, chat_item in reversed(chat_history.chat_history.items()):
if current_message: num_message += 1
current_message = False if num_message == 1:
# skip current_message
continue continue
if chat_item.stop_here:
break
if chat_item.message["en"].startswith('!begin'): if chat_item.message["en"].startswith('!begin'):
break break
if chat_item.message["en"].startswith('!'): if chat_item.message["en"].startswith('!'):
@ -50,10 +58,11 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history):
#if chat_item.message["en"] == bot.greeting: #if chat_item.message["en"] == bot.greeting:
# continue # continue
if chat_item.num_tokens == None: if chat_item.num_tokens == None:
chat_item.num_tokens = await num_tokens("{}: {}".format(chat_item.user_name, chat_item.message["en"])) chat_history.chat_history[key].num_tokens = await num_tokens("{}: {}".format(chat_item.user_name, chat_item.message["en"]))
chat_item = chat_history.chat_history[key]
# TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens?? # TODO: is it MAX_TOKENS or MAX_TOKENS - max_new_tokens??
logger.debug(f"History: " + str(chat_item) + " [" + str(chat_item.num_tokens) + "]") logger.debug(f"History: " + str(chat_item) + " [" + str(chat_item.num_tokens) + "]")
if total_num_tokens + chat_item.num_tokens < MAX_TOKENS - max_new_tokens: if total_num_tokens + chat_item.num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens:
visible_history.append(chat_item) visible_history.append(chat_item)
total_num_tokens += chat_item.num_tokens total_num_tokens += chat_item.num_tokens
else: else:
@ -62,23 +71,43 @@ async def get_full_prompt(simple_prompt: str, bot, chat_history):
if not hasattr(bot, "greeting_num_tokens"): if not hasattr(bot, "greeting_num_tokens"):
bot.greeting_num_tokens = await num_tokens(bot.greeting) bot.greeting_num_tokens = await num_tokens(bot.greeting)
if total_num_tokens + bot.greeting_num_tokens < MAX_TOKENS - max_new_tokens: if total_num_tokens + bot.greeting_num_tokens <= MAX_TOKENS - WINDOW - max_new_tokens:
prompt += bot.name + ": " + bot.greeting + "\n" prompt += bot.name + ": " + bot.greeting + "\n"
total_num_tokens += bot.greeting_num_tokens
for chat_item in visible_history: for chat_item in visible_history:
if chat_item.is_own_message: if chat_item.is_own_message:
prompt += bot.name + ": " + chat_item.message["en"] + "\n" line = bot.name + ": " + chat_item.message["en"] + "\n"
else: else:
prompt += "You" + ": " + chat_item.message["en"] + "\n" line = "You" + ": " + chat_item.message["en"] + "\n"
prompt += line
if chat_history.getSavedPrompt() and not chat_item.is_in_saved_prompt:
logger.info(f"adding to saved prompt: \"{line}\"")
chat_history.setSavedPrompt( chat_history.getSavedPrompt() + line, chat_history.saved_context_num_tokens + chat_item.num_tokens )
chat_item.is_in_saved_prompt = True
if chat_history.saved_context_num_tokens:
logger.info(f"saved_context has {chat_history.saved_context_num_tokens+input_num_tokens} tokens. new context would be {total_num_tokens}. Limit is {MAX_TOKENS}")
if chat_history.getSavedPrompt():
if chat_history.saved_context_num_tokens+input_num_tokens > MAX_TOKENS - max_new_tokens:
chat_history.setFastForward(False)
if chat_history.getFastForward():
logger.info("using saved prompt")
prompt = chat_history.getSavedPrompt()
if not chat_history.getSavedPrompt() or not chat_history.getFastForward():
logger.info("regenerating prompt")
chat_history.setSavedPrompt(prompt, total_num_tokens)
for key, chat_item in chat_history.chat_history.items():
if key != list(chat_history.chat_history)[-1]: # exclude current item
chat_history.chat_history[key].is_in_saved_prompt = True
chat_history.setFastForward(True)
prompt += "You: " + simple_prompt + "\n" prompt += "You: " + simple_prompt + "\n"
prompt += bot.name + ":" prompt += bot.name + ":"
return prompt return prompt
tokenizer = None
async def num_tokens(input_text: str): async def num_tokens(input_text: str):
# os.makedirs("./models/pygmalion-6b", exist_ok=True) # os.makedirs("./models/pygmalion-6b", exist_ok=True)
# hf_hub_download(repo_id="PygmalionAI/pygmalion-6b", filename="config.json", cache_dir="./models/pygmalion-6b") # hf_hub_download(repo_id="PygmalionAI/pygmalion-6b", filename="config.json", cache_dir="./models/pygmalion-6b")

14
matrix_pygmalion_bot/chatlog.py

@ -12,8 +12,10 @@ class ChatMessage:
self.user_name = user_name self.user_name = user_name
self.is_own_message = is_own_message self.is_own_message = is_own_message
self.is_command = is_command self.is_command = is_command
self.stop_here = False
self.relates_to_event = relates_to_event self.relates_to_event = relates_to_event
self.num_tokens = None self.num_tokens = None
self.is_in_saved_prompt = False
self.message = {} self.message = {}
self.message[language] = message self.message[language] = message
if not (language == "en"): if not (language == "en"):
@ -40,6 +42,9 @@ class ChatHistory:
def __init__(self, bot_name, room_name): def __init__(self, bot_name, room_name):
self.bot_name = bot_name self.bot_name = bot_name
self.room_name = room_name self.room_name = room_name
self.context_fast_forward = False
self.saved_context = None
self.saved_context_num_tokens = None
self.chat_history = {} self.chat_history = {}
def __str__(self): def __str__(self):
return str("Chat History for {} in room {}".format(self.bot_name, self.room_name)) return str("Chat History for {} in room {}".format(self.bot_name, self.room_name))
@ -70,6 +75,15 @@ class ChatHistory:
def getLastItem(self): def getLastItem(self):
key, chat_item = list(self.chat_history.items())[-1] key, chat_item = list(self.chat_history.items())[-1]
return chat_item return chat_item
def setFastForward(self, value: bool) -> None:
    """Store the fast-forward flag for this chat history.

    The prompt builders in this commit reuse the saved prompt while the
    flag is true and rebuild (and re-save) the prompt when it is false.

    Args:
        value: New value for ``self.context_fast_forward``.
    """
    self.context_fast_forward = value
def getFastForward(self) -> bool:
    """Return the current value of ``self.context_fast_forward``."""
    return self.context_fast_forward
def getSavedPrompt(self):
    """Return the saved prompt text held in ``self.saved_context``.

    Returns:
        The value last stored via ``setSavedPrompt``; ``None`` until a
        prompt has been saved (``__init__`` initialises it to ``None``).
    """
    return self.saved_context
def setSavedPrompt(self, context, num_tokens) -> None:
    """Replace the saved prompt and its token count together.

    Args:
        context: Prompt text to store in ``self.saved_context``.
        num_tokens: Token count to store in
            ``self.saved_context_num_tokens``; callers pass the count that
            corresponds to ``context``.
    """
    # Both fields are updated in the same call so the stored count always
    # belongs to the stored text.
    self.saved_context = context
    self.saved_context_num_tokens = num_tokens
class BotChatHistory: class BotChatHistory:
def __init__(self, bot_name): def __init__(self, bot_name):

39
matrix_pygmalion_bot/core.py

@ -59,12 +59,14 @@ class Callbacks(object):
else: else:
english_original_message = None english_original_message = None
chat_message = self.bot.chat_history.room(room.display_name).add(event.event_id, event.server_timestamp, room.user_name(event.sender), event.sender == self.client.user, is_command, relates_to, event.body, language, english_original_message) chat_message = self.bot.chat_history.room(room.room_id).add(event.event_id, event.server_timestamp, room.user_name(event.sender), event.sender == self.client.user, is_command, relates_to, event.body, language, english_original_message)
# parse keywords # parse keywords
self.bot.extra_info = {"persona": [], "scenario": [], "example_dialogue": []} self.bot.extra_info = {"persona": [], "scenario": [], "example_dialogue": []}
for i, keyword in enumerate(self.bot.keywords): for i, keyword in enumerate(self.bot.keywords):
if re.search(keyword["regex"], event.body, flags=re.IGNORECASE): if re.search(keyword["regex"], event.body, flags=re.IGNORECASE):
if not 'active' in self.bot.keywords[i] or self.bot.keywords[i]['active'] < 1:
self.bot.chat_history.room(room.room_id).setFastForward(False)
self.bot.keywords[i]['active'] = int(keyword["duration"]) self.bot.keywords[i]['active'] = int(keyword["duration"])
logger.info(f"keyword \"{keyword['regex']}\" detected") logger.info(f"keyword \"{keyword['regex']}\" detected")
if 'active' in self.bot.keywords[i]: if 'active' in self.bot.keywords[i]:
@ -185,7 +187,7 @@ class Callbacks(object):
elif event.body.startswith('!temperature'): elif event.body.startswith('!temperature'):
self.bot.temperature = float( event.body.removeprefix('!temperature').strip() ) self.bot.temperature = float( event.body.removeprefix('!temperature').strip() )
elif event.body.startswith('!begin'): elif event.body.startswith('!begin'):
self.bot.chat_history.room(room.display_name).clear() self.bot.chat_history.room(room.room_id).clear()
self.bot.room_config[room.room_id]["tick"] = 0 self.bot.room_config[room.room_id]["tick"] = 0
await self.bot.write_conf2(self.bot.name) await self.bot.write_conf2(self.bot.name)
await self.bot.send_message(self.client, room.room_id, self.bot.greeting) await self.bot.send_message(self.client, room.room_id, self.bot.greeting)
@ -197,30 +199,30 @@ class Callbacks(object):
self.bot.room_config[room.room_id]["disabled"] = True self.bot.room_config[room.room_id]["disabled"] = True
return return
elif event.body.startswith('!!!'): elif event.body.startswith('!!!'):
if self.bot.chat_history.room(room.display_name).getLen() < 3: if self.bot.chat_history.room(room.room_id).getLen() < 3:
return return
chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) # current chat_history_item = self.bot.chat_history.room(room.room_id).remove(1) # current
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) chat_history_item = self.bot.chat_history.room(room.room_id).remove(1)
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) chat_history_item = self.bot.chat_history.room(room.room_id).remove(1)
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
return return
elif event.body.startswith('!!'): elif event.body.startswith('!!'):
if self.bot.chat_history.room(room.display_name).getLen() < 3: if self.bot.chat_history.room(room.room_id).getLen() < 3:
return return
chat_history_item = self.bot.chat_history.room(room.display_name).remove(1)# current chat_history_item = self.bot.chat_history.room(room.room_id).remove(1)# current
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) chat_history_item = self.bot.chat_history.room(room.room_id).remove(1)
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
chat_message = self.bot.chat_history.room(room.display_name).getLastItem() # new current chat_message = self.bot.chat_history.room(room.room_id).getLastItem() # new current
# don't return, we generate a new answer # don't return, we generate a new answer
elif event.body.startswith('!replace'): elif event.body.startswith('!replace'):
if self.bot.chat_history.room(room.display_name).getLen() < 3: if self.bot.chat_history.room(room.room_id).getLen() < 3:
return return
chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) # current chat_history_item = self.bot.chat_history.room(room.room_id).remove(1) # current
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
chat_history_item = self.bot.chat_history.room(room.display_name).remove(1) chat_history_item = self.bot.chat_history.room(room.room_id).remove(1)
await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request") await self.client.room_redact(room.room_id, chat_history_item.event_id, reason="user-request")
new_answer = event.body.removeprefix('!replace').strip() new_answer = event.body.removeprefix('!replace').strip()
await self.bot.send_message(self.client, room.room_id, new_answer, reply_to=chat_history_item.relates_to_event) await self.bot.send_message(self.client, room.room_id, new_answer, reply_to=chat_history_item.relates_to_event)
@ -237,7 +239,7 @@ class Callbacks(object):
# send, mail, drop, snap picture, photo, image, portrait # send, mail, drop, snap picture, photo, image, portrait
pass pass
full_prompt = await text_ai.get_full_prompt(chat_message.getTranslation("en"), self.bot, self.bot.chat_history.room(room.display_name)) full_prompt = await text_ai.get_full_prompt(chat_message.getTranslation("en"), self.bot, self.bot.chat_history.room(room.room_id))
num_tokens = await text_ai.num_tokens(full_prompt) num_tokens = await text_ai.num_tokens(full_prompt)
logger.debug(full_prompt) logger.debug(full_prompt)
logger.info(f"Prompt has " + str(num_tokens) + " tokens") logger.info(f"Prompt has " + str(num_tokens) + " tokens")
@ -288,11 +290,12 @@ class Callbacks(object):
logger.info(f"event redacted in room {room.room_id}. event_id: {event.redacts}") logger.info(f"event redacted in room {room.room_id}. event_id: {event.redacts}")
for bot in bots: for bot in bots:
# for room in bot.chat_history.chat_rooms.keys(): # for room in bot.chat_history.chat_rooms.keys():
if room.display_name in bot.chat_history.chat_rooms: if room.room_id in bot.chat_history.chat_rooms:
logger.info("room found") logger.info("room found")
if bot.chat_history.chat_rooms[room.display_name].exists_id(event.redacts): if bot.chat_history.chat_rooms[room.room_id].exists_id(event.redacts):
logger.info("found it") logger.info("found it")
bot.chat_history.chat_rooms[room.display_name].remove_id(event.redacts) bot.chat_history.chat_rooms[room.room_id].remove_id(event.redacts)
self.bot.chat_history.room(room.room_id).setFastForward(False)
class ChatBot(object): class ChatBot(object):
"""Main chatbot""" """Main chatbot"""
@ -338,6 +341,7 @@ class ChatBot(object):
self.scenario = scenario self.scenario = scenario
self.greeting = greeting self.greeting = greeting
self.example_dialogue = example_dialogue self.example_dialogue = example_dialogue
self.chat_history = BotChatHistory(self.name)
def get_persona(self): def get_persona(self):
return ' '.join([self.persona, ' '.join(self.extra_info['persona'])]) return ' '.join([self.persona, ' '.join(self.extra_info['persona'])])
@ -374,7 +378,6 @@ class ChatBot(object):
async def login(self): async def login(self):
self.config = AsyncClientConfig(store_sync_tokens=True) self.config = AsyncClientConfig(store_sync_tokens=True)
self.client = AsyncClient(self.homeserver, self.user_id, store_path=STORE_PATH, config=self.config) self.client = AsyncClient(self.homeserver, self.user_id, store_path=STORE_PATH, config=self.config)
self.chat_history = BotChatHistory(self.name)
self.callbacks = Callbacks(self.client, self) self.callbacks = Callbacks(self.client, self)
self.client.add_event_callback(self.callbacks.message_cb, RoomMessageText) self.client.add_event_callback(self.callbacks.message_cb, RoomMessageText)
self.client.add_event_callback(self.callbacks.invite_cb, InviteEvent) self.client.add_event_callback(self.callbacks.invite_cb, InviteEvent)

Loading…
Cancel
Save