From 140ce0b5ef09e102478cf237276d4a13dd68940f Mon Sep 17 00:00:00 2001 From: Hendrik Langer Date: Thu, 6 Apr 2023 19:14:24 +0200 Subject: [PATCH] async api --- matrix_pygmalion_bot/ai/runpod_pygmalion.py | 28 ++++++--------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/matrix_pygmalion_bot/ai/runpod_pygmalion.py b/matrix_pygmalion_bot/ai/runpod_pygmalion.py index a872943..0d63f02 100644 --- a/matrix_pygmalion_bot/ai/runpod_pygmalion.py +++ b/matrix_pygmalion_bot/ai/runpod_pygmalion.py @@ -21,7 +21,7 @@ async def generate_sync( bot, ): # Set the API endpoint URL - endpoint = "https://api.runpod.ai/v2/pygmalion-6b/runsync" + endpoint = "https://api.runpod.ai/v2/pygmalion-6b/run" # Set the headers for the request headers = { @@ -46,25 +46,11 @@ async def generate_sync( # Make the request r = requests.post(endpoint, json=input_data, headers=headers, timeout=180) - r_json = r.json() logger.info(r_json) - status = r_json["status"] - - if status == 'COMPLETED': - text = r_json["output"] - answer = text.removeprefix(prompt) -# lines = reply.split('\n') -# reply = lines[0].strip() - idx = answer.find(f"\nYou:") - if idx != -1: - reply = answer[:idx].strip() - else: - reply = answer.removesuffix('<|endoftext|>').strip() - reply = reply.replace("\n{bot.name}: ", " ") - reply = reply.replace("\n: ", " ") - return reply - elif status == 'IN_PROGRESS' or status == 'IN_QUEUE': + + if r.status_code == 200: + status = r_json["status"] job_id = r_json["id"] TIMEOUT = 360 DELAY = 5 @@ -92,9 +78,11 @@ async def generate_sync( reply = reply.replace("\n: ", " ") return reply else: - return "" + err_msg = r_json["error"] if "error" in r_json else "" + raise ValueError(f"RETURN CODE {status}: {err_msg}") + raise ValueError(f"") else: - return "" + raise ValueError(f"") async def get_full_prompt(simple_prompt: str, bot, chat_history):