@@ -100,6 +100,11 @@ INPUT_SCHEMA = {
         'required': False,
         'default': 0
     },
+    'chat_generation_attempts': {
+        'type': int,
+        'required': False,
+        'default': 1
+    },
 }
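With the schema addition in place, a request can ask for multiple generations of the same prompt. A minimal sketch of a job payload — the field values are illustrative, only 'prompt' and 'chat_generation_attempts' appear in this diff, and the handler is assumed to validate job['input'] against INPUT_SCHEMA (as the val_input name below suggests):

    job = {
        'input': {
            'prompt': 'Hello, how are you?',    # existing prompt field
            'chat_generation_attempts': 3       # new: request three completions
        }
    }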
@@ -191,6 +196,8 @@ def generator(job):
     input_ids = tokenizer(val_input['prompt'], return_tensors="pt").input_ids.to(device)
 
-    gen_tokens = model.generate(
-        input_ids,
-        do_sample=val_input['do_sample'],
+    output = []
+    for i in range(val_input['chat_generation_attempts']):
+        gen_tokens = model.generate(
+            input_ids,
+            do_sample=val_input['do_sample'],
@@ -210,8 +217,12 @@ def generator(job):
-    ).to(device)
+        ).to(device)
 
-    gen_text = tokenizer.batch_decode(gen_tokens)[0]
+        gen_text = tokenizer.batch_decode(gen_tokens)[0]
+
+        if val_input['chat_generation_attempts'] == 1:
+            output = gen_text
+        else:
+            output.append(gen_text)
 
-    return gen_text
+    return output
 
 # ---------------------------------------------------------------------------- #
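Note that the return shape now depends on the attempt count: with the default of 1 the handler returns the decoded string directly, while higher counts return a list with one decoded string per attempt. A hedged sketch of consuming the result — the direct generator() call below is an assumption for illustration and may differ from how the serverless runtime actually drives the handler:

    result = generator({'input': {'prompt': 'Hi there', 'chat_generation_attempts': 2}})
    if isinstance(result, list):
        for text in result:   # one decoded generation per attempt
            print(text)
    else:
        print(result)         # a single attempt yields the bare string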