
Commit 8afad24

Merge pull request #1420 from bghira/bugfix/video-model-image-training
When running sample transforms, the dataset_type should be considered so that we do not run video transforms on images.
2 parents: 1a5a65b + 35940a7

2 files changed: +23 -12 lines changed

helpers/caching/vae.py (+13 -7)
@@ -116,7 +116,7 @@ def __init__(
         self.vae_batch_size = vae_batch_size
         self.instance_data_dir = instance_data_dir
         self.model = model
-        self.transform_sample = model.get_transforms()
+        self.transform_sample = model.get_transforms(dataset_type=dataset_type)
         self.num_video_frames = None
         if self.dataset_type == "video":
             self.num_video_frames = num_video_frames
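The constructor change above threads the dataset's own type into the transform lookup, so an image dataset served by a video-capable model no longer receives video transforms. A minimal sketch of the wiring, assuming a model object that exposes the patched get_transforms() (the CacheSketch class is illustrative, not from the repository):

    class CacheSketch:
        def __init__(self, model, dataset_type: str = "image"):
            self.dataset_type = dataset_type
            # The transform now follows the dataset's type rather than
            # unconditionally using whatever the model defaults to.
            self.transform_sample = model.get_transforms(dataset_type=dataset_type)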
@@ -894,13 +894,19 @@ def _encode_images_in_batch(
         count_to_process = min(qlen, self.vae_batch_size)
         for idx in range(0, count_to_process):
             if image_pixel_values:
-                pixel_values, filepath, aspect_bucket, is_final_sample = (
-                    image_pixel_values.pop()
-                )
+                (
+                    pixel_values,
+                    filepath,
+                    aspect_bucket,
+                    is_final_sample,
+                ) = image_pixel_values.pop()
             else:
-                pixel_values, filepath, aspect_bucket, is_final_sample = (
-                    self.vae_input_queue.get()
-                )
+                (
+                    pixel_values,
+                    filepath,
+                    aspect_bucket,
+                    is_final_sample,
+                ) = self.vae_input_queue.get()

             if batch_aspect_bucket is None:
                 batch_aspect_bucket = aspect_bucket

helpers/models/common.py (+10 -5)
@@ -5,6 +5,7 @@
 import logging
 import inspect
 import os
+from torchvision import transforms
 from diffusers import DiffusionPipeline
 from torch.distributions import Beta
 from helpers.training.wrappers import unwrap_model
@@ -250,10 +251,16 @@ def get_flavour_choices(cls):
         """
         return list(cls.HUGGINGFACE_PATHS.keys())

-    def get_transforms(self):
+    def get_transforms(self, dataset_type: str = "image"):
         """
         Returns nothing, but subclasses can implement different torchvision transforms as needed.
+
+        dataset_type is passed in for models that support transforming videos or images etc.
         """
+        if dataset_type in ["video"]:
+            raise ValueError(
+                f"{dataset_type} transforms are not supported by {self.NAME}."
+            )
         return transforms.Compose(
             [
                 transforms.ToTensor(),
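With this guard, an image-only model now fails fast when asked for video transforms instead of silently producing mismatched tensors. A hedged sketch of the resulting behavior, assuming torchvision is installed (ExampleImageModel and its NAME attribute are illustrative stand-ins for a model subclass, not code from the repository):

    from torchvision import transforms

    class ExampleImageModel:
        NAME = "example-image-model"  # illustrative; real models define their own NAME

        def get_transforms(self, dataset_type: str = "image"):
            # Reject video requests up front, mirroring the patch above.
            if dataset_type in ["video"]:
                raise ValueError(
                    f"{dataset_type} transforms are not supported by {self.NAME}."
                )
            return transforms.Compose([transforms.ToTensor()])

    model = ExampleImageModel()
    model.get_transforms(dataset_type="image")  # returns the standard image pipeline
    # model.get_transforms(dataset_type="video")  # would raise ValueError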
@@ -1328,12 +1335,10 @@ def __init__(self, config, accelerator):
         # }
         # The trainer or child class might call self._init_text_encoders() at the right time.

-    def get_transforms(self):
-        from torchvision import transforms
-
+    def get_transforms(self, dataset_type: str = "image"):
         return transforms.Compose(
             [
-                VideoToTensor(),
+                VideoToTensor() if dataset_type == "video" else transforms.ToTensor(),
             ]
         )
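The video model's override now works in the other direction as well: an image dataset falls back to torchvision's per-image ToTensor rather than the project's VideoToTensor. A small sketch of how the patched conditional behaves, with VideoToTensor stubbed out since the real helper lives in the repository:

    from torchvision import transforms

    class VideoToTensor:
        """Illustrative stub for the project's video transform."""
        def __call__(self, frames):
            return frames  # the real helper converts a frame stack to a tensor

    def build_transform(dataset_type: str = "image"):
        # Mirrors the patched conditional: only video datasets get VideoToTensor.
        return transforms.Compose(
            [VideoToTensor() if dataset_type == "video" else transforms.ToTensor()]
        )

    assert isinstance(build_transform("image").transforms[0], transforms.ToTensor)

This is why the base class gained the dataset_type parameter: the image-only default and the video override can both key their pipelines off the same argument.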
