
Minimal txt2img, v2 #79


Status: Draft. Wants to merge 8 commits into base branch `main`.

Changes from all commits are shown below.
40 changes: 21 additions & 19 deletions README.md
@@ -49,6 +49,8 @@ For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stabl
## Installation:
<a name="installation"></a>

**NOTE:** This is tested under `python3.8` and `python3.10`. For other Python versions, you might encounter version conflicts.

#### 1. Clone the repo

```shell
@@ -60,36 +62,26 @@ cd generative-models

This is assuming you have navigated to the `generative-models` root after cloning it.

-**NOTE:** This is tested under `python3.8` and `python3.10`. For other python versions, you might encounter version conflicts.
-
-**PyTorch 1.13**
-
```shell
-# install required packages from pypi
-python3 -m venv .pt13
-source .pt13/bin/activate
-pip3 install -r requirements/pt13.txt
+python3 -m venv venv
+source venv/bin/activate
+pip install -U setuptools wheel
```
-
-**PyTorch 2.0**
-
-```shell
-# install required packages from pypi
-python3 -m venv .pt2
-source .pt2/bin/activate
-pip3 install -r requirements/pt2.txt
-```
+
+Then, depending on your use case, choose a set of requirements to install.
+
+* `pip install -r requirements/demo-streamlit.txt`: Demo inference dependencies, enough to run the Streamlit demo
+* `pip install -r requirements/demo-minimal.txt`: Demo inference dependencies, enough to run inference
+* `pip install -r requirements/pt2.txt`: PyTorch 2, including training dependencies
+* `pip install -r requirements/pt13.txt`: PyTorch 1.13, including training dependencies

#### 3. Install `sgm`

```shell
pip3 install .
```

-#### 4. Install `sdata` for training
+#### 4. Optionally install `sdata` for training

```shell
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
@@ -114,6 +106,16 @@ depending on your use case and PyTorch version, manually.

## Inference

### Minimal txt2img demo

There is a minimal text-to-image demo available as `txt2img.py`:

```shell
python txt2img.py --prompt "Big fluffy cat in a cereal bowl" --steps 25 --seed 1050
```

### Streamlit demo

We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`.
We provide file hashes for the complete file as well as for only the saved tensors in the file (see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that).
The following models are currently supported:
9 changes: 9 additions & 0 deletions requirements/demo-minimal.txt
@@ -0,0 +1,9 @@
einops
invisible-watermark~=0.2.0
kornia~=0.6.12
omegaconf
open-clip-torch
pytorch-lightning~=2.0.5
safetensors~=0.3.1
torchvision~=0.15.2
transformers~=4.31.0
3 changes: 3 additions & 0 deletions requirements/demo-streamlit.txt
@@ -0,0 +1,3 @@
-r ./demo-minimal.txt
-e git+https://github.com/openai/CLIP.git@main#egg=clip
streamlit
22 changes: 19 additions & 3 deletions sgm/inference/helpers.py
@@ -1,4 +1,5 @@
import os
from contextlib import nullcontext
from typing import Union, List, Optional

import math
@@ -98,6 +99,13 @@ def __call__(self, *args, **kwargs):
return sigmas


def safe_autocast(device):
"""Autocast that doesn't crash on devices unsupported by autocast."""
if device not in ("cpu", "cuda"):
return nullcontext()
return torch.autocast(device)


def do_sample(
model,
sampler,
@@ -119,13 +127,14 @@
batch2model_input = []

with torch.no_grad():
-with autocast(device) as precision_scope:
+with safe_autocast(device):
with model.ema_scope():
num_samples = [num_samples]
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
device=device,
)
for key in batch:
if isinstance(batch[key], torch.Tensor):
@@ -170,7 +179,13 @@ def denoiser(input, sigma, c):
return samples


-def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
+def get_batch(
+    keys,
+    value_dict,
+    N: Union[List, ListConfig],
+    *,
+    device: str,
+):
# Hardcoded demo setups; might undergo some changes in the future

batch = {}
@@ -255,12 +270,13 @@ def do_img2img(
device="cuda",
):
with torch.no_grad():
-with autocast(device) as precision_scope:
+with safe_autocast(device):
with model.ema_scope():
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
[num_samples],
device=device,
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
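For illustration, here is a sketch of how the `safe_autocast` helper added above behaves for different device strings. The `"mps"` case is an assumption about the kind of device this is meant to cover, not something stated in the diff:

```python
import torch

from sgm.inference.helpers import safe_autocast  # helper added in this PR

# On "cpu" and "cuda" this is an ordinary torch.autocast context...
with safe_autocast("cpu"):
    y = torch.randn(2, 2) @ torch.randn(2, 2)

# ...while for any other device string (e.g. "mps") it degrades to a
# contextlib.nullcontext() instead of raising inside torch.autocast.
with safe_autocast("mps"):
    y = torch.randn(2, 2) @ torch.randn(2, 2)
```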
13 changes: 10 additions & 3 deletions sgm/models/diffusion.py
@@ -1,4 +1,4 @@
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
from typing import Any, Dict, List, Tuple, Union

import pytorch_lightning as pl
@@ -13,6 +13,7 @@
from ..util import (
default,
disabled_train,
get_default_device_name,
get_obj_from_str,
instantiate_from_config,
log_txt_as_img,
@@ -114,16 +115,22 @@ def get_input(self, batch):
# image tensors should be scaled to -1 ... 1 and in bchw format
return batch[self.input_key]

def _first_stage_autocast_context(self):
device = get_default_device_name()
if device not in ("cpu", "cuda"):
return nullcontext()
return torch.autocast(device, enabled=not self.disable_first_stage_autocast)

@torch.no_grad()
def decode_first_stage(self, z):
z = 1.0 / self.scale_factor * z
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
with self._first_stage_autocast_context():
out = self.first_stage_model.decode(z)
return out

@torch.no_grad()
def encode_first_stage(self, x):
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
with self._first_stage_autocast_context():
z = self.first_stage_model.encode(x)
z = self.scale_factor * z
return z
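`decode_first_stage` and `encode_first_stage` now build their autocast context from `get_default_device_name()`, which is imported from `..util` (i.e. `sgm/util.py`) but not shown in this excerpt. A minimal sketch of what such a helper might look like follows; it is an assumption for illustration, not the PR's actual implementation:

```python
import torch


def get_default_device_name() -> str:
    """Guess a reasonable default device: prefer CUDA, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return "cuda"
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"
```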
4 changes: 3 additions & 1 deletion sgm/modules/attention.py
@@ -1,3 +1,5 @@
import warnings

import math
from inspect import isfunction
from typing import Any, Optional
@@ -393,7 +395,7 @@ def __init__(
super().__init__()
assert attn_mode in self.ATTENTION_MODES
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
-print(
+warnings.warn(
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
)
8 changes: 8 additions & 0 deletions sgm/modules/diffusionmodules/model.py
@@ -1,4 +1,5 @@
# pytorch_diffusion + derived encoder decoder
import warnings
import math
from typing import Any, Callable, Optional

@@ -288,6 +289,13 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
)
attn_type = "vanilla-xformers"
if attn_type == "vanilla-xformers" and not XFORMERS_IS_AVAILABLE:
warnings.warn(
f"Requested attention type {attn_type!r} but Xformers is not available; "
f"falling back to vanilla attention"
)
attn_type = "vanilla"
attn_kwargs = None
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
9 changes: 5 additions & 4 deletions sgm/modules/diffusionmodules/openaimodel.py
@@ -19,7 +19,7 @@
timestep_embedding,
zero_module,
)
-from ...util import default, exists
+from ...util import default, exists, get_default_device_name


# dummy replace
@@ -1241,6 +1241,7 @@ def __init__(self, in_channels=3, model_channels=64):
]
)

device = get_default_device_name()
model = UNetModel(
use_checkpoint=True,
image_size=64,
@@ -1255,8 +1256,8 @@ def __init__(self, in_channels=3, model_channels=64):
use_linear_in_transformer=True,
transformer_depth=1,
legacy=False,
-).cuda()
-x = th.randn(11, 4, 64, 64).cuda()
-t = th.randint(low=0, high=10, size=(11,), device="cuda")
+).to(device)
+x = th.randn(11, 4, 64, 64).to(device)
+t = th.randint(low=0, high=10, size=(11,), device=device)
o = model(x, t)
print("done.")
6 changes: 4 additions & 2 deletions sgm/modules/diffusionmodules/sampling.py
@@ -16,7 +16,7 @@
to_neg_log_sigma,
to_sigma,
)
-from ...util import append_dims, default, instantiate_from_config
+from ...util import append_dims, default, instantiate_from_config, get_default_device_name

DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}

@@ -28,8 +28,10 @@ def __init__(
num_steps: Union[int, None] = None,
guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
verbose: bool = False,
-device: str = "cuda",
+device: Union[str, None] = None,
):
if device is None:
device = get_default_device_name()
self.num_steps = num_steps
self.discretization = instantiate_from_config(discretization_config)
self.guider = instantiate_from_config(
26 changes: 18 additions & 8 deletions sgm/modules/encoders/modules.py
@@ -29,6 +29,7 @@
default,
disabled_train,
expand_dims_like,
get_default_device_name,
instantiate_from_config,
)

@@ -236,7 +237,9 @@ def forward(self, c):
c = c[:, None, :]
return c

-def get_unconditional_conditioning(self, bs, device="cuda"):
+def get_unconditional_conditioning(self, bs, device=None):
if device is None:
device = get_default_device_name()
uc_class = (
self.n_classes - 1
) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
@@ -261,9 +264,10 @@ class FrozenT5Embedder(AbstractEmbModel):
"""Uses the T5 transformer encoder for text"""

def __init__(
self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
self, version="google/t5-v1_1-xxl", device=None, max_length=77, freeze=True
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
device = device or get_default_device_name()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
@@ -304,9 +308,10 @@ class FrozenByT5Embedder(AbstractEmbModel):
"""

def __init__(
self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
self, version="google/byt5-base", device=None, max_length=77, freeze=True
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
device = device or get_default_device_name()
self.tokenizer = ByT5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
@@ -348,14 +353,15 @@ class FrozenCLIPEmbedder(AbstractEmbModel):
def __init__(
self,
version="openai/clip-vit-large-patch14",
device="cuda",
device=None,
max_length=77,
freeze=True,
layer="last",
layer_idx=None,
always_return_pooled=False,
): # clip-vit-base-patch32
super().__init__()
device = device or get_default_device_name()
assert layer in self.LAYERS
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
@@ -416,14 +422,15 @@ def __init__(
self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device="cuda",
device=None,
max_length=77,
freeze=True,
layer="last",
always_return_pooled=False,
legacy=True,
):
super().__init__()
device = device or get_default_device_name()
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(
arch,
@@ -518,12 +525,13 @@ def __init__(
self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device="cuda",
device=None,
max_length=77,
freeze=True,
layer="last",
):
super().__init__()
device = device or get_default_device_name()
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(
arch, device=torch.device("cpu"), pretrained=version
@@ -588,7 +596,7 @@ def __init__(
self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device="cuda",
device=None,
max_length=77,
freeze=True,
antialias=True,
@@ -599,6 +607,7 @@
output_tokens=False,
):
super().__init__()
device = device or get_default_device_name()
model, _, _ = open_clip.create_model_and_transforms(
arch,
device=torch.device("cpu"),
@@ -744,11 +753,12 @@ def __init__(
self,
clip_version="openai/clip-vit-large-patch14",
t5_version="google/t5-v1_1-xl",
device="cuda",
device=None,
clip_max_length=77,
t5_max_length=77,
):
super().__init__()
device = device or get_default_device_name()
self.clip_encoder = FrozenCLIPEmbedder(
clip_version, device, max_length=clip_max_length
)
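Taken together, the encoder changes replace the hard-coded `device="cuda"` defaults with `device=None`, resolved through `get_default_device_name()` at construction time. A hedged usage sketch follows; the class and arguments are as shown in the diff above, while the cross-device behavior is the stated intent of the PR rather than something verified here:

```python
from sgm.modules.encoders.modules import FrozenCLIPEmbedder

# device=None now falls back to get_default_device_name(), so the same code
# can run on a CUDA box, an Apple Silicon machine, or plain CPU without edits.
embedder = FrozenCLIPEmbedder(
    version="openai/clip-vit-large-patch14",  # default shown in the diff
    device=None,
    max_length=77,
    freeze=True,
)
```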