|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Facebook, Inc. and its affiliates. |
| 3 | +# |
| 4 | +# This source code is licensed under the MIT license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import logging |
| 8 | +from typing import Any, Iterable, Iterator |
| 9 | + |
| 10 | +from .dataloader_wrapper import DataloaderWrapper |
| 11 | + |
| 12 | + |
class DataloaderLimitWrapper(DataloaderWrapper):
    """
    Dataloader which wraps another dataloader and only returns a limited
    number of items.

    This is useful for Iterable datasets where the length of the datasets isn't known.
    Such datasets can wrap their returned iterators with this class. See
    :func:`SyntheticImageStreamingDataset.iterator` for an example.

    Attribute accesses are passed to the wrapped dataloader.
    """

    def __init__(
        self, dataloader: Iterable, limit: int, wrap_around: bool = True
    ) -> None:
        """Constructor for DataloaderLimitWrapper.

        Args:
            dataloader: The dataloader to wrap around
            limit: Specify the number of calls to the underlying dataloader. The wrapper
                will raise a `StopIteration` after `limit` calls.
            wrap_around: Whether to wrap around the original dataloader if the
                dataloader is exhausted before `limit` calls.
        Raises:
            RuntimeError: Raised while iterating (not by the constructor) if
                `wrap_around` is set to `False` and the underlying dataloader
                is exhausted before `limit` calls.
        """
        super().__init__(dataloader)
        # we use self.__dict__ to set the attributes since the __setattr__ method
        # is overridden
        # NOTE: `_count` stays None until `__iter__` is called, so calling
        # `next()` on a fresh wrapper fails on the `>=` comparison in `__next__`.
        attributes = {"limit": limit, "wrap_around": wrap_around, "_count": None}
        self.__dict__.update(attributes)

    def __iter__(self) -> Iterator[Any]:
        """Start a new pass over the wrapped dataloader.

        Obtains a fresh iterator from the wrapped dataloader, resets the call
        counter to zero and returns this wrapper as the iterator.
        """
        self._iter = iter(self.dataloader)
        self._count = 0
        return self

    def __next__(self) -> Any:
        """Return the next item, stopping after `limit` calls.

        Returns:
            The next item produced by the wrapped dataloader's iterator.

        Raises:
            StopIteration: Once `limit` calls have been made in this pass.
            RuntimeError: If the wrapped dataloader is exhausted early and
                `wrap_around` is `False`, or if a freshly created iterator
                yields nothing when wrapping around.
        """
        if self._count >= self.limit:
            raise StopIteration
        # The counter is bumped before fetching, so it counts calls rather than
        # items successfully returned; the wrap-around log below therefore
        # includes the call that hit the end of the wrapped dataloader.
        self._count += 1
        try:
            return next(self._iter)
        except StopIteration as exhausted:
            if not self.wrap_around:
                # Chain explicitly so the traceback shows the StopIteration as
                # the direct cause of this error.
                raise RuntimeError(
                    f"StopIteration raised before {self.limit} items were returned"
                ) from exhausted
            # create a new iterator to load data from the beginning
            logging.info(
                f"Wrapping around after {self._count} calls. Limit: {self.limit}"
            )
            try:
                self._iter = iter(self.dataloader)
                return next(self._iter)
            except StopIteration as empty:
                # A brand new iterator that is immediately exhausted means the
                # wrapped dataloader produced no items at all.
                raise RuntimeError(
                    "Looks like the dataset is empty, "
                    "have you configured it properly?"
                ) from empty

    def __len__(self) -> int:
        """Return `limit`, the number of items one full pass yields."""
        return self.limit
0 commit comments