You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
190 lines
6.0 KiB
190 lines
6.0 KiB
2 years ago
|
'''
|
||
|
RunPod | Transformer | Handler
|
||
|
'''
|
||
|
import argparse
|
||
|
|
||
|
import torch
|
||
|
import runpod
|
||
|
from runpod.serverless.utils.rp_validator import validate
|
||
|
from transformers import (GPTNeoForCausalLM, GPT2Tokenizer, GPTNeoXForCausalLM,
|
||
|
GPTNeoXTokenizerFast, GPTJForCausalLM, AutoTokenizer, AutoModelForCausalLM)
|
||
|
|
||
|
|
||
|
# Select the GPU when torch can see one, otherwise fall back to the CPU.
# (The original had a bare `torch.cuda.is_available()` call before this line
# whose result was discarded; it had no effect and has been removed.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
|
|
||
|
# Schema handed to runpod's `validate` for every job's input.
# `prompt` is the only required field. `do_sample` carries a description.
# Every other generation knob is an optional typed value with a default,
# so those are listed compactly as (name, type, default) rows below.
# Key order matches the order the parameters are passed to model.generate().
INPUT_SCHEMA = {
    'prompt': {'type': str, 'required': True},
    'do_sample': {
        'type': bool,
        'required': False,
        'default': True,
        'description': '''
        Enables decoding strategies such as multinomial sampling,
        beam-search multinomial sampling, Top-K sampling and Top-p sampling.
        All these strategies select the next token from the probability distribution
        over the entire vocabulary with various strategy-specific adjustments.
        '''
    },
    **{
        param_name: {'type': value_type, 'required': False, 'default': fallback}
        for param_name, value_type, fallback in (
            ('max_length', int, 100),
            ('temperature', float, 0.9),
            ('repetition_penalty', float, 1.1),
            ('top_p', float, 0.5),
            ('top_k', int, 40),
            ('typical_p', float, 1.0),
            ('encoder_repetition_penalty', float, 1.0),
            ('min_length', int, 0),
            ('num_beams', int, 1),
            ('early_stopping', bool, False),
            ('penalty_alpha', float, 0.0),
            ('length_penalty', float, 1.0),
            ('no_repeat_ngram_size', int, 0),
        )
    },
}
|
||
|
|
||
|
|
||
|
def generator(job):
    '''
    Run the job input to generate text output.

    Validates ``job['input']`` against INPUT_SCHEMA, tokenizes the prompt,
    calls ``model.generate`` with the validated generation parameters and
    returns the decoded text of the first generated sequence (which, as
    produced by ``batch_decode``, includes the prompt tokens).

    Returns ``{"error": <validation errors>}`` instead when validation fails.

    Depends on the module-level ``model``, ``tokenizer`` and ``device``
    globals; ``model`` and ``tokenizer`` are only assigned in the
    ``__main__`` section of this file.
    '''
    # Validate the input
    validation = validate(job['input'], INPUT_SCHEMA)
    if 'errors' in validation:
        return {"error": validation['errors']}
    val_input = validation['validated_input']

    # Keep the attention mask the tokenizer produces and hand it to
    # generate(); passing input_ids alone makes transformers warn and infer
    # the mask itself. `.get` tolerates a tokenizer that returns no mask.
    encoded = tokenizer(val_input['prompt'], return_tensors="pt").to(device)

    gen_tokens = model.generate(
        encoded['input_ids'],
        attention_mask=encoded.get('attention_mask'),
        do_sample=val_input['do_sample'],
        temperature=val_input['temperature'],
        max_length=val_input['max_length'],
        repetition_penalty=val_input['repetition_penalty'],
        top_p=val_input['top_p'],
        top_k=val_input['top_k'],
        typical_p=val_input['typical_p'],
        encoder_repetition_penalty=val_input['encoder_repetition_penalty'],
        min_length=val_input['min_length'],
        num_beams=val_input['num_beams'],
        early_stopping=val_input['early_stopping'],
        penalty_alpha=val_input['penalty_alpha'],
        length_penalty=val_input['length_penalty'],
        no_repeat_ngram_size=val_input['no_repeat_ngram_size'],
    )

    # generate() already returns a tensor on the model's device, and
    # batch_decode accepts it wherever it lives, so no extra .to() is needed.
    return tokenizer.batch_decode(gen_tokens)[0]
|
||
|
|
||
|
|
||
|
# ---------------------------------------------------------------------------- #
#                                Parse Arguments                               #
# ---------------------------------------------------------------------------- #
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--model_name", type=str, default="gpt-neo-1.3B",
    # Restrict to the names the model-loading code understands, so a typo is
    # rejected at startup instead of leaving the handler without a model.
    choices=[
        "gpt-neo-1.3B",
        "gpt-neo-2.7B",
        "gpt-neox-20b",
        "pygmalion-6b",
        "gpt-j-6b",
        "shygmalion-6b",
        "erebus-13b",
    ],
    help="Short name of the locally cached model to load and serve.")
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
args = parser.parse_args()
|
||
|
|
||
|
# --------------------------------- Neo 1.3B --------------------------------- #
|
||
|
if args.model_name == 'gpt-neo-1.3B':
|
||
|
model = GPTNeoForCausalLM.from_pretrained(
|
||
|
"EleutherAI/gpt-neo-1.3B", local_files_only=True).to(device)
|
||
|
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B", local_files_only=True)
|
||
|
|
||
|
elif args.model_name == 'gpt-neo-2.7B':
|
||
|
model = GPTNeoForCausalLM.from_pretrained(
|
||
|
"EleutherAI/gpt-neo-2.7B", local_files_only=True, torch_dtype=torch.float16).to(device)
|
||
|
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B", local_files_only=True)
|
||
|
|
||
|
elif args.model_name == 'gpt-neox-20b':
|
||
|
model = GPTNeoXForCausalLM.from_pretrained(
|
||
|
"EleutherAI/gpt-neox-20b", local_files_only=True).half().to(device)
|
||
|
tokenizer = GPTNeoXTokenizerFast.from_pretrained(
|
||
|
"EleutherAI/gpt-neox-20b", local_files_only=True)
|
||
|
|
||
|
elif args.model_name == 'pygmalion-6b':
|
||
|
model = AutoModelForCausalLM.from_pretrained(
|
||
|
"PygmalionAI/pygmalion-6b", local_files_only=True).to(device)
|
||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||
|
"PygmalionAI/pygmalion-6b", local_files_only=True)
|
||
|
|
||
|
elif args.model_name == 'gpt-j-6b':
|
||
|
model = GPTJForCausalLM.from_pretrained(
|
||
|
"EleutherAI/gpt-j-6B", local_files_only=True, revision="float16",
|
||
|
torch_dtype=torch.float16).to(device)
|
||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||
|
"EleutherAI/gpt-j-6B", local_files_only=True)
|
||
|
|
||
|
elif args.model_name == 'shygmalion-6b':
|
||
|
model = AutoModelForCausalLM.from_pretrained(
|
||
|
"TehVenom/PPO_Shygmalion-6b", local_files_only=True).to(device)
|
||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||
|
"TehVenom/PPO_Shygmalion-6b", local_files_only=True)
|
||
|
|
||
|
elif args.model_name == 'erebus-13b':
|
||
|
model = AutoModelForCausalLM.from_pretrained(
|
||
|
"KoboldAI/OPT-13B-Erebus", local_files_only=True).to(device)
|
||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||
|
"KoboldAI/OPT-13B-Erebus", local_files_only=True)
|
||
|
|
||
|
runpod.serverless.start({"handler": generator})
|