"""
RunPod | serverless-ckpt-template | model_fetcher.py

Downloads the model from the URL passed in.
"""

import argparse
import os
import shutil
import subprocess
import sys

import requests
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)

SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"
MODEL_CACHE_DIR = "diffusers-cache"

# Where a converted (non-Hugging-Face) checkpoint is written and later loaded from.
CONVERTED_MODEL_DIR = os.path.join(MODEL_CACHE_DIR, "pt")

CONVERT_SCRIPT_URL = (
    "https://raw.githubusercontent.com/huggingface/diffusers/main/scripts/"
    "convert_original_stable_diffusion_to_diffusers.py"
)
CONVERT_SCRIPT_NAME = "convert_original_stable_diffusion_to_diffusers.py"


def _download_file(url: str, destination: str, timeout: int = 600) -> None:
    """Stream ``url`` into ``destination``.

    Raises ``requests.HTTPError`` on a non-2xx response, so an error page is
    never saved to disk as if it were the requested file.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(destination, "wb") as out_file:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                if chunk:  # skip keep-alive chunks
                    out_file.write(chunk)


def _hf_repo_id(model_url: str) -> str:
    """Return the ``owner/name`` repo id taken from the last two URL path parts.

    A trailing slash is ignored so ``https://huggingface.co/org/repo/`` still
    yields ``org/repo``.
    """
    url_parts = model_url.rstrip("/").split("/")
    return f"{url_parts[-2]}/{url_parts[-1]}"


def _convert_checkpoint(checkpoint_path: str, dump_path: str, from_safetensors: bool) -> None:
    """Convert an original SD checkpoint into a diffusers model directory.

    Fetches the upstream diffusers conversion script into the current working
    directory (overwriting any stale copy) and runs it with the current
    interpreter. Raises ``subprocess.CalledProcessError`` if conversion fails.
    """
    _download_file(CONVERT_SCRIPT_URL, CONVERT_SCRIPT_NAME)
    command = [
        sys.executable,
        CONVERT_SCRIPT_NAME,
        "--checkpoint_path",
        checkpoint_path,
        "--dump_path",
        dump_path,
    ]
    if from_safetensors:
        command.append("--from_safetensors")
    subprocess.run(command, check=True)


def download_model(model_url: str) -> None:
    '''
    Downloads the model from the URL passed in.

    Wipes and recreates MODEL_CACHE_DIR, then:
      * for a huggingface.co URL, uses the last two path segments as the repo id;
      * otherwise downloads the checkpoint into MODEL_CACHE_DIR (``.ckpt`` URLs
        as a pickle checkpoint, anything else as safetensors), converts it into
        CONVERTED_MODEL_DIR and uses that directory as the model id.
    Finally loads the safety checker and the pipeline with
    ``cache_dir=MODEL_CACHE_DIR`` so their files are fetched into the cache.

    Raises requests.HTTPError on a failed download and
    subprocess.CalledProcessError on a failed conversion.
    '''
    if os.path.exists(MODEL_CACHE_DIR):
        shutil.rmtree(MODEL_CACHE_DIR)
    os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

    # Check if the URL is from huggingface.co, if so, grab the model repo id.
    if "huggingface.co" in model_url:
        model_id = _hf_repo_id(model_url)
    else:
        # Ignore any query string when sniffing the file extension.
        is_ckpt = model_url.split("?", 1)[0].lower().endswith(".ckpt")
        checkpoint_name = "model.ckpt" if is_ckpt else "model.safetensors"
        checkpoint_path = os.path.join(MODEL_CACHE_DIR, checkpoint_name)
        _download_file(model_url, checkpoint_path)
        _convert_checkpoint(checkpoint_path, CONVERTED_MODEL_DIR, from_safetensors=not is_ckpt)
        # Load from where the converter actually wrote the model.
        model_id = CONVERTED_MODEL_DIR

    # Loaded only for the side effect of populating MODEL_CACHE_DIR.
    StableDiffusionSafetyChecker.from_pretrained(
        SAFETY_MODEL_ID,
        cache_dir=MODEL_CACHE_DIR,
    )
    StableDiffusionPipeline.from_pretrained(
        model_id,
        cache_dir=MODEL_CACHE_DIR,
    )


# ---------------------------------------------------------------------------- #
#                                Parse Arguments                               #
# ---------------------------------------------------------------------------- #
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--model_url",
    type=str,
    default="https://huggingface.co/stabilityai/stable-diffusion-2-1",
    help="URL of the model to download.",
)

if __name__ == "__main__":
    args = parser.parse_args()
    download_model(args.model_url)