Skip to content

Commit b60fd09

Browse files
committed
revert to original set_sampler logic
1 parent fc76f6b commit b60fd09

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

src/accelerate/data_loader.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import torch
2121
from packaging import version
22-
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler
22+
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
2323

2424
from .logging import get_logger
2525
from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
@@ -631,10 +631,12 @@ def get_sampler(self):
631631
return get_sampler(self)
632632

633633
def set_sampler(self, sampler):
634-
if isinstance(sampler, BatchSampler):
635-
self.sampler.batch_sampler = sampler
636-
elif isinstance(sampler, Sampler):
634+
if isinstance(self.sampler, BatchSampler):
637635
self.sampler.sampler = sampler
636+
else:
637+
self.batch_sampler.sampler = sampler
638+
if hasattr(self.batch_sampler, "batch_sampler"):
639+
self.batch_sampler.batch_sampler.sampler = sampler
638640

639641

640642
if is_torch_xla_available():
@@ -955,12 +957,12 @@ def get_sampler(self):
955957
return get_sampler(self)
956958

957959
def set_sampler(self, sampler):
958-
if isinstance(sampler, BatchSampler):
959-
self.sampler.batch_sampler = sampler
960-
elif isinstance(sampler, Sampler):
960+
if isinstance(self.sampler, BatchSampler):
961961
self.sampler.sampler = sampler
962962
else:
963-
raise ValueError(f"{sampler} must be of type torch.utils.data.Sampler or torch.utils.data.BatchSampler")
963+
self.batch_sampler.sampler = sampler
964+
if hasattr(self.batch_sampler, "batch_sampler"):
965+
self.batch_sampler.batch_sampler.sampler = sampler
964966

965967

966968
def get_sampler(dataloader):

0 commit comments

Comments
 (0)