Chatbot
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

61 lines
2.7 KiB

'''
RunPod | Transformer | Model Fetcher
'''
import argparse
import torch
from transformers import (GPTNeoForCausalLM, GPT2Tokenizer, GPTNeoXForCausalLM,
GPTNeoXTokenizerFast, GPTJForCausalLM, AutoTokenizer, AutoModelForCausalLM)
def _fetch(model_cls, tokenizer_cls, repo_id, **model_kwargs):
    """Load a checkpoint and its tokenizer from ``repo_id``, then discard them.

    Only the side effect matters: ``from_pretrained`` pulls the files from the
    Hugging Face Hub into the local cache. ``model_kwargs`` are forwarded to
    the model's ``from_pretrained`` only; the tokenizer is always fetched from
    the repo's default revision.
    """
    model_cls.from_pretrained(repo_id, **model_kwargs)
    tokenizer_cls.from_pretrained(repo_id)


# Supported --model_name keys -> zero-argument fetcher. Each entry is a lambda
# so the transformers classes are only looked up when that model is requested.
_FETCHERS = {
    'gpt-neo-1.3B': lambda: _fetch(
        GPTNeoForCausalLM, GPT2Tokenizer, "EleutherAI/gpt-neo-1.3B"),
    'gpt-neo-2.7B': lambda: _fetch(
        GPTNeoForCausalLM, GPT2Tokenizer, "EleutherAI/gpt-neo-2.7B",
        torch_dtype=torch.float16),
    # Load straight into fp16 instead of fp32 followed by .half(): same final
    # dtype, roughly half the peak host memory for the 20B checkpoint.
    'gpt-neox-20b': lambda: _fetch(
        GPTNeoXForCausalLM, GPTNeoXTokenizerFast, "EleutherAI/gpt-neox-20b",
        torch_dtype=torch.float16),
    'pygmalion-6b': lambda: _fetch(
        AutoModelForCausalLM, AutoTokenizer, "PygmalionAI/pygmalion-6b"),
    # revision="float16" selects that branch of the repo for the weights only.
    'gpt-j-6b': lambda: _fetch(
        GPTJForCausalLM, AutoTokenizer, "EleutherAI/gpt-j-6B",
        revision="float16", torch_dtype=torch.float16),
    'shygmalion-6b': lambda: _fetch(
        AutoModelForCausalLM, AutoTokenizer, "TehVenom/PPO_Shygmalion-6b"),
    # Erebus 13B (NSFW).
    'erebus-13b': lambda: _fetch(
        AutoModelForCausalLM, AutoTokenizer, "KoboldAI/OPT-13B-Erebus"),
}


def download_model(model_name):
    """Fetch the weights and tokenizer for one supported model.

    Args:
        model_name: Short model key, e.g. ``'gpt-neo-1.3B'`` or ``'gpt-j-6b'``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported keys.
    """
    try:
        fetch = _FETCHERS[model_name]
    except KeyError:
        # Fail loudly: a silently skipped download would only surface later,
        # as a missing model at inference time.
        raise ValueError(
            f"Unknown model_name {model_name!r}; "
            f"expected one of: {', '.join(_FETCHERS)}") from None
    fetch()
# ---------------------------------------------------------------------------- #
#                                Parse Arguments                               #
# ---------------------------------------------------------------------------- #
# Module-level parser so the CLI definition can be inspected or reused;
# it is only invoked when the file is run as a script.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--model_name", type=str, default="gpt-neo-1.3B",
    help="Short name of the model to download, e.g. 'gpt-neo-1.3B' or 'gpt-j-6b'.")

if __name__ == "__main__":
    args = parser.parse_args()
    download_model(args.model_name)