Browse Source

async api

master
Hendrik Langer 2 years ago
parent
commit
140ce0b5ef
  1. 26
      matrix_pygmalion_bot/ai/runpod_pygmalion.py

26
matrix_pygmalion_bot/ai/runpod_pygmalion.py

@ -21,7 +21,7 @@ async def generate_sync(
bot,
):
# Set the API endpoint URL
endpoint = "https://api.runpod.ai/v2/pygmalion-6b/runsync"
endpoint = "https://api.runpod.ai/v2/pygmalion-6b/run"
# Set the headers for the request
headers = {
@ -46,25 +46,11 @@ async def generate_sync(
# Make the request
r = requests.post(endpoint, json=input_data, headers=headers, timeout=180)
r_json = r.json()
logger.info(r_json)
status = r_json["status"]
if status == 'COMPLETED':
text = r_json["output"]
answer = text.removeprefix(prompt)
# lines = reply.split('\n')
# reply = lines[0].strip()
idx = answer.find(f"\nYou:")
if idx != -1:
reply = answer[:idx].strip()
else:
reply = answer.removesuffix('<|endoftext|>').strip()
reply = reply.replace("\n{bot.name}: ", " ")
reply = reply.replace("\n<BOT>: ", " ")
return reply
elif status == 'IN_PROGRESS' or status == 'IN_QUEUE':
if r.status_code == 200:
status = r_json["status"]
job_id = r_json["id"]
TIMEOUT = 360
DELAY = 5
@ -92,9 +78,11 @@ async def generate_sync(
reply = reply.replace("\n<BOT>: ", " ")
return reply
else:
return "<ERROR>"
err_msg = r_json["error"] if "error" in r_json else ""
raise ValueError(f"RETURN CODE {status}: {err_msg}")
raise ValueError(f"<TIMEOUT>")
else:
return "<ERROR>"
raise ValueError(f"<ERROR>")
async def get_full_prompt(simple_prompt: str, bot, chat_history):

Loading…
Cancel
Save