You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
48 lines
1.7 KiB
48 lines
1.7 KiB
import asyncio
|
|
import json
|
|
from runpod import RunpodWrapper
|
|
import logging
|
|
|
|
# Per-module logger, keyed by this module's import name.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class RunpodTextOobaboogaWrapper(RunpodWrapper):
    """RunPod wrapper that sends text-generation requests.

    The request payload has the shape
    ``{"input": {"data": [json.dumps([prompt, params])]}}``, and the reply is
    read from ``output["data"][0]``.
    """

    # Token ceiling used both to cap ``max_new_tokens`` and as
    # ``truncation_length`` in the request parameters.
    _CONTEXT_LENGTH = 2048

    async def generate(self, prompt, api_key, typing_fn, temperature=0.72,
                       max_new_tokens=200, timeout=180, *, user_name=None):
        """Generate a completion for ``prompt`` through the RunPod endpoint.

        Args:
            prompt: Text prompt sent to the model. If the returned text
                starts with the prompt, that prefix is removed.
            api_key: Forwarded to ``RunpodWrapper.generate``.
            typing_fn: Forwarded to ``RunpodWrapper.generate`` (presumably a
                typing-indicator callback -- confirm against the base class).
            temperature: Sampling temperature placed in the request params.
            max_new_tokens: Requested generation budget, capped at
                ``_CONTEXT_LENGTH`` (2048).
            timeout: Forwarded to ``RunpodWrapper.generate``.
            user_name: Optional speaker name. When given, a newline followed
                by ``"{user_name}:"`` is sent as a stopping string so that
                generation halts when the model starts writing that speaker's
                turn. When omitted, no stopping strings are sent.

        Returns:
            The generated text with a leading ``prompt`` removed.

        Raises:
            ValueError: If the endpoint returns an empty list of outputs.
        """
        stopping_strings = [f"\n{user_name}:"] if user_name else []

        params = {
            'max_new_tokens': min(max_new_tokens, self._CONTEXT_LENGTH),
            'do_sample': True,
            'temperature': temperature,
            'top_p': 0.73,
            'typical_p': 1,
            'repetition_penalty': 1.1,
            'encoder_repetition_penalty': 1.0,
            'top_k': 0,
            'min_length': 0,
            'no_repeat_ngram_size': 0,
            'num_beams': 1,
            'penalty_alpha': 0,
            'length_penalty': 1,
            'early_stopping': False,
            'seed': -1,
            'add_bos_token': True,
            'stopping_strings': stopping_strings,
            'truncation_length': self._CONTEXT_LENGTH,
            'ban_eos_token': False,
            'skip_special_tokens': True,
        }

        # The endpoint expects a single JSON-encoded [prompt, params] pair
        # inside the "data" list.
        input_data = {"input": {"data": [json.dumps([prompt, params])]}}

        output = await super().generate(input_data, api_key, typing_fn, timeout)

        if isinstance(output, list):
            if not output:
                raise ValueError("RunPod endpoint returned an empty output list")
            # Keep the longest result. On ties, max() keeps the first one,
            # which matches the previous stable sort(reverse=True)[0], but
            # without mutating the caller's list.
            output = max(output, key=len)

        return output["data"][0].removeprefix(prompt)
|
|
|