
Commit afa6c66

Add composer model class for running with precomputed CLIP and T5 text latents (#171)
1 parent 4d6e4aa · commit afa6c66

5 files changed, +886 -24 lines changed

diffusion/callbacks/log_diffusion_images.py

Lines changed: 25 additions & 10 deletions
@@ -39,6 +39,11 @@ class LogDiffusionImages(Callback):
         use_table (bool): Whether to make a table of the images or not. Default: ``False``.
         t5_encoder (str, optional): path to the T5 encoder to as a second text encoder.
         clip_encoder (str, optional): path to the CLIP encoder as the first text encoder.
+        t5_latent_key: (str): key to use for the T5 latents in the batch. Default: ``'T5_LATENTS'``.
+        t5_mask_key: (str): key to use for the T5 attention mask in the batch. Default: ``'T5_ATTENTION_MASK'``.
+        clip_latent_key: (str): key to use for the CLIP latents in the batch. Default: ``'CLIP_LATENTS'``.
+        clip_mask_key: (str): key to use for the CLIP attention mask in the batch. Default: ``'CLIP_ATTENTION_MASK'``.
+        clip_pooled_key: (str): key to use for the CLIP pooled in the batch. Default: ``'CLIP_POOLED'``.
         cache_dir: (str, optional): path for HF to cache files while downloading model
     """

@@ -53,6 +58,11 @@ def __init__(self,
                  use_table: bool = False,
                  t5_encoder: Optional[str] = None,
                  clip_encoder: Optional[str] = None,
+                 t5_latent_key: str = 'T5_LATENTS',
+                 t5_mask_key: str = 'T5_ATTENTION_MASK',
+                 clip_latent_key: str = 'CLIP_LATENTS',
+                 clip_mask_key: str = 'CLIP_ATTENTION_MASK',
+                 clip_pooled_key: str = 'CLIP_POOLED',
                  cache_dir: Optional[str] = '/tmp/hf_files'):
         self.prompts = prompts
         self.size = (size, size) if isinstance(size, int) else size
@@ -61,6 +71,11 @@ def __init__(self,
         self.rescaled_guidance = rescaled_guidance
         self.seed = seed
         self.use_table = use_table
+        self.t5_latent_key = t5_latent_key
+        self.t5_mask_key = t5_mask_key
+        self.clip_latent_key = clip_latent_key
+        self.clip_mask_key = clip_mask_key
+        self.clip_pooled_key = clip_pooled_key
         self.cache_dir = cache_dir

         # Batch prompts
@@ -120,10 +135,11 @@ def __init__(self,
                 clip_pooled = clip_outputs[1].cpu()
                 clip_attention_mask = clip_attention_mask.cpu().to(torch.long)

-                latent_batch['T5_LATENTS'] = t5_latents
-                latent_batch['CLIP_LATENTS'] = clip_latents
-                latent_batch['ATTENTION_MASK'] = torch.cat([t5_attention_mask, clip_attention_mask], dim=1)
-                latent_batch['CLIP_POOLED'] = clip_pooled
+                latent_batch[self.t5_latent_key] = t5_latents
+                latent_batch[self.t5_mask_key] = t5_attention_mask
+                latent_batch[self.clip_latent_key] = clip_latents
+                latent_batch[self.clip_mask_key] = clip_attention_mask
+                latent_batch[self.clip_pooled_key] = clip_pooled
                 self.batched_latents.append(latent_batch)

             del t5_model
@@ -143,12 +159,11 @@ def eval_start(self, state: State, logger: Logger):
         all_gen_images = []
         if self.precomputed_latents:
             for batch in self.batched_latents:
-                pooled_prompt = batch['CLIP_POOLED'].cuda()
-                prompt_mask = batch['ATTENTION_MASK'].cuda()
-                t5_embeds = model.t5_proj(batch['T5_LATENTS'].cuda())
-                clip_embeds = model.clip_proj(batch['CLIP_LATENTS'].cuda())
-                prompt_embeds = torch.cat([t5_embeds, clip_embeds], dim=1)
-
+                pooled_prompt = batch[self.clip_pooled_key].cuda()
+                prompt_embeds, prompt_mask = model.prepare_text_embeddings(batch[self.t5_latent_key].cuda(),
+                                                                           batch[self.clip_latent_key].cuda(),
+                                                                           batch[self.t5_mask_key].cuda(),
+                                                                           batch[self.clip_mask_key].cuda())
                 gen_images = model.generate(prompt_embeds=prompt_embeds,
                                             pooled_prompt=pooled_prompt,
                                             prompt_mask=prompt_mask,
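Note: the hunk above now calls model.prepare_text_embeddings, a method on the new PrecomputedTextLatentDiffusion composer model added by this commit but not shown in this excerpt. The snippet below is only a sketch of what that method presumably does, reconstructed from the lines it replaces; the projection names t5_proj and clip_proj come from the removed code, while the dimensions, module layout, and class name are assumptions, not the actual implementation.

import torch
import torch.nn as nn


class TextEmbeddingPrepSketch(nn.Module):
    """Hypothetical stand-in for the projection/concatenation step on the new model class."""

    def __init__(self, t5_dim: int = 4096, clip_dim: int = 768, model_dim: int = 2048):
        super().__init__()
        # Project each text encoder's latents into a common width before concatenation.
        self.t5_proj = nn.Linear(t5_dim, model_dim)
        self.clip_proj = nn.Linear(clip_dim, model_dim)

    def prepare_text_embeddings(self, t5_latents, clip_latents, t5_mask, clip_mask):
        # Project the precomputed latents, then concatenate embeddings and masks along the
        # sequence dimension, mirroring what the old callback code did inline.
        t5_embeds = self.t5_proj(t5_latents)
        clip_embeds = self.clip_proj(clip_latents)
        prompt_embeds = torch.cat([t5_embeds, clip_embeds], dim=1)
        prompt_mask = torch.cat([t5_mask, clip_mask], dim=1)
        return prompt_embeds, prompt_mask

The callback then forwards prompt_embeds, pooled_prompt, and prompt_mask to model.generate, exactly as shown in the diff above.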

diffusion/datasets/image_caption_latents.py

Lines changed: 37 additions & 10 deletions
@@ -14,7 +14,8 @@
 from torch.utils.data import DataLoader
 from torchvision import transforms

-from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransform, RandomCropSquare
+from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropAspectRatioTransform,
+                                                 RandomCropBucketedAspectRatioTransform, RandomCropSquare)
 from diffusion.datasets.utils import make_streams

 log = logging.getLogger(__name__)
@@ -32,6 +33,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset):
         image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
         caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
         caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
+        aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
         text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
             Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
         text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
@@ -40,6 +42,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset):
         attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
             Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
         latent_dtype (torch.dtype): The dtype to cast the text latents to. Default: ``torch.bfloat16``.
+        drop_nans (bool): Whether to treat samples with NaN latents as dropped captions. Default: ``True``.
         **streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader
     """

@@ -53,10 +56,12 @@ def __init__(
         image_key: str = 'image',
         caption_keys: Tuple[str, ...] = ('caption',),
         caption_selection_probs: Tuple[float, ...] = (1.0,),
+        aspect_ratio_bucket_key: Optional[str] = None,
         text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
         text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)),
         attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
         latent_dtype: torch.dtype = torch.bfloat16,
+        drop_nans: bool = True,
         **streaming_kwargs,
     ):

@@ -72,10 +77,14 @@ def __init__(
         self.image_key = image_key
         self.caption_keys = caption_keys
         self.caption_selection_probs = caption_selection_probs
+        self.aspect_ratio_bucket_key = aspect_ratio_bucket_key
+        if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
+            assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransform'
         self.text_latent_keys = text_latent_keys
         self.text_latent_shapes = text_latent_shapes
         self.attention_mask_keys = attention_mask_keys
         self.latent_dtype = latent_dtype
+        self.drop_nans = drop_nans

     def __getitem__(self, index):
         sample = super().__getitem__(index)
@@ -90,15 +99,16 @@ def __getitem__(self, index):
         out['cond_original_size'] = torch.tensor(img.size)

         # Image transforms
-        if self.crop is not None:
+        if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
+            img, crop_top, crop_left = self.crop(img, sample[self.aspect_ratio_bucket_key])
+        elif self.crop is not None:
             img, crop_top, crop_left = self.crop(img)
         else:
             crop_top, crop_left = 0, 0
-        out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])
-
         if self.transform is not None:
             img = self.transform(img)
         out['image'] = img
+        out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])

         # Get the new height and width
         if isinstance(img, torch.Tensor):
@@ -140,6 +150,13 @@ def __getitem__(self, index):
             if 'CLIP_LATENTS' in latent_key:
                 clip_pooled = np.frombuffer(sample[f'{caption_key}_CLIP_POOLED_TEXT'], dtype=np.float32).copy()
                 out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).to(self.latent_dtype).reshape(latent_shape[1])
+        if self.drop_nans:
+            for latent_key, attn_key in zip(self.text_latent_keys, self.attention_mask_keys):
+                if out[latent_key].isnan().any():
+                    out[latent_key] = torch.zeros_like(out[latent_key])
+                    out[attn_key] = torch.zeros_like(out[attn_key])
+                if 'CLIP_LATENTS' in latent_key and out['CLIP_POOLED'].isnan().any():
+                    out['CLIP_POOLED'] = torch.zeros_like(out['CLIP_POOLED'])
         return out

@@ -160,6 +177,7 @@ def build_streaming_image_caption_latents_dataloader(
     text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)),
     attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
     latent_dtype: str = 'torch.bfloat16',
+    aspect_ratio_bucket_key: Optional[str] = None,
     streaming_kwargs: Optional[Dict] = None,
     dataloader_kwargs: Optional[Dict] = None,
 ):
@@ -178,11 +196,12 @@ def build_streaming_image_caption_latents_dataloader(
             ``None``, the bucket with the smallest distance to the current sample's aspect ratio is selected.
             Default: ``None``.
         transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
-        crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio'].
+        crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio'].
             Default: ``'square'``.
         image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
         caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
         caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
+        aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
         text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
             Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
         text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
@@ -192,18 +211,22 @@ def build_streaming_image_caption_latents_dataloader(
             Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
         latent_dtype (str): The torch dtype to cast the text latents to. One of 'torch.float16', 'torch.float32',
             or 'torch.bfloat16'. Default: ``'torch.bfloat16'``.
+        aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset.
+            Needed if using ``crop_type='bucketed_aspect_ratio'``. Default: ``None``.
         streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
         dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
     """
     # Check crop type
     if crop_type is not None:
         crop_type = crop_type.lower()
-        if crop_type not in ['square', 'random', 'aspect_ratio']:
-            raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]')
-        if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)):
+        if crop_type not in ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']:
             raise ValueError(
-                'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.')
-
+                f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", "bucketed_aspect_ratio", None]'
+            )
+        if crop_type in ['aspect_ratio', 'bucketed_aspect_ratio'] and (isinstance(resize_size, int) or
+                                                                       isinstance(resize_size[0], int)):
+            raise ValueError(
+                'If using aspect ratio bucketing, specify aspect ratio buckets in resize_size as a tuple of tuples.')
     # Check latent dtype
     dtypes = {'torch.float16': torch.float16, 'torch.float32': torch.float32, 'torch.bfloat16': torch.bfloat16}
     assert latent_dtype in dtypes, f'Invalid latent_dtype: {latent_dtype}. Must be one of {list(dtypes.keys())}'
@@ -225,6 +248,9 @@ def build_streaming_image_caption_latents_dataloader(
         crop = RandomCropSquare(resize_size)
     elif crop_type == 'aspect_ratio':
         crop = RandomCropAspectRatioTransform(resize_size, ar_bucket_boundaries)  # type: ignore
+    elif crop_type == 'bucketed_aspect_ratio':
+        assert aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using bucketed_aspect_ratio crop type'
+        crop = RandomCropBucketedAspectRatioTransform(resize_size)  # type: ignore
     else:
         crop = None

@@ -242,6 +268,7 @@ def build_streaming_image_caption_latents_dataloader(
         image_key=image_key,
         caption_keys=caption_keys,
         caption_selection_probs=caption_selection_probs,
+        aspect_ratio_bucket_key=aspect_ratio_bucket_key,
         text_latent_keys=text_latent_keys,
         text_latent_shapes=text_latent_shapes,
         attention_mask_keys=attention_mask_keys,
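Note: with the new 'bucketed_aspect_ratio' crop type, the builder expects each streaming sample to carry its aspect ratio bucket under aspect_ratio_bucket_key. A hedged usage sketch follows; the remote, local, and batch_size arguments, the bucket sizes, and the key name 'aspect_ratio_bucket' are illustrative assumptions, while crop_type, aspect_ratio_bucket_key, and the latent/mask keys come from the diff above.

from diffusion.datasets.image_caption_latents import build_streaming_image_caption_latents_dataloader

# Hypothetical call; only the parameters visible in this diff are certain.
dataloader = build_streaming_image_caption_latents_dataloader(
    remote='s3://my-bucket/mds-latents',             # assumed dataset location
    local='/tmp/mds-cache',                          # assumed local cache path
    batch_size=32,                                   # assumed
    resize_size=((1024, 1024), (1152, 896)),         # aspect ratio buckets as a tuple of tuples
    crop_type='bucketed_aspect_ratio',               # new crop type added in this commit
    aspect_ratio_bucket_key='aspect_ratio_bucket',   # key storing each sample's precomputed bucket
    text_latent_keys=('T5_LATENTS', 'CLIP_LATENTS'),
    attention_mask_keys=('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
    latent_dtype='torch.bfloat16',
)

Samples whose precomputed latents contain NaNs are zeroed out (latents, attention masks, and CLIP pooled output) by the dataset's drop_nans behavior, so they act like dropped captions rather than poisoning a batch.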

diffusion/models/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -4,10 +4,11 @@
 """Diffusion models."""

 from diffusion.models.models import (build_autoencoder, build_diffusers_autoencoder, continuous_pixel_diffusion,
-                                     discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl,
-                                     text_to_image_transformer)
+                                     discrete_pixel_diffusion, precomputed_text_latent_diffusion, stable_diffusion_2,
+                                     stable_diffusion_xl, text_to_image_transformer)
 from diffusion.models.noop import NoOpModel
 from diffusion.models.pixel_diffusion import PixelDiffusion
+from diffusion.models.precomputed_text_latent_diffusion import PrecomputedTextLatentDiffusion
 from diffusion.models.stable_diffusion import StableDiffusion

 __all__ = [
@@ -17,8 +18,10 @@
     'discrete_pixel_diffusion',
     'NoOpModel',
     'PixelDiffusion',
+    'precomputed_text_latent_diffusion',
     'stable_diffusion_2',
     'stable_diffusion_xl',
     'StableDiffusion',
+    'PrecomputedTextLatentDiffusion',
     'text_to_image_transformer',
 ]
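Note: both the builder function and the composer model class are now exported from the package root. A minimal sketch of how they are expected to be used; the builder itself is defined in diffusion/models/models.py (changed in this commit but not shown in this excerpt), so its arguments are not reproduced here.

from diffusion.models import PrecomputedTextLatentDiffusion, precomputed_text_latent_diffusion

# Placeholder rather than a working call, since the builder's signature is not in this excerpt:
# model = precomputed_text_latent_diffusion(...)
#
# Whatever it returns is expected to be a PrecomputedTextLatentDiffusion instance that the
# LogDiffusionImages callback drives as in the first file of this diff:
#     prompt_embeds, prompt_mask = model.prepare_text_embeddings(t5_latents, clip_latents,
#                                                                t5_mask, clip_mask)
#     images = model.generate(prompt_embeds=prompt_embeds, pooled_prompt=pooled_prompt,
#                             prompt_mask=prompt_mask)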
