You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
96 lines
4.9 KiB
96 lines
4.9 KiB
'''
|
|
RunPod | Transformer | Model Fetcher
|
|
'''
|
|
|
|
import os
|
|
import argparse
|
|
|
|
import torch
|
|
from transformers import (GPTNeoForCausalLM, GPT2Tokenizer, GPTNeoXForCausalLM,
|
|
GPTNeoXTokenizerFast, GPTJForCausalLM, AutoTokenizer, AutoModelForCausalLM)
|
|
from huggingface_hub import snapshot_download, hf_hub_download
|
|
|
|
def download_model(model_name):
    """Fetch the weights (and tokenizer, where applicable) for ``model_name``.

    Most entries are fetched by calling ``from_pretrained`` on the matching
    ``transformers`` class; the returned objects are discarded, so the call
    is made only for its download/caching side effect. The Pygmalion entries
    are fetched with ``snapshot_download`` and the resulting snapshot
    directory is symlinked to ``/workdir/model``. The RWKV entry downloads a
    single checkpoint file with ``hf_hub_download``.

    Args:
        model_name: Short model key, e.g. ``'gpt-neo-1.3B'`` or ``'gpt-j-6b'``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported keys.
    """
    import subprocess

    # Only the snapshot_download branches set this. Every other branch leaves
    # it as None, so the symlink step at the end is skipped instead of
    # failing with UnboundLocalError.
    snapshot_path = None

    # --------------------------------- Neo 1.3B --------------------------------- #
    if model_name == 'gpt-neo-1.3B':
        GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
        GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")

    # --------------------------------- Neo 2.7B --------------------------------- #
    elif model_name == 'gpt-neo-2.7B':
        GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16)
        GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")

    # ----------------------------------- NeoX ----------------------------------- #
    elif model_name == 'gpt-neox-20b':
        # Load straight into fp16 (as the Neo 2.7B branch does) instead of
        # materialising the full fp32 model and then calling .half() on an
        # object that is thrown away anyway.
        GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16)
        GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")

    # --------------------------------- Pygmalion -------------------------------- #
    elif model_name == 'pygmalion-6b':
        snapshot_path = snapshot_download(repo_id="PygmalionAI/pygmalion-6b", revision="main")

    # --------------------------------- Pygmalion -------------------------------- #
    elif model_name == 'pygmalion-6b-4bit-128g':
        snapshot_path = snapshot_download(repo_id="mayaeary/pygmalion-6b-4bit-128g", revision="main")

    # --------------------------------- Pygmalion -------------------------------- #
    elif model_name == 'pygmalion-6b-gptq-4bit':
        snapshot_path = snapshot_download(repo_id="OccamRazor/pygmalion-6b-gptq-4bit", revision="main")

    # ----------------------------------- GPT-J ----------------------------------- #
    elif model_name == 'gpt-j-6b':
        GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16",
                                        torch_dtype=torch.float16)
        AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

    # ------------------------------ PPO Shygmalion 6B ----------------------------- #
    elif model_name == 'ppo-shygmalion-6b':
        AutoModelForCausalLM.from_pretrained("TehVenom/PPO_Shygmalion-6b", load_in_8bit=True)
        AutoTokenizer.from_pretrained("TehVenom/PPO_Shygmalion-6b")

    # ------------------------------ Dolly Shygmalion 6B ----------------------------- #
    elif model_name == 'dolly-shygmalion-6b':
        AutoModelForCausalLM.from_pretrained("TehVenom/Dolly_Shygmalion-6b", load_in_8bit=True)
        AutoTokenizer.from_pretrained("TehVenom/Dolly_Shygmalion-6b")

    # ------------------------------ Erebus 13B (NSFW) ----------------------------- #
    elif model_name == 'erebus-13b':
        AutoModelForCausalLM.from_pretrained("KoboldAI/OPT-13B-Erebus", load_in_8bit=True)
        AutoTokenizer.from_pretrained("KoboldAI/OPT-13B-Erebus")

    # --------------------------- Alpaca 13B (Quantized) -------------------------- #
    elif model_name == 'gpt4-x-alpaca-13b-native-4bit-128g':
        # NOTE(review): the repo name suggests a pre-quantized 4-bit checkpoint;
        # confirm that a plain from_pretrained call can load it.
        AutoModelForCausalLM.from_pretrained("anon8231489123/gpt4-x-alpaca-13b-native-4bit-128g")
        AutoTokenizer.from_pretrained("anon8231489123/gpt4-x-alpaca-13b-native-4bit-128g")

    # --------------------------------- Alpaca 13B -------------------------------- #
    elif model_name == 'gpt4-x-alpaca':
        AutoModelForCausalLM.from_pretrained("chavinlo/gpt4-x-alpaca", load_in_8bit=True)
        AutoTokenizer.from_pretrained("chavinlo/gpt4-x-alpaca")

    # --------------------------------- RWKV Raven 7B -------------------------------- #
    elif model_name == 'rwkv-4-raven-7b':
        # Single checkpoint file; it is not linked into /workdir/model.
        hf_hub_download(repo_id="BlinkDL/rwkv-4-raven", filename="RWKV-4-Raven-7B-v7-EngAndMore-20230404-ctx4096.pth")

    else:
        raise ValueError(f"Unknown model_name: {model_name!r}")

    if snapshot_path:
        print(f"model downloaded to \"{snapshot_path}\"")
        # Pass an argument list rather than a shell string so quotes or other
        # shell metacharacters in the path cannot break the command. As with
        # the previous os.system call, a failing `ln` is not fatal.
        subprocess.run(["ln", "-s", snapshot_path, "/workdir/model"], check=False)
|
|
|
|
# ---------------------------------------------------------------------------- #
#                                Parse Arguments                               #
# ---------------------------------------------------------------------------- #
# Command-line interface used when this file is run as a script. The module
# docstring doubles as the --help description.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--model_name", type=str,
                    default="gpt-neo-1.3B",
                    help="Name of the model to download (e.g. gpt-neo-1.3B).")


if __name__ == "__main__":
    args = parser.parse_args()
    download_model(args.model_name)
|
|
|