Chatbot
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

190 lines
6.0 KiB

'''
RunPod | Transformer | Handler
'''
import argparse
import torch
import runpod
from runpod.serverless.utils.rp_validator import validate
from transformers import (GPTNeoForCausalLM, GPT2Tokenizer, GPTNeoXForCausalLM,
GPTNeoXTokenizerFast, GPTJForCausalLM, AutoTokenizer, AutoModelForCausalLM)
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# (The previous bare `torch.cuda.is_available()` statement discarded its
# result and has been removed; the check below is the only one needed.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Request schema handed to runpod's `validate`. `prompt` is the only required
# field; `do_sample` carries a human-readable description; every other field
# is an optional generation parameter with the default shown.
INPUT_SCHEMA = {
    'prompt': {'type': str, 'required': True},
    'do_sample': {
        'type': bool,
        'required': False,
        'default': True,
        'description': (
            '\n'
            'Enables decoding strategies such as multinomial sampling,\n'
            'beam-search multinomial sampling, Top-K sampling and Top-p sampling.\n'
            'All these strategies select the next token from the probability distribution\n'
            'over the entire vocabulary with various strategy-specific adjustments.\n'
        ),
    },
}

# Optional fields that share the same shape, listed as (name, type, default)
# in schema order.
INPUT_SCHEMA.update(
    (field_name, {'type': field_type, 'required': False, 'default': field_default})
    for field_name, field_type, field_default in (
        ('max_length', int, 100),
        ('temperature', float, 0.9),
        ('repetition_penalty', float, 1.1),
        ('top_p', float, 0.5),
        ('top_k', int, 40),
        ('typical_p', float, 1.0),
        ('encoder_repetition_penalty', float, 1.0),
        ('min_length', int, 0),
        ('num_beams', int, 1),
        ('early_stopping', bool, False),
        ('penalty_alpha', float, 0.0),
        ('length_penalty', float, 1.0),
        ('no_repeat_ngram_size', int, 0),
    )
)
def generator(job):
    '''
    Run the job input to generate text output.

    Validates ``job['input']`` against INPUT_SCHEMA, tokenizes the prompt,
    calls ``model.generate`` with the validated generation parameters and
    decodes the first returned sequence. For the causal-LM models loaded in
    ``__main__`` the decoded text contains the prompt followed by the
    continuation.

    Relies on the module-level ``model``, ``tokenizer`` and ``device``
    globals; ``model`` and ``tokenizer`` are assigned in the ``__main__``
    section before the RunPod handler is started.

    Args:
        job: RunPod job dict; only ``job['input']`` is read.

    Returns:
        The decoded text on success, or ``{"error": <validation errors>}``
        when the input does not match INPUT_SCHEMA.
    '''
    # Validate the input
    validated = validate(job['input'], INPUT_SCHEMA)
    if 'errors' in validated:
        return {"error": validated['errors']}
    val_input = validated['validated_input']

    encoded = tokenizer(val_input['prompt'], return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    # Forward the tokenizer's attention mask instead of letting generate()
    # infer one from input_ids, which it warns about and which can be wrong
    # when the pad and EOS token ids coincide.
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    # generate() already returns tensors on the model's device, so no extra
    # .to(device) is needed on the result.
    gen_tokens = model.generate(
        input_ids,
        attention_mask=attention_mask,
        do_sample=val_input['do_sample'],
        temperature=val_input['temperature'],
        max_length=val_input['max_length'],
        repetition_penalty=val_input['repetition_penalty'],
        top_p=val_input['top_p'],
        top_k=val_input['top_k'],
        typical_p=val_input['typical_p'],
        encoder_repetition_penalty=val_input['encoder_repetition_penalty'],
        min_length=val_input['min_length'],
        num_beams=val_input['num_beams'],
        early_stopping=val_input['early_stopping'],
        penalty_alpha=val_input['penalty_alpha'],
        length_penalty=val_input['length_penalty'],
        no_repeat_ngram_size=val_input['no_repeat_ngram_size'],
    )
    # Only one prompt is ever submitted, so return the first decoded sequence.
    gen_text = tokenizer.batch_decode(gen_tokens)[0]
    return gen_text
# ---------------------------------------------------------------------------- #
#                                Parse Arguments                               #
# ---------------------------------------------------------------------------- #
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--model_name",
    type=str,
    default="gpt-neo-1.3B",
    # Restrict to the names the __main__ section knows how to load, so a typo
    # is rejected at startup instead of leaving the handler without a model.
    choices=(
        "gpt-neo-1.3B",
        "gpt-neo-2.7B",
        "gpt-neox-20b",
        "pygmalion-6b",
        "gpt-j-6b",
        "shygmalion-6b",
        "erebus-13b",
    ),
    help="Name of the model to serve; its weights must already be available "
         "locally (they are loaded with local_files_only=True).",
)
if __name__ == "__main__":
    args = parser.parse_args()

    # Load exactly one model/tokenizer pair selected by --model_name. They are
    # bound as module-level globals because `generator` reads `model` and
    # `tokenizer` when it handles a job. local_files_only=True means the
    # weights must already be present locally; nothing is downloaded.
    # --------------------------------- Neo 1.3B --------------------------------- #
    if args.model_name == 'gpt-neo-1.3B':
        model = GPTNeoForCausalLM.from_pretrained(
            "EleutherAI/gpt-neo-1.3B", local_files_only=True).to(device)
        tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B", local_files_only=True)
    elif args.model_name == 'gpt-neo-2.7B':
        model = GPTNeoForCausalLM.from_pretrained(
            "EleutherAI/gpt-neo-2.7B", local_files_only=True, torch_dtype=torch.float16).to(device)
        tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B", local_files_only=True)
    elif args.model_name == 'gpt-neox-20b':
        model = GPTNeoXForCausalLM.from_pretrained(
            "EleutherAI/gpt-neox-20b", local_files_only=True).half().to(device)
        tokenizer = GPTNeoXTokenizerFast.from_pretrained(
            "EleutherAI/gpt-neox-20b", local_files_only=True)
    elif args.model_name == 'pygmalion-6b':
        model = AutoModelForCausalLM.from_pretrained(
            "PygmalionAI/pygmalion-6b", local_files_only=True).to(device)
        tokenizer = AutoTokenizer.from_pretrained(
            "PygmalionAI/pygmalion-6b", local_files_only=True)
    elif args.model_name == 'gpt-j-6b':
        model = GPTJForCausalLM.from_pretrained(
            "EleutherAI/gpt-j-6B", local_files_only=True, revision="float16",
            torch_dtype=torch.float16).to(device)
        tokenizer = AutoTokenizer.from_pretrained(
            "EleutherAI/gpt-j-6B", local_files_only=True)
    elif args.model_name == 'shygmalion-6b':
        model = AutoModelForCausalLM.from_pretrained(
            "TehVenom/PPO_Shygmalion-6b", local_files_only=True).to(device)
        tokenizer = AutoTokenizer.from_pretrained(
            "TehVenom/PPO_Shygmalion-6b", local_files_only=True)
    elif args.model_name == 'erebus-13b':
        model = AutoModelForCausalLM.from_pretrained(
            "KoboldAI/OPT-13B-Erebus", local_files_only=True).to(device)
        tokenizer = AutoTokenizer.from_pretrained(
            "KoboldAI/OPT-13B-Erebus", local_files_only=True)
    else:
        # Without this branch an unrecognised name would leave `model` and
        # `tokenizer` undefined: the worker would start and then fail every
        # job with NameError. Exit with a usage error instead.
        parser.error(f"unsupported --model_name: {args.model_name!r}")

    runpod.serverless.start({"handler": generator})