diff --git a/matrix_pygmalion_bot/ai/koboldcpp.py b/matrix_pygmalion_bot/ai/koboldcpp.py
index 2ecc5de..4b6347e 100644
--- a/matrix_pygmalion_bot/ai/koboldcpp.py
+++ b/matrix_pygmalion_bot/ai/koboldcpp.py
@@ -37,7 +37,7 @@ async def generate_sync(
         "Content-Type": "application/json",
     }
 
-    max_new_tokens = 120
+    max_new_tokens = 200
     prompt_num_tokens = await num_tokens(prompt)
 
     # Define your inputs
@@ -55,28 +55,45 @@ async def generate_sync(
 
     logger.info(f"sending request to koboldcpp")
 
-    # Make the request
-    try:
-        r = requests.post(endpoint, json=input_data, headers=headers, timeout=360)
-    except requests.exceptions.RequestException as e:
-        raise ValueError(f" {e}")
-    r_json = r.json()
-    logger.info(r_json)
-
-    if r.status_code == 200:
-        reply = r_json["results"][0]["text"]
-        idx = reply.find(f"\nYou:")
-        if idx != -1:
-            reply = reply[:idx].strip()
+
+
+    TIMEOUT = 360
+    DELAY = 5
+    tokens = 0
+    complete_reply = ""
+    for i in range(TIMEOUT//DELAY):
+        input_data["max_length"] = 16 # pseudo streaming
+        # Make the request
+        try:
+            r = requests.post(endpoint, json=input_data, headers=headers, timeout=360)
+        except requests.exceptions.RequestException as e:
+            raise ValueError(f" {e}")
+        r_json = r.json()
+        logger.info(r_json)
+        if r.status_code == 200:
+            partial_reply = r_json["results"][0]["text"]
+            input_data["prompt"] += partial_reply
+            complete_reply += partial_reply
+            tokens += input_data["max_length"]
+            await typing_fn()
+            if not partial_reply or partial_reply.find('<|endoftext|>') != -1 or partial_reply.find("\nYou:") != -1 or tokens >= max_new_tokens:
+                idx = complete_reply.find(f"\nYou:")
+                if idx != -1:
+                    complete_reply = complete_reply[:idx].strip()
+                else:
+                    complete_reply = complete_reply.removesuffix('<|endoftext|>').strip()
+                complete_reply = complete_reply.replace(f"\n{bot.name}: ", " ")
+                complete_reply = complete_reply.replace(f"\n<USER>: ", " ")
+                complete_reply = complete_reply.replace(f"<BOT>", "{bot.name}")
+                complete_reply = complete_reply.replace(f"<USER>", "You")
+                return complete_reply.strip()
+            else:
+                continue
+        elif r.status_code == 503:
+            #model busy
+            await asyncio.sleep(DELAY)
         else:
-            reply = reply.removesuffix('<|endoftext|>').strip()
-        reply = reply.replace(f"\n{bot.name}: ", " ")
-        reply = reply.replace(f"\n<USER>: ", " ")
-        reply = reply.replace(f"<BOT>", "{bot.name}")
-        reply = reply.replace(f"<USER>", "You")
-        return reply.strip()
-    else:
-        raise ValueError(f"")
+            raise ValueError(f"")
 
 
 async def generate_image(input_prompt: str, negative_prompt: str, api_url: str, api_key: str, typing_fn):
diff --git a/matrix_pygmalion_bot/core.py b/matrix_pygmalion_bot/core.py
index 94df4ef..05a0ad4 100644
--- a/matrix_pygmalion_bot/core.py
+++ b/matrix_pygmalion_bot/core.py
@@ -16,8 +16,8 @@ import json
 from .helpers import Event
 from .chatlog import BotChatHistory
 
-ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod_pygmalion")
-ai = importlib.import_module("matrix_pygmalion_bot.ai.koboldcpp")
+image_ai = importlib.import_module("matrix_pygmalion_bot.ai.runpod_pygmalion")
+text_ai = importlib.import_module("matrix_pygmalion_bot.ai.koboldcpp")
 #ai = importlib.import_module("matrix_pygmalion_bot.ai.stablehorde")
 #from .llama_cpp import generate, get_full_prompt, get_full_prompt_chat_style
 #from .runpod_pygmalion import generate_sync, get_full_prompt
@@ -114,30 +114,30 @@ class Callbacks(object):
         typing = lambda : self.client.room_typing(room.room_id, True, 15000)
         if self.bot.service == "runpod":
             if num == 1:
-                output = await ai.generate_image1(prompt, negative_prompt, self.bot.runpod_api_key, typing)
+                output = await image_ai.generate_image1(prompt, negative_prompt, self.bot.runpod_api_key, typing)
             elif num == 2:
-                output = await ai.generate_image2(prompt, negative_prompt, self.bot.runpod_api_key, typing)
+                output = await image_ai.generate_image2(prompt, negative_prompt, self.bot.runpod_api_key, typing)
             elif num == 3:
-                output = await ai.generate_image3(prompt, negative_prompt, self.bot.runpod_api_key, typing)
+                output = await image_ai.generate_image3(prompt, negative_prompt, self.bot.runpod_api_key, typing)
             elif num == 4:
-                output = await ai.generate_image4(prompt, negative_prompt, self.bot.runpod_api_key, typing)
+                output = await image_ai.generate_image4(prompt, negative_prompt, self.bot.runpod_api_key, typing)
             elif num == 5:
-                output = await ai.generate_image5(prompt, negative_prompt, self.bot.runpod_api_key, typing)
+                output = await image_ai.generate_image5(prompt, negative_prompt, self.bot.runpod_api_key, typing)
             elif num == 6:
-                output = await ai.generate_image6(prompt, negative_prompt, self.bot.runpod_api_key, typing)
+                output = await image_ai.generate_image6(prompt, negative_prompt, self.bot.runpod_api_key, typing)
             elif num == 7:
-                output = await ai.generate_image7(prompt, negative_prompt, self.bot.runpod_api_key, typing)
+                output = await image_ai.generate_image7(prompt, negative_prompt, self.bot.runpod_api_key, typing)
             elif num == 8:
-                output = await ai.generate_image8(prompt, negative_prompt, self.bot.runpod_api_key, typing)
+                output = await image_ai.generate_image8(prompt, negative_prompt, self.bot.runpod_api_key, typing)
             else:
                 raise ValueError('no image generator with that number')
         elif self.bot.service == "stablehorde":
             if num == 1:
-                output = await ai.generate_image1(prompt, negative_prompt, self.bot.stablehorde_api_key, typing)
+                output = await image_ai.generate_image1(prompt, negative_prompt, self.bot.stablehorde_api_key, typing)
             elif num == 2:
-                output = await ai.generate_image2(prompt, negative_prompt, self.bot.stablehorde_api_key, typing)
+                output = await image_ai.generate_image2(prompt, negative_prompt, self.bot.stablehorde_api_key, typing)
             elif num == 3:
-                output = await ai.generate_image3(prompt, negative_prompt, self.bot.stablehorde_api_key, typing)
+                output = await image_ai.generate_image3(prompt, negative_prompt, self.bot.stablehorde_api_key, typing)
             else:
                 raise ValueError('no image generator with that number')
         else:
@@ -216,8 +216,8 @@ class Callbacks(object):
             # send, mail, drop, snap picture, photo, image, portrait
             pass
 
-        full_prompt = await ai.get_full_prompt(chat_message.getTranslation("en"), self.bot, self.bot.chat_history.room(room.display_name))
-        num_tokens = await ai.num_tokens(full_prompt)
+        full_prompt = await text_ai.get_full_prompt(chat_message.getTranslation("en"), self.bot, self.bot.chat_history.room(room.display_name))
+        num_tokens = await text_ai.num_tokens(full_prompt)
         logger.debug(full_prompt)
         logger.debug(f"Prompt has " + str(num_tokens) + " tokens")
         # answer = ""
@@ -237,7 +237,7 @@ class Callbacks(object):
         # print("")
         try:
             typing = lambda : self.client.room_typing(room.room_id, True, 15000)
-            answer = await ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot, typing, api_endpoint)
+            answer = await text_ai.generate_sync(full_prompt, self.bot.runpod_api_key, self.bot, typing, api_endpoint)
             answer = answer.strip()
             await self.client.room_typing(room.room_id, False)
             if not (self.bot.translate is None):
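
Note on the koboldcpp.py change: instead of one blocking request, generate_sync now polls the KoboldCpp API in 16-token chunks ("pseudo streaming"), appending each partial result to the prompt and to the accumulated reply, refreshing the typing indicator between chunks, retrying on HTTP 503 (model busy), and stopping on an empty chunk, an <|endoftext|> marker, a "\nYou:" turn marker, or once max_new_tokens (now 200) is reached. The sketch below restates that loop as a minimal synchronous function for illustration; the generate_pseudo_streaming name, the default parameters, and the use of requests/time are assumptions of this sketch, not code from the bot, and the request/response fields are only those the diff itself relies on.

# Standalone sketch of the "pseudo streaming" loop introduced above, assuming an
# endpoint that accepts {"prompt", "max_length"} and returns
# {"results": [{"text": ...}]} (the same fields the diff uses).
import time
import requests

def generate_pseudo_streaming(endpoint, prompt, max_new_tokens=200,
                              chunk_size=16, timeout=360, delay=5):
    headers = {"Content-Type": "application/json"}
    input_data = {"prompt": prompt, "max_length": chunk_size}
    complete_reply = ""
    tokens = 0
    for _ in range(timeout // delay):
        r = requests.post(endpoint, json=input_data, headers=headers, timeout=timeout)
        if r.status_code == 503:        # model busy: wait and retry
            time.sleep(delay)
            continue
        r.raise_for_status()
        partial = r.json()["results"][0]["text"]
        input_data["prompt"] += partial  # feed the partial reply back as context
        complete_reply += partial
        tokens += chunk_size
        # Stop on empty output, end-of-text marker, next speaker turn, or token budget.
        if not partial or "<|endoftext|>" in partial or "\nYou:" in partial or tokens >= max_new_tokens:
            idx = complete_reply.find("\nYou:")
            if idx != -1:
                complete_reply = complete_reply[:idx]
            return complete_reply.removesuffix("<|endoftext|>").strip()
    raise TimeoutError("no complete reply within the polling budget")

In the bot itself the loop stays asynchronous: it calls await typing_fn() after each chunk and await asyncio.sleep(DELAY) while the model is busy, as shown in the diff.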