@@ -100,6 +100,11 @@ INPUT_SCHEMA = {
         'required': False,
         'default': 0
     },
+    'chat_generation_attempts': {
+        'type': int,
+        'required': False,
+        'default': 1
+    },
 
 }
 
@@ -191,27 +196,33 @@ def generator(job):
 
     input_ids = tokenizer(val_input['prompt'], return_tensors="pt").input_ids.to(device)
 
-    gen_tokens = model.generate(
-        input_ids,
-        do_sample=val_input['do_sample'],
-        temperature=val_input['temperature'],
-        max_length=val_input['max_length'],
-        repetition_penalty=val_input['repetition_penalty'],
-        top_p=val_input['top_p'],
-        top_k=val_input['top_k'],
-        typical_p=val_input['typical_p'],
-        encoder_repetition_penalty=val_input['encoder_repetition_penalty'],
-        min_length=val_input['min_length'],
-        num_beams=val_input['num_beams'],
-        early_stopping=val_input['early_stopping'],
-        penalty_alpha=val_input['penalty_alpha'],
-        length_penalty=val_input['length_penalty'],
-        no_repeat_ngram_size=val_input['no_repeat_ngram_size'],
-    ).to(device)
-
-    gen_text = tokenizer.batch_decode(gen_tokens)[0]
-
-    return gen_text
+    output = []
+    for i in range(val_input['chat_generation_attempts']):
+        gen_tokens = model.generate(
+            input_ids,
+            do_sample=val_input['do_sample'],
+            temperature=val_input['temperature'],
+            max_length=val_input['max_length'],
+            repetition_penalty=val_input['repetition_penalty'],
+            top_p=val_input['top_p'],
+            top_k=val_input['top_k'],
+            typical_p=val_input['typical_p'],
+            encoder_repetition_penalty=val_input['encoder_repetition_penalty'],
+            min_length=val_input['min_length'],
+            num_beams=val_input['num_beams'],
+            early_stopping=val_input['early_stopping'],
+            penalty_alpha=val_input['penalty_alpha'],
+            length_penalty=val_input['length_penalty'],
+            no_repeat_ngram_size=val_input['no_repeat_ngram_size'],
+        ).to(device)
+
+        gen_text = tokenizer.batch_decode(gen_tokens)[0]
+        if val_input['chat_generation_attempts'] == 1:
+            output = gen_text
+        else:
+            output.append(gen_text)
+
+    return output
 
 
 # ---------------------------------------------------------------------------- #
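
For reviewers: a quick sketch of a client payload that would exercise the new
'chat_generation_attempts' input. The /runsync URL shape and Bearer auth follow
the usual RunPod serverless pattern; the endpoint ID, API key, and prompt below
are placeholders, not values taken from this diff.

import os

import requests

# Hypothetical endpoint; substitute a real endpoint ID.
ENDPOINT_URL = "https://api.runpod.ai/v2/<endpoint_id>/runsync"

payload = {
    "input": {
        "prompt": "Hello, how are you?",
        # New field added by this diff; defaults to 1 when omitted.
        "chat_generation_attempts": 3,
    }
}

resp = requests.post(
    ENDPOINT_URL,
    json=payload,
    headers={"Authorization": f"Bearer {os.environ['RUNPOD_API_KEY']}"},
    timeout=120,
)
resp.raise_for_status()

# Per the generator() change above: 'output' is a single string when
# chat_generation_attempts == 1, and a list of strings otherwise.
print(resp.json()["output"])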
|
|
|