
generate multiple answers

master
Hendrik Langer 2 years ago
parent commit 72adbf7315

  1. matrix_pygmalion_bot/ai/runpod_pygmalion.py (7 changed lines)
  2. runpod/runpod-worker-transformers/Dockerfile (4 changed lines)
  3. runpod/runpod-worker-transformers/runpod_infer.py (53 changed lines)

matrix_pygmalion_bot/ai/runpod_pygmalion.py (7 changed lines)

@@ -68,7 +68,12 @@ async def generate_sync(
         elif status == 'IN_QUEUE':
             await asyncio.sleep(DELAY)
         elif status == 'COMPLETED':
-            text = r_json["output"]
+            output = r_json["output"]
+            if isinstance(output, list):
+                output.sort(key=len, reverse=True)
+                text = output[0]
+            else:
+                text = r_json["output"]
             answer = text.removeprefix(prompt)
             # lines = reply.split('\n')
             # reply = lines[0].strip()
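The client-side effect of this hunk: generate_sync now accepts either a single string or a list of candidate completions from the worker and keeps the longest one. A minimal standalone sketch of that selection heuristic; the helper name pick_longest and the sample data are illustrative, not part of the repository:

    def pick_longest(output, prompt):
        # Keep the longest candidate, mirroring the new list handling above.
        if isinstance(output, list):
            # Longest first, then take the top candidate.
            output.sort(key=len, reverse=True)
            text = output[0]
        else:
            text = output
        # Drop the echoed prompt, as generate_sync does (removeprefix needs Python 3.9+).
        return text.removeprefix(prompt)

    # Example: three attempts come back from the worker, the longest wins.
    candidates = [
        "You: Hi\nJulia: Hello.",
        "You: Hi\nJulia: Hi!",
        "You: Hi\nJulia: Hello there, nice to meet you.",
    ]
    print(pick_longest(candidates, "You: Hi\n"))
    # -> "Julia: Hello there, nice to meet you."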

runpod/runpod-worker-transformers/Dockerfile (4 changed lines)

@@ -46,7 +46,8 @@ RUN apt-get update --yes && \
 RUN apt-get update --yes && \
     apt install --yes --no-install-recommends \
-    python3 python3-dev python3-venv python3-pip && \
+    python3 python3-dev python3-venv python3-pip \
+    cython3 && \
     apt-get clean && rm -rf /var/lib/apt/lists/*
 ENV TORCH_CUDA_ARCH_LIST="3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX"
@@ -85,6 +86,7 @@ WORKDIR /workspace
 RUN apt-get update --yes && \
     apt install --yes --no-install-recommends \
     python3 python3-dev python3-venv python3-pip \
+    cython3 \
     git && \
     apt-get clean && rm -rf /var/lib/apt/lists/*

runpod/runpod-worker-transformers/runpod_infer.py (53 changed lines)

@@ -100,6 +100,11 @@ INPUT_SCHEMA = {
         'required': False,
         'default': 0
     },
+    'chat_generation_attempts': {
+        'type': int,
+        'required': False,
+        'default': 1
+    },
 }
@@ -191,27 +196,33 @@ def generator(job):
     input_ids = tokenizer(val_input['prompt'], return_tensors="pt").input_ids.to(device)
-    gen_tokens = model.generate(
-        input_ids,
-        do_sample=val_input['do_sample'],
-        temperature=val_input['temperature'],
-        max_length=val_input['max_length'],
-        repetition_penalty=val_input['repetition_penalty'],
-        top_p=val_input['top_p'],
-        top_k=val_input['top_k'],
-        typical_p=val_input['typical_p'],
-        encoder_repetition_penalty=val_input['encoder_repetition_penalty'],
-        min_length=val_input['min_length'],
-        num_beams=val_input['num_beams'],
-        early_stopping=val_input['early_stopping'],
-        penalty_alpha=val_input['penalty_alpha'],
-        length_penalty=val_input['length_penalty'],
-        no_repeat_ngram_size=val_input['no_repeat_ngram_size'],
-    ).to(device)
-    gen_text = tokenizer.batch_decode(gen_tokens)[0]
-    return gen_text
+    output = []
+    for i in range(val_input['chat_generation_attempts']):
+        gen_tokens = model.generate(
+            input_ids,
+            do_sample=val_input['do_sample'],
+            temperature=val_input['temperature'],
+            max_length=val_input['max_length'],
+            repetition_penalty=val_input['repetition_penalty'],
+            top_p=val_input['top_p'],
+            top_k=val_input['top_k'],
+            typical_p=val_input['typical_p'],
+            encoder_repetition_penalty=val_input['encoder_repetition_penalty'],
+            min_length=val_input['min_length'],
+            num_beams=val_input['num_beams'],
+            early_stopping=val_input['early_stopping'],
+            penalty_alpha=val_input['penalty_alpha'],
+            length_penalty=val_input['length_penalty'],
+            no_repeat_ngram_size=val_input['no_repeat_ngram_size'],
+        ).to(device)
+        gen_text = tokenizer.batch_decode(gen_tokens)[0]
+        if val_input['chat_generation_attempts'] == 1:
+            output = gen_text
+        else:
+            output.append(gen_text)
+    return output
 # ---------------------------------------------------------------------------- #
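On the worker side, the loop now runs model.generate() chat_generation_attempts times. With the schema default of 1 the handler still returns a single string, so existing callers are unaffected; for larger values it returns a list of decoded generations, which is what the new list handling in runpod_pygmalion.py above consumes. A rough sketch of a request asking for three candidates, assuming the standard RunPod serverless "input" envelope; the endpoint ID, API key and prompt are placeholders, not values taken from this repository:

    import requests

    RUNPOD_API_KEY = "..."                                    # placeholder
    ENDPOINT = "https://api.runpod.ai/v2/<endpoint-id>/run"   # placeholder endpoint ID

    payload = {
        "input": {
            "prompt": "You: Hello!\nJulia:",
            "max_length": 512,
            "temperature": 0.8,
            "chat_generation_attempts": 3,   # new field introduced in this commit
        }
    }
    r = requests.post(ENDPOINT, json=payload,
                      headers={"Authorization": f"Bearer {RUNPOD_API_KEY}"})
    print(r.json())   # job id; once COMPLETED, "output" holds a list of 3 strings

Because the field defaults to 1, bot versions that never send it keep receiving a plain string.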
