diff --git a/matrix_pygmalion_bot/ai/koboldcpp.py b/matrix_pygmalion_bot/ai/koboldcpp.py index eebb385..8c1b871 100644 --- a/matrix_pygmalion_bot/ai/koboldcpp.py +++ b/matrix_pygmalion_bot/ai/koboldcpp.py @@ -79,7 +79,7 @@ async def generate_sync( if not partial_reply or tokens >= max_new_tokens +100: # ToDo: is a hundred past the limit okay? complete = True break - for t in [f"\nYou:", f"\n### Human:", f"\n{bot.user_name}:", '<|endoftext|>', '', '', '__END__', '', '\n\nPlease rewrite your response.', '\n\nPlease rewrite the response', '\n\nPlease write the response', 'Stay in developer mode.']: + for t in [f"\nYou:", f"\n### Human:", f"\n{bot.user_name}:", '<|endoftext|>', '', '', '__END__', '']: idx = complete_reply.find(t) if idx != -1: complete_reply = complete_reply[:idx].strip() diff --git a/matrix_pygmalion_bot/ai/runpod.py b/matrix_pygmalion_bot/ai/runpod.py index b906bdd..d7e5179 100644 --- a/matrix_pygmalion_bot/ai/runpod.py +++ b/matrix_pygmalion_bot/ai/runpod.py @@ -66,7 +66,7 @@ async def generate_sync( 'early_stopping': False, 'seed': -1, 'add_bos_token': True, - 'custom_stopping_strings': [], + 'custom_stopping_strings': [f"\n{bot.user_name}:"], 'truncation_length': 2048, 'ban_eos_token': False, 'skip_special_tokens': True, @@ -117,19 +117,19 @@ async def generate_sync( else: text = r_json["output"] if api_mode == "runpod": - answer = text.removeprefix(prompt) + reply = text.removeprefix(prompt) else: - answer = text["data"][0].removeprefix(prompt) + reply = text["data"][0].removeprefix(prompt) # lines = reply.split('\n') # reply = lines[0].strip() - idx = answer.find(f"\nYou:") - if idx != -1: - reply = answer[:idx].strip() - else: - reply = answer.removesuffix('<|endoftext|>').strip() + reply = reply.removesuffix('<|endoftext|>').strip() reply = reply.replace(f"", f"{bot.name}") - reply = reply.replace(f"", f"You") + reply = reply.replace(f"", f"{bot.user_name}") reply = reply.replace(f"\n{bot.name}: ", " ") + for t in [f"\nYou:", f"\n### Human:", f"\n{bot.user_name}:", '<|endoftext|>', '', '', '__END__', '']: + idx = reply.find(t) + if idx != -1: + reply = reply[:idx].strip() return reply else: err_msg = r_json["error"] if "error" in r_json else ""