from torch.utils.data import DataLoader
from torchvision import transforms

-from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransform, RandomCropSquare
+from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropAspectRatioTransform,
+                                                 RandomCropBucketedAspectRatioTransform, RandomCropSquare)
from diffusion.datasets.utils import make_streams

log = logging.getLogger(__name__)
@@ -32,6 +33,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset):
        image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
        caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
        caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
+        aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
        text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
            Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
        text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
@@ -40,6 +42,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset):
        attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
            Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
        latent_dtype (torch.dtype): The dtype to cast the text latents to. Default: ``torch.bfloat16``.
+        drop_nans (bool): Whether to treat samples with NaN latents as dropped captions. Default: ``True``.
        **streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader
    """
@@ -53,10 +56,12 @@ def __init__(
        image_key: str = 'image',
        caption_keys: Tuple[str, ...] = ('caption',),
        caption_selection_probs: Tuple[float, ...] = (1.0,),
+        aspect_ratio_bucket_key: Optional[str] = None,
        text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
        text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)),
        attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
        latent_dtype: torch.dtype = torch.bfloat16,
+        drop_nans: bool = True,
        **streaming_kwargs,
    ):
@@ -72,10 +77,14 @@ def __init__(
        self.image_key = image_key
        self.caption_keys = caption_keys
        self.caption_selection_probs = caption_selection_probs
+        self.aspect_ratio_bucket_key = aspect_ratio_bucket_key
+        if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
+            assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransform'
        self.text_latent_keys = text_latent_keys
        self.text_latent_shapes = text_latent_shapes
        self.attention_mask_keys = attention_mask_keys
        self.latent_dtype = latent_dtype
+        self.drop_nans = drop_nans

    def __getitem__(self, index):
        sample = super().__getitem__(index)
@@ -90,15 +99,16 @@ def __getitem__(self, index):
        out['cond_original_size'] = torch.tensor(img.size)

        # Image transforms
-        if self.crop is not None:
+        if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
+            img, crop_top, crop_left = self.crop(img, sample[self.aspect_ratio_bucket_key])
+        elif self.crop is not None:
            img, crop_top, crop_left = self.crop(img)
        else:
            crop_top, crop_left = 0, 0
-        out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])
-
        if self.transform is not None:
            img = self.transform(img)
        out['image'] = img
+        out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])

        # Get the new height and width
        if isinstance(img, torch.Tensor):
@@ -140,6 +150,13 @@ def __getitem__(self, index):
            if 'CLIP_LATENTS' in latent_key:
                clip_pooled = np.frombuffer(sample[f'{caption_key}_CLIP_POOLED_TEXT'], dtype=np.float32).copy()
                out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).to(self.latent_dtype).reshape(latent_shape[1])
+        if self.drop_nans:
+            for latent_key, attn_key in zip(self.text_latent_keys, self.attention_mask_keys):
+                if out[latent_key].isnan().any():
+                    out[latent_key] = torch.zeros_like(out[latent_key])
+                    out[attn_key] = torch.zeros_like(out[attn_key])
+                if 'CLIP_LATENTS' in latent_key and out['CLIP_POOLED'].isnan().any():
+                    out['CLIP_POOLED'] = torch.zeros_like(out['CLIP_POOLED'])
        return out
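A minimal standalone sketch of what the new `drop_nans` handling does (hypothetical values, not code from this diff): a text latent containing NaNs is zeroed along with its attention mask, so the sample behaves like a dropped caption.

```python
# Hedged illustration of the drop_nans logic above; the tensor shapes follow the
# default text_latent_shapes and are assumptions for this example only.
import torch

out = {
    'T5_LATENTS': torch.full((512, 4096), float('nan')),  # corrupted latent
    'T5_ATTENTION_MASK': torch.ones(512),
}
for latent_key, attn_key in [('T5_LATENTS', 'T5_ATTENTION_MASK')]:
    if out[latent_key].isnan().any():
        # Zero the latent and its mask, mimicking an empty (dropped) caption.
        out[latent_key] = torch.zeros_like(out[latent_key])
        out[attn_key] = torch.zeros_like(out[attn_key])

assert not out['T5_LATENTS'].isnan().any()
assert out['T5_ATTENTION_MASK'].sum() == 0
```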
@@ -160,6 +177,7 @@ def build_streaming_image_caption_latents_dataloader(
    text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)),
    attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
    latent_dtype: str = 'torch.bfloat16',
+    aspect_ratio_bucket_key: Optional[str] = None,
    streaming_kwargs: Optional[Dict] = None,
    dataloader_kwargs: Optional[Dict] = None,
):
@@ -178,11 +196,12 @@ def build_streaming_image_caption_latents_dataloader(
            ``None``, the bucket with the smallest distance to the current sample's aspect ratio is selected.
            Default: ``None``.
        transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
-        crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio'].
+        crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio'].
            Default: ``'square'``.
        image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
        caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
        caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
+        aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
        text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
            Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
        text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
@@ -192,18 +211,22 @@ def build_streaming_image_caption_latents_dataloader(
            Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
        latent_dtype (str): The torch dtype to cast the text latents to. One of 'torch.float16', 'torch.float32',
            or 'torch.bfloat16'. Default: ``'torch.bfloat16'``.
+        aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset.
+            Needed if using ``crop_type='bucketed_aspect_ratio'``. Default: ``None``.
        streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
        dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
    """
    # Check crop type
    if crop_type is not None:
        crop_type = crop_type.lower()
-        if crop_type not in ['square', 'random', 'aspect_ratio']:
-            raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]')
-        if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)):
+        if crop_type not in ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']:
            raise ValueError(
-                'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.')
-
+                f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", "bucketed_aspect_ratio", None]'
+            )
+        if crop_type in ['aspect_ratio', 'bucketed_aspect_ratio'] and (isinstance(resize_size, int) or
+                                                                       isinstance(resize_size[0], int)):
+            raise ValueError(
+                'If using aspect ratio bucketing, specify aspect ratio buckets in resize_size as a tuple of tuples.')
    # Check latent dtype
    dtypes = {'torch.float16': torch.float16, 'torch.float32': torch.float32, 'torch.bfloat16': torch.bfloat16}
    assert latent_dtype in dtypes, f'Invalid latent_dtype: {latent_dtype}. Must be one of {list(dtypes.keys())}'
@@ -225,6 +248,9 @@ def build_streaming_image_caption_latents_dataloader(
        crop = RandomCropSquare(resize_size)
    elif crop_type == 'aspect_ratio':
        crop = RandomCropAspectRatioTransform(resize_size, ar_bucket_boundaries)  # type: ignore
+    elif crop_type == 'bucketed_aspect_ratio':
+        assert aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using bucketed_aspect_ratio crop type'
+        crop = RandomCropBucketedAspectRatioTransform(resize_size)  # type: ignore
    else:
        crop = None
@@ -242,6 +268,7 @@ def build_streaming_image_caption_latents_dataloader(
        image_key=image_key,
        caption_keys=caption_keys,
        caption_selection_probs=caption_selection_probs,
+        aspect_ratio_bucket_key=aspect_ratio_bucket_key,
        text_latent_keys=text_latent_keys,
        text_latent_shapes=text_latent_shapes,
        attention_mask_keys=attention_mask_keys,
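For reference, a hedged usage sketch of the updated builder with the new crop type. The module path, the remote/local/batch_size arguments, the example paths, and the `'aspect_ratio_bucket'` column name are assumptions taken from the surrounding file and repo layout, not from this diff; `resize_size` is assumed to be a tuple of (height, width) bucket tuples, per the validation added above.

```python
# Hypothetical call site; argument names outside this diff are assumed.
from diffusion.datasets.image_caption_latents import build_streaming_image_caption_latents_dataloader

dataloader = build_streaming_image_caption_latents_dataloader(
    remote='s3://my-bucket/latents',   # assumed shard location
    local='/tmp/mds-cache/latents',    # assumed local cache directory
    batch_size=32,
    resize_size=((512, 512), (640, 384), (384, 640)),  # aspect ratio buckets, assumed (height, width) order
    crop_type='bucketed_aspect_ratio',
    aspect_ratio_bucket_key='aspect_ratio_bucket',      # must match the key stored in the dataset shards
)
```

With `crop_type='bucketed_aspect_ratio'`, each sample's precomputed bucket is read from `aspect_ratio_bucket_key` and passed to the crop transform, instead of picking the nearest bucket from the image's own aspect ratio.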