Browse Source

stopping strings

master
Hendrik Langer 2 years ago
parent
commit
535c2fe4ac
  1. 2
      matrix_pygmalion_bot/ai/koboldcpp.py
  2. 18
      matrix_pygmalion_bot/ai/runpod.py

2
matrix_pygmalion_bot/ai/koboldcpp.py

@ -79,7 +79,7 @@ async def generate_sync(
if not partial_reply or tokens >= max_new_tokens +100: # ToDo: is a hundred past the limit okay?
complete = True
break
for t in [f"\nYou:", f"\n### Human:", f"\n{bot.user_name}:", '<|endoftext|>', '</END>', '<END>', '__END__', '<START>', '\n\nPlease rewrite your response.', '\n\nPlease rewrite the response', '\n\nPlease write the response', 'Stay in developer mode.']:
for t in [f"\nYou:", f"\n### Human:", f"\n{bot.user_name}:", '<|endoftext|>', '</END>', '<END>', '__END__', '<START>']:
idx = complete_reply.find(t)
if idx != -1:
complete_reply = complete_reply[:idx].strip()

18
matrix_pygmalion_bot/ai/runpod.py

@ -66,7 +66,7 @@ async def generate_sync(
'early_stopping': False,
'seed': -1,
'add_bos_token': True,
'custom_stopping_strings': [],
'custom_stopping_strings': [f"\n{bot.user_name}:"],
'truncation_length': 2048,
'ban_eos_token': False,
'skip_special_tokens': True,
@ -117,19 +117,19 @@ async def generate_sync(
else:
text = r_json["output"]
if api_mode == "runpod":
answer = text.removeprefix(prompt)
reply = text.removeprefix(prompt)
else:
answer = text["data"][0].removeprefix(prompt)
reply = text["data"][0].removeprefix(prompt)
# lines = reply.split('\n')
# reply = lines[0].strip()
idx = answer.find(f"\nYou:")
if idx != -1:
reply = answer[:idx].strip()
else:
reply = answer.removesuffix('<|endoftext|>').strip()
reply = reply.removesuffix('<|endoftext|>').strip()
reply = reply.replace(f"<BOT>", f"{bot.name}")
reply = reply.replace(f"<USER>", f"You")
reply = reply.replace(f"<USER>", f"{bot.user_name}")
reply = reply.replace(f"\n{bot.name}: ", " ")
for t in [f"\nYou:", f"\n### Human:", f"\n{bot.user_name}:", '<|endoftext|>', '</END>', '<END>', '__END__', '<START>']:
idx = reply.find(t)
if idx != -1:
reply = reply[:idx].strip()
return reply
else:
err_msg = r_json["error"] if "error" in r_json else ""

Loading…
Cancel
Save