''' StableDiffusion-v1 Predict Module '''

import os
from typing import List

import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    # StableDiffusionInpaintPipeline,
    StableDiffusionInpaintPipelineLegacy,
    DDIMScheduler,
    DDPMScheduler,
    # DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    IPNDMScheduler,
    KDPM2AncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    # KarrasVeScheduler,
    PNDMScheduler,
    # RePaintScheduler,
    # ScoreSdeVeScheduler,
    # ScoreSdeVpScheduler,
    # UnCLIPScheduler,
    # VQDiffusionScheduler,
    LMSDiscreteScheduler,
)
from PIL import Image
from cog import BasePredictor, Input, Path

MODEL_CACHE = "diffusers-cache"
SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"


class Predictor(BasePredictor):
    '''Predictor class for StableDiffusion-v1'''

    def __init__(self, model_id):
        self.model_id = model_id

    def setup(self):
        '''
        Load the model into memory to make running multiple predictions efficient
        '''
        print("Loading pipeline...")

        # safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        #     SAFETY_MODEL_ID,
        #     cache_dir=MODEL_CACHE,
        #     local_files_only=True,
        # )
        self.txt2img_pipe = StableDiffusionPipeline.from_pretrained(
            self.model_id,
            safety_checker=None,
            # safety_checker=safety_checker,
            cache_dir=MODEL_CACHE,
            local_files_only=True,
        ).to("cuda")
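        # The img2img and inpaint pipelines below reuse the txt2img components
        # (VAE, text encoder, tokenizer, UNet, scheduler), so the model weights
        # are loaded onto the GPU only once.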
        self.img2img_pipe = StableDiffusionImg2ImgPipeline(
            vae=self.txt2img_pipe.vae,
            text_encoder=self.txt2img_pipe.text_encoder,
            tokenizer=self.txt2img_pipe.tokenizer,
            unet=self.txt2img_pipe.unet,
            scheduler=self.txt2img_pipe.scheduler,
            safety_checker=None,
            # safety_checker=self.txt2img_pipe.safety_checker,
            feature_extractor=self.txt2img_pipe.feature_extractor,
        ).to("cuda")
        self.inpaint_pipe = StableDiffusionInpaintPipelineLegacy(
            vae=self.txt2img_pipe.vae,
            text_encoder=self.txt2img_pipe.text_encoder,
            tokenizer=self.txt2img_pipe.tokenizer,
            unet=self.txt2img_pipe.unet,
            scheduler=self.txt2img_pipe.scheduler,
            safety_checker=None,
            # safety_checker=self.txt2img_pipe.safety_checker,
            feature_extractor=self.txt2img_pipe.feature_extractor,
        ).to("cuda")
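
        # Note (assumption): enabling xformers requires a working xformers build
        # for the installed torch/CUDA combination; if it is unavailable, these
        # calls raise at setup time and can be removed to fall back to the
        # default attention implementation.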
        self.txt2img_pipe.enable_xformers_memory_efficient_attention()
        self.img2img_pipe.enable_xformers_memory_efficient_attention()
        self.inpaint_pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    @torch.cuda.amp.autocast()
    def predict(
        self,
        prompt: str = Input(description="Input prompt", default=""),
        negative_prompt: str = Input(
            description="Specify things not to see in the output",
            default=None,
        ),
        width: int = Input(
            description="Output image width; max 1024x768 or 768x1024 due to memory limits",
            choices=[128, 256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024],
            default=512,
        ),
        height: int = Input(
            description="Output image height; max 1024x768 or 768x1024 due to memory limits",
            choices=[128, 256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024],
            default=512,
        ),
        init_image: Path = Input(
            description="Initial image to generate variations of, resized to the specified WxH.",
            default=None,
        ),
        mask: Path = Input(
            description="""Black and white image to use as mask for inpainting over init_image.
            Black pixels are inpainted and white pixels are preserved.
            Tends to work better with prompt strength of 0.5-0.7""",
            default=None,
        ),
        prompt_strength: float = Input(
            description="Prompt strength when using an init image. 1.0 corresponds to full destruction of information in the init image",
            default=0.8,
        ),
        num_outputs: int = Input(
            description="Number of images to output.",
            ge=1,
            le=10,
            default=1,
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=7.5
        ),
        scheduler: str = Input(
            default="K-LMS",
            choices=["DDIM", "DDPM", "DPM-M", "DPM-S", "EULER-A", "EULER-D",
                     "HEUN", "IPNDM", "KDPM2-A", "KDPM2-D", "PNDM", "K-LMS"],
            description="Choose a scheduler.",
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> List[Path]:
        '''
        Run a single prediction on the model
        '''
        if seed is None:
            # Note: two random bytes give only 65536 distinct seeds; widen to
            # more bytes if a larger seed space is needed.
            seed = int.from_bytes(os.urandom(2), "big")

        if width * height > 786432:  # 1024 * 768
            raise ValueError(
                "Maximum size is 1024x768 or 768x1024 pixels, because of memory limits."
            )

        # Dispatch to the inpaint, img2img, or txt2img pipeline based on inputs.
        extra_kwargs = {}
        if mask:
            if not init_image:
                raise ValueError("mask was provided without init_image")

            pipe = self.inpaint_pipe
            init_image = Image.open(init_image).convert("RGB")
            extra_kwargs = {
                "mask_image": Image.open(mask).convert("RGB").resize(init_image.size),
                "image": init_image,
                "strength": prompt_strength,
            }
        elif init_image:
            pipe = self.img2img_pipe
            extra_kwargs = {
                # Note: newer diffusers releases renamed this kwarg to "image";
                # "init_image" matches the older version this module targets.
                "init_image": Image.open(init_image).convert("RGB"),
                "strength": prompt_strength,
            }
        else:
            pipe = self.txt2img_pipe
            extra_kwargs = {
                "width": width,
                "height": height,
            }

        pipe.scheduler = make_scheduler(scheduler, pipe.scheduler.config)

        generator = torch.Generator("cuda").manual_seed(seed)
        output = pipe(
            prompt=[prompt] * num_outputs if prompt is not None else None,
            negative_prompt=[negative_prompt] * num_outputs if negative_prompt is not None else None,
            # width=width,
            # height=height,
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_inference_steps,
            **extra_kwargs,
        )

        output_paths = []
        for i, sample in enumerate(output.images):
            # if output.nsfw_content_detected and output.nsfw_content_detected[i] and self.NSFW:
            #     continue

            output_path = f"/tmp/out-{i}.png"
            sample.save(output_path)
            output_paths.append(Path(output_path))

        # With the safety checker disabled above, this branch is unreachable; it
        # only applies if the NSFW-filtering loop above is re-enabled.
        if len(output_paths) == 0:
            raise Exception(
                "NSFW content detected. Try running it again, or try a different prompt."
            )

        return output_paths


def make_scheduler(name, config):
    '''
    Return a scheduler built from a name and a scheduler config,
    e.g. make_scheduler("K-LMS", pipe.scheduler.config).
    '''
    # Map names to classes so only the requested scheduler is instantiated,
    # rather than constructing every scheduler on each call.
    return {
        "DDIM": DDIMScheduler,
        "DDPM": DDPMScheduler,
        # "DEIS": DEISMultistepScheduler,
        "DPM-M": DPMSolverMultistepScheduler,
        "DPM-S": DPMSolverSinglestepScheduler,
        "EULER-A": EulerAncestralDiscreteScheduler,
        "EULER-D": EulerDiscreteScheduler,
        "HEUN": HeunDiscreteScheduler,
        "IPNDM": IPNDMScheduler,
        "KDPM2-A": KDPM2AncestralDiscreteScheduler,
        "KDPM2-D": KDPM2DiscreteScheduler,
        # "KARRAS-VE": KarrasVeScheduler,
        "PNDM": PNDMScheduler,
        # "RE-PAINT": RePaintScheduler,
        # "SCORE-VE": ScoreSdeVeScheduler,
        # "SCORE-VP": ScoreSdeVpScheduler,
        # "UN-CLIPS": UnCLIPScheduler,
        # "VQD": VQDiffusionScheduler,
        "K-LMS": LMSDiscreteScheduler,
    }[name].from_config(config)
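

# A minimal local-run sketch, not part of the Cog interface. It assumes the
# weights for the illustrative model id "runwayml/stable-diffusion-v1-5" are
# already present in MODEL_CACHE (the pipelines load with
# local_files_only=True) and that a CUDA GPU is available.
if __name__ == "__main__":
    predictor = Predictor("runwayml/stable-diffusion-v1-5")
    predictor.setup()
    paths = predictor.predict(
        prompt="a photo of an astronaut riding a horse",
        negative_prompt=None,
        width=512,
        height=512,
        init_image=None,
        mask=None,
        prompt_strength=0.8,
        num_outputs=1,
        num_inference_steps=50,
        guidance_scale=7.5,
        scheduler="K-LMS",
        seed=None,
    )
    print([str(p) for p in paths])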