diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml index 88b879cc9..6cb6a8375 100644 --- a/.github/workflows/test-inference.yml +++ b/.github/workflows/test-inference.yml @@ -15,7 +15,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: "Symlink checkpoints" - run: ln -s ${{vars.SGM_CHECKPOINTS_PATH}} checkpoints + run: ln -s $SGM_CHECKPOINTS checkpoints - name: "Setup python" uses: actions/setup-python@v4 with: diff --git a/pyproject.toml b/pyproject.toml index 2cc502168..94ba68dfe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,5 +44,5 @@ dependencies = [ test-inference = [ "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118", "pip install -r requirements/pt2.txt", - "pytest -v tests/inference/test_inference.py {args}", + "pytest -v tests/inference {args}", ] diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 2984dbf7a..017db211b 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -1,6 +1,30 @@ +import os + +import numpy as np +import streamlit as st +import torch +from einops import repeat from pytorch_lightning import seed_everything -from scripts.demo.streamlit_helpers import * +from sgm.inference.api import ( + SamplingSpec, + SamplingParams, + ModelArchitecture, + SamplingPipeline, + model_specs, +) +from sgm.inference.helpers import ( + get_unique_embedder_keys_from_conditioner, + perform_save_locally, +) +from scripts.demo.streamlit_helpers import ( + get_interactive_image, + init_embedder_options, + init_sampling, + init_save_locally, + init_st, + show_samples, +) SAVE_PATH = "outputs/demo/txt2img/" @@ -33,63 +57,6 @@ "3.0": (1728, 576), } -VERSION2SPECS = { - "SDXL-base-1.0": { - "H": 1024, - "W": 1024, - "C": 4, - "f": 8, - "is_legacy": False, - "config": "configs/inference/sd_xl_base.yaml", - "ckpt": "checkpoints/sd_xl_base_1.0.safetensors", - }, - "SDXL-base-0.9": { - "H": 1024, - "W": 1024, - "C": 4, - "f": 8, - "is_legacy": False, - "config": "configs/inference/sd_xl_base.yaml", - "ckpt": "checkpoints/sd_xl_base_0.9.safetensors", - }, - "SD-2.1": { - "H": 512, - "W": 512, - "C": 4, - "f": 8, - "is_legacy": True, - "config": "configs/inference/sd_2_1.yaml", - "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors", - }, - "SD-2.1-768": { - "H": 768, - "W": 768, - "C": 4, - "f": 8, - "is_legacy": True, - "config": "configs/inference/sd_2_1_768.yaml", - "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors", - }, - "SDXL-refiner-0.9": { - "H": 1024, - "W": 1024, - "C": 4, - "f": 8, - "is_legacy": True, - "config": "configs/inference/sd_xl_refiner.yaml", - "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors", - }, - "SDXL-refiner-1.0": { - "H": 1024, - "W": 1024, - "C": 4, - "f": 8, - "is_legacy": True, - "config": "configs/inference/sd_xl_refiner.yaml", - "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors", - }, -} - def load_img(display=True, key=None, device="cuda"): image = get_interactive_image(key=key) @@ -111,170 +78,181 @@ def load_img(display=True, key=None, device="cuda"): def run_txt2img( state, - version, - version_dict, - is_legacy=False, + model_id: ModelArchitecture, + prompt: str, + negative_prompt: str, return_latents=False, - filter=None, stage2strength=None, ): - if version.startswith("SDXL-base"): - W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10) + model: SamplingPipeline = state["model"] + params: SamplingParams = state["params"] + if model_id in sdxl_base_model_list: + 
width, height = st.selectbox( + "Resolution:", list(SD_XL_BASE_RATIOS.values()), 10 + ) else: - H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048) - W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048) - C = version_dict["C"] - F = version_dict["f"] - - init_dict = { - "orig_width": W, - "orig_height": H, - "target_width": W, - "target_height": H, - } - value_dict = init_embedder_options( - get_unique_embedder_keys_from_conditioner(state["model"].conditioner), - init_dict, + height = int( + st.number_input("H", value=params.height, min_value=64, max_value=2048) + ) + width = int( + st.number_input("W", value=params.width, min_value=64, max_value=2048) + ) + + params = init_embedder_options( + get_unique_embedder_keys_from_conditioner(model.model.conditioner), + params=params, prompt=prompt, negative_prompt=negative_prompt, ) - sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength) + params, num_rows, num_cols = init_sampling(params=params) num_samples = num_rows * num_cols + params.height = height + params.width = width if st.button("Sample"): st.write(f"**Model I:** {version}") - out = do_sample( - state["model"], - sampler, - value_dict, - num_samples, - H, - W, - C, - F, - force_uc_zero_embeddings=["txt"] if not is_legacy else [], + outputs = st.empty() + st.text("Sampling") + out = model.text_to_image( + params=params, + prompt=prompt, + negative_prompt=negative_prompt, + samples=int(num_samples), return_latents=return_latents, - filter=filter, + noise_strength=stage2strength, + filter=state["filter"], ) + + show_samples(out, outputs) + return out def run_img2img( state, - version_dict, - is_legacy=False, + prompt: str, + negative_prompt: str, return_latents=False, - filter=None, stage2strength=None, ): + model: SamplingPipeline = state["model"] + params: SamplingParams = state["params"] + img = load_img() if img is None: return None - H, W = img.shape[2], img.shape[3] - - init_dict = { - "orig_width": W, - "orig_height": H, - "target_width": W, - "target_height": H, - } - value_dict = init_embedder_options( - get_unique_embedder_keys_from_conditioner(state["model"].conditioner), - init_dict, + params.height, params.width = img.shape[2], img.shape[3] + + params = init_embedder_options( + get_unique_embedder_keys_from_conditioner(model.model.conditioner), + params=params, prompt=prompt, negative_prompt=negative_prompt, ) - strength = st.number_input( + params.img2img_strength = st.number_input( "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0 ) - sampler, num_rows, num_cols = init_sampling( - img2img_strength=strength, - stage2strength=stage2strength, - ) + params, num_rows, num_cols = init_sampling(params=params) num_samples = num_rows * num_cols if st.button("Sample"): - out = do_img2img( - repeat(img, "1 ... -> n ...", n=num_samples), - state["model"], - sampler, - value_dict, - num_samples, - force_uc_zero_embeddings=["txt"] if not is_legacy else [], + outputs = st.empty() + st.text("Sampling") + out = model.image_to_image( + image=repeat(img, "1 ... 
-> n ...", n=num_samples), + params=params, + prompt=prompt, + negative_prompt=negative_prompt, + samples=int(num_samples), return_latents=return_latents, - filter=filter, + noise_strength=stage2strength, + filter=state["filter"], ) + + show_samples(out, outputs) return out def apply_refiner( input, state, - sampler, - num_samples, - prompt, - negative_prompt, - filter=None, + num_samples: int, + prompt: str, + negative_prompt: str, finish_denoising=False, ): - init_dict = { - "orig_width": input.shape[3] * 8, - "orig_height": input.shape[2] * 8, - "target_width": input.shape[3] * 8, - "target_height": input.shape[2] * 8, - } - - value_dict = init_dict - value_dict["prompt"] = prompt - value_dict["negative_prompt"] = negative_prompt + model: SamplingPipeline = state["model"] + params: SamplingParams = state["params"] - value_dict["crop_coords_top"] = 0 - value_dict["crop_coords_left"] = 0 - - value_dict["aesthetic_score"] = 6.0 - value_dict["negative_aesthetic_score"] = 2.5 + params.orig_width = input.shape[3] * 8 + params.orig_height = input.shape[2] * 8 + params.width = input.shape[3] * 8 + params.height = input.shape[2] * 8 st.warning(f"refiner input shape: {input.shape}") - samples = do_img2img( - input, - state["model"], - sampler, - value_dict, - num_samples, - skip_encode=True, - filter=filter, + + samples = model.refiner( + image=input, + params=params, + prompt=prompt, + negative_prompt=negative_prompt, + samples=num_samples, + return_latents=False, + filter=state["filter"], add_noise=not finish_denoising, ) return samples +sdxl_base_model_list = [ + ModelArchitecture.SDXL_V1_0_BASE, + ModelArchitecture.SDXL_V0_9_BASE, +] + +sdxl_refiner_model_list = [ + ModelArchitecture.SDXL_V1_0_REFINER, + ModelArchitecture.SDXL_V0_9_REFINER, +] + if __name__ == "__main__": st.title("Stable Diffusion") - version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0) - version_dict = VERSION2SPECS[version] + version = st.selectbox( + "Model Version", + [member.value for member in ModelArchitecture], + 0, + ) + version_enum = ModelArchitecture(version) + specs = model_specs[version_enum] mode = st.radio("Mode", ("txt2img", "img2img"), 0) st.write("__________________________") - set_lowvram_mode(st.checkbox("Low vram mode", True)) + st.write("**Performance Options:**") + use_fp16 = st.checkbox("Use fp16 (Saves VRAM)", True) + enable_swap = st.checkbox("Swap models to CPU (Saves VRAM, uses RAM)", True) + st.write("__________________________") - if version.startswith("SDXL-base"): + if version_enum in sdxl_base_model_list: add_pipeline = st.checkbox("Load SDXL-refiner?", False) st.write("__________________________") else: add_pipeline = False - seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)) + seed = int( + st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)) + ) seed_everything(seed) - save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version)) - - state = init_st(version_dict, load_filter=True) - if state["msg"]: - st.info(state["msg"]) + save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version))) + state = init_st( + model_specs[version_enum], + load_filter=True, + use_fp16=use_fp16, + enable_swap=enable_swap, + ) model = state["model"] - is_legacy = version_dict["is_legacy"] + is_legacy = specs.is_legacy prompt = st.text_input( "prompt", @@ -290,47 +268,58 @@ def apply_refiner( if add_pipeline: st.write("__________________________") - version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", 
"SDXL-refiner-0.9"]) + version2 = ModelArchitecture( + st.selectbox( + "Refiner:", + [member.value for member in sdxl_refiner_model_list], + ) + ) st.warning( f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) " ) st.write("**Refiner Options:**") - version_dict2 = VERSION2SPECS[version2] - state2 = init_st(version_dict2, load_filter=False) - st.info(state2["msg"]) + specs2 = model_specs[version2] + state2 = init_st( + specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap + ) + params2 = state2["params"] - stage2strength = st.number_input( + params2.img2img_strength = st.number_input( "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0 ) - sampler2, *_ = init_sampling( + params2, *_ = init_sampling( + params=state2["params"], key=2, - img2img_strength=stage2strength, specify_num_samples=False, ) st.write("__________________________") finish_denoising = st.checkbox("Finish denoising with refiner.", True) - if not finish_denoising: + if finish_denoising: + stage2strength = params2.img2img_strength + else: stage2strength = None + else: + state2 = None + params2 = None + stage2strength = None if mode == "txt2img": out = run_txt2img( - state, - version, - version_dict, - is_legacy=is_legacy, + state=state, + model_id=version_enum, + prompt=prompt, + negative_prompt=negative_prompt, return_latents=add_pipeline, - filter=state.get("filter"), stage2strength=stage2strength, ) elif mode == "img2img": out = run_img2img( - state, - version_dict, - is_legacy=is_legacy, + state=state, + prompt=prompt, + negative_prompt=negative_prompt, return_latents=add_pipeline, - filter=state.get("filter"), stage2strength=stage2strength, ) else: @@ -342,17 +331,17 @@ def apply_refiner( samples_z = None if add_pipeline and samples_z is not None: + outputs = st.empty() st.write("**Running Refinement Stage**") samples = apply_refiner( - samples_z, - state2, - sampler2, - samples_z.shape[0], + input=samples_z, + state=state2, + num_samples=samples_z.shape[0], prompt=prompt, negative_prompt=negative_prompt if is_legacy else "", - filter=state.get("filter"), finish_denoising=finish_denoising, ) + show_samples(samples, outputs) if save_locally and samples is not None: perform_save_locally(save_path, samples) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 82b7fb9cc..a0f3848a9 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -1,166 +1,68 @@ -import math import os -from typing import List, Union import numpy as np import streamlit as st import torch from einops import rearrange, repeat -from imwatermark import WatermarkEncoder -from omegaconf import ListConfig, OmegaConf from PIL import Image -from safetensors.torch import load_file as load_safetensors -from torch import autocast from torchvision import transforms -from torchvision.utils import make_grid +from typing import Optional, Tuple, Dict, Any -from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering -from sgm.modules.diffusionmodules.sampling import ( - DPMPP2MSampler, - DPMPP2SAncestralSampler, - EulerAncestralSampler, - EulerEDMSampler, - HeunEDMSampler, - LinearMultistepSampler, -) -from sgm.util import append_dims, instantiate_from_config - - -class WatermarkEmbedder: - def __init__(self, watermark): - self.watermark = watermark - self.num_bits = len(WATERMARK_BITS) - self.encoder = WatermarkEncoder() - self.encoder.set_watermark("bits", self.watermark) - - def __call__(self, image: 
torch.Tensor): - """ - Adds a predefined watermark to the input image - - Args: - image: ([N,] B, C, H, W) in range [0, 1] - - Returns: - same as input but watermarked - """ - # watermarking libary expects input as cv2 BGR format - squeeze = len(image.shape) == 4 - if squeeze: - image = image[None, ...] - n = image.shape[0] - image_np = rearrange( - (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" - ).numpy()[:, :, :, ::-1] - # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] - for k in range(image_np.shape[0]): - image_np[k] = self.encoder.encode(image_np[k], "dwtDct") - image = torch.from_numpy( - rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n) - ).to(image.device) - image = torch.clamp(image / 255, min=0.0, max=1.0) - if squeeze: - image = image[0] - return image +from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering -# A fixed 48-bit message that was choosen at random -# WATERMARK_MESSAGE = 0xB3EC907BB19E -WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 -# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 -WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] -embed_watemark = WatermarkEmbedder(WATERMARK_BITS) +from sgm.inference.api import ( + Discretization, + Guider, + Sampler, + SamplingParams, + SamplingSpec, + SamplingPipeline, + Thresholder, +) +from sgm.inference.helpers import embed_watermark, CudaModelManager @st.cache_resource() -def init_st(version_dict, load_ckpt=True, load_filter=True): - state = dict() - if not "model" in state: - config = version_dict["config"] - ckpt = version_dict["ckpt"] - - config = OmegaConf.load(config) - model, msg = load_model_from_config(config, ckpt if load_ckpt else None) - - state["msg"] = msg - state["model"] = model - state["ckpt"] = ckpt if load_ckpt else None - state["config"] = config - if load_filter: - state["filter"] = DeepFloydDataFiltering(verbose=False) - return state - - -def load_model(model): - model.cuda() - - -lowvram_mode = False - - -def set_lowvram_mode(mode): - global lowvram_mode - lowvram_mode = mode - - -def initial_model_load(model): - global lowvram_mode - if lowvram_mode: - model.model.half() +def init_st( + spec: SamplingSpec, + load_ckpt=True, + load_filter=True, + use_fp16=True, + enable_swap=True, +) -> Dict[str, Any]: + state: Dict[str, Any] = dict() + config = spec.config + ckpt = spec.ckpt + + if enable_swap: + pipeline = SamplingPipeline( + model_spec=spec, + use_fp16=use_fp16, + device=CudaModelManager(device="cuda", swap_device="cpu"), + ) else: - model.cuda() - return model - - -def unload_model(model): - global lowvram_mode - if lowvram_mode: - model.cpu() - torch.cuda.empty_cache() - - -def load_model_from_config(config, ckpt=None, verbose=True): - model = instantiate_from_config(config.model) - - if ckpt is not None: - print(f"Loading model from {ckpt}") - if ckpt.endswith("ckpt"): - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - global_step = pl_sd["global_step"] - st.info(f"loaded ckpt from global step {global_step}") - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - elif ckpt.endswith("safetensors"): - sd = load_safetensors(ckpt) - else: - raise NotImplementedError - - msg = None - - m, u = model.load_state_dict(sd, strict=False) - - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) + pipeline = SamplingPipeline(model_spec=spec, 
use_fp16=use_fp16) + + state["spec"] = spec + state["model"] = pipeline + state["ckpt"] = ckpt if load_ckpt else None + state["config"] = config + state["params"] = spec.default_params + if load_filter: + state["filter"] = DeepFloydDataFiltering(verbose=False) else: - msg = None - - model = initial_model_load(model) - model.eval() - return model, msg + state["filter"] = None + return state def get_unique_embedder_keys_from_conditioner(conditioner): return list(set([x.input_key for x in conditioner.embedders])) -def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): - # Hardcoded demo settings; might undergo some changes in the future - - value_dict = {} +def init_embedder_options( + keys, params: SamplingParams, prompt=None, negative_prompt=None +) -> SamplingParams: for key in keys: if key == "txt": if prompt is None: @@ -170,46 +72,38 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): if negative_prompt is None: negative_prompt = st.text_input("Negative prompt", "") - value_dict["prompt"] = prompt - value_dict["negative_prompt"] = negative_prompt - if key == "original_size_as_tuple": orig_width = st.number_input( "orig_width", - value=init_dict["orig_width"], + value=params.orig_width, min_value=16, ) orig_height = st.number_input( "orig_height", - value=init_dict["orig_height"], + value=params.orig_height, min_value=16, ) - value_dict["orig_width"] = orig_width - value_dict["orig_height"] = orig_height + params.orig_width = int(orig_width) + params.orig_height = int(orig_height) if key == "crop_coords_top_left": - crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0) - crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0) - - value_dict["crop_coords_top"] = crop_coord_top - value_dict["crop_coords_left"] = crop_coord_left - - if key == "aesthetic_score": - value_dict["aesthetic_score"] = 6.0 - value_dict["negative_aesthetic_score"] = 2.5 - - if key == "target_size_as_tuple": - value_dict["target_width"] = init_dict["target_width"] - value_dict["target_height"] = init_dict["target_height"] + crop_coord_top = st.number_input( + "crop_coords_top", value=params.crop_coords_top, min_value=0 + ) + crop_coord_left = st.number_input( + "crop_coords_left", value=params.crop_coords_left, min_value=0 + ) - return value_dict + params.crop_coords_top = int(crop_coord_top) + params.crop_coords_left = int(crop_coord_left) + return params def perform_save_locally(save_path, samples): os.makedirs(os.path.join(save_path), exist_ok=True) base_count = len(os.listdir(os.path.join(save_path))) - samples = embed_watemark(samples) + samples = embed_watermark(samples) for sample in samples: sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") Image.fromarray(sample.astype(np.uint8)).save( @@ -228,78 +122,26 @@ def init_save_locally(_dir, init_value: bool = False): return save_locally, save_path -class Img2ImgDiscretizationWrapper: - """ - wraps a discretizer, and prunes the sigmas - params: - strength: float between 0.0 and 1.0. 
1.0 means full sampling (all sigmas are returned) - """ - - def __init__(self, discretization, strength: float = 1.0): - self.discretization = discretization - self.strength = strength - assert 0.0 <= self.strength <= 1.0 - - def __call__(self, *args, **kwargs): - # sigmas start large first, and decrease then - sigmas = self.discretization(*args, **kwargs) - print(f"sigmas after discretization, before pruning img2img: ", sigmas) - sigmas = torch.flip(sigmas, (0,)) - sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] - print("prune index:", max(int(self.strength * len(sigmas)), 1)) - sigmas = torch.flip(sigmas, (0,)) - print(f"sigmas after pruning: ", sigmas) - return sigmas - - -class Txt2NoisyDiscretizationWrapper: - """ - wraps a discretizer, and prunes the sigmas - params: - strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned) - """ - - def __init__(self, discretization, strength: float = 0.0, original_steps=None): - self.discretization = discretization - self.strength = strength - self.original_steps = original_steps - assert 0.0 <= self.strength <= 1.0 - - def __call__(self, *args, **kwargs): - # sigmas start large first, and decrease then - sigmas = self.discretization(*args, **kwargs) - print(f"sigmas after discretization, before pruning img2img: ", sigmas) - sigmas = torch.flip(sigmas, (0,)) - if self.original_steps is None: - steps = len(sigmas) - else: - steps = self.original_steps + 1 - prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0) - sigmas = sigmas[prune_index:] - print("prune index:", prune_index) - sigmas = torch.flip(sigmas, (0,)) - print(f"sigmas after pruning: ", sigmas) - return sigmas - - -def get_guider(key): - guider = st.sidebar.selectbox( - f"Discretization #{key}", - [ - "VanillaCFG", - "IdentityGuider", - ], +def show_samples(samples, outputs): + if isinstance(samples, tuple): + samples, _ = samples + grid = embed_watermark(torch.stack([samples])) + grid = rearrange(grid, "n b c h w -> (n h) (b w) c") + outputs.image(grid.cpu().numpy()) + + +def get_guider(params: SamplingParams, key=1) -> SamplingParams: + params.guider = Guider( + st.sidebar.selectbox( + f"Discretization #{key}", [member.value for member in Guider] + ) ) - if guider == "IdentityGuider": - guider_config = { - "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" - } - elif guider == "VanillaCFG": + if params.guider == Guider.VANILLA: scale = st.number_input( - f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0 + f"cfg-scale #{key}", value=params.scale, min_value=0.0, max_value=100.0 ) - + params.scale = scale thresholder = st.sidebar.selectbox( f"Thresholder #{key}", [ @@ -308,182 +150,97 @@ def get_guider(key): ) if thresholder == "None": - dyn_thresh_config = { - "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" - } + params.thresholder = Thresholder.NONE else: raise NotImplementedError - - guider_config = { - "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", - "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config}, - } - else: - raise NotImplementedError - return guider_config + return params def init_sampling( + params: SamplingParams, key=1, - img2img_strength=1.0, specify_num_samples=True, - stage2strength=None, -): +) -> Tuple[SamplingParams, int, int]: num_rows, num_cols = 1, 1 if specify_num_samples: num_cols = st.number_input( f"num cols #{key}", value=2, min_value=1, max_value=10 ) - steps = st.sidebar.number_input( - f"steps #{key}", value=40, 
min_value=1, max_value=1000 + params.steps = int( + st.sidebar.number_input( + f"steps #{key}", value=params.steps, min_value=1, max_value=1000 + ) ) - sampler = st.sidebar.selectbox( - f"Sampler #{key}", - [ - "EulerEDMSampler", - "HeunEDMSampler", - "EulerAncestralSampler", - "DPMPP2SAncestralSampler", - "DPMPP2MSampler", - "LinearMultistepSampler", - ], - 0, + + params.sampler = Sampler( + st.sidebar.selectbox( + f"Sampler #{key}", + [member.value for member in Sampler], + 0, + ) ) - discretization = st.sidebar.selectbox( - f"Discretization #{key}", - [ - "LegacyDDPMDiscretization", - "EDMDiscretization", - ], + params.discretization = Discretization( + st.sidebar.selectbox( + f"Discretization #{key}", + [member.value for member in Discretization], + ) ) - discretization_config = get_discretization(discretization, key=key) + params = get_discretization(params=params, key=key) + params = get_guider(params=params, key=key) + params = get_sampler(params=params, key=key) + + return params, num_rows, num_cols + - guider_config = get_guider(key=key) +def get_discretization(params: SamplingParams, key=1) -> SamplingParams: + if params.discretization == Discretization.EDM: + params.sigma_min = st.number_input(f"sigma_min #{key}", value=params.sigma_min) + params.sigma_max = st.number_input(f"sigma_max #{key}", value=params.sigma_max) + params.rho = st.number_input(f"rho #{key}", value=params.rho) + return params - sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key) - if img2img_strength < 1.0: - st.warning( - f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper" + +def get_sampler(params: SamplingParams, key=1) -> SamplingParams: + if params.sampler in (Sampler.EULER_EDM, Sampler.HEUN_EDM): + params.s_churn = st.sidebar.number_input( + f"s_churn #{key}", value=params.s_churn, min_value=0.0 ) - sampler.discretization = Img2ImgDiscretizationWrapper( - sampler.discretization, strength=img2img_strength + params.s_tmin = st.sidebar.number_input( + f"s_tmin #{key}", value=params.s_tmin, min_value=0.0 ) - if stage2strength is not None: - sampler.discretization = Txt2NoisyDiscretizationWrapper( - sampler.discretization, strength=stage2strength, original_steps=steps + params.s_tmax = st.sidebar.number_input( + f"s_tmax #{key}", value=params.s_tmax, min_value=0.0 ) - return sampler, num_rows, num_cols - - -def get_discretization(discretization, key=1): - if discretization == "LegacyDDPMDiscretization": - discretization_config = { - "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", - } - elif discretization == "EDMDiscretization": - sigma_min = st.number_input(f"sigma_min #{key}", value=0.03) # 0.0292 - sigma_max = st.number_input(f"sigma_max #{key}", value=14.61) # 14.6146 - rho = st.number_input(f"rho #{key}", value=3.0) - discretization_config = { - "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", - "params": { - "sigma_min": sigma_min, - "sigma_max": sigma_max, - "rho": rho, - }, - } - - return discretization_config - - -def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1): - if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler": - s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0) - s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0) - s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0) - s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0) - - 
if sampler_name == "EulerEDMSampler": - sampler = EulerEDMSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - s_churn=s_churn, - s_tmin=s_tmin, - s_tmax=s_tmax, - s_noise=s_noise, - verbose=True, - ) - elif sampler_name == "HeunEDMSampler": - sampler = HeunEDMSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - s_churn=s_churn, - s_tmin=s_tmin, - s_tmax=s_tmax, - s_noise=s_noise, - verbose=True, - ) - elif ( - sampler_name == "EulerAncestralSampler" - or sampler_name == "DPMPP2SAncestralSampler" - ): - s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0) - eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0) - - if sampler_name == "EulerAncestralSampler": - sampler = EulerAncestralSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - eta=eta, - s_noise=s_noise, - verbose=True, - ) - elif sampler_name == "DPMPP2SAncestralSampler": - sampler = DPMPP2SAncestralSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - eta=eta, - s_noise=s_noise, - verbose=True, - ) - elif sampler_name == "DPMPP2MSampler": - sampler = DPMPP2MSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - verbose=True, + params.s_noise = st.sidebar.number_input( + f"s_noise #{key}", value=params.s_noise, min_value=0.0 ) - elif sampler_name == "LinearMultistepSampler": - order = st.sidebar.number_input("order", value=4, min_value=1) - sampler = LinearMultistepSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - order=order, - verbose=True, + + elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL): + params.s_noise = st.sidebar.number_input( + "s_noise", value=params.s_noise, min_value=0.0 ) - else: - raise ValueError(f"unknown sampler {sampler_name}!") + params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0) - return sampler + elif params.sampler == Sampler.LINEAR_MULTISTEP: + params.order = int( + st.sidebar.number_input("order", value=params.order, min_value=1) + ) + return params -def get_interactive_image(key=None) -> Image.Image: +def get_interactive_image(key=None) -> Optional[Image.Image]: image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key) if image is not None: image = Image.open(image) if not image.mode == "RGB": image = image.convert("RGB") return image + return None -def load_img(display=True, key=None): +def load_img(display=True, key=None) -> Optional[torch.Tensor]: image = get_interactive_image(key=key) if image is None: return None @@ -507,238 +264,3 @@ def get_init_img(batch_size=1, key=None): init_image = load_img(key=key).cuda() init_image = repeat(init_image, "1 ... 
-> b ...", b=batch_size) return init_image - - -def do_sample( - model, - sampler, - value_dict, - num_samples, - H, - W, - C, - F, - force_uc_zero_embeddings: List = None, - batch2model_input: List = None, - return_latents=False, - filter=None, -): - if force_uc_zero_embeddings is None: - force_uc_zero_embeddings = [] - if batch2model_input is None: - batch2model_input = [] - - st.text("Sampling") - - outputs = st.empty() - precision_scope = autocast - with torch.no_grad(): - with precision_scope("cuda"): - with model.ema_scope(): - num_samples = [num_samples] - load_model(model.conditioner) - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - num_samples, - ) - for key in batch: - if isinstance(batch[key], torch.Tensor): - print(key, batch[key].shape) - elif isinstance(batch[key], list): - print(key, [len(l) for l in batch[key]]) - else: - print(key, batch[key]) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) - unload_model(model.conditioner) - - for k in c: - if not k == "crossattn": - c[k], uc[k] = map( - lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc) - ) - - additional_model_inputs = {} - for k in batch2model_input: - additional_model_inputs[k] = batch[k] - - shape = (math.prod(num_samples), C, H // F, W // F) - randn = torch.randn(shape).to("cuda") - - def denoiser(input, sigma, c): - return model.denoiser( - model.model, input, sigma, c, **additional_model_inputs - ) - - load_model(model.denoiser) - load_model(model.model) - samples_z = sampler(denoiser, randn, cond=c, uc=uc) - unload_model(model.model) - unload_model(model.denoiser) - - load_model(model.first_stage_model) - samples_x = model.decode_first_stage(samples_z) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) - unload_model(model.first_stage_model) - - if filter is not None: - samples = filter(samples) - - grid = torch.stack([samples]) - grid = rearrange(grid, "n b c h w -> (n h) (b w) c") - outputs.image(grid.cpu().numpy()) - - if return_latents: - return samples, samples_z - return samples - - -def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): - # Hardcoded demo setups; might undergo some changes in the future - - batch = {} - batch_uc = {} - - for key in keys: - if key == "txt": - batch["txt"] = ( - np.repeat([value_dict["prompt"]], repeats=math.prod(N)) - .reshape(N) - .tolist() - ) - batch_uc["txt"] = ( - np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) - .reshape(N) - .tolist() - ) - elif key == "original_size_as_tuple": - batch["original_size_as_tuple"] = ( - torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) - .to(device) - .repeat(*N, 1) - ) - elif key == "crop_coords_top_left": - batch["crop_coords_top_left"] = ( - torch.tensor( - [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] - ) - .to(device) - .repeat(*N, 1) - ) - elif key == "aesthetic_score": - batch["aesthetic_score"] = ( - torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) - ) - batch_uc["aesthetic_score"] = ( - torch.tensor([value_dict["negative_aesthetic_score"]]) - .to(device) - .repeat(*N, 1) - ) - - elif key == "target_size_as_tuple": - batch["target_size_as_tuple"] = ( - torch.tensor([value_dict["target_height"], value_dict["target_width"]]) - .to(device) - .repeat(*N, 1) - ) - else: - batch[key] = value_dict[key] - - for key in batch.keys(): - if key not in batch_uc 
and isinstance(batch[key], torch.Tensor): - batch_uc[key] = torch.clone(batch[key]) - return batch, batch_uc - - -@torch.no_grad() -def do_img2img( - img, - model, - sampler, - value_dict, - num_samples, - force_uc_zero_embeddings=[], - additional_kwargs={}, - offset_noise_level: int = 0.0, - return_latents=False, - skip_encode=False, - filter=None, - add_noise=True, -): - st.text("Sampling") - - outputs = st.empty() - precision_scope = autocast - with torch.no_grad(): - with precision_scope("cuda"): - with model.ema_scope(): - load_model(model.conditioner) - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - [num_samples], - ) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) - unload_model(model.conditioner) - for k in c: - c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc)) - - for k in additional_kwargs: - c[k] = uc[k] = additional_kwargs[k] - if skip_encode: - z = img - else: - load_model(model.first_stage_model) - z = model.encode_first_stage(img) - unload_model(model.first_stage_model) - - noise = torch.randn_like(z) - - sigmas = sampler.discretization(sampler.num_steps).cuda() - sigma = sigmas[0] - - st.info(f"all sigmas: {sigmas}") - st.info(f"noising sigma: {sigma}") - if offset_noise_level > 0.0: - noise = noise + offset_noise_level * append_dims( - torch.randn(z.shape[0], device=z.device), z.ndim - ) - if add_noise: - noised_z = z + noise * append_dims(sigma, z.ndim).cuda() - noised_z = noised_z / torch.sqrt( - 1.0 + sigmas[0] ** 2.0 - ) # Note: hardcoded to DDPM-like scaling. need to generalize later. - else: - noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0) - - def denoiser(x, sigma, c): - return model.denoiser(model.model, x, sigma, c) - - load_model(model.denoiser) - load_model(model.model) - samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) - unload_model(model.model) - unload_model(model.denoiser) - - load_model(model.first_stage_model) - samples_x = model.decode_first_stage(samples_z) - unload_model(model.first_stage_model) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) - - if filter is not None: - samples = filter(samples) - - grid = embed_watemark(torch.stack([samples])) - grid = rearrange(grid, "n b c h w -> (n h) (b w) c") - outputs.image(grid.cpu().numpy()) - if return_latents: - return samples, samples_z - return samples diff --git a/sgm/__init__.py b/sgm/__init__.py index 24bc84af8..2a0589311 100644 --- a/sgm/__init__.py +++ b/sgm/__init__.py @@ -1,4 +1,4 @@ from .models import AutoencodingEngine, DiffusionEngine from .util import get_configs_path, instantiate_from_config -__version__ = "0.1.0" +__version__ = "0.1.1" diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 12efc064c..f4f2faa7c 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -1,11 +1,14 @@ from dataclasses import dataclass, asdict from enum import Enum from omegaconf import OmegaConf -import pathlib +import os from sgm.inference.helpers import ( do_sample, do_img2img, + DeviceModelManager, + get_model_manager, Img2ImgDiscretizationWrapper, + Txt2NoisyDiscretizationWrapper, ) from sgm.modules.diffusionmodules.sampling import ( EulerEDMSampler, @@ -15,17 +18,18 @@ DPMPP2MSampler, LinearMultistepSampler, ) -from sgm.util import load_model_from_config -from typing import Optional +from sgm.util import load_model_from_config, get_configs_path, get_checkpoints_path +import torch +from 
typing import Optional, Dict, Any, Union class ModelArchitecture(str, Enum): - SD_2_1 = "stable-diffusion-v2-1" - SD_2_1_768 = "stable-diffusion-v2-1-768" + SDXL_V1_0_BASE = "stable-diffusion-xl-v1-base" + SDXL_V1_0_REFINER = "stable-diffusion-xl-v1-refiner" SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base" SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner" - SDXL_V1_BASE = "stable-diffusion-xl-v1-base" - SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner" + SD_2_1 = "stable-diffusion-v2-1" + SD_2_1_768 = "stable-diffusion-v2-1-768" class Sampler(str, Enum): @@ -53,16 +57,20 @@ class Thresholder(str, Enum): @dataclass class SamplingParams: - width: int = 1024 - height: int = 1024 - steps: int = 50 - sampler: Sampler = Sampler.DPMPP2M + """ + Parameters for sampling. + """ + + width: Optional[int] = None + height: Optional[int] = None + steps: Optional[int] = None + sampler: Sampler = Sampler.EULER_EDM discretization: Discretization = Discretization.LEGACY_DDPM guider: Guider = Guider.VANILLA thresholder: Thresholder = Thresholder.NONE - scale: float = 6.0 - aesthetic_score: float = 5.0 - negative_aesthetic_score: float = 5.0 + scale: float = 5.0 + aesthetic_score: float = 6.0 + negative_aesthetic_score: float = 2.5 img2img_strength: float = 1.0 orig_width: int = 1024 orig_height: int = 1024 @@ -89,8 +97,10 @@ class SamplingSpec: config: str ckpt: str is_guided: bool + default_params: SamplingParams +# The defaults here are derived from user preference testing. model_specs = { ModelArchitecture.SD_2_1: SamplingSpec( height=512, @@ -101,6 +111,12 @@ class SamplingSpec: config="sd_2_1.yaml", ckpt="v2-1_512-ema-pruned.safetensors", is_guided=True, + default_params=SamplingParams( + width=512, + height=512, + steps=40, + scale=7.0, + ), ), ModelArchitecture.SD_2_1_768: SamplingSpec( height=768, @@ -111,6 +127,12 @@ class SamplingSpec: config="sd_2_1_768.yaml", ckpt="v2-1_768-ema-pruned.safetensors", is_guided=True, + default_params=SamplingParams( + width=768, + height=768, + steps=40, + scale=7.0, + ), ), ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec( height=1024, @@ -121,6 +143,7 @@ class SamplingSpec: config="sd_xl_base.yaml", ckpt="sd_xl_base_0.9.safetensors", is_guided=True, + default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0), ), ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec( height=1024, @@ -131,8 +154,11 @@ class SamplingSpec: config="sd_xl_refiner.yaml", ckpt="sd_xl_refiner_0.9.safetensors", is_guided=True, + default_params=SamplingParams( + width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15 + ), ), - ModelArchitecture.SDXL_V1_BASE: SamplingSpec( + ModelArchitecture.SDXL_V1_0_BASE: SamplingSpec( height=1024, width=1024, channels=4, @@ -141,8 +167,9 @@ class SamplingSpec: config="sd_xl_base.yaml", ckpt="sd_xl_base_1.0.safetensors", is_guided=True, + default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0), ), - ModelArchitecture.SDXL_V1_REFINER: SamplingSpec( + ModelArchitecture.SDXL_V1_0_REFINER: SamplingSpec( height=1024, width=1024, channels=4, @@ -151,34 +178,97 @@ class SamplingSpec: config="sd_xl_refiner.yaml", ckpt="sd_xl_refiner_1.0.safetensors", is_guided=True, + default_params=SamplingParams( + width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15 + ), ), } +def wrap_discretization( + discretization, image_strength=None, noise_strength=None, steps=None +): + if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance( + discretization, Txt2NoisyDiscretizationWrapper + ): + return 
discretization # Already wrapped + if image_strength is not None and image_strength < 1.0 and image_strength > 0.0: + discretization = Img2ImgDiscretizationWrapper( + discretization, strength=image_strength + ) + + if ( + noise_strength is not None + and noise_strength < 1.0 + and noise_strength > 0.0 + and steps is not None + ): + discretization = Txt2NoisyDiscretizationWrapper( + discretization, + strength=noise_strength, + original_steps=steps, + ) + return discretization + + class SamplingPipeline: def __init__( self, - model_id: ModelArchitecture, - model_path="checkpoints", - config_path="configs/inference", - device="cuda", - use_fp16=True, + model_id: Optional[ModelArchitecture] = None, + model_spec: Optional[SamplingSpec] = None, + model_path: Optional[str] = None, + config_path: Optional[str] = None, + use_fp16: bool = True, + device: Optional[Union[DeviceModelManager, str, torch.device]] = None, ) -> None: - if model_id not in model_specs: - raise ValueError(f"Model {model_id} not supported") + """ + Sampling pipeline for generating images from a model. + + @param model_id: Model architecture to use. If not specified, model_spec must be specified. + @param model_spec: Model specification to use. If not specified, model_id must be specified. + @param model_path: Path to model checkpoints folder. + @param config_path: Path to model config folder. + @param use_fp16: Whether to use fp16 for sampling. + @param device: Device manager to use with this pipeline. If a string or torch.device is passed, a device manager will be created based on device type if possible. + """ + self.model_id = model_id - self.specs = model_specs[self.model_id] - self.config = str(pathlib.Path(config_path, self.specs.config)) - self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt)) - self.device = device - self.model = self._load_model(device=device, use_fp16=use_fp16) + if model_spec is not None: + self.specs = model_spec + elif model_id is not None: + if model_id not in model_specs: + raise ValueError(f"Model {model_id} not supported") + self.specs = model_specs[model_id] + else: + raise ValueError("Either model_id or model_spec should be provided") + + if model_path is None: + model_path = get_checkpoints_path() + if config_path is None: + config_path = get_configs_path() + self.config = os.path.join(config_path, "inference", self.specs.config) + self.ckpt = os.path.join(model_path, self.specs.ckpt) + if not os.path.exists(self.config): + raise ValueError( + f"Config {self.config} not found, check model spec or config_path" + ) + if not os.path.exists(self.ckpt): + raise ValueError( + f"Checkpoint {self.ckpt} not found, check model spec or config_path" + ) - def _load_model(self, device="cuda", use_fp16=True): + self.device_manager = get_model_manager(device) + + self.model = self._load_model( + device_manager=self.device_manager, use_fp16=use_fp16 + ) + + def _load_model(self, device_manager: DeviceModelManager, use_fp16=True): config = OmegaConf.load(self.config) model = load_model_from_config(config, self.ckpt) if model is None: raise ValueError(f"Model {self.model_id} could not be loaded") - model.to(device) + device_manager.load(model) if use_fp16: model.conditioner.half() model.model.half() @@ -186,13 +276,34 @@ def _load_model(self, device="cuda", use_fp16=True): def text_to_image( self, - params: SamplingParams, prompt: str, + params: Optional[SamplingParams] = None, negative_prompt: str = "", samples: int = 1, return_latents: bool = False, + noise_strength: Optional[float] = None, + 
filter=None, ): + if params is None: + params = self.specs.default_params + else: + # Set defaults if optional params are not specified + if params.width is None: + params.width = self.specs.default_params.width + if params.height is None: + params.height = self.specs.default_params.height + if params.steps is None: + params.steps = self.specs.default_params.steps + sampler = get_sampler_config(params) + + sampler.discretization = wrap_discretization( + sampler.discretization, + image_strength=None, + noise_strength=noise_strength, + steps=params.steps, + ) + value_dict = asdict(params) value_dict["prompt"] = prompt value_dict["negative_prompt"] = negative_prompt @@ -209,31 +320,40 @@ def text_to_image( self.specs.factor, force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, - filter=None, + filter=filter, + device=self.device_manager, ) def image_to_image( self, - params: SamplingParams, - image, + image: torch.Tensor, prompt: str, + params: Optional[SamplingParams] = None, negative_prompt: str = "", samples: int = 1, return_latents: bool = False, + noise_strength: Optional[float] = None, + filter=None, ): + if params is None: + params = self.specs.default_params sampler = get_sampler_config(params) - if params.img2img_strength < 1.0: - sampler.discretization = Img2ImgDiscretizationWrapper( - sampler.discretization, - strength=params.img2img_strength, - ) + sampler.discretization = wrap_discretization( + sampler.discretization, + image_strength=params.img2img_strength, + noise_strength=noise_strength, + steps=params.steps, + ) + height, width = image.shape[2], image.shape[3] value_dict = asdict(params) value_dict["prompt"] = prompt value_dict["negative_prompt"] = negative_prompt value_dict["target_width"] = width value_dict["target_height"] = height + value_dict["orig_width"] = width + value_dict["orig_height"] = height return do_img2img( image, self.model, @@ -242,18 +362,24 @@ def image_to_image( samples, force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, - filter=None, + filter=filter, + device=self.device_manager, ) def refiner( self, - params: SamplingParams, - image, + image: torch.Tensor, prompt: str, - negative_prompt: Optional[str] = None, + negative_prompt: str = "", + params: Optional[SamplingParams] = None, samples: int = 1, return_latents: bool = False, + filter: Any = None, + add_noise: bool = False, ): + if params is None: + params = self.specs.default_params + sampler = get_sampler_config(params) value_dict = { "orig_width": image.shape[3] * 8, @@ -268,6 +394,10 @@ def refiner( "negative_aesthetic_score": 2.5, } + sampler.discretization = wrap_discretization( + sampler.discretization, image_strength=params.img2img_strength + ) + return do_img2img( image, self.model, @@ -276,11 +406,14 @@ def refiner( samples, skip_encode=True, return_latents=return_latents, - filter=None, + filter=filter, + add_noise=add_noise, + device=self.device_manager, ) -def get_guider_config(params: SamplingParams): +def get_guider_config(params: SamplingParams) -> Dict[str, Any]: + guider_config: Dict[str, Any] if params.guider == Guider.IDENTITY: guider_config = { "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" @@ -306,7 +439,8 @@ def get_guider_config(params: SamplingParams): return guider_config -def get_discretization_config(params: SamplingParams): +def get_discretization_config(params: SamplingParams) -> Dict[str, Any]: + discretization_config: Dict[str, Any] if 
params.discretization == Discretization.LEGACY_DDPM: discretization_config = { "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 1c653708b..e84b8a270 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -1,3 +1,4 @@ +import contextlib import os from typing import Union, List, Optional @@ -8,7 +9,6 @@ from einops import rearrange from imwatermark import WatermarkEncoder from omegaconf import ListConfig -from torch import autocast from sgm.util import append_dims @@ -58,6 +58,73 @@ def __call__(self, image: torch.Tensor): embed_watermark = WatermarkEmbedder(WATERMARK_BITS) +class DeviceModelManager(object): + """ + Default model loading class, should work for all device classes. + """ + + def __init__( + self, + device: Union[torch.device, str], + swap_device: Optional[Union[torch.device, str]] = None, + ) -> None: + """ + Args: + device (Union[torch.device, str]): The device to use for the model. + """ + self.device = torch.device(device) + self.swap_device = ( + torch.device(swap_device) if swap_device is not None else self.device + ) + + def load(self, model: torch.nn.Module) -> None: + """ + Loads a model to the (swap) device. + """ + model.to(self.swap_device) + + def autocast(self): + """ + Context manager that enables autocast for the device if supported. + """ + if self.device.type not in ("cuda", "cpu"): + return contextlib.nullcontext() + return torch.autocast(self.device.type) + + @contextlib.contextmanager + def use(self, model: torch.nn.Module): + """ + Context manager that ensures a model is on the correct device during use. + The default model loader does not perform any swapping, so the model will + stay on device. + """ + try: + model.to(self.device) + yield + finally: + if self.device != self.swap_device: + model.to(self.swap_device) + + +class CudaModelManager(DeviceModelManager): + """ + Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use. + """ + + @contextlib.contextmanager + def use(self, model: Union[torch.nn.Module, torch.Tensor]): + """ + Context manager that ensures a model is on the correct device during use. + If a swap device was provided, this will move the model to it after use and clear cache. + """ + model.to(self.device) + yield + if self.device != self.swap_device: + model.to(self.swap_device) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def get_unique_embedder_keys_from_conditioner(conditioner): return list({x.input_key for x in conditioner.embedders}) @@ -74,6 +141,20 @@ def perform_save_locally(save_path, samples): base_count += 1 +def get_model_manager( + device: Optional[Union[DeviceModelManager, str, torch.device]] +) -> DeviceModelManager: + if isinstance(device, DeviceModelManager): + return device + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + if device.type == "cuda": + return CudaModelManager(device=device) + else: + return DeviceModelManager(device=device) + + class Img2ImgDiscretizationWrapper: """ wraps a discretizer, and prunes the sigmas @@ -98,6 +179,36 @@ def __call__(self, *args, **kwargs): return sigmas +class Txt2NoisyDiscretizationWrapper: + """ + wraps a discretizer, and prunes the sigmas + params: + strength: float between 0.0 and 1.0. 
0.0 means full sampling (all sigmas are returned) + """ + + def __init__(self, discretization, strength: float = 0.0, original_steps=None): + self.discretization = discretization + self.strength = strength + self.original_steps = original_steps + assert 0.0 <= self.strength <= 1.0 + + def __call__(self, *args, **kwargs): + # sigmas start large first, and decrease then + sigmas = self.discretization(*args, **kwargs) + print(f"sigmas after discretization, before pruning img2img: ", sigmas) + sigmas = torch.flip(sigmas, (0,)) + if self.original_steps is None: + steps = len(sigmas) + else: + steps = self.original_steps + 1 + prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0) + sigmas = sigmas[prune_index:] + print("prune index:", prune_index) + sigmas = torch.flip(sigmas, (0,)) + print(f"sigmas after pruning: ", sigmas) + return sigmas + + def do_sample( model, sampler, @@ -111,39 +222,45 @@ def do_sample( batch2model_input: Optional[List] = None, return_latents=False, filter=None, - device="cuda", + device: Optional[Union[DeviceModelManager, str, torch.device]] = None, ): if force_uc_zero_embeddings is None: force_uc_zero_embeddings = [] if batch2model_input is None: batch2model_input = [] + device_manager = get_model_manager(device=device) + with torch.no_grad(): - with autocast(device) as precision_scope: + with device_manager.autocast(): with model.ema_scope(): num_samples = [num_samples] - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - num_samples, - ) - for key in batch: - if isinstance(batch[key], torch.Tensor): - print(key, batch[key].shape) - elif isinstance(batch[key], list): - print(key, [len(l) for l in batch[key]]) - else: - print(key, batch[key]) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) + with device_manager.use(model.conditioner): + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + num_samples, + ) + for key in batch: + if isinstance(batch[key], torch.Tensor): + print(key, batch[key].shape) + elif isinstance(batch[key], list): + print(key, [len(l) for l in batch[key]]) + else: + print(key, batch[key]) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) for k in c: if not k == "crossattn": c[k], uc[k] = map( - lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc) + lambda y: y[k][: math.prod(num_samples)].to( + device_manager.device + ), + (c, uc), ) additional_model_inputs = {} @@ -151,16 +268,20 @@ def do_sample( additional_model_inputs[k] = batch[k] shape = (math.prod(num_samples), C, H // F, W // F) - randn = torch.randn(shape).to(device) + randn = torch.randn(shape).to(device_manager.device) def denoiser(input, sigma, c): return model.denoiser( model.model, input, sigma, c, **additional_model_inputs ) - samples_z = sampler(denoiser, randn, cond=c, uc=uc) - samples_x = model.decode_first_stage(samples_z) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + with device_manager.use(model.denoiser): + with device_manager.use(model.model): + samples_z = sampler(denoiser, randn, cond=c, uc=uc) + + with device_manager.use(model.first_stage_model): + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) if filter is not None: samples = filter(samples) @@ -252,32 
+373,40 @@ def do_img2img( return_latents=False, skip_encode=False, filter=None, - device="cuda", + add_noise=True, + device: Optional[Union[DeviceModelManager, str, torch.device]] = None, ): + device_manager = get_model_manager(device) with torch.no_grad(): - with autocast(device) as precision_scope: + with device_manager.autocast(): with model.ema_scope(): - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - [num_samples], - ) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) + with device_manager.use(model.conditioner): + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + [num_samples], + ) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) for k in c: - c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc)) + c[k], uc[k] = map( + lambda y: y[k][:num_samples].to(device_manager.device), (c, uc) + ) for k in additional_kwargs: c[k] = uc[k] = additional_kwargs[k] if skip_encode: z = img else: - z = model.encode_first_stage(img) + with device_manager.use(model.first_stage_model): + z = model.encode_first_stage(img) + noise = torch.randn_like(z) + sigmas = sampler.discretization(sampler.num_steps) sigma = sigmas[0].to(z.device) @@ -285,17 +414,24 @@ def do_img2img( noise = noise + offset_noise_level * append_dims( torch.randn(z.shape[0], device=z.device), z.ndim ) - noised_z = z + noise * append_dims(sigma, z.ndim) - noised_z = noised_z / torch.sqrt( - 1.0 + sigmas[0] ** 2.0 - ) # Note: hardcoded to DDPM-like scaling. need to generalize later. + if add_noise: + noised_z = z + noise * append_dims(sigma, z.ndim).cuda() + noised_z = noised_z / torch.sqrt( + 1.0 + sigmas[0] ** 2.0 + ) # Note: hardcoded to DDPM-like scaling. need to generalize later. + else: + noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0) def denoiser(x, sigma, c): return model.denoiser(model.model, x, sigma, c) - samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) - samples_x = model.decode_first_stage(samples_z) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + with device_manager.use(model.denoiser): + with device_manager.use(model.model): + samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) + + with device_manager.use(model.first_stage_model): + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) if filter is not None: samples = filter(samples) diff --git a/sgm/util.py b/sgm/util.py index c5e68f4b5..1f96aeb36 100644 --- a/sgm/util.py +++ b/sgm/util.py @@ -230,6 +230,24 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True): return model +def get_checkpoints_path() -> str: + """ + Get the `checkpoints` directory. + This could be in the root of the repository for a working copy, + or in the cwd for other use cases. + """ + this_dir = os.path.dirname(__file__) + candidates = ( + os.path.join(this_dir, "checkpoints"), + os.path.join(os.getcwd(), "checkpoints"), + ) + for candidate in candidates: + candidate = os.path.abspath(candidate) + if os.path.isdir(candidate): + return candidate + raise FileNotFoundError(f"Could not find SGM checkpoints in {candidates}") + + def get_configs_path() -> str: """ Get the `configs` directory. 
diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py
index 2b2af11e4..04eceb7a1 100644
--- a/tests/inference/test_inference.py
+++ b/tests/inference/test_inference.py
@@ -27,7 +27,7 @@ def pipeline(self, request) -> SamplingPipeline:
     @fixture(
         scope="class",
         params=[
-            [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
+            [ModelArchitecture.SDXL_V1_0_BASE, ModelArchitecture.SDXL_V1_0_REFINER],
             [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
         ],
         ids=["SDXL_V1", "SDXL_V0_9"],
@@ -68,9 +68,7 @@ def test_img2img(self, pipeline: SamplingPipeline, sampler_enum):
         assert output is not None
 
     @pytest.mark.parametrize("sampler_enum", Sampler)
-    @pytest.mark.parametrize(
-        "use_init_image", [True, False], ids=["img2img", "txt2img"]
-    )
+    @pytest.mark.parametrize("use_init_image", [True, False], ids=["img2img", "txt2img"])
     def test_sdxl_with_refiner(
         self,
         sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
@@ -81,13 +79,12 @@ def test_sdxl_with_refiner(
         if use_init_image:
             output = base_pipeline.image_to_image(
                 params=SamplingParams(sampler=sampler_enum.value, steps=10),
-                image=self.create_init_image(
-                    base_pipeline.specs.height, base_pipeline.specs.width
-                ),
+                image=self.create_init_image(base_pipeline.specs.height, base_pipeline.specs.width),
                 prompt="A professional photograph of an astronaut riding a pig",
                 negative_prompt="",
                 samples=1,
                 return_latents=True,
+                noise_strength=0.15,
             )
         else:
             output = base_pipeline.text_to_image(
@@ -96,6 +93,7 @@
                 negative_prompt="",
                 samples=1,
                 return_latents=True,
+                noise_strength=0.15,
             )
 
         assert isinstance(output, (tuple, list))
@@ -103,9 +101,9 @@
         assert samples is not None
         assert samples_z is not None
         refiner_pipeline.refiner(
-            params=SamplingParams(sampler=sampler_enum.value, steps=10),
             image=samples_z,
             prompt="A professional photograph of an astronaut riding a pig",
+            params=SamplingParams(sampler=sampler_enum.value, steps=40, img2img_strength=0.15),
             negative_prompt="",
             samples=1,
         )
diff --git a/tests/inference/test_modelmanager.py b/tests/inference/test_modelmanager.py
new file mode 100644
index 000000000..bb1ab0e4b
--- /dev/null
+++ b/tests/inference/test_modelmanager.py
@@ -0,0 +1,44 @@
+import pytest
+import torch
+
+from sgm.inference.api import (
+    SamplingPipeline,
+    ModelArchitecture,
+)
+import sgm.inference.helpers as helpers
+
+def get_torch_device(model: torch.nn.Module) -> torch.device:
+    param = next(model.parameters(), None)
+    if param is not None:
+        return param.device
+    else:
+        buf = next(model.buffers(), None)
+        if buf is not None:
+            return buf.device
+        else:
+            raise TypeError("Could not determine device of input model")
+
+
+@pytest.mark.inference
+def test_default_loading():
+    pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1)
+    assert get_torch_device(pipeline.model.model).type == "cuda"
+    assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+    with pipeline.device_manager.use(pipeline.model.model):
+        assert get_torch_device(pipeline.model.model).type == "cuda"
+    assert get_torch_device(pipeline.model.model).type == "cuda"
+    with pipeline.device_manager.use(pipeline.model.conditioner):
+        assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+    assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+
+@pytest.mark.inference
+def test_model_swapping():
+    pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1, device=helpers.CudaModelManager(device="cuda", swap_device="cpu"))
+    assert get_torch_device(pipeline.model.model).type == "cpu"
+    assert get_torch_device(pipeline.model.conditioner).type == "cpu"
+    with pipeline.device_manager.use(pipeline.model.model):
+        assert get_torch_device(pipeline.model.model).type == "cuda"
+    assert get_torch_device(pipeline.model.model).type == "cpu"
+    with pipeline.device_manager.use(pipeline.model.conditioner):
+        assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+    assert get_torch_device(pipeline.model.conditioner).type == "cpu"
\ No newline at end of file