Chatbot
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

69 lines
2.5 KiB

'''
RunPod | serverless-ckpt-template | model_fetcher.py
Downloads the model from the URL passed in.
'''
import argparse
import os
import shutil
import subprocess
import sys

import requests
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
# Hugging Face repo id of the safety checker that is fetched together with
# every model.
SAFETY_MODEL_ID: str = "CompVis/stable-diffusion-safety-checker"

# Relative directory (resolved against the working directory) that receives
# all downloaded and converted weights; download_model() wipes and recreates it.
MODEL_CACHE_DIR: str = "diffusers-cache"
def download_model(model_url: str):
    '''
    Downloads the model from the URL passed in.

    MODEL_CACHE_DIR is wiped and recreated first. A URL containing
    "huggingface.co" is reduced to its "<owner>/<repo>" id and loaded from the
    Hub. Any other URL is treated as a direct link to a single-file
    .safetensors checkpoint. That file is downloaded and converted to a local
    diffusers folder. The safety checker (SAFETY_MODEL_ID) and the pipeline are
    then loaded with cache_dir=MODEL_CACHE_DIR.

    Args:
        model_url: Hugging Face repository URL or direct .safetensors URL.

    Raises:
        ValueError: if a huggingface.co URL has no "<owner>/<repo>" path.
        requests.HTTPError: if a direct download returns an error status.
        subprocess.CalledProcessError: if installing omegaconf or the
            checkpoint conversion fails.
    '''
    # Start from an empty cache so files from a previous run are not reused.
    if os.path.exists(MODEL_CACHE_DIR):
        shutil.rmtree(MODEL_CACHE_DIR)
    os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

    if "huggingface.co" in model_url:
        model_id = _hf_repo_id(model_url)
    else:
        model_id = _download_and_convert_checkpoint(model_url)

    # The returned objects are not used here; loading them with cache_dir is
    # what stores any Hub downloads under MODEL_CACHE_DIR.
    StableDiffusionSafetyChecker.from_pretrained(
        SAFETY_MODEL_ID,
        cache_dir=MODEL_CACHE_DIR,
    )
    StableDiffusionPipeline.from_pretrained(
        model_id,
        cache_dir=MODEL_CACHE_DIR,
    )


def _hf_repo_id(model_url: str) -> str:
    '''
    Return the "<owner>/<repo>" id from a huggingface.co URL.

    The id is taken from the first two path segments after the host. Trailing
    slashes and suffixes such as "/tree/main" or "/resolve/main/<file>" are
    therefore ignored.

    Raises:
        ValueError: if fewer than two path segments follow the host.
    '''
    path = model_url.split("huggingface.co", 1)[1]
    # Drop any query string or fragment before splitting the path.
    path = path.split("?", 1)[0].split("#", 1)[0]
    segments = [segment for segment in path.split("/") if segment]
    if len(segments) < 2:
        raise ValueError(
            f"Cannot find '<owner>/<repo>' in Hugging Face URL: {model_url}"
        )
    return f"{segments[0]}/{segments[1]}"


def _download_and_convert_checkpoint(model_url: str) -> str:
    '''
    Download a .safetensors checkpoint and convert it to a diffusers folder.

    The checkpoint is streamed to MODEL_CACHE_DIR/model.safetensors. It is
    converted into ./MODEL_CACHE_DIR/model with the conversion script pinned to
    diffusers v0.14.0, and then deleted. Only .safetensors input is handled,
    because the script is run with --from_safetensors.

    Returns:
        Path of the converted diffusers model folder.

    Raises:
        requests.HTTPError: if the checkpoint or the script cannot be fetched.
        subprocess.CalledProcessError: if installing omegaconf or the
            conversion fails.
    '''
    script_url = (
        "https://raw.githubusercontent.com/huggingface/diffusers/v0.14.0/"
        "scripts/convert_original_stable_diffusion_to_diffusers.py"
    )
    script_path = "convert_original_stable_diffusion_to_diffusers.py"
    # The same checkpoint path is used for the download, the conversion and
    # the clean-up.
    checkpoint_path = os.path.join(MODEL_CACHE_DIR, "model.safetensors")
    dump_path = os.path.join(".", MODEL_CACHE_DIR, "model")

    # Stream to disk. raise_for_status() keeps an HTTP error page from being
    # saved as the checkpoint, and `with` releases the connection.
    with requests.get(model_url, stream=True, timeout=600) as response:
        response.raise_for_status()
        with open(checkpoint_path, "wb") as checkpoint_file:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                if chunk:  # skip empty chunks
                    checkpoint_file.write(chunk)

    # Fetch the pinned conversion script into the working directory.
    script = requests.get(script_url, timeout=600)
    script.raise_for_status()
    with open(script_path, "wb") as script_file:
        script_file.write(script.content)

    # Use the running interpreter rather than whatever `pip`/`python3` is on
    # PATH, so omegaconf is installed into, and the script runs in, this
    # environment. check=True turns a failed step into an exception instead of
    # silently continuing without a converted model.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "omegaconf"],
        check=True,
    )
    subprocess.run(
        [
            sys.executable, script_path,
            "--from_safetensors",
            "--checkpoint_path", checkpoint_path,
            "--dump_path", dump_path,
        ],
        check=True,
    )
    # Only the converted folder is needed from here on.
    os.remove(checkpoint_path)
    return dump_path
# ---------------------------------------------------------------------------- #
#                                Parse Arguments                               #
# ---------------------------------------------------------------------------- #
# Command-line interface: one optional --model_url flag whose default points at
# the Stable Diffusion 2.1 repository on Hugging Face.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--model_url",
    type=str,
    default="https://huggingface.co/stabilityai/stable-diffusion-2-1",
    help="URL of the model to download.",
)

if __name__ == "__main__":
    # Only runs when executed as a script, not on import.
    args = parser.parse_args()
    download_model(args.model_url)