You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
60 lines
2.0 KiB
60 lines
2.0 KiB
2 years ago
|
import asyncio
|
||
|
import requests
|
||
|
import json
|
||
|
import logging
|
||
|
|
||
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
|
||
|
class RunpodWrapper(object):
    """Base class for talking to the RunPod HTTP API (``api.runpod.ai/v2``).

    A job is submitted to ``/<endpoint>/run`` and then polled on
    ``/<endpoint>/status/<job_id>`` until it completes, fails, or the
    polling budget is exhausted.
    """

    # Total polling budget and pause between two status polls, in seconds.
    # Kept as class attributes so subclasses can tune them.
    POLL_TIMEOUT = 360
    POLL_DELAY = 5

    def __init__(self, api_key):
        """Remember the API key; it is sent as a Bearer token on every request."""
        self.api_key = api_key

    @staticmethod
    async def _call(method, url, **kwargs):
        """Run a blocking ``requests`` call in the default executor.

        ``requests`` is synchronous; calling it directly inside a coroutine
        would stall the event loop for the whole duration of the HTTP call.

        Raises:
            ValueError: ``<HTTP ERROR>`` on any ``requests`` transport error
                (the original exception is chained as ``__cause__``).
        """
        loop = asyncio.get_event_loop()
        try:
            return await loop.run_in_executor(None, lambda: method(url, **kwargs))
        except requests.exceptions.RequestException as e:
            raise ValueError("<HTTP ERROR>") from e

    @staticmethod
    def _json_or_none(response):
        """Return the decoded JSON body of *response*, or None if it is not JSON."""
        try:
            return response.json()
        except ValueError:
            # JSON decode errors raised by requests are ValueError subclasses.
            return None

    async def generate(self, input_data, endpoint_name, typing_fn, timeout=180):
        """Submit a job to a RunPod endpoint and wait for its output.

        Args:
            input_data: JSON-serialisable payload posted as the request body.
            endpoint_name: Endpoint identifier inserted into the API URL.
            typing_fn: Zero-argument coroutine function awaited on every poll
                that reports ``IN_PROGRESS`` (presumably to keep a "typing"
                indicator alive -- confirm with the caller).
            timeout: Per-request HTTP timeout in seconds, applied to the
                submit call and to every status poll.

        Returns:
            The ``output`` field of the status response once the job reports
            ``COMPLETED``.

        Raises:
            ValueError: ``<HTTP ERROR>`` on a transport failure, ``<ERROR>``
                on a non-200 / malformed response, ``<ERROR> RETURN CODE ...``
                when the job reports any other status, and
                ``<ERROR> TIMEOUT`` when the polling budget runs out.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        logger.info('sending request to runpod.io. endpoint="%s"', endpoint_name)

        # Submit the job.
        run_url = f"https://api.runpod.ai/v2/{endpoint_name}/run"
        r = await self._call(
            requests.post, run_url, json=input_data, headers=headers, timeout=timeout
        )
        r_json = self._json_or_none(r)
        logger.debug(r_json)

        # Check the status code before trusting the body: error responses
        # are not guaranteed to be JSON or to carry a job id.
        if r.status_code != 200 or not isinstance(r_json, dict) or "id" not in r_json:
            raise ValueError("<ERROR>")
        job_id = r_json["id"]

        # Poll until the job leaves the queued / running states.
        status_url = f"https://api.runpod.ai/v2/{endpoint_name}/status/{job_id}"
        for _ in range(self.POLL_TIMEOUT // self.POLL_DELAY):
            # A timeout is required here too, otherwise a stuck connection
            # would block forever and defeat the polling budget.
            r = await self._call(
                requests.get, status_url, headers=headers, timeout=timeout
            )
            r_json = self._json_or_none(r)
            logger.info(r_json)
            if not isinstance(r_json, dict):
                raise ValueError("<ERROR>")

            status = r_json.get("status")
            if status == "IN_PROGRESS":
                await typing_fn()
                await asyncio.sleep(self.POLL_DELAY)
            elif status == "IN_QUEUE":
                await asyncio.sleep(self.POLL_DELAY)
            elif status == "COMPLETED":
                return r_json["output"]
            else:
                # The error payload may not be a string; coerce it before
                # turning escaped newlines into real ones.
                err_msg = str(r_json.get("error", "")).replace("\\n", "\n")
                raise ValueError(f"<ERROR> RETURN CODE {status}: {err_msg}")

        raise ValueError("<ERROR> TIMEOUT")
|