|
|
@ -21,7 +21,7 @@ async def generate_sync( |
|
|
|
bot, |
|
|
|
): |
|
|
|
# Set the API endpoint URL |
|
|
|
endpoint = "https://api.runpod.ai/v2/pygmalion-6b/runsync" |
|
|
|
endpoint = "https://api.runpod.ai/v2/pygmalion-6b/run" |
|
|
|
|
|
|
|
# Set the headers for the request |
|
|
|
headers = { |
|
|
@ -46,25 +46,11 @@ async def generate_sync( |
|
|
|
|
|
|
|
# Make the request |
|
|
|
r = requests.post(endpoint, json=input_data, headers=headers, timeout=180) |
|
|
|
|
|
|
|
r_json = r.json() |
|
|
|
logger.info(r_json) |
|
|
|
status = r_json["status"] |
|
|
|
|
|
|
|
if status == 'COMPLETED': |
|
|
|
text = r_json["output"] |
|
|
|
answer = text.removeprefix(prompt) |
|
|
|
# lines = reply.split('\n') |
|
|
|
# reply = lines[0].strip() |
|
|
|
idx = answer.find(f"\nYou:") |
|
|
|
if idx != -1: |
|
|
|
reply = answer[:idx].strip() |
|
|
|
else: |
|
|
|
reply = answer.removesuffix('<|endoftext|>').strip() |
|
|
|
reply = reply.replace("\n{bot.name}: ", " ") |
|
|
|
reply = reply.replace("\n<BOT>: ", " ") |
|
|
|
return reply |
|
|
|
elif status == 'IN_PROGRESS' or status == 'IN_QUEUE': |
|
|
|
|
|
|
|
if r.status_code == 200: |
|
|
|
status = r_json["status"] |
|
|
|
job_id = r_json["id"] |
|
|
|
TIMEOUT = 360 |
|
|
|
DELAY = 5 |
|
|
@ -92,9 +78,11 @@ async def generate_sync( |
|
|
|
reply = reply.replace("\n<BOT>: ", " ") |
|
|
|
return reply |
|
|
|
else: |
|
|
|
return "<ERROR>" |
|
|
|
err_msg = r_json["error"] if "error" in r_json else "" |
|
|
|
raise ValueError(f"RETURN CODE {status}: {err_msg}") |
|
|
|
raise ValueError(f"<TIMEOUT>") |
|
|
|
else: |
|
|
|
return "<ERROR>" |
|
|
|
raise ValueError(f"<ERROR>") |
|
|
|
|
|
|
|
async def get_full_prompt(simple_prompt: str, bot, chat_history): |
|
|
|
|
|
|
|