@@ -19,7 +19,7 @@
 
 import torch
 from packaging import version
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler
+from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
 
 from .logging import get_logger
 from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
@@ -631,10 +631,12 @@ def get_sampler(self):
         return get_sampler(self)
 
     def set_sampler(self, sampler):
-        if isinstance(sampler, BatchSampler):
-            self.sampler.batch_sampler = sampler
-        elif isinstance(sampler, Sampler):
+        if isinstance(self.sampler, BatchSampler):
             self.sampler.sampler = sampler
+        else:
+            self.batch_sampler.sampler = sampler
+            if hasattr(self.batch_sampler, "batch_sampler"):
+                self.batch_sampler.batch_sampler.sampler = sampler
 
 
 if is_torch_xla_available():
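For reference, the two branches mirror where a PyTorch DataLoader actually stores its sampler. A minimal sketch in plain PyTorch (not accelerate's code; the dataset and sampler objects here are illustrative) of the two layouts the patched set_sampler distinguishes:

# Sketch of the two storage layouts the patched set_sampler handles.
import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
new_sampler = RandomSampler(dataset)

# Usual layout: DataLoader wraps its sampler in a BatchSampler, so the inner
# sampler sits at dataloader.batch_sampler.sampler (the else branch above).
dl = DataLoader(dataset, batch_size=2)
assert not isinstance(dl.sampler, BatchSampler)
dl.batch_sampler.sampler = new_sampler

# "Sampler is a batch sampler" layout: a BatchSampler passed as sampler= with
# batch_size=None becomes dataloader.sampler itself, so the inner sampler sits
# at dataloader.sampler.sampler (the if branch above).
bs = BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=False)
dl2 = DataLoader(dataset, sampler=bs, batch_size=None)
assert isinstance(dl2.sampler, BatchSampler)
dl2.sampler.sampler = new_sampler

The nested hasattr check covers a third layout specific to accelerate, where self.batch_sampler is itself a wrapper (such as a sharding batch sampler) that keeps the real batch sampler at its .batch_sampler attribute, so the inner sampler there must be updated as well.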
@@ -955,12 +957,12 @@ def get_sampler(self):
         return get_sampler(self)
 
     def set_sampler(self, sampler):
-        if isinstance(sampler, BatchSampler):
-            self.sampler.batch_sampler = sampler
-        elif isinstance(sampler, Sampler):
+        if isinstance(self.sampler, BatchSampler):
             self.sampler.sampler = sampler
         else:
-            raise ValueError(f"{sampler} must be of type torch.utils.data.Sampler or torch.utils.data.BatchSampler")
+            self.batch_sampler.sampler = sampler
+            if hasattr(self.batch_sampler, "batch_sampler"):
+                self.batch_sampler.batch_sampler.sampler = sampler
 
 
 def get_sampler(dataloader):
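Taken together, the two hunks make set_sampler dispatch on the type of self.sampler instead of the type of its argument, and write the new sampler where iteration will actually pick it up, which makes the old ValueError branch unnecessary. A hedged usage sketch, assuming a standard Accelerator setup (the per-run reseeding motive is illustrative, not taken from this commit):

# Hypothetical usage: swap in a freshly seeded sampler on a prepared dataloader.
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

accelerator = Accelerator()
dataset = TensorDataset(torch.arange(100))
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=10, shuffle=True))

# set_sampler is the method patched above; with this fix the replacement
# sampler lands on the inner (batch) sampler, so the new ordering actually
# takes effect on the next iteration.
generator = torch.Generator().manual_seed(42)
dataloader.set_sampler(RandomSampler(dataset, generator=generator))

for (batch,) in dataloader:
    pass  # batches now follow the seeded random order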