
generate multiple answers

Branch: master
Hendrik Langer, 2 years ago
Commit 72adbf7315
  1. matrix_pygmalion_bot/ai/runpod_pygmalion.py (5 changed lines)
  2. runpod/runpod-worker-transformers/Dockerfile (4 changed lines)
  3. runpod/runpod-worker-transformers/runpod_infer.py (13 changed lines)

matrix_pygmalion_bot/ai/runpod_pygmalion.py (5 changed lines)

@@ -68,6 +68,11 @@ async def generate_sync(
         elif status == 'IN_QUEUE':
             await asyncio.sleep(DELAY)
         elif status == 'COMPLETED':
-            text = r_json["output"]
+            output = r_json["output"]
+            if isinstance(output, list):
+                output.sort(key=len, reverse=True)
+                text = output[0]
+            else:
+                text = r_json["output"]
             answer = text.removeprefix(prompt)
             # lines = reply.split('\n')
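In plain terms, the worker may now return either a single string or a list of candidate answers, and the bot keeps the longest candidate. A minimal standalone sketch of that selection step, with an illustrative function name that does not exist in the repository:

    def pick_answer(output, prompt):
        # The worker returns a list when several generation attempts were requested.
        if isinstance(output, list):
            # Prefer the longest candidate; sorting descending puts it first.
            output = sorted(output, key=len, reverse=True)
            text = output[0]
        else:
            text = output
        # The generated text echoes the prompt, so strip it off the front.
        return text.removeprefix(prompt)

Keeping the longest string is a simple heuristic that favours candidates which did not stop after only a few words.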

runpod/runpod-worker-transformers/Dockerfile (4 changed lines)

@@ -46,7 +46,8 @@ RUN apt-get update --yes && \
 RUN apt-get update --yes && \
     apt install --yes --no-install-recommends \
-    python3 python3-dev python3-venv python3-pip && \
+    python3 python3-dev python3-venv python3-pip \
+    cython3 && \
     apt-get clean && rm -rf /var/lib/apt/lists/*
 ENV TORCH_CUDA_ARCH_LIST="3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX"

@@ -85,6 +86,7 @@ WORKDIR /workspace
 RUN apt-get update --yes && \
     apt install --yes --no-install-recommends \
     python3 python3-dev python3-venv python3-pip \
+    cython3 \
     git && \
     apt-get clean && rm -rf /var/lib/apt/lists/*

runpod/runpod-worker-transformers/runpod_infer.py (13 changed lines)

@@ -100,6 +100,11 @@ INPUT_SCHEMA = {
         'required': False,
         'default': 0
     },
+    'chat_generation_attempts': {
+        'type': int,
+        'required': False,
+        'default': 1
+    },
 }
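With this schema entry in place, a caller can ask the worker for several candidate completions per job. An illustrative job payload (field values are examples only; the remaining fields follow the existing INPUT_SCHEMA defaults):

    job_input = {
        "prompt": "<persona and chat history>",   # prompt text is illustrative
        "do_sample": True,                        # existing sampling switch
        "chat_generation_attempts": 3,            # new: request three candidates
    }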
@@ -191,6 +196,8 @@ def generator(job):
     input_ids = tokenizer(val_input['prompt'], return_tensors="pt").input_ids.to(device)
+    output = []
+    for i in range(val_input['chat_generation_attempts']):
         gen_tokens = model.generate(
             input_ids,
             do_sample=val_input['do_sample'],

@@ -210,8 +217,12 @@ def generator(job):
         ).to(device)
         gen_text = tokenizer.batch_decode(gen_tokens)[0]
+        if val_input['chat_generation_attempts'] == 1:
+            output = gen_text
+        else:
+            output.append(gen_text)
-    return gen_text
+    return output
 # ---------------------------------------------------------------------------- #
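Taken together, the runpod_infer.py changes turn the single generate() call into a loop. A condensed sketch of the resulting control flow (model, tokenizer, device and the remaining generation parameters are set up elsewhere in the file and omitted here; the function name is illustrative):

    def generate_answers(val_input, model, tokenizer, device):
        # Tokenize the prompt once; every attempt reuses the same input_ids.
        input_ids = tokenizer(val_input['prompt'], return_tensors="pt").input_ids.to(device)

        output = []
        for _ in range(val_input['chat_generation_attempts']):
            gen_tokens = model.generate(
                input_ids,
                do_sample=val_input['do_sample'],
                # ...remaining sampling parameters from INPUT_SCHEMA...
            )
            gen_text = tokenizer.batch_decode(gen_tokens)[0]
            if val_input['chat_generation_attempts'] == 1:
                # Single attempt: keep the previous behaviour, return a plain string.
                output = gen_text
            else:
                # Several attempts: collect every candidate in a list.
                output.append(gen_text)
        return output

The bot side (runpod_pygmalion.py above) then handles both shapes: a plain string for a single attempt, a list of candidates otherwise.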
