Chatbot
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

60 lines
2.0 KiB

2 years ago
import asyncio
import requests
import json
import logging
# Module-level logger, named after this module via __name__; used by
# RunpodWrapper to report requests and the responses it receives.
logger = logging.getLogger(__name__)
class RunpodWrapper(object):
    """Base class for clients of the runpod.io HTTP API.

    Stores the API key and offers :meth:`generate`, which submits a job to an
    endpoint's ``/run`` route and then polls ``/status/{job_id}`` until the
    job reports ``COMPLETED``, reports any other terminal status, or the
    polling budget is used up.
    """

    # Overall polling budget in seconds and the pause between two status
    # polls; together they bound the number of polls to POLL_TIMEOUT // POLL_DELAY.
    POLL_TIMEOUT = 360
    POLL_DELAY = 5

    def __init__(self, api_key):
        """Remember the API key that is sent as a bearer token."""
        self.api_key = api_key

    @staticmethod
    async def _http(send, url, **kwargs):
        """Run a blocking ``requests`` call without blocking the event loop.

        Args:
            send: The blocking callable to invoke (``requests.post`` or
                ``requests.get``).
            url: Target URL, passed as the first positional argument.
            **kwargs: Forwarded unchanged to ``send``.

        Returns:
            The response object returned by ``send``.

        Raises:
            ValueError: ``"<HTTP ERROR>"`` when ``send`` raises a
                ``requests.exceptions.RequestException``; the original
                exception is chained as the cause.
        """
        loop = asyncio.get_running_loop()
        try:
            # The default executor runs the call in a worker thread, so other
            # coroutines keep running while the request is in flight.
            return await loop.run_in_executor(None, lambda: send(url, **kwargs))
        except requests.exceptions.RequestException as e:
            logger.error("request to runpod.io failed: %s", e)
            raise ValueError("<HTTP ERROR>") from e

    @staticmethod
    def _decode(response):
        """Return the JSON object carried by ``response``.

        Raises:
            ValueError: ``"<ERROR> INVALID RESPONSE"`` when the body is not
                valid JSON or is not a JSON object.
        """
        try:
            payload = response.json()
        except ValueError as e:
            # Every JSON decode error raised by requests derives from ValueError.
            raise ValueError("<ERROR> INVALID RESPONSE") from e
        if not isinstance(payload, dict):
            raise ValueError("<ERROR> INVALID RESPONSE")
        return payload

    async def generate(self, input_data, endpoint_name, typing_fn, timeout=180):
        """Submit a job to a runpod.io endpoint and wait for its output.

        Args:
            input_data: JSON-serialisable request body posted to ``/run``.
            endpoint_name: Endpoint identifier inserted into the API URL.
            typing_fn: Coroutine function awaited (without arguments) each
                time a poll reports the job as ``IN_PROGRESS``.
            timeout: Timeout in seconds applied to each individual HTTP
                request. The overall polling budget is ``POLL_TIMEOUT``.

        Returns:
            The ``output`` field of the status payload once the job reports
            ``COMPLETED``.

        Raises:
            ValueError: With a message starting with ``<HTTP ERROR>`` when a
                request fails at the transport level, or ``<ERROR>`` when the
                submission is rejected, a response cannot be interpreted, the
                job ends in a status other than ``COMPLETED``, or polling
                exceeds ``POLL_TIMEOUT`` seconds.
        """
        base_url = f"https://api.runpod.ai/v2/{endpoint_name}"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        logger.info("sending request to runpod.io. endpoint=\"%s\"", endpoint_name)

        response = await self._http(
            requests.post,
            f"{base_url}/run",
            json=input_data,
            headers=headers,
            timeout=timeout,
        )
        # Check the status code before parsing: an error reply is not
        # guaranteed to carry a JSON body.
        if response.status_code != 200:
            logger.debug(response.text)
            raise ValueError("<ERROR>")
        submitted = self._decode(response)
        logger.debug(submitted)
        job_id = submitted.get("id")
        if job_id is None:
            raise ValueError("<ERROR> INVALID RESPONSE")

        status_url = f"{base_url}/status/{job_id}"
        for _ in range(self.POLL_TIMEOUT // self.POLL_DELAY):
            # The per-request timeout keeps a stalled poll from hanging forever.
            response = await self._http(
                requests.get, status_url, headers=headers, timeout=timeout
            )
            state = self._decode(response)
            logger.info(state)
            status = state.get("status")
            if status == "COMPLETED":
                return state["output"]
            if status == "IN_PROGRESS":
                await typing_fn()
            elif status != "IN_QUEUE":
                # Any other status is treated as terminal. The error field may
                # be absent or not a string, so coerce it before unescaping
                # the literal "\n" sequences.
                err_msg = str(state.get("error", "")).replace("\\n", "\n")
                raise ValueError(f"<ERROR> RETURN CODE {status}: {err_msg}")
            await asyncio.sleep(self.POLL_DELAY)
        raise ValueError("<ERROR> TIMEOUT")