Hendrik Langer
2 years ago
14 changed files with 530 additions and 167 deletions
@@ -1,9 +0,0 @@
from python:3.11.1-buster

WORKDIR /

RUN pip install runpod

ADD handler.py .

CMD [ "python", "-u", "/handler.py" ]
@@ -1,21 +0,0 @@
#!/usr/bin/env python
''' Contains the handler function that will be called by the serverless. '''

import runpod

# Load models into VRAM here so they can be warm between requests


def handler(event):
    '''
    This is the handler function that will be called by the serverless.
    '''
    print(event)

    # do the things

    # return the output that you want to be returned like pre-signed URLs to output artifacts
    return "Hello World"


runpod.serverless.start({"handler": handler})
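For reference, a handler like the template above can be smoke-tested without the job queue by calling it with a hand-built event. A minimal, self-contained sketch; the event field values and the "local-test" id are made up for illustration, only the {"id": ..., "input": ...} shape follows what RunPod passes to a handler:

''' quick local check (not part of this change): exercise the template handler directly '''

def handler(event):
    ''' copy of the template handler above '''
    print(event)
    return "Hello World"

# RunPod delivers each job as a dict with at least an "id" and an "input" key;
# the values below are hypothetical.
sample_event = {"id": "local-test", "input": {"prompt": "hello"}}
assert handler(sample_event) == "Hello World"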
@@ -1,18 +0,0 @@
FROM runpod/stable-diffusion:web-automatic-1.5.16

SHELL ["/bin/bash", "-c"]

ENV PATH="${PATH}:/workspace/stable-diffusion-webui/venv/bin"

WORKDIR /

RUN rm /workspace/v1-5-pruned-emaonly.ckpt
RUN wget -O model.safetensors https://civitai.com/api/download/models/5616
RUN pip install -U xformers
RUN pip install runpod

ADD handler.py .
ADD start.sh /start.sh
RUN chmod +x /start.sh

CMD [ "/start.sh" ]
@@ -1,39 +0,0 @@
import runpod
import subprocess
import requests
import time

def check_api_availability(host):
    while True:
        try:
            response = requests.get(host)
            return
        except requests.exceptions.RequestException as e:
            print(f"API is not available, retrying in 200ms... ({e})")
        except Exception as e:
            print('something went wrong')
        time.sleep(200/1000)

check_api_availability("http://127.0.0.1:3000/sdapi/v1/txt2img")

print('run handler')

def handler(event):
    '''
    This is the handler function that will be called by the serverless.
    '''
    print('got event')
    print(event)

    response = requests.post(url=f'http://127.0.0.1:3000/sdapi/v1/txt2img', json=event["input"])

    json = response.json()
    # do the things

    print(json)

    # return the output that you want to be returned like pre-signed URLs to output artifacts
    return json


runpod.serverless.start({"handler": handler})
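Since event["input"] is forwarded verbatim to the local webui API, the job payload has to use that API's field names. A hedged sketch of such an input; the field names follow the AUTOMATIC1111 /sdapi/v1/txt2img API and the values are made up, so adjust them to the webui version actually baked into the image:

# Hypothetical job payload for this worker; "input" is passed straight through
# to http://127.0.0.1:3000/sdapi/v1/txt2img.
sample_job = {
    "input": {
        "prompt": "a watercolor painting of a lighthouse",
        "negative_prompt": "blurry, low quality",
        "steps": 20,
        "width": 512,
        "height": 512,
        "cfg_scale": 7,
        "seed": -1,
    }
}
# The handler returns the endpoint's JSON response, which for A1111 typically
# contains base64-encoded images rather than URLs.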
@@ -1,11 +0,0 @@
#!/bin/bash
echo "Container Started"
export PYTHONUNBUFFERED=1
source /workspace/stable-diffusion-webui/venv/bin/activate
cd /workspace/stable-diffusion-webui
echo "starting api"
python webui.py --port 3000 --nowebui --api --xformers --ckpt /model.safetensors &
cd /

echo "starting worker"
python -u handler.py
@@ -1,18 +0,0 @@
FROM runpod/stable-diffusion:web-automatic-1.5.16

SHELL ["/bin/bash", "-c"]

ENV PATH="${PATH}:/workspace/stable-diffusion-webui/venv/bin"

WORKDIR /

RUN rm /workspace/v1-5-pruned-emaonly.ckpt
RUN wget -O model.safetensors https://civitai.com/api/download/models/5616
RUN pip install -U xformers
RUN pip install runpod

ADD handler.py .
ADD start.sh /start.sh
RUN chmod +x /start.sh

CMD [ "/start.sh" ]
@@ -1,39 +0,0 @@
import runpod
import subprocess
import requests
import time

def check_api_availability(host):
    while True:
        try:
            response = requests.get(host)
            return
        except requests.exceptions.RequestException as e:
            print(f"API is not available, retrying in 200ms... ({e})")
        except Exception as e:
            print('something went wrong')
        time.sleep(200/1000)

check_api_availability("http://127.0.0.1:3000/sdapi/v1/txt2img")

print('run handler')

def handler(event):
    '''
    This is the handler function that will be called by the serverless.
    '''
    print('got event')
    print(event)

    response = requests.post(url=f'http://127.0.0.1:3000/sdapi/v1/txt2img', json=event["input"])

    json = response.json()
    # do the things

    print(json)

    # return the output that you want to be returned like pre-signed URLs to output artifacts
    return json


runpod.serverless.start({"handler": handler})
@@ -1,11 +0,0 @@
#!/bin/bash
echo "Container Started"
export PYTHONUNBUFFERED=1
source /workspace/stable-diffusion-webui/venv/bin/activate
cd /workspace/stable-diffusion-webui
echo "starting api"
python webui.py --port 3000 --nowebui --api --xformers --ckpt /model.safetensors &
cd /

echo "starting worker"
python -u handler.py
@@ -0,0 +1,57 @@
ARG BASE_IMAGE=nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
FROM ${BASE_IMAGE} as dev-base

WORKDIR /
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
ENV DEBIAN_FRONTEND=noninteractive \
    SHELL=/bin/bash

RUN apt-get update --yes && \
    # - apt-get upgrade is run to patch known vulnerabilities in apt-get packages as
    # the ubuntu base image is rebuilt too seldom sometimes (less than once a month)
    apt-get upgrade --yes && \
    apt install --yes --no-install-recommends \
    build-essential \
    ca-certificates \
    git \
    git-lfs \
    wget \
    curl \
    bash \
    libgl1 \
    software-properties-common \
    openssh-server && \
    apt-get clean && rm -rf /var/lib/apt/lists/* && \
    echo "en_US.UTF-8 UTF-8" > /etc/locale.gen

RUN apt-key del 7fa2af80 && \
    apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub

RUN add-apt-repository ppa:deadsnakes/ppa && \
    apt-get install python3.10 python3.10-dev python3.10-venv python3-pip -y --no-install-recommends && \
    apt-get clean && rm -rf /var/lib/apt/lists/*

RUN pip install --upgrade pip && \
    pip install huggingface-hub && \
    pip install diffusers && \
    pip install torch torchvision torchaudio --extra-index-url=https://download.pytorch.org/whl/cu116 && \
    pip install bitsandbytes && \
    pip install transformers accelerate xformers triton && \
    pip install runpod

RUN mkdir /workspace
WORKDIR /workspace

COPY . /workspace/
#ADD start.sh /workspace/

ARG MODEL_URL
ENV MODEL_URL=${MODEL_URL}

RUN python3 model_fetcher.py --model_url=${MODEL_URL} && \
    echo "Model URL: $MODEL_URL"
#RUN wget -O model.safetensors https://civitai.com/api/download/models/5616
# PFG: https://civitai.com/api/download/models/1316
# Hassanblend: https://civitai.com/api/download/models/4635

CMD python3 -u runpod_infer.py --model_url="$MODEL_URL"
@@ -0,0 +1,12 @@
git clone https://github.com/runpod/serverless-ckpt-template.git
cd serverless-ckpt-template

docker build --build-arg MODEL_URL=https://huggingface.co/hassanblend/HassanBlend1.5.1.2 -t magn418/runpod-hassan:1.5 .
docker login
docker push magn418/runpod-hassan:1.5

Models:
PFG https://civitai.com/models/1227/pfg
hassanblend https://civitai.com/models/1173/hassanblend-15-and-previous-versions
Deliberate
Anything v3 ?
@@ -0,0 +1,66 @@
'''
RunPod | serverless-ckpt-template | model_fetcher.py

Downloads the model from the URL passed in.
'''

import os
import shutil
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)

import requests
import argparse


SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"
MODEL_CACHE_DIR = "diffusers-cache"


def download_model(model_url: str):
    '''
    Downloads the model from the URL passed in.
    '''
    if os.path.exists(MODEL_CACHE_DIR):
        shutil.rmtree(MODEL_CACHE_DIR)
    os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

    # Check if the URL is from huggingface.co, if so, grab the model repo id.
    if "huggingface.co" in model_url:
        url_parts = model_url.split("/")
        model_id = f"{url_parts[-2]}/{url_parts[-1]}"
    else:
        downloaded_model = requests.get(model_url, stream=True, timeout=600)
        with open(f"{MODEL_CACHE_DIR}/model.safetensors", "wb") as f:
            for chunk in downloaded_model.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
        os.system("wget -q https://raw.githubusercontent.com/huggingface/diffusers/main/scripts/convert_original_stable_diffusion_to_diffusers.py")
        os.system(f"python3 convert_original_stable_diffusion_to_diffusers.py --from_safetensors --checkpoint_path {MODEL_CACHE_DIR}/model.safetensors --dump_path {MODEL_CACHE_DIR}/pt")
        #os.system(f"python3 convert_original_stable_diffusion_to_diffusers.py --checkpoint_path model.ckpt --dump_path pt")
        model_id = "pt"

    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        SAFETY_MODEL_ID,
        cache_dir=MODEL_CACHE_DIR,
    )

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        cache_dir=MODEL_CACHE_DIR,
    )


# ---------------------------------------------------------------------------- #
#                                Parse Arguments                                #
# ---------------------------------------------------------------------------- #
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--model_url", type=str,
                    default="https://huggingface.co/stabilityai/stable-diffusion-2-1", help="URL of the model to download.")


if __name__ == "__main__":
    args = parser.parse_args()
    download_model(args.model_url)
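To illustrate the branch above: a Hugging Face URL is reduced to its last two path segments and treated as a repo id, while any other URL is downloaded and converted to a diffusers folder. A small standalone sketch that mirrors only the id resolution (not the download/convert step); the example URLs are the ones listed elsewhere in this change:

def resolve_model_id(model_url: str) -> str:
    ''' mirror of the URL handling in download_model(), for illustration only '''
    if "huggingface.co" in model_url:
        parts = model_url.split("/")
        return f"{parts[-2]}/{parts[-1]}"   # e.g. "hassanblend/HassanBlend1.5.1.2"
    return "pt"  # non-HF checkpoints are converted into the local diffusers cache

assert resolve_model_id("https://huggingface.co/hassanblend/HassanBlend1.5.1.2") \
    == "hassanblend/HassanBlend1.5.1.2"
assert resolve_model_id("https://civitai.com/api/download/models/5616") == "pt"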
@@ -0,0 +1,237 @@
''' StableDiffusion-v1 Predict Module '''

import os
from typing import List

import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    # StableDiffusionInpaintPipeline,
    StableDiffusionInpaintPipelineLegacy,

    DDIMScheduler,
    DDPMScheduler,
    # DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    IPNDMScheduler,
    KDPM2AncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    # KarrasVeScheduler,
    PNDMScheduler,
    # RePaintScheduler,
    # ScoreSdeVeScheduler,
    # ScoreSdeVpScheduler,
    # UnCLIPScheduler,
    # VQDiffusionScheduler,
    LMSDiscreteScheduler
)

from PIL import Image
from cog import BasePredictor, Input, Path
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

MODEL_CACHE = "diffusers-cache"
SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"


class Predictor(BasePredictor):
    '''Predictor class for StableDiffusion-v1'''

    def __init__(self, model_id):
        self.model_id = model_id

    def setup(self):
        '''
        Load the model into memory to make running multiple predictions efficient
        '''
        print("Loading pipeline...")

        # safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        #     SAFETY_MODEL_ID,
        #     cache_dir=MODEL_CACHE,
        #     local_files_only=True,
        # )
        self.txt2img_pipe = StableDiffusionPipeline.from_pretrained(
            self.model_id,
            safety_checker=None,
            # safety_checker=safety_checker,
            cache_dir=MODEL_CACHE,
            local_files_only=True,
        ).to("cuda")
        self.img2img_pipe = StableDiffusionImg2ImgPipeline(
            vae=self.txt2img_pipe.vae,
            text_encoder=self.txt2img_pipe.text_encoder,
            tokenizer=self.txt2img_pipe.tokenizer,
            unet=self.txt2img_pipe.unet,
            scheduler=self.txt2img_pipe.scheduler,
            safety_checker=None,
            # safety_checker=self.txt2img_pipe.safety_checker,
            feature_extractor=self.txt2img_pipe.feature_extractor,
        ).to("cuda")
        self.inpaint_pipe = StableDiffusionInpaintPipelineLegacy(
            vae=self.txt2img_pipe.vae,
            text_encoder=self.txt2img_pipe.text_encoder,
            tokenizer=self.txt2img_pipe.tokenizer,
            unet=self.txt2img_pipe.unet,
            scheduler=self.txt2img_pipe.scheduler,
            safety_checker=None,
            # safety_checker=self.txt2img_pipe.safety_checker,
            feature_extractor=self.txt2img_pipe.feature_extractor,
        ).to("cuda")

        self.txt2img_pipe.enable_xformers_memory_efficient_attention()
        self.img2img_pipe.enable_xformers_memory_efficient_attention()
        self.inpaint_pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    @torch.cuda.amp.autocast()
    def predict(
        self,
        prompt: str = Input(description="Input prompt", default=""),
        negative_prompt: str = Input(
            description="Specify things to not see in the output",
            default=None,
        ),
        width: int = Input(
            description="Output image width; max 1024x768 or 768x1024 due to memory limits",
            choices=[128, 256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024],
            default=512,
        ),
        height: int = Input(
            description="Output image height; max 1024x768 or 768x1024 due to memory limits",
            choices=[128, 256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024],
            default=512,
        ),
        init_image: Path = Input(
            description="Initial image to generate variations of, resized to the specified WxH.",
            default=None,
        ),
        mask: Path = Input(
            description="""Black and white image to use as mask for inpainting over init_image.
            Black pixels are inpainted and white pixels are preserved.
            Tends to work better with prompt strength of 0.5-0.7""",
            default=None,
        ),
        prompt_strength: float = Input(
            description="Prompt strength init image. 1.0 full destruction of init image",
            default=0.8,
        ),
        num_outputs: int = Input(
            description="Number of images to output.",
            ge=1,
            le=10,
            default=1
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=7.5
        ),
        scheduler: str = Input(
            default="K-LMS",
            choices=["DDIM", "DDPM", "DPM-M", "DPM-S", "EULER-A", "EULER-D",
                     "HEUN", "IPNDM", "KDPM2-A", "KDPM2-D", "PNDM", "K-LMS"],
            description="Choose a scheduler. If you use an init image, PNDM will be used",
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> List[Path]:
        '''
        Run a single prediction on the model
        '''
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")

        if width * height > 786432:
            raise ValueError(
                "Maximum size is 1024x768 or 768x1024 pixels, because of memory limits."
            )

        extra_kwargs = {}
        if mask:
            if not init_image:
                raise ValueError("mask was provided without init_image")

            pipe = self.inpaint_pipe
            init_image = Image.open(init_image).convert("RGB")
            extra_kwargs = {
                "mask_image": Image.open(mask).convert("RGB").resize(init_image.size),
                "image": init_image,
                "strength": prompt_strength,
            }
        elif init_image:
            pipe = self.img2img_pipe
            extra_kwargs = {
                "init_image": Image.open(init_image).convert("RGB"),
                "strength": prompt_strength,
            }
        else:
            pipe = self.txt2img_pipe
            extra_kwargs = {
                "width": width,
                "height": height,
            }

        pipe.scheduler = make_scheduler(scheduler, pipe.scheduler.config)

        generator = torch.Generator("cuda").manual_seed(seed)
        output = pipe(
            prompt=[prompt] * num_outputs if prompt is not None else None,
            negative_prompt=[negative_prompt]*num_outputs if negative_prompt is not None else None,
            # width=width,
            # height=height,
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_inference_steps,
            **extra_kwargs,
        )

        output_paths = []
        for i, sample in enumerate(output.images):
            # if output.nsfw_content_detected and output.nsfw_content_detected[i] and self.NSFW:
            #     continue

            output_path = f"/tmp/out-{i}.png"
            sample.save(output_path)
            output_paths.append(Path(output_path))

        if len(output_paths) == 0:
            raise Exception(
                "NSFW content detected. Try running it again, or try a different prompt."
            )

        return output_paths


def make_scheduler(name, config):
    '''
    Returns a scheduler from a name and config.
    '''
    return {
        "DDIM": DDIMScheduler.from_config(config),
        "DDPM": DDPMScheduler.from_config(config),
        # "DEIS": DEISMultistepScheduler.from_config(config),
        "DPM-M": DPMSolverMultistepScheduler.from_config(config),
        "DPM-S": DPMSolverSinglestepScheduler.from_config(config),
        "EULER-A": EulerAncestralDiscreteScheduler.from_config(config),
        "EULER-D": EulerDiscreteScheduler.from_config(config),
        "HEUN": HeunDiscreteScheduler.from_config(config),
        "IPNDM": IPNDMScheduler.from_config(config),
        "KDPM2-A": KDPM2AncestralDiscreteScheduler.from_config(config),
        "KDPM2-D": KDPM2DiscreteScheduler.from_config(config),
        # "KARRAS-VE": KarrasVeScheduler.from_config(config),
        "PNDM": PNDMScheduler.from_config(config),
        # "RE-PAINT": RePaintScheduler.from_config(config),
        # "SCORE-VE": ScoreSdeVeScheduler.from_config(config),
        # "SCORE-VP": ScoreSdeVpScheduler.from_config(config),
        # "UN-CLIPS": UnCLIPScheduler.from_config(config),
        # "VQD": VQDiffusionScheduler.from_config(config),
        "K-LMS": LMSDiscreteScheduler.from_config(config)
    }[name]
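For orientation, the intended call sequence is: construct a Predictor with a model id, call setup() once to load the pipelines onto the GPU, then call predict() per request. A hedged sketch, assuming a CUDA device and a model already cached in diffusers-cache (as prepared by model_fetcher.py); the model id, prompt, and seed are arbitrary examples:

import predict  # the module above

model = predict.Predictor("stabilityai/stable-diffusion-2-1")  # any locally cached model id
model.setup()  # builds the txt2img/img2img/inpaint pipelines once

paths = model.predict(
    prompt="a cabin in a snowy forest, golden hour",
    negative_prompt=None,
    width=512, height=512,
    init_image=None, mask=None,
    prompt_strength=0.8,
    num_outputs=1,
    num_inference_steps=30,
    guidance_scale=7.5,
    scheduler="K-LMS",
    seed=1234,
)
print(paths)  # e.g. [Path('/tmp/out-0.png')]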
@@ -0,0 +1,157 @@
''' infer.py for runpod worker '''

import os
import argparse

import predict

import runpod
from runpod.serverless.utils import rp_download, rp_upload, rp_cleanup
from runpod.serverless.utils.rp_validator import validate


INPUT_SCHEMA = {
    'prompt': {
        'type': str,
        'required': True
    },
    'negative_prompt': {
        'type': str,
        'required': False,
        'default': None
    },
    'width': {
        'type': int,
        'required': False,
        'default': 512,
        'constraints': lambda width: width in [128, 256, 384, 448, 512, 576, 640, 704, 768]
    },
    'height': {
        'type': int,
        'required': False,
        'default': 512,
        'constraints': lambda height: height in [128, 256, 384, 448, 512, 576, 640, 704, 768]
    },
    'init_image': {
        'type': str,
        'required': False,
        'default': None
    },
    'mask': {
        'type': str,
        'required': False,
        'default': None
    },
    'prompt_strength': {
        'type': float,
        'required': False,
        'default': 0.8,
        'constraints': lambda prompt_strength: 0 <= prompt_strength <= 1
    },
    'num_outputs': {
        'type': int,
        'required': False,
        'default': 1,
        'constraints': lambda num_outputs: 10 > num_outputs > 0
    },
    'num_inference_steps': {
        'type': int,
        'required': False,
        'default': 50,
        'constraints': lambda num_inference_steps: 0 < num_inference_steps < 500
    },
    'guidance_scale': {
        'type': float,
        'required': False,
        'default': 7.5,
        'constraints': lambda guidance_scale: 0 < guidance_scale < 20
    },
    'scheduler': {
        'type': str,
        'required': False,
        'default': 'K-LMS',
        'constraints': lambda scheduler: scheduler in ['DDIM', 'DDPM', 'DPM-M', 'DPM-S', 'EULER-A', 'EULER-D', 'HEUN', 'IPNDM', 'KDPM2-A', 'KDPM2-D', 'PNDM', 'K-LMS']
    },
    'seed': {
        'type': int,
        'required': False,
        'default': None
    },
    'nsfw': {
        'type': bool,
        'required': False,
        'default': False
    }
}


def run(job):
    '''
    Run inference on the model.
    Returns the output paths, with the seed used to generate each image.
    '''
    job_input = job['input']

    # Input validation
    validated_input = validate(job_input, INPUT_SCHEMA)

    if 'errors' in validated_input:
        return {"error": validated_input['errors']}
    validated_input = validated_input['validated_input']

    # Download input objects
    job_input['init_image'], job_input['mask'] = rp_download.download_input_objects(
        [job_input.get('init_image', None), job_input.get('mask', None)]
    )  # pylint: disable=unbalanced-tuple-unpacking

    MODEL.NSFW = job_input.get('nsfw', True)

    if job_input['seed'] is None:
        job_input['seed'] = int.from_bytes(os.urandom(2), "big")

    img_paths = MODEL.predict(
        prompt=job_input["prompt"],
        negative_prompt=job_input.get("negative_prompt", None),
        width=job_input.get('width', 512),
        height=job_input.get('height', 512),
        init_image=job_input['init_image'],
        mask=job_input['mask'],
        prompt_strength=job_input['prompt_strength'],
        num_outputs=job_input.get('num_outputs', 1),
        num_inference_steps=job_input.get('num_inference_steps', 50),
        guidance_scale=job_input['guidance_scale'],
        scheduler=job_input.get('scheduler', "K-LMS"),
        seed=job_input['seed']
    )

    job_output = []
    for index, img_path in enumerate(img_paths):
        image_url = rp_upload.upload_image(job['id'], img_path, index)

        job_output.append({
            "image": image_url,
            "seed": job_input['seed'] + index
        })

    # Remove downloaded input objects
    rp_cleanup.clean(['input_objects'])

    return job_output


parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--model_url", type=str,
                    default=None, help="Model URL")


if __name__ == "__main__":
    args = parser.parse_args()
    print(args)

    if "huggingface.co" in args.model_url:
        url_parts = args.model_url.split("/")
        model_id = f"{url_parts[-2]}/{url_parts[-1]}"
    else:
        model_id = "model.safetensors"

    MODEL = predict.Predictor(model_id)
    MODEL.setup()

    runpod.serverless.start({"handler": run})
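Putting it together, a request to this worker only needs the fields defined in INPUT_SCHEMA; anything omitted falls back to the defaults above. A hypothetical example job in the shape run() receives it, with made-up id and prompt values:

# Hypothetical job as the run() handler would receive it from the queue.
example_job = {
    "id": "sync-12345-example",
    "input": {
        "prompt": "portrait photo of an old fisherman, dramatic lighting",
        "negative_prompt": "cartoon, drawing",
        "width": 512,
        "height": 768,
        "num_inference_steps": 40,
        "guidance_scale": 7.5,
        "scheduler": "EULER-A",
        "seed": 42,
    },
}

# run(example_job) returns one dict per generated image, e.g.
# [{"image": "<pre-signed URL from rp_upload>", "seed": 42}]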