
generate multiple answers

master
Hendrik Langer 2 years ago
parent commit 72adbf7315
  1. matrix_pygmalion_bot/ai/runpod_pygmalion.py (7 changed lines)
  2. runpod/runpod-worker-transformers/Dockerfile (4 changed lines)
  3. runpod/runpod-worker-transformers/runpod_infer.py (53 changed lines)

matrix_pygmalion_bot/ai/runpod_pygmalion.py (7 changed lines)

@@ -68,7 +68,12 @@ async def generate_sync(
         elif status == 'IN_QUEUE':
             await asyncio.sleep(DELAY)
         elif status == 'COMPLETED':
-            text = r_json["output"]
+            output = r_json["output"]
+            if isinstance(output, list):
+                output.sort(key=len, reverse=True)
+                text = output[0]
+            else:
+                text = r_json["output"]
             answer = text.removeprefix(prompt)
             # lines = reply.split('\n')
             # reply = lines[0].strip()
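
On the bot side, the worker's "output" field can now be either a single string or a list of candidate replies, and the longest candidate wins. Below is a minimal standalone sketch of that selection logic, assuming the same response shape as in the hunk above; the function name and the sample strings are illustrative only, not part of the bot.

def pick_answer(output, prompt):
    """Choose a reply from the worker's "output" field.

    `output` is either one generated string or a list of candidates
    (one per generation attempt); the longest candidate is preferred,
    mirroring output.sort(key=len, reverse=True) in the bot.
    """
    if isinstance(output, list):
        text = max(output, key=len)
    else:
        text = output
    # Strip the echoed prompt, as the bot does with str.removeprefix().
    return text.removeprefix(prompt)

# Example: the longest of three candidates is returned.
print(pick_answer(["Hi.", "Hi! How are you today?", "Hello."], ""))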

runpod/runpod-worker-transformers/Dockerfile (4 changed lines)

@@ -46,7 +46,8 @@ RUN apt-get update --yes && \
 RUN apt-get update --yes && \
     apt install --yes --no-install-recommends \
-    python3 python3-dev python3-venv python3-pip && \
+    python3 python3-dev python3-venv python3-pip \
+    cython3 && \
     apt-get clean && rm -rf /var/lib/apt/lists/*
 
 ENV TORCH_CUDA_ARCH_LIST="3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX"
@@ -85,6 +86,7 @@ WORKDIR /workspace
 RUN apt-get update --yes && \
     apt install --yes --no-install-recommends \
     python3 python3-dev python3-venv python3-pip \
+    cython3 \
     git && \
     apt-get clean && rm -rf /var/lib/apt/lists/*

runpod/runpod-worker-transformers/runpod_infer.py (53 changed lines)

@@ -100,6 +100,11 @@ INPUT_SCHEMA = {
         'required': False,
         'default': 0
     },
+    'chat_generation_attempts': {
+        'type': int,
+        'required': False,
+        'default': 1
+    },
 }
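
The new chat_generation_attempts field defaults to 1, so existing callers keep getting a single reply. An illustrative job payload requesting three candidates is shown below, assuming the usual RunPod serverless envelope of {"input": {...}}; the other fields follow the existing schema and the values here are examples only.

job = {
    "input": {
        "prompt": "You are a friendly assistant.\nYou: Hello!\nBot:",
        "do_sample": True,
        "temperature": 0.8,
        "max_length": 200,
        # New in this commit: ask the worker for three candidate replies.
        "chat_generation_attempts": 3,
    }
}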
@@ -191,27 +196,33 @@ def generator(job):
     input_ids = tokenizer(val_input['prompt'], return_tensors="pt").input_ids.to(device)
 
-    gen_tokens = model.generate(
-        input_ids,
-        do_sample=val_input['do_sample'],
-        temperature=val_input['temperature'],
-        max_length=val_input['max_length'],
-        repetition_penalty=val_input['repetition_penalty'],
-        top_p=val_input['top_p'],
-        top_k=val_input['top_k'],
-        typical_p=val_input['typical_p'],
-        encoder_repetition_penalty=val_input['encoder_repetition_penalty'],
-        min_length=val_input['min_length'],
-        num_beams=val_input['num_beams'],
-        early_stopping=val_input['early_stopping'],
-        penalty_alpha=val_input['penalty_alpha'],
-        length_penalty=val_input['length_penalty'],
-        no_repeat_ngram_size=val_input['no_repeat_ngram_size'],
-    ).to(device)
-
-    gen_text = tokenizer.batch_decode(gen_tokens)[0]
-    return gen_text
+    output = []
+    for i in range(val_input['chat_generation_attempts']):
+        gen_tokens = model.generate(
+            input_ids,
+            do_sample=val_input['do_sample'],
+            temperature=val_input['temperature'],
+            max_length=val_input['max_length'],
+            repetition_penalty=val_input['repetition_penalty'],
+            top_p=val_input['top_p'],
+            top_k=val_input['top_k'],
+            typical_p=val_input['typical_p'],
+            encoder_repetition_penalty=val_input['encoder_repetition_penalty'],
+            min_length=val_input['min_length'],
+            num_beams=val_input['num_beams'],
+            early_stopping=val_input['early_stopping'],
+            penalty_alpha=val_input['penalty_alpha'],
+            length_penalty=val_input['length_penalty'],
+            no_repeat_ngram_size=val_input['no_repeat_ngram_size'],
+        ).to(device)
+        gen_text = tokenizer.batch_decode(gen_tokens)[0]
+        if val_input['chat_generation_attempts'] == 1:
+            output = gen_text
+        else:
+            output.append(gen_text)
+    return output
 
 # ---------------------------------------------------------------------------- #
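
Because the loop only appends when more than one attempt was requested, the handler's return type now depends on chat_generation_attempts: a plain string for the default of 1 (preserving the old behaviour) and a list of decoded strings otherwise. A small sketch of those two shapes follows, with a stub standing in for model.generate() plus tokenizer.batch_decode(); names here are illustrative, not the worker's own helpers.

def run_attempts(attempts, generate_once):
    # Mirrors the loop in generator(): keep the old single-string return
    # for one attempt, otherwise collect one decoded text per attempt.
    output = []
    for _ in range(attempts):
        gen_text = generate_once()
        if attempts == 1:
            output = gen_text
        else:
            output.append(gen_text)
    return output

# Stub generations standing in for the real model.
candidates = iter(["Hello!", "Hi there, nice to meet you!", "Hey."])
print(run_attempts(1, lambda: "Hello!"))          # -> 'Hello!' (a string)
print(run_attempts(3, lambda: next(candidates)))  # -> a list of three strings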
