''' RunPod | Transformer | Model Fetcher '''

import os
import argparse

import torch
from transformers import (GPTNeoForCausalLM, GPT2Tokenizer, GPTNeoXForCausalLM,
                          GPTNeoXTokenizerFast, GPTJForCausalLM, AutoTokenizer,
                          AutoModelForCausalLM)
from huggingface_hub import snapshot_download, hf_hub_download


def download_model(model_name):
    # Only the snapshot/raw-file branches set snapshot_path; the from_pretrained
    # branches warm the Hugging Face cache instead. Initialise it to None so the
    # symlink step at the bottom does not raise a NameError for those branches.
    snapshot_path = None

    # --------------------------------- Neo 1.3B --------------------------------- #
    if model_name == 'gpt-neo-1.3B':
        GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
        GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")

    # --------------------------------- Neo 2.7B --------------------------------- #
    elif model_name == 'gpt-neo-2.7B':
        GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16)
        GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")

    # ----------------------------------- NeoX ----------------------------------- #
    elif model_name == 'gpt-neox-20b':
        GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b").half()
        GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")

    # --------------------------------- Pygmalion -------------------------------- #
    elif model_name == 'pygmalion-6b':
        # AutoModelForCausalLM.from_pretrained("PygmalionAI/pygmalion-6b", load_in_8bit=True)
        # AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")
        snapshot_path = snapshot_download(repo_id="PygmalionAI/pygmalion-6b", revision="main")

    # --------------------------------- Pygmalion -------------------------------- #
    elif model_name == 'pygmalion-6b-4bit-128g':
        snapshot_path = snapshot_download(repo_id="mayaeary/pygmalion-6b-4bit-128g", revision="main")

    # --------------------------------- Pygmalion -------------------------------- #
    elif model_name == 'pygmalion-6b-gptq-4bit':
        # AutoModelForCausalLM.from_pretrained("OccamRazor/pygmalion-6b-gptq-4bit", from_pt=True)
        # AutoTokenizer.from_pretrained("OccamRazor/pygmalion-6b-gptq-4bit")
        snapshot_path = snapshot_download(repo_id="OccamRazor/pygmalion-6b-gptq-4bit", revision="main")

    # ----------------------------------- GPT-J ----------------------------------- #
    elif model_name == 'gpt-j-6b':
        GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16",
                                        torch_dtype=torch.float16)
        AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

    # ------------------------------ PPO Shygmalion 6B ----------------------------- #
    elif model_name == 'ppo-shygmalion-6b':
        AutoModelForCausalLM.from_pretrained("TehVenom/PPO_Shygmalion-6b", load_in_8bit=True)
        AutoTokenizer.from_pretrained("TehVenom/PPO_Shygmalion-6b")

    # ------------------------------ Dolly Shygmalion 6B ----------------------------- #
    elif model_name == 'dolly-shygmalion-6b':
        AutoModelForCausalLM.from_pretrained("TehVenom/Dolly_Shygmalion-6b", load_in_8bit=True)
        AutoTokenizer.from_pretrained("TehVenom/Dolly_Shygmalion-6b")

    # ------------------------------ Erebus 13B (NSFW) ----------------------------- #
    elif model_name == 'erebus-13b':
        AutoModelForCausalLM.from_pretrained("KoboldAI/OPT-13B-Erebus", load_in_8bit=True)
        AutoTokenizer.from_pretrained("KoboldAI/OPT-13B-Erebus")

    # --------------------------- Alpaca 13B (Quantized) -------------------------- #
    elif model_name == 'gpt4-x-alpaca-13b-native-4bit-128g':
        AutoModelForCausalLM.from_pretrained("anon8231489123/gpt4-x-alpaca-13b-native-4bit-128g")
        AutoTokenizer.from_pretrained("anon8231489123/gpt4-x-alpaca-13b-native-4bit-128g")

    # --------------------------------- Alpaca 13B -------------------------------- #
    elif model_name == 'gpt4-x-alpaca':
        AutoModelForCausalLM.from_pretrained("chavinlo/gpt4-x-alpaca", load_in_8bit=True)
        AutoTokenizer.from_pretrained("chavinlo/gpt4-x-alpaca")

    # --------------------------------- RWKV Raven 7B -------------------------------- #
    elif model_name == 'rwkv-4-raven-7b':
        # hf_hub_download returns the path to the downloaded *file*, so the
        # tokenizer is placed in the checkpoint's parent directory rather than
        # passing the file path itself as local_dir.
        snapshot_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-raven",
                                        filename="RWKV-4-Raven-7B-v8-Eng-20230408-ctx4096.pth")
        hf_hub_download(repo_id="BlinkDL/Raven-RWKV-7B", filename="20B_tokenizer.json",
                        local_dir=os.path.dirname(snapshot_path))
        # https://huggingface.co/yahma/RWKV-14b_quant/resolve/main/RWKV-4-Pile-14B-20230213-8019.pqth

    # Expose snapshot-style downloads at a fixed path for the worker to pick up.
    if snapshot_path:
        print(f"model downloaded to \"{snapshot_path}\"")
        os.system(f"ln -s \"{snapshot_path}\" /workdir/model")


# ---------------------------------------------------------------------------- #
#                                Parse Arguments                                #
# ---------------------------------------------------------------------------- #
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--model_name", type=str, default="gpt-neo-1.3B",
                    help="Name of the model to download.")


if __name__ == "__main__":
    args = parser.parse_args()
    download_model(args.model_name)