Chatbot
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

238 lines
8.6 KiB

''' StableDiffusion-v1 Predict Module '''
import os
from typing import List
import torch
from diffusers import (
StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
# StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
DDIMScheduler,
DDPMScheduler,
# DEISMultistepScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
IPNDMScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
# KarrasVeScheduler,
PNDMScheduler,
# RePaintScheduler,
# ScoreSdeVeScheduler,
# ScoreSdeVpScheduler,
# UnCLIPScheduler,
# VQDiffusionScheduler,
LMSDiscreteScheduler
)
from PIL import Image
from cog import BasePredictor, Input, Path
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
# Directory handed to ``from_pretrained`` as ``cache_dir`` when the pipeline is loaded.
MODEL_CACHE = "diffusers-cache"
# Identifier of the safety checker model; in the visible source it is only
# referenced from commented-out code (the safety checker is currently disabled).
SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"
class Predictor(BasePredictor):
    '''
    Predictor class for StableDiffusion-v1.

    Holds three pipelines that share one set of model components:
    text-to-image, image-to-image and (legacy) inpainting. ``predict``
    picks the pipeline from the inputs that were supplied.
    '''

    def __init__(self, model_id):
        '''
        Store the model identifier; nothing is loaded until ``setup`` runs.

        model_id is passed unchanged to ``StableDiffusionPipeline.from_pretrained``.
        '''
        self.model_id = model_id

    def setup(self) -> None:
        '''
        Load the model into memory to make running multiple predictions efficient.

        Loads the text-to-image pipeline from the local cache only
        (``local_files_only=True``), then builds the img2img and inpaint
        pipelines from the same components so the weights are not loaded
        twice. All pipelines are moved to CUDA and have xformers
        memory-efficient attention enabled.
        '''
        print("Loading pipeline...")
        # The safety checker is deliberately disabled for every pipeline
        # (safety_checker=None).
        self.txt2img_pipe = StableDiffusionPipeline.from_pretrained(
            self.model_id,
            safety_checker=None,
            cache_dir=MODEL_CACHE,
            local_files_only=True,
        ).to("cuda")

        base = self.txt2img_pipe
        # Components reused by the two derived pipelines.
        shared_components = {
            "vae": base.vae,
            "text_encoder": base.text_encoder,
            "tokenizer": base.tokenizer,
            "unet": base.unet,
            "scheduler": base.scheduler,
            "safety_checker": None,
            "feature_extractor": base.feature_extractor,
        }
        self.img2img_pipe = StableDiffusionImg2ImgPipeline(
            **shared_components
        ).to("cuda")
        self.inpaint_pipe = StableDiffusionInpaintPipelineLegacy(
            **shared_components
        ).to("cuda")

        for pipe in (self.txt2img_pipe, self.img2img_pipe, self.inpaint_pipe):
            pipe.enable_xformers_memory_efficient_attention()

    @staticmethod
    def _open_rgb(path):
        '''Load the image at ``path`` as RGB and close the underlying file.'''
        # ``convert`` returns a new, fully loaded image, so the file handle
        # can be released as soon as the ``with`` block exits.
        with Image.open(path) as img:
            return img.convert("RGB")

    @torch.inference_mode()
    @torch.cuda.amp.autocast()
    def predict(
        self,
        prompt: str = Input(description="Input prompt", default=""),
        negative_prompt: str = Input(
            description="Specify things to not see in the output",
            default=None,
        ),
        width: int = Input(
            description="Output image width; max 1024x768 or 768x1024 due to memory limits",
            choices=[128, 256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024],
            default=512,
        ),
        height: int = Input(
            description="Output image height; max 1024x768 or 768x1024 due to memory limits",
            choices=[128, 256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024],
            default=512,
        ),
        init_image: Path = Input(
            description="Initial image to generate variations of, resized to the specified WxH.",
            default=None,
        ),
        mask: Path = Input(
            description=(
                "Black and white image to use as mask for inpainting over init_image.\n"
                "Black pixels are inpainted and white pixels are preserved.\n"
                "Tends to work better with prompt strength of 0.5-0.7"
            ),
            default=None,
        ),
        prompt_strength: float = Input(
            description="Prompt strength init image. 1.0 full destruction of init image",
            ge=0.0,
            le=1.0,
            default=0.8,
        ),
        num_outputs: int = Input(
            description="Number of images to output.",
            ge=1,
            le=10,
            default=1,
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=7.5
        ),
        scheduler: str = Input(
            default="K-LMS",
            choices=["DDIM", "DDPM", "DPM-M", "DPM-S", "EULER-A", "EULER-D",
                     "HEUN", "IPNDM", "KDPM2-A", "KDPM2-D", "PNDM", "K-LMS"],
            description="Choose a scheduler. If you use an init image, PNDM will be used",
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> List[Path]:
        '''
        Run a single prediction on the model.

        Pipeline selection:
          * ``mask`` and ``init_image`` given -> inpainting pipeline
          * only ``init_image`` given         -> image-to-image pipeline
          * neither                           -> text-to-image pipeline

        Returns the list of paths of the generated PNG files
        (``/tmp/out-<i>.png``).

        Raises ValueError when ``width * height`` exceeds 786432 pixels or
        when ``mask`` is supplied without ``init_image``; raises Exception
        when the pipeline produced no images.
        '''
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")

        if width * height > 786432:
            raise ValueError(
                "Maximum size is 1024x768 or 768x1024 pixels, because of memory limits."
            )

        if mask:
            if not init_image:
                raise ValueError("mask was provided without init_image")
            pipe = self.inpaint_pipe
            source_image = self._open_rgb(init_image)
            extra_kwargs = {
                # The mask is scaled to the init image so both line up pixel-wise.
                "mask_image": self._open_rgb(mask).resize(source_image.size),
                "image": source_image,
                "strength": prompt_strength,
            }
        elif init_image:
            pipe = self.img2img_pipe
            extra_kwargs = {
                # Same keyword as the inpaint branch; ``init_image`` is the
                # deprecated spelling of this pipeline argument.
                "image": self._open_rgb(init_image),
                "strength": prompt_strength,
            }
        else:
            pipe = self.txt2img_pipe
            # width/height are only forwarded for text-to-image.
            # NOTE(review): the init_image description promises a resize to
            # WxH, but the image is passed at its own size - confirm intent.
            extra_kwargs = {
                "width": width,
                "height": height,
            }

        # The requested scheduler is applied to whichever pipeline was chosen.
        # NOTE(review): the scheduler description says PNDM is forced for
        # init images; the code does not do that - confirm intent.
        pipe.scheduler = make_scheduler(scheduler, pipe.scheduler.config)

        generator = torch.Generator("cuda").manual_seed(seed)
        output = pipe(
            prompt=[prompt] * num_outputs if prompt is not None else None,
            negative_prompt=(
                [negative_prompt] * num_outputs if negative_prompt is not None else None
            ),
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_inference_steps,
            **extra_kwargs,
        )

        output_paths = []
        for i, sample in enumerate(output.images):
            output_path = f"/tmp/out-{i}.png"
            sample.save(output_path)
            output_paths.append(Path(output_path))

        if not output_paths:
            raise Exception(
                "NSFW content detected. Try running it again, or try a different prompt."
            )
        return output_paths
def make_scheduler(name, config):
    '''
    Returns a scheduler from a name and config.

    name must be one of the keys below (the same strings offered by the
    ``scheduler`` input of ``Predictor.predict``); config is forwarded to
    the selected class's ``from_config``.

    Only the requested scheduler is instantiated, so an unrelated scheduler
    class rejecting ``config`` cannot break the call and no throw-away
    objects are built.

    Raises KeyError if ``name`` is not a known scheduler name.
    '''
    scheduler_classes = {
        "DDIM": DDIMScheduler,
        "DDPM": DDPMScheduler,
        "DPM-M": DPMSolverMultistepScheduler,
        "DPM-S": DPMSolverSinglestepScheduler,
        "EULER-A": EulerAncestralDiscreteScheduler,
        "EULER-D": EulerDiscreteScheduler,
        "HEUN": HeunDiscreteScheduler,
        "IPNDM": IPNDMScheduler,
        "KDPM2-A": KDPM2AncestralDiscreteScheduler,
        "KDPM2-D": KDPM2DiscreteScheduler,
        "PNDM": PNDMScheduler,
        "K-LMS": LMSDiscreteScheduler,
    }
    return scheduler_classes[name].from_config(config)