@@ -7,7 +7,7 @@ import argparse
import torch
from transformers import (GPTNeoForCausalLM, GPT2Tokenizer, GPTNeoXForCausalLM,
                          GPTNeoXTokenizerFast, GPTJForCausalLM, AutoTokenizer, AutoModelForCausalLM)
from huggingface_hub import snapshot_download

def download_model(model_name):
@@ -28,8 +28,9 @@ def download_model(model_name):
    # --------------------------------- Pygmalion -------------------------------- #
    elif model_name == 'pygmalion-6b':
        AutoModelForCausalLM.from_pretrained("PygmalionAI/pygmalion-6b")
        AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")
        # AutoModelForCausalLM.from_pretrained("PygmalionAI/pygmalion-6b", load_in_8bit=True)
        # AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")
        snapshot_download(repo_id="PygmalionAI/pygmalion-6b", revision="main")

    # ----------------------------------- GPT-J ----------------------------------- #
    elif model_name == 'gpt-j-6b':
@@ -39,17 +40,17 @@ def download_model(model_name):
    # ------------------------------ PPO Shygmalion 6B ----------------------------- #
    elif model_name == 'ppo-shygmalion-6b':
        AutoModelForCausalLM.from_pretrained("TehVenom/PPO_Shygmalion-6b")
        AutoModelForCausalLM.from_pretrained("TehVenom/PPO_Shygmalion-6b", load_in_8bit=True)
        AutoTokenizer.from_pretrained("TehVenom/PPO_Shygmalion-6b")

    # ------------------------------ Dolly Shygmalion 6B ----------------------------- #
    elif model_name == 'dolly-shygmalion-6b':
        AutoModelForCausalLM.from_pretrained("TehVenom/Dolly_Shygmalion-6b")
        AutoModelForCausalLM.from_pretrained("TehVenom/Dolly_Shygmalion-6b", load_in_8bit=True)
        AutoTokenizer.from_pretrained("TehVenom/Dolly_Shygmalion-6b")

    # ------------------------------ Erebus 13B (NSFW) ----------------------------- #
    elif model_name == 'erebus-13b':
        AutoModelForCausalLM.from_pretrained("KoboldAI/OPT-13B-Erebus")
        AutoModelForCausalLM.from_pretrained("KoboldAI/OPT-13B-Erebus", load_in_8bit=True)
        AutoTokenizer.from_pretrained("KoboldAI/OPT-13B-Erebus")

    # --------------------------- Alpaca 13B (Quantized) -------------------------- #
@@ -59,7 +60,7 @@ def download_model(model_name):
    # --------------------------------- Alpaca 13B -------------------------------- #
    elif model_name == 'gpt4-x-alpaca':
        AutoModelForCausalLM.from_pretrained("chavinlo/gpt4-x-alpaca")
        AutoModelForCausalLM.from_pretrained("chavinlo/gpt4-x-alpaca", load_in_8bit=True)
        AutoTokenizer.from_pretrained("chavinlo/gpt4-x-alpaca")
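
A note on the Pygmalion branch above: unlike the from_pretrained calls used in the other branches, snapshot_download only fetches the repository files into the local Hugging Face cache and returns the cache path, without building the model in memory. A minimal standalone sketch of that call (the print is illustrative only):

import hugging face hub helper
from huggingface_hub import snapshot_download

# Downloads every file in the repo at the given revision into the local
# HF cache and returns the cache directory; no model is loaded into RAM.
local_path = snapshot_download(repo_id="PygmalionAI/pygmalion-6b", revision="main")
print(local_path)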
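
The load_in_8bit=True variants in the hunks above depend on the bitsandbytes package; in the transformers releases of that era, 8-bit loading also expected a device map so accelerate could place the quantized weights, roughly as sketched below (the device_map="auto" argument is an assumption, not shown in the diff):

from transformers import AutoModelForCausalLM, AutoTokenizer

# 8-bit loading requires bitsandbytes to be installed; device_map="auto"
# (an assumption here) lets accelerate place quantized weights on the GPU.
model = AutoModelForCausalLM.from_pretrained(
    "TehVenom/PPO_Shygmalion-6b",
    load_in_8bit=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("TehVenom/PPO_Shygmalion-6b")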
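
Since the first hunk header shows the script imports argparse and every branch above dispatches on model_name, a hypothetical entry point might look like the following (the --model flag name and help text are assumptions for illustration; only argparse and download_model appear in the diff):

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Pre-download model weights from the Hugging Face Hub")
    # The flag name below is hypothetical; the diff only shows that argparse is imported.
    parser.add_argument("--model", required=True,
                        help="e.g. pygmalion-6b, gpt-j-6b, ppo-shygmalion-6b, "
                             "dolly-shygmalion-6b, erebus-13b, gpt4-x-alpaca")
    args = parser.parse_args()
    download_model(args.model)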