Chatbot
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

32 lines
1009 B

2 years ago
import asyncio
import json
from .runpod import RunpodWrapper
import logging
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
class RunpodTextWrapper(RunpodWrapper):
    """Text-generation client built on top of ``RunpodWrapper``.

    Builds a text-generation request payload and delegates the actual request
    to ``RunpodWrapper.generate``, then strips the prompt from the result.
    """

    def __init__(self, api_key, endpoint):
        """Store the API key and the default endpoint.

        Args:
            api_key: Key forwarded to ``RunpodWrapper.generate`` on every call.
            endpoint: Endpoint name used by ``generate2``.
        """
        # NOTE(review): RunpodWrapper.__init__ is not called here; confirm the
        # base class needs no initialisation of its own.
        self.api_key = api_key
        self.endpoint = endpoint

    async def generate(
        self,
        prompt: str,
        endpoint_name,
        typing_fn,
        temperature: float = 0.72,
        max_new_tokens: int = 200,
        timeout: float = 180,
    ) -> str:
        """Generate text for ``prompt`` on the given endpoint.

        Args:
            prompt: Input text sent to the model.
            endpoint_name: Endpoint forwarded to ``RunpodWrapper.generate``.
            typing_fn: Forwarded unchanged to ``RunpodWrapper.generate``
                (presumably a typing-indicator callback -- confirm).
            temperature: Sampling temperature placed in the request payload.
            max_new_tokens: Requested length, sent as ``max_length`` and capped
                at 2048.
            timeout: Forwarded to ``RunpodWrapper.generate`` (units presumably
                seconds -- confirm against the base class).

        Returns:
            The base class's output with a leading copy of ``prompt`` removed.
        """
        input_data = {
            "input": {
                "prompt": prompt,
                # Hard cap on the requested length; presumably the backend's
                # limit -- TODO confirm.
                "max_length": min(max_new_tokens, 2048),
                # Use the caller-supplied value so the parameter takes effect.
                "temperature": temperature,
                "do_sample": True,
            }
        }
        output = await super().generate(
            input_data, endpoint_name, self.api_key, typing_fn, timeout
        )
        # Drop the prompt in case the backend echoes it at the start of the text.
        return output.removeprefix(prompt)

    async def generate2(
        self,
        prompt: str,
        typing_fn,
        temperature: float = 0.72,
        max_new_tokens: int = 200,
        timeout: float = 180,
    ) -> str:
        """Same as ``generate`` but uses the endpoint given at construction."""
        return await self.generate(
            prompt, self.endpoint, typing_fn, temperature, max_new_tokens, timeout
        )