You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
32 lines
1009 B
32 lines
1009 B
2 years ago
|
import asyncio
|
||
|
import json
|
||
|
from .runpod import RunpodWrapper
|
||
|
import logging
|
||
|
|
||
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
class RunpodTextWrapper(RunpodWrapper):
    """Text-generation client built on top of ``RunpodWrapper``.

    Builds the text-generation input payload and delegates the actual
    request to ``RunpodWrapper.generate``.
    """

    def __init__(self, api_key, endpoint):
        """Store the API key and the default endpoint name.

        Args:
            api_key: Key forwarded to ``RunpodWrapper.generate`` on every call.
            endpoint: Endpoint name used by ``generate2``.
        """
        # NOTE(review): RunpodWrapper.__init__ is not called here; its
        # signature is not visible from this file — confirm the base class
        # needs no initialisation of its own.
        self.api_key = api_key
        self.endpoint = endpoint

    async def generate(self, prompt, endpoint_name, typing_fn, temperature=0.72, max_new_tokens=200, timeout=180):
        """Generate a completion for ``prompt`` on the named endpoint.

        Args:
            prompt: Input text sent to the model.
            endpoint_name: Endpoint to send the request to.
            typing_fn: Forwarded unchanged to ``RunpodWrapper.generate``
                (presumably a typing-indicator callback — confirm against
                the base class).
            temperature: Sampling temperature placed in the payload.
            max_new_tokens: Requested generation length; sent as
                ``max_length`` and capped at 2048.
            timeout: Forwarded unchanged to ``RunpodWrapper.generate``.

        Returns:
            The base class's output with a leading copy of ``prompt``
            removed (``str.removeprefix`` leaves it unchanged otherwise).
        """
        input_data = {
            "input": {
                "prompt": prompt,
                # The payload field is max_length; never request more than 2048.
                "max_length": min(max_new_tokens, 2048),
                "temperature": temperature,
                "do_sample": True,
            }
        }
        output = await super().generate(input_data, endpoint_name, self.api_key, typing_fn, timeout)
        # Drop the prompt if the returned text starts with it.
        return output.removeprefix(prompt)

    async def generate2(self, prompt, typing_fn, temperature=0.72, max_new_tokens=200, timeout=180):
        """Same as ``generate`` but targets the endpoint given to ``__init__``.

        Returns:
            Whatever ``generate`` returns for ``self.endpoint``.
        """
        return await self.generate(
            prompt, self.endpoint, typing_fn, temperature, max_new_tokens, timeout
        )
|