You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
66 lines
2.4 KiB
66 lines
2.4 KiB
'''
RunPod | serverless-ckpt-template | model_fetcher.py

Downloads the model from the URL passed in.
'''
|
|
|
|
import os
|
|
import shutil
|
|
from diffusers import StableDiffusionPipeline
|
|
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
|
StableDiffusionSafetyChecker,
|
|
)
|
|
|
|
import requests
|
|
import argparse
|
|
|
|
|
|
# Identifier handed to StableDiffusionSafetyChecker.from_pretrained in download_model.
SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"
|
|
# Directory that download_model deletes and recreates on every run, then passes
# as cache_dir to the from_pretrained calls.
MODEL_CACHE_DIR = "diffusers-cache"
|
|
|
|
|
|
def download_model(model_url: str):
    '''
    Downloads the model from the URL passed in.

    The cache directory (MODEL_CACHE_DIR) is removed and recreated first, so
    every call starts from an empty cache.

    Two kinds of URL are handled:

    * URLs containing "huggingface.co": the last two path segments are used
      as the model repo id (a trailing slash is ignored).
    * Any other URL: the response body is streamed to
      MODEL_CACHE_DIR/model.safetensors, then converted to the diffusers
      layout under MODEL_CACHE_DIR/pt with the diffusers conversion script,
      which is fetched with wget into the current working directory.

    Finally the safety checker and the pipeline are loaded once with
    cache_dir=MODEL_CACHE_DIR; the loaded objects are discarded.

    Args:
        model_url: Hugging Face model URL, or a direct link to a
            .safetensors checkpoint.

    Raises:
        requests.HTTPError: the checkpoint download returned an error status.
        subprocess.CalledProcessError: fetching or running the conversion
            script exited with a non-zero status.
    '''
    # Local import: keeps this function self-contained without touching the
    # module's import block.
    import subprocess

    convert_script = "convert_original_stable_diffusion_to_diffusers.py"
    convert_script_url = (
        "https://raw.githubusercontent.com/huggingface/diffusers/main/scripts/"
        + convert_script
    )

    if os.path.exists(MODEL_CACHE_DIR):
        shutil.rmtree(MODEL_CACHE_DIR)
    os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

    # Check if the URL is from huggingface.co, if so, grab the model repo id.
    if "huggingface.co" in model_url:
        # rstrip so ".../owner/name/" does not produce an empty repo name.
        url_parts = model_url.rstrip("/").split("/")
        model_id = f"{url_parts[-2]}/{url_parts[-1]}"
    else:
        checkpoint_path = os.path.join(MODEL_CACHE_DIR, "model.safetensors")
        converted_dir = os.path.join(MODEL_CACHE_DIR, "pt")

        with requests.get(model_url, stream=True, timeout=600) as response:
            # Fail here rather than saving an HTML error page as the checkpoint.
            response.raise_for_status()
            with open(checkpoint_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)

        # Argument lists (no shell) and check=True: a failed fetch or
        # conversion stops the run instead of being silently ignored.
        # -O overwrites a stale copy instead of creating "<name>.1".
        subprocess.run(
            ["wget", "-q", "-O", convert_script, convert_script_url],
            check=True,
        )
        subprocess.run(
            [
                "python3",
                convert_script,
                "--from_safetensors",
                "--checkpoint_path",
                checkpoint_path,  # the file written above, not ./model.safetensors
                "--dump_path",
                converted_dir,
            ],
            check=True,
        )

        # Load from the directory the converter actually wrote to.
        model_id = converted_dir

    StableDiffusionSafetyChecker.from_pretrained(
        SAFETY_MODEL_ID,
        cache_dir=MODEL_CACHE_DIR,
    )

    StableDiffusionPipeline.from_pretrained(
        model_id,
        cache_dir=MODEL_CACHE_DIR,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------- #
|
|
# Parse Arguments #
|
|
# ---------------------------------------------------------------------------- #
|
|
# Command-line interface: --model_url selects the model handed to download_model.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--model_url",
    type=str,
    default="https://huggingface.co/stabilityai/stable-diffusion-2-1",
    help="URL of the model to download.",
)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read --model_url and fetch that model.
    args = parser.parse_args()
    download_model(args.model_url)
|
|
|