
Commit 9bd745e

mannatsingh authored and facebook-github-bot committed
Move Dataloader Wrappers to OSS (#455)

Summary: Pull Request resolved: #455

This will be helpful for OSS users who implement their own Iterable datasets.

Reviewed By: vreis

Differential Revision: D20605900

fbshipit-source-id: 9243ab7be3e47e55f4e2ae6a7bdb3b27ae10be92

1 parent ffbad54    commit 9bd745e

6 files changed: +308 −0 lines

classy_vision/dataset/__init__.py (+10)
@@ -72,20 +72,30 @@ def register_dataset_cls(cls):
 from .classy_imagenet import ImageNetDataset  # isort:skip
 from .classy_kinetics400 import Kinetics400Dataset  # isort:skip
 from .classy_synthetic_image import SyntheticImageDataset  # isort:skip
+from .classy_synthetic_image_streaming import (  # isort:skip
+    SyntheticImageStreamingDataset,  # isort:skip
+)  # isort:skip
 from .classy_synthetic_video import SyntheticVideoDataset  # isort:skip
 from .classy_ucf101 import UCF101Dataset  # isort:skip
 from .classy_video_dataset import ClassyVideoDataset  # isort:skip
+from .dataloader_limit_wrapper import DataloaderLimitWrapper  # isort:skip
+from .dataloader_skip_none_wrapper import DataloaderSkipNoneWrapper  # isort:skip
+from .dataloader_wrapper import DataloaderWrapper  # isort:skip
 from .image_path_dataset import ImagePathDataset  # isort:skip

 __all__ = [
     "CIFARDataset",
     "ClassyDataset",
     "ClassyVideoDataset",
+    "DataloaderLimitWrapper",
+    "DataloaderSkipNoneWrapper",
+    "DataloaderWrapper",
     "HMDB51Dataset",
     "ImageNetDataset",
     "ImagePathDataset",
     "Kinetics400Dataset",
     "SyntheticImageDataset",
+    "SyntheticImageStreamingDataset",
     "SyntheticVideoDataset",
     "UCF101Dataset",
     "build_dataset",
classy_vision/dataset/classy_synthetic_image_streaming.py (new file, +88)

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import torchvision.transforms as transforms
from classy_vision.dataset import register_dataset
from classy_vision.dataset.classy_dataset import ClassyDataset
from classy_vision.dataset.core import RandomImageBinaryClassDataset
from classy_vision.dataset.dataloader_limit_wrapper import DataloaderLimitWrapper
from classy_vision.dataset.transforms.util import (
    ImagenetConstants,
    build_field_transform_default_imagenet,
)


@register_dataset("synthetic_image_streaming")
class SyntheticImageStreamingDataset(ClassyDataset):
    """
    Synthetic image dataset that behaves like a streaming dataset.

    Requires a "num_samples" argument which decides the number of samples in the
    phase. Also takes an optional "length" input which sets the length of the
    dataset.
    """

    def __init__(
        self,
        batchsize_per_replica,
        shuffle,
        transform,
        num_samples,
        crop_size,
        class_ratio,
        seed,
        length=None,
    ):
        if length is None:
            # If length not provided, set to be same as num_samples
            length = num_samples

        dataset = RandomImageBinaryClassDataset(crop_size, class_ratio, length, seed)
        super().__init__(
            dataset, batchsize_per_replica, shuffle, transform, num_samples
        )

    @classmethod
    def from_config(cls, config):
        assert all(key in config for key in ["crop_size", "class_ratio", "seed"])
        length = config.get("length")
        crop_size = config["crop_size"]
        class_ratio = config["class_ratio"]
        seed = config["seed"]
        (
            transform_config,
            batchsize_per_replica,
            shuffle,
            num_samples,
        ) = cls.parse_config(config)
        default_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=ImagenetConstants.MEAN, std=ImagenetConstants.STD
                ),
            ]
        )
        transform = build_field_transform_default_imagenet(
            transform_config, default_transform=default_transform
        )
        return cls(
            batchsize_per_replica,
            shuffle,
            transform,
            num_samples,
            crop_size,
            class_ratio,
            seed,
            length=length,
        )

    def iterator(self, *args, **kwargs):
        return DataloaderLimitWrapper(
            super().iterator(*args, **kwargs),
            self.num_samples // self.get_global_batchsize(),
        )
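
As a usage sketch (not part of the diff), the registered name can be passed to `build_dataset`; the config below mirrors the one used in the test at the end of this commit, and it assumes the default `iterator()` arguments suffice on a single, non-distributed replica:

# Hedged sketch: build the streaming dataset from a config and iterate it.
from classy_vision.dataset import build_dataset

config = {
    "name": "synthetic_image_streaming",
    "split": "train",
    "crop_size": 224,
    "class_ratio": 0.5,
    "num_samples": 2000,
    "length": 4000,
    "seed": 0,
    "batchsize_per_replica": 32,
    "use_shuffle": True,
}
dataset = build_dataset(config)
loader = dataset.iterator()  # a DataloaderLimitWrapper, per the override above
print(len(loader))  # the limit: num_samples // global batch size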
classy_vision/dataset/dataloader_limit_wrapper.py (new file, +77)

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Iterable, Iterator

from .dataloader_wrapper import DataloaderWrapper


class DataloaderLimitWrapper(DataloaderWrapper):
    """
    Dataloader which wraps another dataloader and only returns a limited
    number of items.

    This is useful for Iterable datasets where the length of the dataset isn't
    known. Such datasets can wrap their returned iterators with this class. See
    :func:`SyntheticImageStreamingDataset.iterator` for an example.

    Attribute accesses are passed to the wrapped dataloader.
    """

    def __init__(
        self, dataloader: Iterable, limit: int, wrap_around: bool = True
    ) -> None:
        """Constructor for DataloaderLimitWrapper.

        Args:
            dataloader: The dataloader to wrap around
            limit: Specify the number of calls to the underlying dataloader. The
                wrapper will raise a `StopIteration` after `limit` calls.
            wrap_around: Whether to wrap around the original dataloader if the
                dataloader is exhausted before `limit` calls.
        Raises:
            RuntimeError: If `wrap_around` is set to `False` and the underlying
                dataloader is exhausted before `limit` calls.
        """
        super().__init__(dataloader)
        # we use self.__dict__ to set the attributes since the __setattr__ method
        # is overridden
        attributes = {"limit": limit, "wrap_around": wrap_around, "_count": None}
        self.__dict__.update(attributes)

    def __iter__(self) -> Iterator[Any]:
        self._iter = iter(self.dataloader)
        self._count = 0
        return self

    def __next__(self) -> Any:
        if self._count >= self.limit:
            raise StopIteration
        self._count += 1
        try:
            return next(self._iter)
        except StopIteration:
            if self.wrap_around:
                # create a new iterator to load data from the beginning
                logging.info(
                    f"Wrapping around after {self._count} calls. Limit: {self.limit}"
                )
                try:
                    self._iter = iter(self.dataloader)
                    return next(self._iter)
                except StopIteration:
                    raise RuntimeError(
                        "Looks like the dataset is empty, "
                        "have you configured it properly?"
                    )
            else:
                raise RuntimeError(
                    f"StopIteration raised before {self.limit} items were returned"
                )

    def __len__(self) -> int:
        return self.limit
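
A quick sketch of the wrap-around semantics, using a plain Python list to stand in for a dataloader (any `Iterable` works):

# With wrap_around=True (the default), the wrapper restarts the underlying
# iterable until `limit` items have been returned.
loader = DataloaderLimitWrapper([1, 2, 3], limit=5)
print(list(loader))  # [1, 2, 3, 1, 2]
print(len(loader))   # 5 -- __len__ reports the limit

# With wrap_around=False, exhausting the iterable early is an error.
strict = DataloaderLimitWrapper([1, 2, 3], limit=5, wrap_around=False)
list(strict)  # raises RuntimeError after the third item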
classy_vision/dataset/dataloader_skip_none_wrapper.py (new file, +33)

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Iterable, Iterator

from .dataloader_wrapper import DataloaderWrapper


class DataloaderSkipNoneWrapper(DataloaderWrapper):
    """
    Dataloader which wraps another dataloader and skips `None` batch data.

    Attribute accesses are passed to the wrapped dataloader.
    """

    def __init__(self, dataloader: Iterable) -> None:
        super().__init__(dataloader)

    def __iter__(self) -> Iterator[Any]:
        self._iter = iter(self.dataloader)
        return self

    def __next__(self) -> Any:
        # we may get `None` batch data when all the images/videos in the batch
        # are corrupted. In such a case, we keep getting the next batch until
        # we meet a good batch.
        next_batch = None
        while next_batch is None:
            next_batch = next(self._iter)
        return next_batch
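
A minimal sketch of the skipping behavior, again with a list standing in for a real dataloader:

# None entries model batches where every image/video was corrupted.
batches = ["batch0", None, None, "batch1", "batch2"]
loader = DataloaderSkipNoneWrapper(batches)
print(list(loader))  # ['batch0', 'batch1', 'batch2']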
classy_vision/dataset/dataloader_wrapper.py (new file, +47)

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any, Iterable, Iterator


class DataloaderWrapper(ABC):
    """
    Abstract class representing a dataloader which wraps another dataloader.

    Attribute accesses are passed to the wrapped dataloader.
    """

    def __init__(self, dataloader: Iterable) -> None:
        # we use self.__dict__ to set the attributes since the __setattr__ method
        # is overridden
        attributes = {"dataloader": dataloader, "_iter": None}
        self.__dict__.update(attributes)

    @abstractmethod
    def __iter__(self) -> Iterator[Any]:
        pass

    @abstractmethod
    def __next__(self) -> Any:
        pass

    def __getattr__(self, attr) -> Any:
        """
        Pass the getattr call to the wrapped dataloader
        """
        if attr in self.__dict__:
            return self.__dict__[attr]
        return getattr(self.dataloader, attr)

    def __setattr__(self, attr, value) -> None:
        """
        Pass the setattr call to the wrapped dataloader
        """
        if attr in self.__dict__:
            self.__dict__[attr] = value
        else:
            setattr(self.dataloader, attr, value)
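
Since the commit exists so OSS users can write their own wrappers, here is a hypothetical subclass (the name `DataloaderLoggingWrapper` is illustrative, not part of the library) showing the minimal contract: implement `__iter__` and `__next__`, and let the base class forward everything else to the wrapped dataloader:

import logging
from typing import Any, Iterator


class DataloaderLoggingWrapper(DataloaderWrapper):
    """Hypothetical wrapper that logs every batch it yields."""

    def __iter__(self) -> Iterator[Any]:
        # "_iter" is already in self.__dict__ (set by the base constructor),
        # so plain attribute assignment stays on the wrapper itself
        self._iter = iter(self.dataloader)
        return self

    def __next__(self) -> Any:
        batch = next(self._iter)
        logging.debug("DataloaderLoggingWrapper: yielding a batch")
        return batch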
Test for the dataloader limit wrapper (new file, +53)

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from test.generic.config_utils import get_test_task_config

from classy_vision.tasks import build_task


class TestDataloaderLimitWrapper(unittest.TestCase):
    def _test_number_of_batches(self, data_iterator, expected_batches):
        num_batches = 0
        for _ in data_iterator:
            num_batches += 1
        self.assertEqual(num_batches, expected_batches)

    def test_streaming_dataset(self):
        """
        Test that streaming datasets return the correct number of batches, and that
        the length is also calculated correctly.
        """
        config = get_test_task_config()
        dataset_config = {
            "name": "synthetic_image_streaming",
            "split": "train",
            "crop_size": 224,
            "class_ratio": 0.5,
            "num_samples": 2000,
            "length": 4000,
            "seed": 0,
            "batchsize_per_replica": 32,
            "use_shuffle": True,
        }
        expected_batches = 62
        config["dataset"]["train"] = dataset_config
        task = build_task(config)
        task.prepare()
        task.advance_phase()
        # test that the number of batches expected is correct
        self.assertEqual(task.num_batches_per_phase, expected_batches)

        # test that the data iterator returns the expected number of batches
        data_iterator = task.get_data_iterator()
        self._test_number_of_batches(data_iterator, expected_batches)

        # test that the dataloader can be rebuilt from the dataset inside it
        task._recreate_data_loader_from_dataset()
        task.create_data_iterator()
        data_iterator = task.get_data_iterator()
        self._test_number_of_batches(data_iterator, expected_batches)
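
The expected count follows from the limit computed in `SyntheticImageStreamingDataset.iterator`: with `num_samples=2000` and a global batch size of 32 (a single replica with `batchsize_per_replica=32`), the wrapper yields `2000 // 32 = 62` batches, hence `expected_batches = 62`.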
