From 1ac4d1b4cded4f6c0de809b14348ec07e0950c38 Mon Sep 17 00:00:00 2001
From: Hendrik Langer
Date: Tue, 4 Apr 2023 18:37:05 +0200
Subject: [PATCH] remote worker container

---
 matrix_pygmalion_bot/ai/runpod_pygmalion.py |   2 +-
 runpod/example/Dockerfile                   |   9 -
 runpod/example/handler.py                   |  21 --
 runpod/runpod-worker-deliberate/Dockerfile  |  18 --
 runpod/runpod-worker-deliberate/handler.py  |  39 ----
 runpod/runpod-worker-deliberate/start.sh    |  11 -
 runpod/runpod-worker-hassanblend/Dockerfile |  18 --
 runpod/runpod-worker-hassanblend/handler.py |  39 ----
 runpod/runpod-worker-hassanblend/start.sh   |  11 -
 runpod/runpod-worker-sd/Dockerfile          |  57 +++++
 runpod/runpod-worker-sd/README.md           |  12 +
 runpod/runpod-worker-sd/model_fetcher.py    |  66 ++++++
 runpod/runpod-worker-sd/predict.py          | 237 ++++++++++++++++++++
 runpod/runpod-worker-sd/runpod_infer.py     | 157 +++++++++++++
 14 files changed, 530 insertions(+), 167 deletions(-)
 delete mode 100644 runpod/example/Dockerfile
 delete mode 100644 runpod/example/handler.py
 delete mode 100644 runpod/runpod-worker-deliberate/Dockerfile
 delete mode 100644 runpod/runpod-worker-deliberate/handler.py
 delete mode 100644 runpod/runpod-worker-deliberate/start.sh
 delete mode 100644 runpod/runpod-worker-hassanblend/Dockerfile
 delete mode 100644 runpod/runpod-worker-hassanblend/handler.py
 delete mode 100644 runpod/runpod-worker-hassanblend/start.sh
 create mode 100644 runpod/runpod-worker-sd/Dockerfile
 create mode 100644 runpod/runpod-worker-sd/README.md
 create mode 100644 runpod/runpod-worker-sd/model_fetcher.py
 create mode 100644 runpod/runpod-worker-sd/predict.py
 create mode 100644 runpod/runpod-worker-sd/runpod_infer.py

diff --git a/matrix_pygmalion_bot/ai/runpod_pygmalion.py b/matrix_pygmalion_bot/ai/runpod_pygmalion.py
index 8883b62..4f5645e 100644
--- a/matrix_pygmalion_bot/ai/runpod_pygmalion.py
+++ b/matrix_pygmalion_bot/ai/runpod_pygmalion.py
@@ -230,7 +230,7 @@ async def generate_image2(input_prompt: str, negative_prompt: str, api_key: str)
     return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/sd-openjourney/", api_key)
 
 async def generate_image3(input_prompt: str, negative_prompt: str, api_key: str):
-    return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/7me2asocq5lr01/", api_key)
+    return await generate_image(input_prompt, negative_prompt, "https://api.runpod.ai/v1/mf5f6mocy8bsvx/", api_key)
 
 async def download_image(url, path):
     r = requests.get(url, stream=True)
diff --git a/runpod/example/Dockerfile b/runpod/example/Dockerfile
deleted file mode 100644
index 43ef7a4..0000000
--- a/runpod/example/Dockerfile
+++ /dev/null
@@ -1,9 +0,0 @@
-from python:3.11.1-buster
-
-WORKDIR /
-
-RUN pip install runpod
-
-ADD handler.py .
-
-CMD [ "python", "-u", "/handler.py" ]
diff --git a/runpod/example/handler.py b/runpod/example/handler.py
deleted file mode 100644
index dc0c957..0000000
--- a/runpod/example/handler.py
+++ /dev/null
@@ -1,21 +0,0 @@
-#!/usr/bin/env python
-''' Contains the handler function that will be called by the serverless. '''
-
-import runpod
-
-# Load models into VRAM here so they can be warm between requests
-
-
-def handler(event):
-    '''
-    This is the handler function that will be called by the serverless.
-    '''
-    print(event)
-
-    # do the things
-
-    # return the output that you want to be returned like pre-signed URLs to output artifacts
-    return "Hello World"
-
-
-runpod.serverless.start({"handler": handler})
diff --git a/runpod/runpod-worker-deliberate/Dockerfile b/runpod/runpod-worker-deliberate/Dockerfile
deleted file mode 100644
index af04dff..0000000
--- a/runpod/runpod-worker-deliberate/Dockerfile
+++ /dev/null
@@ -1,18 +0,0 @@
-FROM runpod/stable-diffusion:web-automatic-1.5.16
-
-SHELL ["/bin/bash", "-c"]
-
-ENV PATH="${PATH}:/workspace/stable-diffusion-webui/venv/bin"
-
-WORKDIR /
-
-RUN rm /workspace/v1-5-pruned-emaonly.ckpt
-RUN wget -O model.safetensors https://civitai.com/api/download/models/5616
-RUN pip install -U xformers
-RUN pip install runpod
-
-ADD handler.py .
-ADD start.sh /start.sh
-RUN chmod +x /start.sh
-
-CMD [ "/start.sh" ]
diff --git a/runpod/runpod-worker-deliberate/handler.py b/runpod/runpod-worker-deliberate/handler.py
deleted file mode 100644
index 6353f2b..0000000
--- a/runpod/runpod-worker-deliberate/handler.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import runpod
-import subprocess
-import requests
-import time
-
-def check_api_availability(host):
-    while True:
-        try:
-            response = requests.get(host)
-            return
-        except requests.exceptions.RequestException as e:
-            print(f"API is not available, retrying in 200ms... ({e})")
-        except Exception as e:
-            print('something went wrong')
-        time.sleep(200/1000)
-
-check_api_availability("http://127.0.0.1:3000/sdapi/v1/txt2img")
-
-print('run handler')
-
-def handler(event):
-    '''
-    This is the handler function that will be called by the serverless.
-    '''
-    print('got event')
-    print(event)
-
-    response = requests.post(url=f'http://127.0.0.1:3000/sdapi/v1/txt2img', json=event["input"])
-
-    json = response.json()
-    # do the things
-
-    print(json)
-
-    # return the output that you want to be returned like pre-signed URLs to output artifacts
-    return json
-
-
-runpod.serverless.start({"handler": handler})
diff --git a/runpod/runpod-worker-deliberate/start.sh b/runpod/runpod-worker-deliberate/start.sh
deleted file mode 100644
index f79378e..0000000
--- a/runpod/runpod-worker-deliberate/start.sh
+++ /dev/null
@@ -1,11 +0,0 @@
-#!/bin/bash
-echo "Container Started"
-export PYTHONUNBUFFERED=1
-source /workspace/stable-diffusion-webui/venv/bin/activate
-cd /workspace/stable-diffusion-webui
-echo "starting api"
-python webui.py --port 3000 --nowebui --api --xformers --ckpt /model.safetensors &
-cd /
-
-echo "starting worker"
-python -u handler.py
diff --git a/runpod/runpod-worker-hassanblend/Dockerfile b/runpod/runpod-worker-hassanblend/Dockerfile
deleted file mode 100644
index af04dff..0000000
--- a/runpod/runpod-worker-hassanblend/Dockerfile
+++ /dev/null
@@ -1,18 +0,0 @@
-FROM runpod/stable-diffusion:web-automatic-1.5.16
-
-SHELL ["/bin/bash", "-c"]
-
-ENV PATH="${PATH}:/workspace/stable-diffusion-webui/venv/bin"
-
-WORKDIR /
-
-RUN rm /workspace/v1-5-pruned-emaonly.ckpt
-RUN wget -O model.safetensors https://civitai.com/api/download/models/5616
-RUN pip install -U xformers
-RUN pip install runpod
-
-ADD handler.py .
-ADD start.sh /start.sh
-RUN chmod +x /start.sh
-
-CMD [ "/start.sh" ]
diff --git a/runpod/runpod-worker-hassanblend/handler.py b/runpod/runpod-worker-hassanblend/handler.py
deleted file mode 100644
index 6353f2b..0000000
--- a/runpod/runpod-worker-hassanblend/handler.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import runpod
-import subprocess
-import requests
-import time
-
-def check_api_availability(host):
-    while True:
-        try:
-            response = requests.get(host)
-            return
-        except requests.exceptions.RequestException as e:
-            print(f"API is not available, retrying in 200ms... ({e})")
-        except Exception as e:
-            print('something went wrong')
-        time.sleep(200/1000)
-
-check_api_availability("http://127.0.0.1:3000/sdapi/v1/txt2img")
-
-print('run handler')
-
-def handler(event):
-    '''
-    This is the handler function that will be called by the serverless.
-    '''
-    print('got event')
-    print(event)
-
-    response = requests.post(url=f'http://127.0.0.1:3000/sdapi/v1/txt2img', json=event["input"])
-
-    json = response.json()
-    # do the things
-
-    print(json)
-
-    # return the output that you want to be returned like pre-signed URLs to output artifacts
-    return json
-
-
-runpod.serverless.start({"handler": handler})
diff --git a/runpod/runpod-worker-hassanblend/start.sh b/runpod/runpod-worker-hassanblend/start.sh
deleted file mode 100644
index f79378e..0000000
--- a/runpod/runpod-worker-hassanblend/start.sh
+++ /dev/null
@@ -1,11 +0,0 @@
-#!/bin/bash
-echo "Container Started"
-export PYTHONUNBUFFERED=1
-source /workspace/stable-diffusion-webui/venv/bin/activate
-cd /workspace/stable-diffusion-webui
-echo "starting api"
-python webui.py --port 3000 --nowebui --api --xformers --ckpt /model.safetensors &
-cd /
-
-echo "starting worker"
-python -u handler.py
diff --git a/runpod/runpod-worker-sd/Dockerfile b/runpod/runpod-worker-sd/Dockerfile
new file mode 100644
index 0000000..0c16637
--- /dev/null
+++ b/runpod/runpod-worker-sd/Dockerfile
@@ -0,0 +1,57 @@
+ARG BASE_IMAGE=nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
+FROM ${BASE_IMAGE} as dev-base
+
+WORKDIR /
+SHELL ["/bin/bash", "-o", "pipefail", "-c"]
+ENV DEBIAN_FRONTEND=noninteractive \
+    SHELL=/bin/bash
+
+RUN apt-get update --yes && \
+    # - apt-get upgrade is run to patch known vulnerabilities in apt-get packages as
+    #   the ubuntu base image is rebuilt too seldom sometimes (less than once a month)
+    apt-get upgrade --yes && \
+    apt-get install --yes --no-install-recommends \
+    build-essential \
+    ca-certificates \
+    git \
+    git-lfs \
+    wget \
+    curl \
+    bash \
+    libgl1 \
+    software-properties-common \
+    openssh-server && \
+    apt-get clean && rm -rf /var/lib/apt/lists/* && \
+    echo "en_US.UTF-8 UTF-8" > /etc/locale.gen
+
+# Rotate the NVIDIA repository signing key (the old 7fa2af80 key was revoked upstream)
+RUN apt-key del 7fa2af80 && \
+    apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub
+
+RUN add-apt-repository ppa:deadsnakes/ppa && \
+    apt-get install python3.10 python3.10-dev python3.10-venv python3-pip -y --no-install-recommends && \
+    apt-get clean && rm -rf /var/lib/apt/lists/*
+
+RUN pip install --upgrade pip && \
+    pip install huggingface-hub && \
+    pip install diffusers && \
+    pip install torch torchvision torchaudio --extra-index-url=https://download.pytorch.org/whl/cu116 && \
+    pip install bitsandbytes && \
+    pip install transformers accelerate xformers triton && \
+    # cog is imported by predict.py (BasePredictor, Input, Path)
+    pip install cog && \
+    pip install runpod
+
+RUN mkdir /workspace
+WORKDIR /workspace
+
+COPY . /workspace/
+#ADD start.sh /workspace/
+
+ARG MODEL_URL
+ENV MODEL_URL=${MODEL_URL}
+
+# Fetch (and, for raw checkpoints, convert) the model at build time so the
+# serverless worker starts with a warm local cache.
+RUN python3 model_fetcher.py --model_url=${MODEL_URL} && \
+    echo "Model URL: $MODEL_URL"
+#RUN wget -O model.safetensors https://civitai.com/api/download/models/5616
+# PFG: https://civitai.com/api/download/models/1316
+# Hassanblend: https://civitai.com/api/download/models/4635
+
+CMD python3 -u runpod_infer.py --model_url="$MODEL_URL"
diff --git a/runpod/runpod-worker-sd/README.md b/runpod/runpod-worker-sd/README.md
new file mode 100644
index 0000000..dd19aee
--- /dev/null
+++ b/runpod/runpod-worker-sd/README.md
@@ -0,0 +1,12 @@
+# runpod-worker-sd
+
+Serverless RunPod worker that runs a Stable Diffusion checkpoint (based on
+RunPod's serverless-ckpt-template).
+
+Build and push the image, passing the checkpoint to bake in via `MODEL_URL`:
+
+    git clone https://github.com/runpod/serverless-ckpt-template.git
+    cd serverless-ckpt-template
+
+    docker build --build-arg MODEL_URL=https://huggingface.co/hassanblend/HassanBlend1.5.1.2 -t magn418/runpod-hassan:1.5 .
+    docker login
+    docker push magn418/runpod-hassan:1.5
+
+Models:
+- PFG: https://civitai.com/models/1227/pfg
+- HassanBlend: https://civitai.com/models/1173/hassanblend-15-and-previous-versions
+- Deliberate
+- Anything v3?
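+
+## Testing (sketch)
+
+A minimal smoke test against a deployed endpoint. The endpoint ID and API key
+are placeholders, and the request body follows `INPUT_SCHEMA` in
+`runpod_infer.py`; adjust the route if your RunPod serverless API version
+differs:
+
+    curl -X POST "https://api.runpod.ai/v1/${ENDPOINT_ID}/run" \
+      -H "Authorization: Bearer ${RUNPOD_API_KEY}" \
+      -H "Content-Type: application/json" \
+      -d '{"input": {"prompt": "a photo of an astronaut riding a horse", "width": 512, "height": 512}}'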
diff --git a/runpod/runpod-worker-sd/model_fetcher.py b/runpod/runpod-worker-sd/model_fetcher.py
new file mode 100644
index 0000000..c0daf32
--- /dev/null
+++ b/runpod/runpod-worker-sd/model_fetcher.py
@@ -0,0 +1,66 @@
+'''
+RunPod | serverless-ckpt-template | model_fetcher.py
+
+Downloads the model from the URL passed in.
+'''
+
+import argparse
+import os
+import shutil
+
+import requests
+from diffusers import StableDiffusionPipeline
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+    StableDiffusionSafetyChecker,
+)
+
+
+SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"
+MODEL_CACHE_DIR = "diffusers-cache"
+
+
+def download_model(model_url: str):
+    '''
+    Downloads the model from the URL passed in.
+    '''
+    if os.path.exists(MODEL_CACHE_DIR):
+        shutil.rmtree(MODEL_CACHE_DIR)
+    os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
+
+    # Check if the URL is from huggingface.co; if so, grab the model repo id.
+    if "huggingface.co" in model_url:
+        url_parts = model_url.split("/")
+        model_id = f"{url_parts[-2]}/{url_parts[-1]}"
+    else:
+        # Raw checkpoint: download it, then convert it to diffusers format.
+        downloaded_model = requests.get(model_url, stream=True, timeout=600)
+        with open(f"{MODEL_CACHE_DIR}/model.safetensors", "wb") as f:
+            for chunk in downloaded_model.iter_content(chunk_size=1024):
+                if chunk:
+                    f.write(chunk)
+        os.system("wget -q https://raw.githubusercontent.com/huggingface/diffusers/main/scripts/convert_original_stable_diffusion_to_diffusers.py")
+        os.system(f"python3 convert_original_stable_diffusion_to_diffusers.py --from_safetensors --checkpoint_path {MODEL_CACHE_DIR}/model.safetensors --dump_path {MODEL_CACHE_DIR}/pt")
+        #os.system(f"python3 convert_original_stable_diffusion_to_diffusers.py --checkpoint_path model.ckpt --dump_path pt")
+        model_id = f"{MODEL_CACHE_DIR}/pt"
+
+    # Instantiating the pipelines here warms the local cache at build time.
+    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+        SAFETY_MODEL_ID,
+        cache_dir=MODEL_CACHE_DIR,
+    )
+
+    pipe = StableDiffusionPipeline.from_pretrained(
+        model_id,
+        cache_dir=MODEL_CACHE_DIR,
+    )
+
+
+# ---------------------------------------------------------------------------- #
+#                                Parse Arguments                                #
+# ---------------------------------------------------------------------------- #
+parser = argparse.ArgumentParser(description=__doc__)
+parser.add_argument("--model_url", type=str,
+                    default="https://huggingface.co/stabilityai/stable-diffusion-2-1",
+                    help="URL of the model to download.")
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    download_model(args.model_url)
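+
+# Example invocations (hypothetical URLs; the Dockerfile runs this script once
+# at build time via `RUN python3 model_fetcher.py --model_url=${MODEL_URL}`):
+#
+#   python3 model_fetcher.py --model_url=https://huggingface.co/stabilityai/stable-diffusion-2-1
+#   python3 model_fetcher.py --model_url=https://civitai.com/api/download/models/4635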
diff --git a/runpod/runpod-worker-sd/predict.py b/runpod/runpod-worker-sd/predict.py
new file mode 100644
index 0000000..425331c
--- /dev/null
+++ b/runpod/runpod-worker-sd/predict.py
@@ -0,0 +1,237 @@
+''' StableDiffusion-v1 Predict Module '''
+
+import os
+from typing import List
+
+import torch
+from diffusers import (
+    StableDiffusionPipeline,
+    StableDiffusionImg2ImgPipeline,
+    # StableDiffusionInpaintPipeline,
+    StableDiffusionInpaintPipelineLegacy,
+    DDIMScheduler,
+    DDPMScheduler,
+    # DEISMultistepScheduler,
+    DPMSolverMultistepScheduler,
+    DPMSolverSinglestepScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
+    IPNDMScheduler,
+    KDPM2AncestralDiscreteScheduler,
+    KDPM2DiscreteScheduler,
+    # KarrasVeScheduler,
+    PNDMScheduler,
+    # RePaintScheduler,
+    # ScoreSdeVeScheduler,
+    # ScoreSdeVpScheduler,
+    # UnCLIPScheduler,
+    # VQDiffusionScheduler,
+    LMSDiscreteScheduler
+)
+
+from PIL import Image
+from cog import BasePredictor, Input, Path
+
+MODEL_CACHE = "diffusers-cache"
+SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"
+
+
+class Predictor(BasePredictor):
+    '''Predictor class for StableDiffusion-v1'''
+
+    def __init__(self, model_id):
+        self.model_id = model_id
+
+    def setup(self):
+        '''
+        Load the model into memory to make running multiple predictions efficient
+        '''
+        print("Loading pipeline...")
+
+#        safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+#            SAFETY_MODEL_ID,
+#            cache_dir=MODEL_CACHE,
+#            local_files_only=True,
+#        )
+        self.txt2img_pipe = StableDiffusionPipeline.from_pretrained(
+            self.model_id,
+            safety_checker=None,
+#            safety_checker=safety_checker,
+            cache_dir=MODEL_CACHE,
+            local_files_only=True,
+        ).to("cuda")
+        # img2img and inpainting reuse the txt2img components, so the weights
+        # are held in VRAM only once.
+        self.img2img_pipe = StableDiffusionImg2ImgPipeline(
+            vae=self.txt2img_pipe.vae,
+            text_encoder=self.txt2img_pipe.text_encoder,
+            tokenizer=self.txt2img_pipe.tokenizer,
+            unet=self.txt2img_pipe.unet,
+            scheduler=self.txt2img_pipe.scheduler,
+            safety_checker=None,
+            # safety_checker=self.txt2img_pipe.safety_checker,
+            feature_extractor=self.txt2img_pipe.feature_extractor,
+        ).to("cuda")
+        self.inpaint_pipe = StableDiffusionInpaintPipelineLegacy(
+            vae=self.txt2img_pipe.vae,
+            text_encoder=self.txt2img_pipe.text_encoder,
+            tokenizer=self.txt2img_pipe.tokenizer,
+            unet=self.txt2img_pipe.unet,
+            scheduler=self.txt2img_pipe.scheduler,
+            safety_checker=None,
+            # safety_checker=self.txt2img_pipe.safety_checker,
+            feature_extractor=self.txt2img_pipe.feature_extractor,
+        ).to("cuda")
+
+        self.txt2img_pipe.enable_xformers_memory_efficient_attention()
+        self.img2img_pipe.enable_xformers_memory_efficient_attention()
+        self.inpaint_pipe.enable_xformers_memory_efficient_attention()
+
+    @torch.inference_mode()
+    @torch.cuda.amp.autocast()
+    def predict(
+        self,
+        prompt: str = Input(description="Input prompt", default=""),
+        negative_prompt: str = Input(
+            description="Specify things to not see in the output",
+            default=None,
+        ),
+        width: int = Input(
+            description="Output image width; max 1024x768 or 768x1024 due to memory limits",
+            choices=[128, 256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024],
+            default=512,
+        ),
+        height: int = Input(
+            description="Output image height; max 1024x768 or 768x1024 due to memory limits",
+            choices=[128, 256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024],
+            default=512,
+        ),
+        init_image: Path = Input(
+            description="Initial image to generate variations of, resized to the specified WxH.",
+            default=None,
+        ),
+        mask: Path = Input(
+            description="""Black and white image to use as mask for inpainting over init_image.
+            Black pixels are inpainted and white pixels are preserved.
+            Tends to work better with prompt strength of 0.5-0.7""",
+            default=None,
+        ),
+        prompt_strength: float = Input(
+            description="Prompt strength when providing an init image; 1.0 corresponds to full destruction of information in the init image",
+            default=0.8,
+        ),
+        num_outputs: int = Input(
+            description="Number of images to output.",
+            ge=1,
+            le=10,
+            default=1
+        ),
+        num_inference_steps: int = Input(
+            description="Number of denoising steps", ge=1, le=500, default=50
+        ),
+        guidance_scale: float = Input(
+            description="Scale for classifier-free guidance", ge=1, le=20, default=7.5
+        ),
+        scheduler: str = Input(
+            default="K-LMS",
+            choices=["DDIM", "DDPM", "DPM-M", "DPM-S", "EULER-A", "EULER-D",
+                     "HEUN", "IPNDM", "KDPM2-A", "KDPM2-D", "PNDM", "K-LMS"],
+            description="Choose a scheduler.",
+        ),
+        seed: int = Input(
+            description="Random seed. Leave blank to randomize the seed", default=None
+        ),
+    ) -> List[Path]:
+        '''
+        Run a single prediction on the model
+        '''
+        if seed is None:
+            seed = int.from_bytes(os.urandom(2), "big")
+
+        if width * height > 786432:
+            raise ValueError(
+                "Maximum size is 1024x768 or 768x1024 pixels, because of memory limits."
+            )
+
+        extra_kwargs = {}
+        if mask:
+            if not init_image:
+                raise ValueError("mask was provided without init_image")
+
+            pipe = self.inpaint_pipe
+            init_image = Image.open(init_image).convert("RGB")
+            extra_kwargs = {
+                "mask_image": Image.open(mask).convert("RGB").resize(init_image.size),
+                "image": init_image,
+                "strength": prompt_strength,
+            }
+        elif init_image:
+            pipe = self.img2img_pipe
+            extra_kwargs = {
+                # recent diffusers versions renamed this argument from init_image to image
+                "image": Image.open(init_image).convert("RGB"),
+                "strength": prompt_strength,
+            }
+        else:
+            pipe = self.txt2img_pipe
+            extra_kwargs = {
+                "width": width,
+                "height": height,
+            }
+
+        pipe.scheduler = make_scheduler(scheduler, pipe.scheduler.config)
+
+        generator = torch.Generator("cuda").manual_seed(seed)
+        output = pipe(
+            prompt=[prompt] * num_outputs if prompt is not None else None,
+            negative_prompt=[negative_prompt]*num_outputs if negative_prompt is not None else None,
+            # width=width,
+            # height=height,
+            guidance_scale=guidance_scale,
+            generator=generator,
+            num_inference_steps=num_inference_steps,
+            **extra_kwargs,
+        )
+
+        output_paths = []
+        for i, sample in enumerate(output.images):
+            # if output.nsfw_content_detected and output.nsfw_content_detected[i] and self.NSFW:
+            #     continue
+
+            output_path = f"/tmp/out-{i}.png"
+            sample.save(output_path)
+            output_paths.append(Path(output_path))
+
+        if len(output_paths) == 0:
+            raise Exception(
+                "NSFW content detected. Try running it again, or try a different prompt."
+            )
+
+        return output_paths
+
+
+def make_scheduler(name, config):
+    '''
+    Returns a scheduler from a name and config.
+    '''
+    return {
+        "DDIM": DDIMScheduler.from_config(config),
+        "DDPM": DDPMScheduler.from_config(config),
+        # "DEIS": DEISMultistepScheduler.from_config(config),
+        "DPM-M": DPMSolverMultistepScheduler.from_config(config),
+        "DPM-S": DPMSolverSinglestepScheduler.from_config(config),
+        "EULER-A": EulerAncestralDiscreteScheduler.from_config(config),
+        "EULER-D": EulerDiscreteScheduler.from_config(config),
+        "HEUN": HeunDiscreteScheduler.from_config(config),
+        "IPNDM": IPNDMScheduler.from_config(config),
+        "KDPM2-A": KDPM2AncestralDiscreteScheduler.from_config(config),
+        "KDPM2-D": KDPM2DiscreteScheduler.from_config(config),
+        # "KARRAS-VE": KarrasVeScheduler.from_config(config),
+        "PNDM": PNDMScheduler.from_config(config),
+        # "RE-PAINT": RePaintScheduler.from_config(config),
+        # "SCORE-VE": ScoreSdeVeScheduler.from_config(config),
+        # "SCORE-VP": ScoreSdeVpScheduler.from_config(config),
+        # "UN-CLIPS": UnCLIPScheduler.from_config(config),
+        # "VQD": VQDiffusionScheduler.from_config(config),
+        "K-LMS": LMSDiscreteScheduler.from_config(config)
+    }[name]
diff --git a/runpod/runpod-worker-sd/runpod_infer.py b/runpod/runpod-worker-sd/runpod_infer.py
new file mode 100644
index 0000000..abef5d9
--- /dev/null
+++ b/runpod/runpod-worker-sd/runpod_infer.py
@@ -0,0 +1,157 @@
+''' infer.py for runpod worker '''
+
+import argparse
+import os
+
+import predict
+
+import runpod
+from runpod.serverless.utils import rp_download, rp_upload, rp_cleanup
+from runpod.serverless.utils.rp_validator import validate
+
+
+INPUT_SCHEMA = {
+    'prompt': {
+        'type': str,
+        'required': True
+    },
+    'negative_prompt': {
+        'type': str,
+        'required': False,
+        'default': None
+    },
+    'width': {
+        'type': int,
+        'required': False,
+        'default': 512,
+        'constraints': lambda width: width in [128, 256, 384, 448, 512, 576, 640, 704, 768]
+    },
+    'height': {
+        'type': int,
+        'required': False,
+        'default': 512,
+        'constraints': lambda height: height in [128, 256, 384, 448, 512, 576, 640, 704, 768]
+    },
+    'init_image': {
+        'type': str,
+        'required': False,
+        'default': None
+    },
+    'mask': {
+        'type': str,
+        'required': False,
+        'default': None
+    },
+    'prompt_strength': {
+        'type': float,
+        'required': False,
+        'default': 0.8,
+        'constraints': lambda prompt_strength: 0 <= prompt_strength <= 1
+    },
+    'num_outputs': {
+        'type': int,
+        'required': False,
+        'default': 1,
+        'constraints': lambda num_outputs: 0 < num_outputs <= 10
+    },
+    'num_inference_steps': {
+        'type': int,
+        'required': False,
+        'default': 50,
+        'constraints': lambda num_inference_steps: 0 < num_inference_steps <= 500
+    },
+    'guidance_scale': {
+        'type': float,
+        'required': False,
+        'default': 7.5,
+        'constraints': lambda guidance_scale: 0 < guidance_scale < 20
+    },
+    'scheduler': {
+        'type': str,
+        'required': False,
+        'default': 'K-LMS',
+        'constraints': lambda scheduler: scheduler in ['DDIM', 'DDPM', 'DPM-M', 'DPM-S', 'EULER-A', 'EULER-D', 'HEUN', 'IPNDM', 'KDPM2-A', 'KDPM2-D', 'PNDM', 'K-LMS']
+    },
+    'seed': {
+        'type': int,
+        'required': False,
+        'default': None
+    },
+    'nsfw': {
+        'type': bool,
+        'required': False,
+        'default': False
+    }
+}
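+
+# Example of a job this schema accepts (hypothetical values; `init_image` and
+# `mask` would be URLs downloaded by rp_download):
+# {
+#     "input": {
+#         "prompt": "a photo of an astronaut riding a horse on mars",
+#         "negative_prompt": "blurry, deformed",
+#         "width": 512,
+#         "height": 512,
+#         "num_inference_steps": 50,
+#         "guidance_scale": 7.5,
+#         "scheduler": "EULER-A",
+#         "seed": 42
+#     }
+# }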
+
+
+def run(job):
+    '''
+    Run inference on the model.
+    Returns the output image URLs, with the seed used to generate each image.
+    '''
+    job_input = job['input']
+
+    # Input validation
+    validated_input = validate(job_input, INPUT_SCHEMA)
+
+    if 'errors' in validated_input:
+        return {"error": validated_input['errors']}
+    # Use the validated copy so schema defaults are filled in.
+    job_input = validated_input['validated_input']
+
+    # Download input objects
+    job_input['init_image'], job_input['mask'] = rp_download.download_input_objects(
+        [job_input.get('init_image', None), job_input.get('mask', None)]
+    )  # pylint: disable=unbalanced-tuple-unpacking
+
+    MODEL.NSFW = job_input.get('nsfw', False)
+
+    if job_input['seed'] is None:
+        job_input['seed'] = int.from_bytes(os.urandom(2), "big")
+
+    img_paths = MODEL.predict(
+        prompt=job_input["prompt"],
+        negative_prompt=job_input.get("negative_prompt", None),
+        width=job_input.get('width', 512),
+        height=job_input.get('height', 512),
+        init_image=job_input['init_image'],
+        mask=job_input['mask'],
+        prompt_strength=job_input['prompt_strength'],
+        num_outputs=job_input.get('num_outputs', 1),
+        num_inference_steps=job_input.get('num_inference_steps', 50),
+        guidance_scale=job_input['guidance_scale'],
+        scheduler=job_input.get('scheduler', "K-LMS"),
+        seed=job_input['seed']
+    )
+
+    job_output = []
+    for index, img_path in enumerate(img_paths):
+        image_url = rp_upload.upload_image(job['id'], img_path, index)
+
+        job_output.append({
+            "image": image_url,
+            "seed": job_input['seed'] + index
+        })
+
+    # Remove downloaded input objects
+    rp_cleanup.clean(['input_objects'])
+
+    return job_output
+
+
+parser = argparse.ArgumentParser(description=__doc__)
+parser.add_argument("--model_url", type=str,
+                    default=None, help="Model URL")
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    print(args)
+
+    if args.model_url and "huggingface.co" in args.model_url:
+        url_parts = args.model_url.split("/")
+        model_id = f"{url_parts[-2]}/{url_parts[-1]}"
+    else:
+        # Must match the dump path written by model_fetcher.py at build time.
+        model_id = "diffusers-cache/pt"
+
+    MODEL = predict.Predictor(model_id)
+    MODEL.setup()
+
+    runpod.serverless.start({"handler": run})