diff --git a/README.md b/README.md
index 8d7ddfb1a..8fa14135c 100644
--- a/README.md
+++ b/README.md
@@ -49,6 +49,8 @@ For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stabl
 
 ## Installation:
 
+**NOTE:** This is tested under `python3.8` and `python3.10`. For other Python versions, you might encounter version conflicts.
+
 #### 1. Clone the repo
 
 ```shell
@@ -60,28 +62,18 @@ cd generative-models
 
 This is assuming you have navigated to the `generative-models` root after cloning it.
 
-**NOTE:** This is tested under `python3.8` and `python3.10`. For other python versions, you might encounter version conflicts.
-
-
-**PyTorch 1.13**
-
 ```shell
-# install required packages from pypi
-python3 -m venv .pt13
-source .pt13/bin/activate
-pip3 install -r requirements/pt13.txt
+python3 -m venv venv
+source venv/bin/activate
+pip install -U setuptools wheel
 ```
 
-**PyTorch 2.0**
-
-
-```shell
-# install required packages from pypi
-python3 -m venv .pt2
-source .pt2/bin/activate
-pip3 install -r requirements/pt2.txt
-```
+Then, depending on your use case, choose a set of requirements to install.
+* `pip install -r requirements/demo-streamlit.txt`: Demo inference dependencies, enough to run the Streamlit demo
+* `pip install -r requirements/demo-minimal.txt`: Demo inference dependencies, enough to run inference
+* `pip install -r requirements/pt2.txt`: PyTorch 2, including training dependencies
+* `pip install -r requirements/pt13.txt`: PyTorch 1.13, including training dependencies
 
 #### 3. Install `sgm`
 
@@ -89,7 +81,7 @@ pip3 install -r requirements/pt2.txt
 pip3 install .
 ```
 
-#### 4. Install `sdata` for training
+#### 4. Optionally install `sdata` for training
 
 ```shell
 pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
@@ -114,6 +106,16 @@ depending on your use case and PyTorch version, manually.
 
 ## Inference
 
+### Minimal txt2img demo
+
+There is a minimal text-to-image demo available as `txt2img.py`:
+
+```
+python txt2img.py --prompt "Big fluffy cat in a cereal bowl" --steps 25 --seed 1050
+```
+
+### Streamlit demo
+
 We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`.
 We provide file hashes for the complete file as well as for only the saved tensors in the file (see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that).
 The following models are currently supported:
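Before moving on to the code changes, a quick post-install sanity check can confirm the environment matches what the rest of this diff expects. A minimal sketch, assuming `sgm` has been installed with `pip3 install .` as described above and that the `get_default_device_name` helper added to `sgm/util.py` later in this diff is present:

```python
# Post-install sanity check (a sketch; relies on the get_default_device_name
# helper introduced in sgm/util.py further down in this diff).
import torch
from sgm.util import get_default_device_name

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("Device sgm will default to:", get_default_device_name())
```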
diff --git a/requirements/demo-minimal.txt b/requirements/demo-minimal.txt
new file mode 100644
index 000000000..2fd21baf2
--- /dev/null
+++ b/requirements/demo-minimal.txt
@@ -0,0 +1,9 @@
+einops
+invisible-watermark~=0.2.0
+kornia~=0.6.12
+omegaconf
+open-clip-torch
+pytorch-lightning~=2.0.5
+safetensors~=0.3.1
+torchvision~=0.15.2
+transformers~=4.31.0
diff --git a/requirements/demo-streamlit.txt b/requirements/demo-streamlit.txt
new file mode 100644
index 000000000..0fba8e900
--- /dev/null
+++ b/requirements/demo-streamlit.txt
@@ -0,0 +1,3 @@
+-r ./demo-minimal.txt
+-e git+https://github.com/openai/CLIP.git@main#egg=clip
+streamlit
diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py
index 1c653708b..7d588b072 100644
--- a/sgm/inference/helpers.py
+++ b/sgm/inference/helpers.py
@@ -1,4 +1,5 @@
 import os
+from contextlib import nullcontext
 from typing import Union, List, Optional
 
 import math
@@ -98,6 +99,13 @@ def __call__(self, *args, **kwargs):
         return sigmas
 
 
+def safe_autocast(device):
+    """Autocast that doesn't crash on devices unsupported by autocast."""
+    if device not in ("cpu", "cuda"):
+        return nullcontext()
+    return torch.autocast(device)
+
+
 def do_sample(
     model,
     sampler,
@@ -119,13 +127,14 @@ def do_sample(
         batch2model_input = []
 
     with torch.no_grad():
-        with autocast(device) as precision_scope:
+        with safe_autocast(device):
             with model.ema_scope():
                 num_samples = [num_samples]
                 batch, batch_uc = get_batch(
                     get_unique_embedder_keys_from_conditioner(model.conditioner),
                     value_dict,
                     num_samples,
+                    device=device,
                 )
                 for key in batch:
                     if isinstance(batch[key], torch.Tensor):
@@ -170,7 +179,13 @@ def denoiser(input, sigma, c):
     return samples
 
 
-def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
+def get_batch(
+    keys,
+    value_dict,
+    N: Union[List, ListConfig],
+    *,
+    device: str,
+):
     # Hardcoded demo setups; might undergo some changes in the future
     batch = {}
@@ -255,12 +270,13 @@ def do_img2img(
     device="cuda",
 ):
     with torch.no_grad():
-        with autocast(device) as precision_scope:
+        with safe_autocast(device):
             with model.ema_scope():
                 batch, batch_uc = get_batch(
                     get_unique_embedder_keys_from_conditioner(model.conditioner),
                     value_dict,
                     [num_samples],
+                    device=device,
                 )
                 c, uc = model.conditioner.get_unconditional_conditioning(
                     batch,
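The new `safe_autocast` helper simply falls back to a `nullcontext` on devices that `torch.autocast` does not accept. A small sketch of the behaviour (device strings are illustrative):

```python
# Sketch of safe_autocast behaviour (assumes the sgm.inference.helpers module
# from this diff is importable).
import torch
from sgm.inference.helpers import safe_autocast

with safe_autocast("cpu"):           # "cpu"/"cuda": a real torch.autocast context
    x = torch.ones(2, 2) @ torch.ones(2, 2)

with safe_autocast("mps"):           # anything else: a no-op nullcontext, so
    y = torch.ones(2, 2) * 3         # sampling does not crash on e.g. MPS
```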
diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py
index e1f139757..e1e956ec7 100644
--- a/sgm/models/diffusion.py
+++ b/sgm/models/diffusion.py
@@ -1,4 +1,4 @@
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from typing import Any, Dict, List, Tuple, Union
 
 import pytorch_lightning as pl
@@ -13,6 +13,7 @@
 from ..util import (
     default,
     disabled_train,
+    get_default_device_name,
     get_obj_from_str,
     instantiate_from_config,
     log_txt_as_img,
@@ -114,16 +115,22 @@ def get_input(self, batch):
         # image tensors should be scaled to -1 ... 1 and in bchw format
         return batch[self.input_key]
 
+    def _first_stage_autocast_context(self):
+        device = get_default_device_name()
+        if device not in ("cpu", "cuda"):
+            return nullcontext()
+        return torch.autocast(device, enabled=not self.disable_first_stage_autocast)
+
     @torch.no_grad()
     def decode_first_stage(self, z):
         z = 1.0 / self.scale_factor * z
-        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+        with self._first_stage_autocast_context():
             out = self.first_stage_model.decode(z)
         return out
 
     @torch.no_grad()
     def encode_first_stage(self, x):
-        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+        with self._first_stage_autocast_context():
             z = self.first_stage_model.encode(x)
             z = self.scale_factor * z
         return z
diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py
index f813be233..3a65dddc1 100644
--- a/sgm/modules/attention.py
+++ b/sgm/modules/attention.py
@@ -1,3 +1,5 @@
+import warnings
+
 import math
 from inspect import isfunction
 from typing import Any, Optional
@@ -393,7 +395,7 @@ def __init__(
         super().__init__()
         assert attn_mode in self.ATTENTION_MODES
         if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
-            print(
+            warnings.warn(
                 f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
                 f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
             )
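Because the fallback notice in `sgm/modules/attention.py` is now raised through `warnings.warn` instead of `print`, callers can control it with the standard `warnings` filters. A minimal sketch (the message pattern is illustrative):

```python
# Silencing (or escalating) the attention fallback warning (illustrative).
import warnings

# Hide the notice entirely:
warnings.filterwarnings("ignore", message="Attention mode .* is not available")

# ...or turn it into a hard error, e.g. in CI:
# warnings.filterwarnings("error", message="Attention mode .* is not available")
```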
diff --git a/sgm/modules/diffusionmodules/model.py b/sgm/modules/diffusionmodules/model.py
index 26efd0784..550f916f7 100644
--- a/sgm/modules/diffusionmodules/model.py
+++ b/sgm/modules/diffusionmodules/model.py
@@ -1,4 +1,5 @@
 # pytorch_diffusion + derived encoder decoder
+import warnings
 import math
 from typing import Any, Callable, Optional
 
@@ -288,6 +289,13 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
             f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
         )
         attn_type = "vanilla-xformers"
+    if attn_type == "vanilla-xformers" and not XFORMERS_IS_AVAILABLE:
+        warnings.warn(
+            f"Requested attention type {attn_type!r} but Xformers is not available; "
+            f"falling back to vanilla attention"
+        )
+        attn_type = "vanilla"
+        attn_kwargs = None
     print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
         assert attn_kwargs is None
diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py
index e19b83f98..fbdf3337e 100644
--- a/sgm/modules/diffusionmodules/openaimodel.py
+++ b/sgm/modules/diffusionmodules/openaimodel.py
@@ -19,7 +19,7 @@
     timestep_embedding,
     zero_module,
 )
-from ...util import default, exists
+from ...util import default, exists, get_default_device_name
 
 
 # dummy replace
@@ -1241,6 +1241,7 @@ def __init__(self, in_channels=3, model_channels=64):
             ]
         )
 
+    device = get_default_device_name()
    model = UNetModel(
         use_checkpoint=True,
         image_size=64,
@@ -1255,8 +1256,8 @@ def __init__(self, in_channels=3, model_channels=64):
         use_linear_in_transformer=True,
         transformer_depth=1,
         legacy=False,
-    ).cuda()
-    x = th.randn(11, 4, 64, 64).cuda()
-    t = th.randint(low=0, high=10, size=(11,), device="cuda")
+    ).to(device)
+    x = th.randn(11, 4, 64, 64).to(device)
+    t = th.randint(low=0, high=10, size=(11,), device=device)
     o = model(x, t)
     print("done.")
diff --git a/sgm/modules/diffusionmodules/sampling.py b/sgm/modules/diffusionmodules/sampling.py
index 6346829c8..93a251541 100644
--- a/sgm/modules/diffusionmodules/sampling.py
+++ b/sgm/modules/diffusionmodules/sampling.py
@@ -16,7 +16,7 @@
     to_neg_log_sigma,
     to_sigma,
 )
-from ...util import append_dims, default, instantiate_from_config
+from ...util import append_dims, default, instantiate_from_config, get_default_device_name
 
 DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
@@ -28,8 +28,10 @@ def __init__(
         num_steps: Union[int, None] = None,
         guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
         verbose: bool = False,
-        device: str = "cuda",
+        device: Union[str, None] = None,
     ):
+        if device is None:
+            device = get_default_device_name()
         self.num_steps = num_steps
         self.discretization = instantiate_from_config(discretization_config)
         self.guider = instantiate_from_config(
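With the sampler base class now resolving `device=None` via `get_default_device_name()`, samplers can be constructed without naming a device. A sketch, assuming the existing `EulerEDMSampler` class and `EDMDiscretization` target from the codebase (neither is shown in this diff):

```python
# Constructing a sampler without an explicit device (a sketch; the class and
# config target below are taken from the existing sgm codebase).
from sgm.modules.diffusionmodules.sampling import EulerEDMSampler

sampler = EulerEDMSampler(
    num_steps=30,
    discretization_config={
        "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization"
    },
    # no device= argument: the base class now falls back to
    # get_default_device_name(), i.e. CUDA if available, otherwise CPU.
)
```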
diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py
index ed3f2d215..dc03f40c0 100644
--- a/sgm/modules/encoders/modules.py
+++ b/sgm/modules/encoders/modules.py
@@ -29,6 +29,7 @@
     default,
     disabled_train,
     expand_dims_like,
+    get_default_device_name,
     instantiate_from_config,
 )
@@ -236,7 +237,9 @@ def forward(self, c):
         c = c[:, None, :]
         return c
 
-    def get_unconditional_conditioning(self, bs, device="cuda"):
+    def get_unconditional_conditioning(self, bs, device=None):
+        if device is None:
+            device = get_default_device_name()
         uc_class = (
             self.n_classes - 1
         )  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
@@ -261,9 +264,10 @@ class FrozenT5Embedder(AbstractEmbModel):
     """Uses the T5 transformer encoder for text"""
 
     def __init__(
-        self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
+        self, version="google/t5-v1_1-xxl", device=None, max_length=77, freeze=True
     ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
         super().__init__()
+        device = device or get_default_device_name()
         self.tokenizer = T5Tokenizer.from_pretrained(version)
         self.transformer = T5EncoderModel.from_pretrained(version)
         self.device = device
@@ -304,9 +308,10 @@ class FrozenByT5Embedder(AbstractEmbModel):
     """
 
     def __init__(
-        self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
+        self, version="google/byt5-base", device=None, max_length=77, freeze=True
     ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
         super().__init__()
+        device = device or get_default_device_name()
         self.tokenizer = ByT5Tokenizer.from_pretrained(version)
         self.transformer = T5EncoderModel.from_pretrained(version)
         self.device = device
@@ -348,7 +353,7 @@ class FrozenCLIPEmbedder(AbstractEmbModel):
     def __init__(
         self,
         version="openai/clip-vit-large-patch14",
-        device="cuda",
+        device=None,
         max_length=77,
         freeze=True,
         layer="last",
@@ -356,6 +361,7 @@ def __init__(
         always_return_pooled=False,
     ):  # clip-vit-base-patch32
         super().__init__()
+        device = device or get_default_device_name()
         assert layer in self.LAYERS
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
@@ -416,7 +422,7 @@ def __init__(
         self,
         arch="ViT-H-14",
         version="laion2b_s32b_b79k",
-        device="cuda",
+        device=None,
         max_length=77,
         freeze=True,
         layer="last",
@@ -424,6 +430,7 @@ def __init__(
         legacy=True,
     ):
         super().__init__()
+        device = device or get_default_device_name()
         assert layer in self.LAYERS
         model, _, _ = open_clip.create_model_and_transforms(
             arch,
@@ -518,12 +525,13 @@ def __init__(
         self,
         arch="ViT-H-14",
         version="laion2b_s32b_b79k",
-        device="cuda",
+        device=None,
         max_length=77,
         freeze=True,
         layer="last",
     ):
         super().__init__()
+        device = device or get_default_device_name()
         assert layer in self.LAYERS
         model, _, _ = open_clip.create_model_and_transforms(
             arch, device=torch.device("cpu"), pretrained=version
         )
@@ -588,7 +596,7 @@ def __init__(
         self,
         arch="ViT-H-14",
         version="laion2b_s32b_b79k",
-        device="cuda",
+        device=None,
         max_length=77,
         freeze=True,
         antialias=True,
@@ -599,6 +607,7 @@ def __init__(
         output_tokens=False,
     ):
         super().__init__()
+        device = device or get_default_device_name()
         model, _, _ = open_clip.create_model_and_transforms(
             arch,
             device=torch.device("cpu"),
@@ -744,11 +753,12 @@ def __init__(
         self,
         clip_version="openai/clip-vit-large-patch14",
         t5_version="google/t5-v1_1-xl",
-        device="cuda",
+        device=None,
         clip_max_length=77,
         t5_max_length=77,
     ):
         super().__init__()
+        device = device or get_default_device_name()
         self.clip_encoder = FrozenCLIPEmbedder(
             clip_version, device, max_length=clip_max_length
         )
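The embedders above all follow the same pattern: `device=None` in the signature, resolved via `get_default_device_name()` at construction time. A sketch of what that means for callers (assumes the Hugging Face weights for `openai/clip-vit-large-patch14` can be downloaded or are already cached):

```python
# Text encoders now pick their own device when none is given (a sketch; the
# first call downloads the HF checkpoint if it is not cached).
from sgm.modules.encoders.modules import FrozenCLIPEmbedder

embedder = FrozenCLIPEmbedder()   # device=None -> get_default_device_name()
print(embedder.device)            # "cuda" or "cpu" depending on the machine
```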
diff --git a/sgm/util.py b/sgm/util.py
index c5e68f4b5..7690ae34a 100644
--- a/sgm/util.py
+++ b/sgm/util.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import functools
 import importlib
 import os
@@ -11,6 +13,10 @@
 from safetensors.torch import load_file as load_safetensors
 
 
+def get_default_device_name() -> str:
+    return os.environ.get("SGM_DEFAULT_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
+
+
 def disabled_train(self, mode=True):
     """Overwrite model.train with this function to make sure train/eval
     mode does not change anymore."""
@@ -199,19 +205,25 @@ def append_dims(x, target_dims):
     return x[(...,) + (None,) * dims_to_append]
 
 
-def load_model_from_config(config, ckpt, verbose=True, freeze=True):
-    print(f"Loading model from {ckpt}")
-    if ckpt.endswith("ckpt"):
-        pl_sd = torch.load(ckpt, map_location="cpu")
-        if "global_step" in pl_sd:
-            print(f"Global Step: {pl_sd['global_step']}")
-        sd = pl_sd["state_dict"]
-    elif ckpt.endswith("safetensors"):
-        sd = load_safetensors(ckpt)
-    else:
-        raise NotImplementedError
-
+def load_model_from_config(
+    config,
+    ckpt: str | None,
+    verbose=True,
+    freeze=True,
+    device="cpu",
+):
     model = instantiate_from_config(config.model)
+    if ckpt:
+        print(f"Loading model from {ckpt}")
+        if ckpt.endswith("ckpt"):
+            pl_sd = torch.load(ckpt, map_location=device)
+            if verbose and "global_step" in pl_sd:
+                print(f"Global Step: {pl_sd['global_step']}")
+            sd = pl_sd["state_dict"]
+        elif ckpt.endswith("safetensors"):
+            sd = load_safetensors(ckpt, device=device)
+        else:
+            raise NotImplementedError
 
     m, u = model.load_state_dict(sd, strict=False)
@@ -226,7 +238,6 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
         for param in model.parameters():
             param.requires_grad = False
 
-    model.eval()
     return model
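Two behaviours change in `sgm/util.py`: the default device can be forced through the new `SGM_DEFAULT_DEVICE` environment variable, and `load_model_from_config` now accepts a `device` for weight loading and no longer calls `model.eval()` on the caller's behalf. A sketch of both (the config and checkpoint paths are illustrative):

```python
# Sketch of the reworked sgm.util helpers (paths and device names illustrative).
import os
from omegaconf import OmegaConf
from sgm.util import get_default_device_name, load_model_from_config

print(get_default_device_name())          # "cuda" if torch sees a GPU, else "cpu"
os.environ["SGM_DEFAULT_DEVICE"] = "cpu"  # explicit override wins
print(get_default_device_name())          # "cpu"

config = OmegaConf.load("configs/inference/sd_xl_base.yaml")   # assumed path
model = load_model_from_config(
    config,
    ckpt="checkpoints/sd_xl_base_1.0.safetensors",             # assumed path
    device="cpu",  # new: passed to torch.load / load_safetensors
)
model.eval()       # callers are now responsible for eval() and .to(device)
```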
diff --git a/txt2img.py b/txt2img.py
new file mode 100644
index 000000000..4eac1da24
--- /dev/null
+++ b/txt2img.py
@@ -0,0 +1,171 @@
+"""
+This is a very minimal txt2img example using `sgm.inference.api`.
+"""
+from __future__ import annotations
+
+import argparse
+import dataclasses
+import logging
+import os
+import time
+from pathlib import Path
+from unittest.mock import patch
+
+import numpy as np
+import torch
+from PIL import Image
+import einops
+import omegaconf
+import pytorch_lightning
+
+from sgm import get_configs_path
+from sgm.inference.api import (
+    model_specs,
+    ModelArchitecture,
+    SamplingParams,
+    SamplingSpec,
+    get_sampler_config,
+    Discretization,
+)
+from sgm.inference.helpers import do_sample
+from sgm.util import load_model_from_config, get_default_device_name
+
+logger = logging.getLogger("txt2img")
+
+
+def run_txt2img(
+    *,
+    model,
+    spec: SamplingSpec,
+    prompt: str,
+    steps: int,
+    width: int | None,
+    height: int | None,
+    scale: float | None,
+    num_samples=1,
+    seed: int,
+    device: str,
+):
+    params = SamplingParams(
+        discretization=Discretization.EDM,
+        height=(height or spec.height),
+        rho=7,
+        steps=steps,
+        width=(width or spec.width),
+    )
+    if scale:
+        params.scale = scale
+
+    with torch.no_grad(), model.ema_scope():
+        pytorch_lightning.seed_everything(seed)
+        sampler = get_sampler_config(params)
+        value_dict = {
+            **dataclasses.asdict(params),
+            "prompt": prompt,
+            "negative_prompt": "",
+            "target_width": params.width,
+            "target_height": params.height,
+        }
+        logger.info("Starting sampling with %s", params)
+        return do_sample(
+            model,
+            sampler,
+            value_dict,
+            num_samples,
+            params.height,
+            params.width,
+            spec.channels,
+            spec.factor,
+            force_uc_zero_embeddings=["txt"] if not spec.is_legacy else [],
+            return_latents=False,
+            filter=None,
+            device=device,
+        )
+
+
+@torch.no_grad()
+def fast_load(*, config, ckpt, device):
+    config = omegaconf.OmegaConf.load(config)
+    logger.info("Loading model")
+    # This patch is borrowed from AUTOMATIC1111's stable-diffusion-webui;
+    # we don't need to initialize the weights just for them to be overwritten
+    # by the checkpoint.
+    with (
+        patch.object(torch.nn.init, "kaiming_uniform_"),
+        patch.object(torch.nn.init, "_no_grad_normal_"),
+        patch.object(torch.nn.init, "_no_grad_uniform_"),
+    ):
+        model = load_model_from_config(
+            config,
+            ckpt=ckpt,
+            device="cpu",
+            freeze=True,
+            verbose=False,
+        )
+    logger.info("Moving model to device")
+    model.to(device)
+    model.eval()
+    return model
+
+
+def main():
+    logging.basicConfig(
+        level=logging.INFO, format="[%(levelname)s] %(name)s: %(message)s"
+    )
+    # Quiesce some uninformative CLIP and attention logging.
+    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
+    logging.getLogger("sgm.modules.attention").setLevel(logging.ERROR)
+
+    ap = argparse.ArgumentParser()
+    ap.add_argument(
+        "--spec",
+        default=ModelArchitecture.SDXL_V1_BASE.value,
+        choices=[s.value for s in ModelArchitecture],
+    )
+    ap.add_argument("--device", default=get_default_device_name())
+    ap.add_argument(
+        "--prompt",
+        default="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+    )
+    ap.add_argument("--seed", type=int, default=42)
+    ap.add_argument("--steps", type=int, default=20)
+    ap.add_argument("--width", type=int)
+    ap.add_argument("--height", type=int)
+    ap.add_argument("--scale", type=float)
+    ap.add_argument("--num-samples", type=int, default=1)
+    args = ap.parse_args()
+    spec = model_specs[ModelArchitecture(args.spec)]
+    logger.info(f"Using model spec: {spec}")
+    model = fast_load(
+        config=os.path.join(get_configs_path(), "inference", spec.config),
+        ckpt=os.path.join("checkpoints", spec.ckpt),
+        device=args.device,
+    )
+
+    samples = run_txt2img(
+        model=model,
+        spec=spec,
+        prompt=args.prompt,
+        steps=args.steps,
+        width=args.width,
+        height=args.height,
+        scale=args.scale,
+        num_samples=args.num_samples,
+        device=args.device,
+        seed=args.seed,
+    )
+
+    out_path = Path("outputs")
+    out_path.mkdir(exist_ok=True)
+
+    prefix = int(time.time())
+
+    for i, sample in enumerate(samples, 1):
+        filename = out_path / f"{prefix}-{i:04}.png"
+        print(f"Saving {i}/{len(samples)}: {filename}")
+        sample = 255.0 * einops.rearrange(sample, "c h w -> h w c")
+        Image.fromarray(sample.cpu().numpy().astype(np.uint8)).save(filename)
+
+
+if __name__ == "__main__":
+    main()
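Besides the CLI invocation shown in the README, the helpers in `txt2img.py` can be driven programmatically. A sketch mirroring the script's own `main()` (assumes the SDXL-base checkpoint lives under `checkpoints/`, as the script expects):

```python
# Programmatic use of txt2img.py's helpers (a sketch; checkpoint layout and
# device are assumptions).
import os
from sgm import get_configs_path
from sgm.inference.api import ModelArchitecture, model_specs
from txt2img import fast_load, run_txt2img

spec = model_specs[ModelArchitecture.SDXL_V1_BASE]
model = fast_load(
    config=os.path.join(get_configs_path(), "inference", spec.config),
    ckpt=os.path.join("checkpoints", spec.ckpt),
    device="cuda",
)
samples = run_txt2img(
    model=model,
    spec=spec,
    prompt="a watercolor fox in the snow",
    steps=25,
    width=None,    # fall back to the spec's native resolution
    height=None,
    scale=None,    # keep the spec's default guidance scale
    seed=1234,
    device="cuda",
)
print(samples.shape)   # (num_samples, 3, H, W) tensor in [0, 1]
```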