diff --git a/matrix_pygmalion_bot/ai/koboldcpp.py b/matrix_pygmalion_bot/ai/koboldcpp.py
index eebb385..8c1b871 100644
--- a/matrix_pygmalion_bot/ai/koboldcpp.py
+++ b/matrix_pygmalion_bot/ai/koboldcpp.py
@@ -79,7 +79,7 @@ async def generate_sync(
if not partial_reply or tokens >= max_new_tokens +100: # ToDo: is a hundred past the limit okay?
complete = True
break
- for t in [f"\nYou:", f"\n### Human:", f"\n{bot.user_name}:", '<|endoftext|>', '', '', '__END__', '', '\n\nPlease rewrite your response.', '\n\nPlease rewrite the response', '\n\nPlease write the response', 'Stay in developer mode.']:
+ for t in [f"\nYou:", f"\n### Human:", f"\n{bot.user_name}:", '<|endoftext|>', '', '', '__END__', '']:
idx = complete_reply.find(t)
if idx != -1:
complete_reply = complete_reply[:idx].strip()
diff --git a/matrix_pygmalion_bot/ai/runpod.py b/matrix_pygmalion_bot/ai/runpod.py
index b906bdd..d7e5179 100644
--- a/matrix_pygmalion_bot/ai/runpod.py
+++ b/matrix_pygmalion_bot/ai/runpod.py
@@ -66,7 +66,7 @@ async def generate_sync(
'early_stopping': False,
'seed': -1,
'add_bos_token': True,
- 'custom_stopping_strings': [],
+ 'custom_stopping_strings': [f"\n{bot.user_name}:"],
'truncation_length': 2048,
'ban_eos_token': False,
'skip_special_tokens': True,
@@ -117,19 +117,19 @@ async def generate_sync(
else:
text = r_json["output"]
if api_mode == "runpod":
- answer = text.removeprefix(prompt)
+ reply = text.removeprefix(prompt)
else:
- answer = text["data"][0].removeprefix(prompt)
+ reply = text["data"][0].removeprefix(prompt)
# lines = reply.split('\n')
# reply = lines[0].strip()
- idx = answer.find(f"\nYou:")
- if idx != -1:
- reply = answer[:idx].strip()
- else:
- reply = answer.removesuffix('<|endoftext|>').strip()
+ reply = reply.removesuffix('<|endoftext|>').strip()
reply = reply.replace(f"", f"{bot.name}")
- reply = reply.replace(f"", f"You")
+ reply = reply.replace(f"", f"{bot.user_name}")
reply = reply.replace(f"\n{bot.name}: ", " ")
+ for t in [f"\nYou:", f"\n### Human:", f"\n{bot.user_name}:", '<|endoftext|>', '', '', '__END__', '']:
+ idx = reply.find(t)
+ if idx != -1:
+ reply = reply[:idx].strip()
return reply
else:
err_msg = r_json["error"] if "error" in r_json else ""