diff --git a/matrix_pygmalion_bot/ai/runpod_pygmalion.py b/matrix_pygmalion_bot/ai/runpod_pygmalion.py
index 02385a5..e6c75e1 100644
--- a/matrix_pygmalion_bot/ai/runpod_pygmalion.py
+++ b/matrix_pygmalion_bot/ai/runpod_pygmalion.py
@@ -68,7 +68,12 @@ async def generate_sync(
         elif status == 'IN_QUEUE':
             await asyncio.sleep(DELAY)
         elif status == 'COMPLETED':
-            text = r_json["output"]
+            output = r_json["output"]
+            if isinstance(output, list):
+                output.sort(key=len, reverse=True)
+                text = output[0]
+            else:
+                text = r_json["output"]
             answer = text.removeprefix(prompt)
             # lines = reply.split('\n')
             # reply = lines[0].strip()
diff --git a/runpod/runpod-worker-transformers/Dockerfile b/runpod/runpod-worker-transformers/Dockerfile
index 26cd367..57b7dcd 100644
--- a/runpod/runpod-worker-transformers/Dockerfile
+++ b/runpod/runpod-worker-transformers/Dockerfile
@@ -46,7 +46,8 @@ RUN apt-get update --yes && \
 
 RUN apt-get update --yes && \
     apt install --yes --no-install-recommends \
-    python3 python3-dev python3-venv python3-pip && \
+    python3 python3-dev python3-venv python3-pip \
+    cython3 && \
     apt-get clean && rm -rf /var/lib/apt/lists/*
 
 ENV TORCH_CUDA_ARCH_LIST="3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX"
@@ -85,6 +86,7 @@ WORKDIR /workspace
 RUN apt-get update --yes && \
     apt install --yes --no-install-recommends \
     python3 python3-dev python3-venv python3-pip \
+    cython3 \
     git && \
     apt-get clean && rm -rf /var/lib/apt/lists/*
diff --git a/runpod/runpod-worker-transformers/runpod_infer.py b/runpod/runpod-worker-transformers/runpod_infer.py
index c59567e..a04370e 100644
--- a/runpod/runpod-worker-transformers/runpod_infer.py
+++ b/runpod/runpod-worker-transformers/runpod_infer.py
@@ -100,6 +100,11 @@ INPUT_SCHEMA = {
         'required': False,
         'default': 0
     },
+    'chat_generation_attempts': {
+        'type': int,
+        'required': False,
+        'default': 1
+    },
 }
 
 
@@ -191,27 +196,33 @@ def generator(job):
     input_ids = tokenizer(val_input['prompt'],
                           return_tensors="pt").input_ids.to(device)
 
-    gen_tokens = model.generate(
-        input_ids,
-        do_sample=val_input['do_sample'],
-        temperature=val_input['temperature'],
-        max_length=val_input['max_length'],
-        repetition_penalty=val_input['repetition_penalty'],
-        top_p=val_input['top_p'],
-        top_k=val_input['top_k'],
-        typical_p=val_input['typical_p'],
-        encoder_repetition_penalty=val_input['encoder_repetition_penalty'],
-        min_length=val_input['min_length'],
-        num_beams=val_input['num_beams'],
-        early_stopping=val_input['early_stopping'],
-        penalty_alpha=val_input['penalty_alpha'],
-        length_penalty=val_input['length_penalty'],
-        no_repeat_ngram_size=val_input['no_repeat_ngram_size'],
-    ).to(device)
-
-    gen_text = tokenizer.batch_decode(gen_tokens)[0]
-
-    return gen_text
+    output = []
+    for i in range(val_input['chat_generation_attempts']):
+        gen_tokens = model.generate(
+            input_ids,
+            do_sample=val_input['do_sample'],
+            temperature=val_input['temperature'],
+            max_length=val_input['max_length'],
+            repetition_penalty=val_input['repetition_penalty'],
+            top_p=val_input['top_p'],
+            top_k=val_input['top_k'],
+            typical_p=val_input['typical_p'],
+            encoder_repetition_penalty=val_input['encoder_repetition_penalty'],
+            min_length=val_input['min_length'],
+            num_beams=val_input['num_beams'],
+            early_stopping=val_input['early_stopping'],
+            penalty_alpha=val_input['penalty_alpha'],
+            length_penalty=val_input['length_penalty'],
+            no_repeat_ngram_size=val_input['no_repeat_ngram_size'],
+        ).to(device)
+
+        gen_text = tokenizer.batch_decode(gen_tokens)[0]
+        if val_input['chat_generation_attempts'] == 1:
+            output = gen_text
+        else:
+            output.append(gen_text)
+
+    return output
 
 
 # ---------------------------------------------------------------------------- #
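
For reference, here is a minimal client-side sketch (not part of the diff) of how the new chat_generation_attempts input and the list-valued worker output fit together. The payload keys follow INPUT_SCHEMA above; the transport call to the RunPod endpoint is elided, and pick_best() is a hypothetical helper that mirrors the longest-answer selection added to generate_sync():

# Hypothetical helper mirroring the selection logic in generate_sync():
# when several attempts come back as a list, keep the longest reply.
def pick_best(output, prompt):
    if isinstance(output, list):
        output.sort(key=len, reverse=True)  # longest completion first
        text = output[0]
    else:
        text = output  # chat_generation_attempts == 1 returns a bare string
    return text.removeprefix(prompt)

# Example job input; keys follow INPUT_SCHEMA in runpod_infer.py,
# prompt text and values are illustrative only.
payload = {
    "input": {
        "prompt": "You are a friendly bot.\nUser: Hello!\nBot:",
        "chat_generation_attempts": 3,  # new field added by this diff
        "max_length": 128,
    }
}

Note that the worker intentionally returns a bare string when only one attempt is requested, so existing callers keep working; only callers that request multiple attempts need the list-handling branch shown above.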