import asyncio
import json
import logging

from .runpod import RunpodWrapper

logger = logging.getLogger(__name__)


class RunpodTextWrapper(RunpodWrapper):
    """Text-generation client built on top of ``RunpodWrapper``.

    Builds a text-generation request payload and delegates submission to
    ``RunpodWrapper.generate``.
    """

    def __init__(self, api_key, endpoint):
        """Store the API key and the default endpoint name.

        Args:
            api_key: Key forwarded to ``RunpodWrapper.generate`` on each call.
            endpoint: Endpoint name used by ``generate2``.
        """
        # NOTE(review): RunpodWrapper.__init__ is not called here; confirm the
        # base class needs no initialization of its own.
        self.api_key = api_key
        self.endpoint = endpoint

    async def generate(self, prompt, endpoint_name, typing_fn, temperature=0.72,
                       max_new_tokens=200, timeout=180):
        """Generate a completion for ``prompt`` on ``endpoint_name``.

        Args:
            prompt: Input text sent to the model.
            endpoint_name: Endpoint to submit the request to.
            typing_fn: Callback forwarded unchanged to ``RunpodWrapper.generate``
                (its semantics are defined by the base class).
            temperature: Sampling temperature placed in the request payload.
            max_new_tokens: Requested generation length; capped at 2048 when
                sent as ``max_length``.
            timeout: Timeout forwarded to ``RunpodWrapper.generate``.

        Returns:
            The base-class output with a leading copy of ``prompt`` removed, if
            present (presumably a string, since ``removeprefix`` is called on it).
        """
        input_data = {
            "input": {
                "prompt": prompt,
                # Payload length is capped at 2048 regardless of the request.
                "max_length": min(max_new_tokens, 2048),
                # Previously read the undefined name `bot`; use the argument.
                "temperature": temperature,
                "do_sample": True,
            }
        }
        output = await super().generate(
            input_data, endpoint_name, self.api_key, typing_fn, timeout
        )
        # The returned text may echo the prompt; strip it so callers get only
        # the new text.
        return output.removeprefix(prompt)

    async def generate2(self, prompt, typing_fn, temperature=0.72,
                        max_new_tokens=200, timeout=180):
        """Same as ``generate`` but uses the endpoint given at construction.

        Returns:
            Whatever ``generate`` returns for ``self.endpoint``.
        """
        return await self.generate(
            prompt, self.endpoint, typing_fn, temperature, max_new_tokens, timeout
        )