Chatbot
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

97 lines
5.1 KiB

'''
RunPod | Transformer | Model Fetcher
'''
import os
import argparse
import torch
from transformers import (GPTNeoForCausalLM, GPT2Tokenizer, GPTNeoXForCausalLM,
GPTNeoXTokenizerFast, GPTJForCausalLM, AutoTokenizer, AutoModelForCausalLM)
from huggingface_hub import snapshot_download, hf_hub_download
def download_model(model_name):
    '''
    Download the weights (and tokenizer, where applicable) for the requested
    model from the Hugging Face Hub, caching them locally.

    Parameters:
        model_name (str): One of the supported model keys handled below.

    Raises:
        ValueError: If `model_name` is not one of the supported models.

    Side effects:
        For snapshot-style downloads, symlinks the downloaded path to
        /workdir/model.
    '''
    # Some branches download via `from_pretrained` (cached implicitly);
    # others return an explicit path that must be linked into /workdir.
    snapshot_path = None  # guard: avoids UnboundLocalError for branches that never set it

    # --------------------------------- Neo 1.3B --------------------------------- #
    if model_name == 'gpt-neo-1.3B':
        GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
        GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
    # --------------------------------- Neo 2.7B --------------------------------- #
    elif model_name == 'gpt-neo-2.7B':
        GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16)
        GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
    # ----------------------------------- NeoX ----------------------------------- #
    elif model_name == 'gpt-neox-20b':
        GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b").half()
        GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
    # --------------------------------- Pygmalion -------------------------------- #
    elif model_name == 'pygmalion-6b':
        snapshot_path = snapshot_download(repo_id="PygmalionAI/pygmalion-6b", revision="main")
    # --------------------------------- Pygmalion (4-bit, 128g) ------------------ #
    elif model_name == 'pygmalion-6b-4bit-128g':
        snapshot_path = snapshot_download(repo_id="mayaeary/pygmalion-6b-4bit-128g", revision="main")
    # --------------------------------- Pygmalion (GPTQ 4-bit) ------------------- #
    elif model_name == 'pygmalion-6b-gptq-4bit':
        snapshot_path = snapshot_download(repo_id="OccamRazor/pygmalion-6b-gptq-4bit", revision="main")
    # ----------------------------------- GPT-J ----------------------------------- #
    elif model_name == 'gpt-j-6b':
        GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16",
                                        torch_dtype=torch.float16)
        AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    # ------------------------------ PPO Shygmalion 6B ----------------------------- #
    elif model_name == 'ppo-shygmalion-6b':
        AutoModelForCausalLM.from_pretrained("TehVenom/PPO_Shygmalion-6b", load_in_8bit=True)
        AutoTokenizer.from_pretrained("TehVenom/PPO_Shygmalion-6b")
    # ------------------------------ Dolly Shygmalion 6B ----------------------------- #
    elif model_name == 'dolly-shygmalion-6b':
        AutoModelForCausalLM.from_pretrained("TehVenom/Dolly_Shygmalion-6b", load_in_8bit=True)
        AutoTokenizer.from_pretrained("TehVenom/Dolly_Shygmalion-6b")
    # ------------------------------ Erebus 13B (NSFW) ----------------------------- #
    elif model_name == 'erebus-13b':
        AutoModelForCausalLM.from_pretrained("KoboldAI/OPT-13B-Erebus", load_in_8bit=True)
        AutoTokenizer.from_pretrained("KoboldAI/OPT-13B-Erebus")
    # --------------------------- Alpaca 13B (Quantized) -------------------------- #
    elif model_name == 'gpt4-x-alpaca-13b-native-4bit-128g':
        AutoModelForCausalLM.from_pretrained("anon8231489123/gpt4-x-alpaca-13b-native-4bit-128g")
        AutoTokenizer.from_pretrained("anon8231489123/gpt4-x-alpaca-13b-native-4bit-128g")
    # --------------------------------- Alpaca 13B -------------------------------- #
    elif model_name == 'gpt4-x-alpaca':
        AutoModelForCausalLM.from_pretrained("chavinlo/gpt4-x-alpaca", load_in_8bit=True)
        AutoTokenizer.from_pretrained("chavinlo/gpt4-x-alpaca")
    # --------------------------------- RWKV Raven 7B -------------------------------- #
    elif model_name == 'rwkv-4-raven-7b':
        snapshot_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-raven",
                                        filename="RWKV-4-Raven-7B-v8-Eng-20230408-ctx4096.pth")
        # NOTE(review): hf_hub_download returns a *file* path here, but it is
        # passed as local_dir for the tokenizer download — confirm intended.
        hf_hub_download(repo_id="BlinkDL/Raven-RWKV-7B", filename="20B_tokenizer.json",
                        local_dir=snapshot_path)
    else:
        # Fail loudly instead of silently doing nothing (previously this fell
        # through and crashed on the undefined snapshot_path below).
        raise ValueError(f"Unknown model name: {model_name}")

    if snapshot_path:
        print(f"model downloaded to \"{snapshot_path}\"")
        # NOTE(review): shell string via os.system; paths come from the HF
        # cache, but consider os.symlink / subprocess.run for robustness.
        os.system(f"ln -s \"{snapshot_path}\" /workdir/model")
# ---------------------------------------------------------------------------- #
#                                Parse Arguments                               #
# ---------------------------------------------------------------------------- #
parser = argparse.ArgumentParser(description=__doc__)
# Fixed help text: the value is a model *name* keyed into download_model's
# dispatch table, not a URL.
parser.add_argument("--model_name", type=str, default="gpt-neo-1.3B",
                    help="Name of the model to download.")

if __name__ == "__main__":
    args = parser.parse_args()
    download_model(args.model_name)