diff --git a/datumaro/datumaro/plugins/transforms.py b/datumaro/datumaro/plugins/transforms.py index 502414c2d50e..f373932841c4 100644 --- a/datumaro/datumaro/plugins/transforms.py +++ b/datumaro/datumaro/plugins/transforms.py @@ -5,6 +5,7 @@ import logging as log import os.path as osp +import random import pycocotools.mask as mask_utils @@ -295,6 +296,66 @@ def transform_item(self, item): return self.wrap_item(item, subset=self._mapping.get(item.subset, item.subset)) +class RandomSplit(Transform, CliPlugin): + """ + Joins all subsets into one and splits the result into few parts. + It is expected that item ids are unique and subset ratios sum up to 1.|n + |n + Example:|n + |s|s%(prog)s --subset train:.67 --subset test:.33 + """ + + @staticmethod + def _split_arg(s): + parts = s.split(':') + if len(parts) != 2: + import argparse + raise argparse.ArgumentTypeError() + return (parts[0], float(parts[1])) + + @classmethod + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) + parser.add_argument('-s', '--subset', action='append', + type=cls._split_arg, dest='splits', + help="Subsets in the form of: ':' (repeatable)") + parser.add_argument('--seed', type=int, help="Random seed") + return parser + + def __init__(self, extractor, splits, seed=None): + super().__init__(extractor) + + total_ratio = sum((s[1] for s in splits), 0) + if not total_ratio == 1: + raise Exception( + "Sum of ratios is expected to be 1, got %s, which is %s" % + (splits, total_ratio)) + + dataset_size = len(extractor) + indices = list(range(dataset_size)) + + random.seed(seed) + random.shuffle(indices) + + parts = [] + s = 0 + for subset, ratio in splits: + s += ratio + boundary = int(s * dataset_size) + parts.append((boundary, subset)) + + self._parts = parts + + def _find_split(self, index): + for boundary, subset in self._parts: + if index < boundary: + return subset + return subset + + def __iter__(self): + for i, item in enumerate(self._extractor): + yield self.wrap_item(item, subset=self._find_split(i)) + class IdFromImageName(Transform, CliPlugin): def transform_item(self, item): name = item.id diff --git a/datumaro/tests/test_transforms.py b/datumaro/tests/test_transforms.py index 85b776e3edb9..19f9bea2fa77 100644 --- a/datumaro/tests/test_transforms.py +++ b/datumaro/tests/test_transforms.py @@ -320,3 +320,40 @@ def __iter__(self): actual = transforms.BoxesToMasks(SrcExtractor()) compare_datasets(self, DstExtractor(), actual) + + def test_random_split(self): + class SrcExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=1, subset="a"), + DatasetItem(id=2, subset="a"), + DatasetItem(id=3, subset="b"), + DatasetItem(id=4, subset="b"), + DatasetItem(id=5, subset="b"), + DatasetItem(id=6, subset=""), + DatasetItem(id=7, subset=""), + ]) + + actual = transforms.RandomSplit(SrcExtractor(), splits=[ + ('train', 4.0 / 7.0), + ('test', 3.0 / 7.0), + ]) + + self.assertEqual(4, len(actual.get_subset('train'))) + self.assertEqual(3, len(actual.get_subset('test'))) + + def test_random_split_gives_error_on_non1_ratios(self): + class SrcExtractor(Extractor): + def __iter__(self): + return iter([DatasetItem(id=1)]) + + has_error = False + try: + transforms.RandomSplit(SrcExtractor(), splits=[ + ('train', 0.5), + ('test', 0.7), + ]) + except Exception: + has_error = True + + self.assertTrue(has_error) \ No newline at end of file