You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
61 lines
2.7 KiB
61 lines
2.7 KiB
'''
|
|
RunPod | Transformer | Model Fetcher
|
|
'''
|
|
|
|
import argparse
|
|
|
|
import torch
|
|
from transformers import (GPTNeoForCausalLM, GPT2Tokenizer, GPTNeoXForCausalLM,
|
|
GPTNeoXTokenizerFast, GPTJForCausalLM, AutoTokenizer, AutoModelForCausalLM)
|
|
|
|
|
|
def download_model(model_name):
|
|
|
|
# --------------------------------- Neo 1.3B --------------------------------- #
|
|
if model_name == 'gpt-neo-1.3B':
|
|
GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
|
|
GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
|
|
|
|
# --------------------------------- Neo 2.7B --------------------------------- #
|
|
elif model_name == 'gpt-neo-2.7B':
|
|
GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16)
|
|
GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
|
|
|
|
# ----------------------------------- NeoX ----------------------------------- #
|
|
elif model_name == 'gpt-neox-20b':
|
|
GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b").half()
|
|
GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
|
|
|
|
# --------------------------------- Pygmalion -------------------------------- #
|
|
elif model_name == 'pygmalion-6b':
|
|
AutoModelForCausalLM.from_pretrained("PygmalionAI/pygmalion-6b")
|
|
AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")
|
|
|
|
# ----------------------------------- GPT-J ----------------------------------- #
|
|
elif model_name == 'gpt-j-6b':
|
|
GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16",
|
|
torch_dtype=torch.float16)
|
|
AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
|
|
|
|
# ------------------------------ PPO Shygmalion 6B ----------------------------- #
|
|
elif model_name == 'shygmalion-6b':
|
|
AutoModelForCausalLM.from_pretrained("TehVenom/PPO_Shygmalion-6b")
|
|
AutoTokenizer.from_pretrained("TehVenom/PPO_Shygmalion-6b")
|
|
|
|
# ------------------------------ Erebus 13B (NSFW) ----------------------------- #
|
|
elif model_name == 'erebus-13b':
|
|
AutoModelForCausalLM.from_pretrained("KoboldAI/OPT-13B-Erebus")
|
|
AutoTokenizer.from_pretrained("KoboldAI/OPT-13B-Erebus")
|
|
|
|
|
|
# ---------------------------------------------------------------------------- #
|
|
# Parse Arguments #
|
|
# ---------------------------------------------------------------------------- #
|
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
parser.add_argument("--model_name", type=str,
|
|
default="gpt-neo-1.3B", help="URL of the model to download.")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
args = parser.parse_args()
|
|
download_model(args.model_name)
|
|
|