import asyncio
import json
import logging
from typing import Optional

from runpod import RunpodWrapper

logger = logging.getLogger(__name__)


class RunpodTextOobaboogaWrapper(RunpodWrapper):
    """Text-generation wrapper that sends an Oobabooga-style payload via ``RunpodWrapper``."""

    async def generate(
        self,
        prompt: str,
        endpoint_name: str,
        api_key: str,
        typing_fn,
        temperature: float = 0.72,
        max_new_tokens: int = 200,
        timeout: float = 180,
        user_name: Optional[str] = None,
    ) -> str:
        """Generate a completion for *prompt* and return only the new text.

        Args:
            prompt: Text sent to the model; also stripped from the front of
                the returned text if the model echoes it back.
            endpoint_name: Forwarded unchanged to ``RunpodWrapper.generate``.
            api_key: Forwarded unchanged to ``RunpodWrapper.generate``.
            typing_fn: Forwarded unchanged to ``RunpodWrapper.generate``.
            temperature: Sampling temperature placed in the request.
            max_new_tokens: Requested number of new tokens, capped at 2048.
            timeout: Forwarded unchanged to ``RunpodWrapper.generate``.
            user_name: When given, generation is asked to stop at
                ``"\\n<user_name>:"``. When omitted, no stopping string is sent.

        Returns:
            The first ``data`` entry of the response with a leading copy of
            *prompt* removed.

        Raises:
            ValueError: If the parent call returns an empty list.
        """
        stopping_strings = [f"\n{user_name}:"] if user_name is not None else []

        generation_params = {
            'max_new_tokens': min(max_new_tokens, 2048),
            'do_sample': True,
            'temperature': temperature,
            'top_p': 0.73,
            'typical_p': 1,
            'repetition_penalty': 1.1,
            'encoder_repetition_penalty': 1.0,
            'top_k': 0,
            'min_length': 0,
            'no_repeat_ngram_size': 0,
            'num_beams': 1,
            'penalty_alpha': 0,
            'length_penalty': 1,
            'early_stopping': False,
            'seed': -1,
            'add_bos_token': True,
            'stopping_strings': stopping_strings,
            'truncation_length': 2048,
            'ban_eos_token': False,
            'skip_special_tokens': True,
        }

        # The prompt and its parameters travel as one JSON-encoded string
        # inside the "data" list of the request.
        input_data = {
            "input": {
                "data": [json.dumps([prompt, generation_params])],
            }
        }

        output = await super().generate(
            input_data, endpoint_name, api_key, typing_fn, timeout
        )

        if isinstance(output, list):
            if not output:
                raise ValueError("generation endpoint returned an empty result list")
            # Keep the entry with the greatest len(); max() returns the first
            # such entry and leaves the caller-visible list unmodified.
            output = max(output, key=len)

        # NOTE(review): assumes each result is a mapping with a "data" list
        # whose first item is the generated text - confirm against the parent.
        return output["data"][0].removeprefix(prompt)