Chatbot
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

49 lines
1.7 KiB

2 years ago
import asyncio
import json
from runpod import RunpodWrapper
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class RunpodTextOobaboogaWrapper(RunpodWrapper):
    """Text-generation wrapper built on top of ``RunpodWrapper``.

    Packs a prompt and a fixed set of sampling parameters into the request
    payload, delegates the call to the base class, and returns the generated
    text with the echoed prompt removed.

    NOTE(review): the payload layout presumably matches the Oobabooga
    text-generation web UI API (going by the class name) -- confirm against
    the deployed endpoint.
    """

    async def generate(self, prompt, endpoint_name, api_key, typing_fn,
                       temperature=0.72, max_new_tokens=200, timeout=180,
                       user_name=None):
        """Generate a completion for ``prompt`` and return only the new text.

        Args:
            prompt: Prompt string sent to the model. It is also stripped from
                the front of the returned text.
            endpoint_name: Forwarded unchanged to ``RunpodWrapper.generate``.
            api_key: Forwarded unchanged to ``RunpodWrapper.generate``.
            typing_fn: Forwarded unchanged to ``RunpodWrapper.generate``
                (presumably a "typing..." indicator callback -- confirm).
            temperature: Sampling temperature placed in the request.
            max_new_tokens: Requested number of new tokens; capped at 2048.
            timeout: Forwarded unchanged to ``RunpodWrapper.generate``.
            user_name: Optional speaker name. When given, generation is asked
                to stop at ``"\\n<user_name>:"``; when omitted, no stopping
                string is sent.

        Returns:
            The first entry of the response's ``"data"`` list with a leading
            copy of ``prompt`` removed.
        """
        # The original code read these values from an undefined global ``bot``;
        # they now come from the method's own arguments.
        stopping_strings = [f"\n{user_name}:"] if user_name is not None else []

        generation_params = {
            'max_new_tokens': min(max_new_tokens, 2048),
            'do_sample': True,
            'temperature': temperature,
            'top_p': 0.73,
            'typical_p': 1,
            'repetition_penalty': 1.1,
            'encoder_repetition_penalty': 1.0,
            'top_k': 0,
            'min_length': 0,
            'no_repeat_ngram_size': 0,
            'num_beams': 1,
            'penalty_alpha': 0,
            'length_penalty': 1,
            'early_stopping': False,
            'seed': -1,
            'add_bos_token': True,
            'stopping_strings': stopping_strings,
            'truncation_length': 2048,
            'ban_eos_token': False,
            'skip_special_tokens': True,
        }

        # Define your inputs: the prompt and its parameters travel as a single
        # JSON-encoded string inside the "data" list.
        input_data = {
            "input": {
                "data": [json.dumps([prompt, generation_params])]
            }
        }

        output = await super().generate(
            input_data, endpoint_name, api_key, typing_fn, timeout
        )

        # When several results come back, keep the one with the largest len().
        if isinstance(output, list):
            output.sort(key=len, reverse=True)
            output = output[0]

        # Drop the echoed prompt so that only the newly generated text remains.
        output = output["data"][0].removeprefix(prompt)
        return output