Hendrik Langer
2 years ago
14 changed files with 530 additions and 167 deletions
@@ -1,9 +0,0 @@
FROM python:3.11.1-buster

WORKDIR /

RUN pip install runpod

ADD handler.py .

CMD [ "python", "-u", "/handler.py" ]
@@ -1,21 +0,0 @@
#!/usr/bin/env python
''' Contains the handler function that will be called by the serverless worker. '''

import runpod

# Load models into VRAM here so they can stay warm between requests


def handler(event):
    '''
    This is the handler function that will be called by the serverless worker.
    '''
    print(event)

    # do the things

    # Return the output you want returned, e.g. pre-signed URLs to output artifacts.
    return "Hello World"


runpod.serverless.start({"handler": handler})
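For context, RunPod hands the handler a job dict whose user data sits under the "input" key (the txt2img handler later in this diff relies on exactly that shape). A minimal sketch of exercising such a handler locally; the job id and input values are illustrative, not part of this diff:

    # Minimal sketch, assuming the standard RunPod event shape:
    # user data arrives under event["input"]. Values are made up.
    def handler(event):
        print(event)
        return "Hello World"

    sample_event = {"id": "local-test", "input": {"prompt": "Hello"}}
    print(handler(sample_event))  # -> "Hello World"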
@@ -1,18 +0,0 @@
FROM runpod/stable-diffusion:web-automatic-1.5.16

SHELL ["/bin/bash", "-c"]

ENV PATH="${PATH}:/workspace/stable-diffusion-webui/venv/bin"

WORKDIR /

RUN rm /workspace/v1-5-pruned-emaonly.ckpt
RUN wget -O model.safetensors https://civitai.com/api/download/models/5616
RUN pip install -U xformers
RUN pip install runpod

ADD handler.py .
ADD start.sh /start.sh
RUN chmod +x /start.sh

CMD [ "/start.sh" ]
@@ -1,39 +0,0 @@
import runpod
import subprocess
import requests
import time


def check_api_availability(host):
    while True:
        try:
            response = requests.get(host)
            return
        except requests.exceptions.RequestException as e:
            print(f"API is not available, retrying in 200ms... ({e})")
        except Exception as e:
            print('something went wrong')
        time.sleep(200/1000)

check_api_availability("http://127.0.0.1:3000/sdapi/v1/txt2img")

print('run handler')


def handler(event):
    '''
    This is the handler function that will be called by the serverless worker.
    '''
    print('got event')
    print(event)

    response = requests.post(url='http://127.0.0.1:3000/sdapi/v1/txt2img', json=event["input"])

    json = response.json()
    # do the things

    print(json)

    # Return the output you want returned, e.g. pre-signed URLs to output artifacts.
    return json


runpod.serverless.start({"handler": handler})
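Since this handler forwards event["input"] unchanged to the AUTOMATIC1111 API, the job input must already be a valid txt2img request body. A sketch of such a payload; the field names follow the commonly documented /sdapi/v1/txt2img parameters and are assumptions, not taken from this diff:

    # Sketch of a job payload for this worker; the handler passes "input"
    # verbatim to POST /sdapi/v1/txt2img. Field names assume the usual
    # AUTOMATIC1111 web API and may differ in the image actually shipped.
    sample_job = {
        "input": {
            "prompt": "a photo of an astronaut riding a horse",
            "negative_prompt": "blurry",
            "steps": 25,
            "width": 512,
            "height": 512,
            "cfg_scale": 7,
        }
    }
    # The A1111 API returns base64-encoded images in the "images" field
    # of the JSON response, which this handler returns unchanged.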
@@ -1,11 +0,0 @@
#!/bin/bash
echo "Container Started"
export PYTHONUNBUFFERED=1
source /workspace/stable-diffusion-webui/venv/bin/activate
cd /workspace/stable-diffusion-webui
echo "starting api"
python webui.py --port 3000 --nowebui --api --xformers --ckpt /model.safetensors &
cd /

echo "starting worker"
python -u handler.py
@@ -1,18 +0,0 @@
FROM runpod/stable-diffusion:web-automatic-1.5.16

SHELL ["/bin/bash", "-c"]

ENV PATH="${PATH}:/workspace/stable-diffusion-webui/venv/bin"

WORKDIR /

RUN rm /workspace/v1-5-pruned-emaonly.ckpt
RUN wget -O model.safetensors https://civitai.com/api/download/models/5616
RUN pip install -U xformers
RUN pip install runpod

ADD handler.py .
ADD start.sh /start.sh
RUN chmod +x /start.sh

CMD [ "/start.sh" ]
@@ -1,39 +0,0 @@
import runpod
import subprocess
import requests
import time


def check_api_availability(host):
    while True:
        try:
            response = requests.get(host)
            return
        except requests.exceptions.RequestException as e:
            print(f"API is not available, retrying in 200ms... ({e})")
        except Exception as e:
            print('something went wrong')
        time.sleep(200/1000)

check_api_availability("http://127.0.0.1:3000/sdapi/v1/txt2img")

print('run handler')


def handler(event):
    '''
    This is the handler function that will be called by the serverless worker.
    '''
    print('got event')
    print(event)

    response = requests.post(url='http://127.0.0.1:3000/sdapi/v1/txt2img', json=event["input"])

    json = response.json()
    # do the things

    print(json)

    # Return the output you want returned, e.g. pre-signed URLs to output artifacts.
    return json


runpod.serverless.start({"handler": handler})
@@ -1,11 +0,0 @@
#!/bin/bash
echo "Container Started"
export PYTHONUNBUFFERED=1
source /workspace/stable-diffusion-webui/venv/bin/activate
cd /workspace/stable-diffusion-webui
echo "starting api"
python webui.py --port 3000 --nowebui --api --xformers --ckpt /model.safetensors &
cd /

echo "starting worker"
python -u handler.py
@@ -0,0 +1,57 @@
ARG BASE_IMAGE=nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
FROM ${BASE_IMAGE} as dev-base

WORKDIR /
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
ENV DEBIAN_FRONTEND=noninteractive \
    SHELL=/bin/bash

RUN apt-get update --yes && \
    # - apt-get upgrade is run to patch known vulnerabilities in apt-get packages as
    #   the ubuntu base image is rebuilt too seldom sometimes (less than once a month)
    apt-get upgrade --yes && \
    apt-get install --yes --no-install-recommends \
    build-essential \
    ca-certificates \
    git \
    git-lfs \
    wget \
    curl \
    bash \
    libgl1 \
    software-properties-common \
    openssh-server && \
    apt-get clean && rm -rf /var/lib/apt/lists/* && \
    echo "en_US.UTF-8 UTF-8" > /etc/locale.gen

RUN apt-key del 7fa2af80 && \
    apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub

RUN add-apt-repository ppa:deadsnakes/ppa && \
    apt-get install python3.10 python3.10-dev python3.10-venv python3-pip -y --no-install-recommends && \
    apt-get clean && rm -rf /var/lib/apt/lists/*

RUN pip install --upgrade pip && \
    pip install huggingface-hub && \
    pip install diffusers && \
    pip install torch torchvision torchaudio --extra-index-url=https://download.pytorch.org/whl/cu116 && \
    pip install bitsandbytes && \
    pip install transformers accelerate xformers triton && \
    pip install runpod

RUN mkdir /workspace
WORKDIR /workspace

COPY . /workspace/
#ADD start.sh /workspace/

ARG MODEL_URL
ENV MODEL_URL=${MODEL_URL}

RUN python3 model_fetcher.py --model_url=${MODEL_URL} && \
    echo "Model URL: $MODEL_URL"
#RUN wget -O model.safetensors https://civitai.com/api/download/models/5616
# PFG: https://civitai.com/api/download/models/1316
# Hassanblend: https://civitai.com/api/download/models/4635

CMD python3 -u runpod_infer.py --model_url="$MODEL_URL"
@@ -0,0 +1,12 @@
git clone https://github.com/runpod/serverless-ckpt-template.git
cd serverless-ckpt-template

docker build --build-arg MODEL_URL=https://huggingface.co/hassanblend/HassanBlend1.5.1.2 -t magn418/runpod-hassan:1.5 .
docker login
docker push magn418/runpod-hassan:1.5

Models:
PFG https://civitai.com/models/1227/pfg
hassanblend https://civitai.com/models/1173/hassanblend-15-and-previous-versions
Deliberate
Anything v3 ?
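Once the image is deployed as a RunPod serverless endpoint, jobs can be submitted over HTTP. A sketch under the assumption that the endpoint uses RunPod's documented /run route; the endpoint id and API key are placeholders:

    # Sketch of submitting a job to the deployed endpoint. ENDPOINT_ID and
    # API_KEY are placeholders; the /run route follows RunPod's serverless API.
    import requests

    ENDPOINT_ID = "your-endpoint-id"   # placeholder
    API_KEY = "your-api-key"           # placeholder

    resp = requests.post(
        f"https://api.runpod.ai/v2/{ENDPOINT_ID}/run",
        headers={"Authorization": f"Bearer {API_KEY}"},
        json={"input": {"prompt": "portrait photo, 35mm"}},
        timeout=30,
    )
    print(resp.json())  # returns a job id; poll the status route for results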
@@ -0,0 +1,66 @@
'''
RunPod | serverless-ckpt-template | model_fetcher.py

Downloads the model from the URL passed in.
'''

import os
import shutil
import argparse

import requests
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)


SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"
MODEL_CACHE_DIR = "diffusers-cache"


def download_model(model_url: str):
    '''
    Downloads the model from the URL passed in.
    '''
    if os.path.exists(MODEL_CACHE_DIR):
        shutil.rmtree(MODEL_CACHE_DIR)
    os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

    # Check if the URL is from huggingface.co; if so, grab the model repo id.
    if "huggingface.co" in model_url:
        url_parts = model_url.split("/")
        model_id = f"{url_parts[-2]}/{url_parts[-1]}"
    else:
        downloaded_model = requests.get(model_url, stream=True, timeout=600)
        with open(f"{MODEL_CACHE_DIR}/model.safetensors", "wb") as f:
            for chunk in downloaded_model.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
        os.system("wget -q https://raw.githubusercontent.com/huggingface/diffusers/main/scripts/convert_original_stable_diffusion_to_diffusers.py")
        os.system(f"python3 convert_original_stable_diffusion_to_diffusers.py --from_safetensors --checkpoint_path {MODEL_CACHE_DIR}/model.safetensors --dump_path {MODEL_CACHE_DIR}/pt")
        # os.system(f"python3 convert_original_stable_diffusion_to_diffusers.py --checkpoint_path model.ckpt --dump_path pt")
        model_id = f"{MODEL_CACHE_DIR}/pt"  # local path to the converted weights

    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        SAFETY_MODEL_ID,
        cache_dir=MODEL_CACHE_DIR,
    )

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        cache_dir=MODEL_CACHE_DIR,
    )


# ---------------------------------------------------------------------------- #
#                                Parse Arguments                                #
# ---------------------------------------------------------------------------- #
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--model_url", type=str,
                    default="https://huggingface.co/stabilityai/stable-diffusion-2-1",
                    help="URL of the model to download.")


if __name__ == "__main__":
    args = parser.parse_args()
    download_model(args.model_url)
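The Dockerfile invokes this script at build time. For a Hugging Face URL it derives the repo id from the last two path segments; a small illustration of that extraction, using the URL from the build instructions above:

    # Illustration of the repo-id extraction done in download_model().
    url = "https://huggingface.co/hassanblend/HassanBlend1.5.1.2"
    parts = url.split("/")
    print(f"{parts[-2]}/{parts[-1]}")  # -> hassanblend/HassanBlend1.5.1.2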
@@ -0,0 +1,237 @@
''' StableDiffusion-v1 Predict Module '''

import os
from typing import List

import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    # StableDiffusionInpaintPipeline,
    StableDiffusionInpaintPipelineLegacy,

    DDIMScheduler,
    DDPMScheduler,
    # DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    IPNDMScheduler,
    KDPM2AncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    # KarrasVeScheduler,
    PNDMScheduler,
    # RePaintScheduler,
    # ScoreSdeVeScheduler,
    # ScoreSdeVpScheduler,
    # UnCLIPScheduler,
    # VQDiffusionScheduler,
    LMSDiscreteScheduler
)

from PIL import Image
from cog import BasePredictor, Input, Path
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

MODEL_CACHE = "diffusers-cache"
SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"


class Predictor(BasePredictor):
    '''Predictor class for StableDiffusion-v1'''

    def __init__(self, model_id):
        self.model_id = model_id

    def setup(self):
        '''
        Load the model into memory to make running multiple predictions efficient
        '''
        print("Loading pipeline...")

        # safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        #     SAFETY_MODEL_ID,
        #     cache_dir=MODEL_CACHE,
        #     local_files_only=True,
        # )
        self.txt2img_pipe = StableDiffusionPipeline.from_pretrained(
            self.model_id,
            safety_checker=None,
            # safety_checker=safety_checker,
            cache_dir=MODEL_CACHE,
            local_files_only=True,
        ).to("cuda")
        self.img2img_pipe = StableDiffusionImg2ImgPipeline(
            vae=self.txt2img_pipe.vae,
            text_encoder=self.txt2img_pipe.text_encoder,
            tokenizer=self.txt2img_pipe.tokenizer,
            unet=self.txt2img_pipe.unet,
            scheduler=self.txt2img_pipe.scheduler,
            safety_checker=None,
            # safety_checker=self.txt2img_pipe.safety_checker,
            feature_extractor=self.txt2img_pipe.feature_extractor,
        ).to("cuda")
        self.inpaint_pipe = StableDiffusionInpaintPipelineLegacy(
            vae=self.txt2img_pipe.vae,
            text_encoder=self.txt2img_pipe.text_encoder,
            tokenizer=self.txt2img_pipe.tokenizer,
            unet=self.txt2img_pipe.unet,
            scheduler=self.txt2img_pipe.scheduler,
            safety_checker=None,
            # safety_checker=self.txt2img_pipe.safety_checker,
            feature_extractor=self.txt2img_pipe.feature_extractor,
        ).to("cuda")

        self.txt2img_pipe.enable_xformers_memory_efficient_attention()
        self.img2img_pipe.enable_xformers_memory_efficient_attention()
        self.inpaint_pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    @torch.cuda.amp.autocast()
    def predict(
        self,
        prompt: str = Input(description="Input prompt", default=""),
        negative_prompt: str = Input(
            description="Specify things to not see in the output",
            default=None,
        ),
        width: int = Input(
            description="Output image width; max 1024x768 or 768x1024 due to memory limits",
            choices=[128, 256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024],
            default=512,
        ),
        height: int = Input(
            description="Output image height; max 1024x768 or 768x1024 due to memory limits",
            choices=[128, 256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024],
            default=512,
        ),
        init_image: Path = Input(
            description="Initial image to generate variations of, resized to the specified WxH.",
            default=None,
        ),
        mask: Path = Input(
            description="""Black and white image to use as mask for inpainting over init_image.
            Black pixels are inpainted and white pixels are preserved.
            Tends to work better with prompt strength of 0.5-0.7""",
            default=None,
        ),
        prompt_strength: float = Input(
            description="Prompt strength when providing an init image; 1.0 corresponds to full destruction of the information in it",
            default=0.8,
        ),
        num_outputs: int = Input(
            description="Number of images to output.",
            ge=1,
            le=10,
            default=1
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=7.5
        ),
        scheduler: str = Input(
            default="K-LMS",
            choices=["DDIM", "DDPM", "DPM-M", "DPM-S", "EULER-A", "EULER-D",
                     "HEUN", "IPNDM", "KDPM2-A", "KDPM2-D", "PNDM", "K-LMS"],
            description="Choose a scheduler. If you use an init image, PNDM will be used",
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> List[Path]:
        '''
        Run a single prediction on the model
        '''
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")

        if width * height > 786432:
            raise ValueError(
                "Maximum size is 1024x768 or 768x1024 pixels, because of memory limits."
            )

        extra_kwargs = {}
        if mask:
            if not init_image:
                raise ValueError("mask was provided without init_image")

            pipe = self.inpaint_pipe
            init_image = Image.open(init_image).convert("RGB")
            extra_kwargs = {
                "mask_image": Image.open(mask).convert("RGB").resize(init_image.size),
                "image": init_image,
                "strength": prompt_strength,
            }
        elif init_image:
            pipe = self.img2img_pipe
            extra_kwargs = {
                "init_image": Image.open(init_image).convert("RGB"),
                "strength": prompt_strength,
            }
        else:
            pipe = self.txt2img_pipe
            extra_kwargs = {
                "width": width,
                "height": height,
            }

        pipe.scheduler = make_scheduler(scheduler, pipe.scheduler.config)

        generator = torch.Generator("cuda").manual_seed(seed)
        output = pipe(
            prompt=[prompt] * num_outputs if prompt is not None else None,
            negative_prompt=[negative_prompt] * num_outputs if negative_prompt is not None else None,
            # width=width,
            # height=height,
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_inference_steps,
            **extra_kwargs,
        )

        output_paths = []
        for i, sample in enumerate(output.images):
            # if output.nsfw_content_detected and output.nsfw_content_detected[i] and self.NSFW:
            #     continue

            output_path = f"/tmp/out-{i}.png"
            sample.save(output_path)
            output_paths.append(Path(output_path))

        if len(output_paths) == 0:
            raise Exception(
                "NSFW content detected. Try running it again, or try a different prompt."
            )

        return output_paths


def make_scheduler(name, config):
    '''
    Returns a scheduler from a name and config.
    '''
    return {
        "DDIM": DDIMScheduler.from_config(config),
        "DDPM": DDPMScheduler.from_config(config),
        # "DEIS": DEISMultistepScheduler.from_config(config),
        "DPM-M": DPMSolverMultistepScheduler.from_config(config),
        "DPM-S": DPMSolverSinglestepScheduler.from_config(config),
        "EULER-A": EulerAncestralDiscreteScheduler.from_config(config),
        "EULER-D": EulerDiscreteScheduler.from_config(config),
        "HEUN": HeunDiscreteScheduler.from_config(config),
        "IPNDM": IPNDMScheduler.from_config(config),
        "KDPM2-A": KDPM2AncestralDiscreteScheduler.from_config(config),
        "KDPM2-D": KDPM2DiscreteScheduler.from_config(config),
        # "KARRAS-VE": KarrasVeScheduler.from_config(config),
        "PNDM": PNDMScheduler.from_config(config),
        # "RE-PAINT": RePaintScheduler.from_config(config),
        # "SCORE-VE": ScoreSdeVeScheduler.from_config(config),
        # "SCORE-VP": ScoreSdeVpScheduler.from_config(config),
        # "UN-CLIPS": UnCLIPScheduler.from_config(config),
        # "VQD": VQDiffusionScheduler.from_config(config),
        "K-LMS": LMSDiscreteScheduler.from_config(config)
    }[name]
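A minimal local smoke test for the predictor, assuming a CUDA device and a model already present in the local diffusers cache; the model id is a placeholder. Every argument is passed explicitly, as runpod_infer.py does below, because the cog Input defaults only apply when the class runs under cog itself:

    # Hedged smoke test; requires a GPU and cached weights.
    # "runwayml/stable-diffusion-v1-5" is a placeholder model id.
    if __name__ == "__main__":
        predictor = Predictor("runwayml/stable-diffusion-v1-5")
        predictor.setup()
        paths = predictor.predict(
            prompt="a lighthouse at dusk",
            negative_prompt=None,
            width=512,
            height=512,
            init_image=None,
            mask=None,
            prompt_strength=0.8,
            num_outputs=1,
            num_inference_steps=30,
            guidance_scale=7.5,
            scheduler="K-LMS",
            seed=42,
        )
        print(paths)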
@@ -0,0 +1,157 @@
''' infer.py for runpod worker '''

import os
import argparse

import predict

import runpod
from runpod.serverless.utils import rp_download, rp_upload, rp_cleanup
from runpod.serverless.utils.rp_validator import validate


INPUT_SCHEMA = {
    'prompt': {
        'type': str,
        'required': True
    },
    'negative_prompt': {
        'type': str,
        'required': False,
        'default': None
    },
    'width': {
        'type': int,
        'required': False,
        'default': 512,
        'constraints': lambda width: width in [128, 256, 384, 448, 512, 576, 640, 704, 768]
    },
    'height': {
        'type': int,
        'required': False,
        'default': 512,
        'constraints': lambda height: height in [128, 256, 384, 448, 512, 576, 640, 704, 768]
    },
    'init_image': {
        'type': str,
        'required': False,
        'default': None
    },
    'mask': {
        'type': str,
        'required': False,
        'default': None
    },
    'prompt_strength': {
        'type': float,
        'required': False,
        'default': 0.8,
        'constraints': lambda prompt_strength: 0 <= prompt_strength <= 1
    },
    'num_outputs': {
        'type': int,
        'required': False,
        'default': 1,
        'constraints': lambda num_outputs: 10 > num_outputs > 0
    },
    'num_inference_steps': {
        'type': int,
        'required': False,
        'default': 50,
        'constraints': lambda num_inference_steps: 0 < num_inference_steps < 500
    },
    'guidance_scale': {
        'type': float,
        'required': False,
        'default': 7.5,
        'constraints': lambda guidance_scale: 0 < guidance_scale < 20
    },
    'scheduler': {
        'type': str,
        'required': False,
        'default': 'K-LMS',
        'constraints': lambda scheduler: scheduler in ['DDIM', 'DDPM', 'DPM-M', 'DPM-S', 'EULER-A', 'EULER-D', 'HEUN', 'IPNDM', 'KDPM2-A', 'KDPM2-D', 'PNDM', 'K-LMS']
    },
    'seed': {
        'type': int,
        'required': False,
        'default': None
    },
    'nsfw': {
        'type': bool,
        'required': False,
        'default': False
    }
}


def run(job):
    '''
    Run inference on the model.
    Returns the output image URLs, with the seed used to generate each image.
    '''
    job_input = job['input']

    # Input validation
    validated_input = validate(job_input, INPUT_SCHEMA)

    if 'errors' in validated_input:
        return {"error": validated_input['errors']}
    validated_input = validated_input['validated_input']

    # Download input objects
    job_input['init_image'], job_input['mask'] = rp_download.download_input_objects(
        [job_input.get('init_image', None), job_input.get('mask', None)]
    )  # pylint: disable=unbalanced-tuple-unpacking

    MODEL.NSFW = job_input.get('nsfw', False)

    if job_input['seed'] is None:
        job_input['seed'] = int.from_bytes(os.urandom(2), "big")

    img_paths = MODEL.predict(
        prompt=job_input["prompt"],
        negative_prompt=job_input.get("negative_prompt", None),
        width=job_input.get('width', 512),
        height=job_input.get('height', 512),
        init_image=job_input['init_image'],
        mask=job_input['mask'],
        prompt_strength=job_input['prompt_strength'],
        num_outputs=job_input.get('num_outputs', 1),
        num_inference_steps=job_input.get('num_inference_steps', 50),
        guidance_scale=job_input['guidance_scale'],
        scheduler=job_input.get('scheduler', "K-LMS"),
        seed=job_input['seed']
    )

    job_output = []
    for index, img_path in enumerate(img_paths):
        image_url = rp_upload.upload_image(job['id'], img_path, index)

        job_output.append({
            "image": image_url,
            "seed": job_input['seed'] + index
        })

    # Remove downloaded input objects
    rp_cleanup.clean(['input_objects'])

    return job_output


parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--model_url", type=str,
                    default=None, help="Model URL")


if __name__ == "__main__":
    args = parser.parse_args()
    print(args)

    if "huggingface.co" in args.model_url:
        url_parts = args.model_url.split("/")
        model_id = f"{url_parts[-2]}/{url_parts[-1]}"
    else:
        model_id = "diffusers-cache/pt"  # local path produced by model_fetcher.py

    MODEL = predict.Predictor(model_id)
    MODEL.setup()

    runpod.serverless.start({"handler": run})
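For reference, a job payload that passes the INPUT_SCHEMA above; only "prompt" is required, and everything else falls back to the declared defaults. The job id is a placeholder:

    # Example job for the run() handler, matching INPUT_SCHEMA above.
    sample_job = {
        "id": "local-test",
        "input": {
            "prompt": "an astronaut riding a horse",
            "width": 512,
            "height": 512,
            "num_inference_steps": 50,
            "scheduler": "K-LMS",
        }
    }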