Description
System Info
- `Accelerate` version: 0.24.0
- Platform: Linux-4.15.0-213-generic-x86_64-with-glibc2.27
- Python version: 3.9.0
- Numpy version: 1.26.1
- PyTorch version (GPU?): 1.13.1+cu117 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- System RAM: 503.79 GB
- GPU type: Tesla V100-PCIE-32GB
- `Accelerate` default config: Not found
Information
- The official example scripts
- My own modified scripts
Tasks
- One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- My own task or dataset (give details below)
Reproduction
I defined my own batch sampler, which yields a batch of indices on each iteration. It works well when passed as the `batch_sampler` argument to `torch.utils.data.DataLoader`, and it also works after wrapping the dataloader with `accelerate.data_loader.prepare_data_loader()` in version `accelerate==0.20.3`. However, after upgrading `accelerate` to the latest version 0.24.0, it raises `AttributeError: 'MySampler' object has no attribute 'sampler'`. The code snippets are pasted below.
I have compared the source code of `accelerate/data_loader.py` between versions 0.20.3 and 0.24.0, and the main difference related to this bug is the `sampler_is_batch_sampler` variable in the `prepare_data_loader()` function. In line 718 of version 0.20.3, `sampler_is_batch_sampler` is set to `False`, while in line 834 of version 0.24.0 it is set as `sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)`. This check does not cover my case, where no `sampler` is passed to the `DataLoader` and `batch_sampler` is set to my own sampler (which has no `sampler` member), so the code falls into the branch that reads `dataloader.batch_sampler.sampler` and fails.
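For reference, here is a condensed paraphrase of that 0.24.0 code path, reconstructed from the quoted line 834 and the traceback further below (not verbatim library code):

```python
# Paraphrase of the logic around accelerate/data_loader.py:834 in 0.24.0
sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
if sampler_is_batch_sampler:
    sampler = dataloader.sampler.sampler
else:
    # With DataLoader(dataset, batch_sampler=MySampler(...)), dataloader.sampler
    # is the default sampler PyTorch creates (not a BatchSampler), so this branch
    # runs and assumes the batch sampler exposes a `.sampler` attribute.
    sampler = dataloader.batch_sampler.sampler  # AttributeError for MySampler
```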
```python
import accelerate
import numpy as np
from torch.utils.data import Sampler, DataLoader, Dataset


class MyDataset(Dataset):
    def __init__(self, data) -> None:
        super().__init__()
        self.data = data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class MySampler:
    """
    MySampler is an iterator that can be used as a `batch_sampler` in DataLoader.
    It yields a batch of sample indices each time and only handles uniform sampling.
    It works with `num_workers` in DataLoader because each worker loads a batch of
    samples at a time.
    """

    def __init__(self, dataset_length: int, batch_size: int, shuffle: bool = True) -> None:
        self.batch_size = batch_size
        self.data_index = np.arange(dataset_length)
        self.shuffle = shuffle

    def __iter__(self):
        batch_num = len(self)
        if self.shuffle:
            index = np.random.permutation(self.data_index)
        else:
            index = self.data_index
        output = np.array_split(index, batch_num)
        yield from output

    def __len__(self):
        return (len(self.data_index) + self.batch_size - 1) // self.batch_size


dataset = MyDataset(np.arange(10000))
batch_sampler = MySampler(len(dataset), 32)
dataloader_pt = DataLoader(dataset, batch_sampler=batch_sampler)

# Original PyTorch DataLoader, no problem. PyTorch version 1.13.1
for d in dataloader_pt:
    print(d)
    break

# Accelerate DataLoader, attribute not found. Accelerate version 0.24.0
dataloader_al = accelerate.data_loader.prepare_data_loader(dataloader_pt)
for d in dataloader_al:
    print(d)
    break
```
The output of the snippets is:
```
tensor([3851, 7922, 4125, 2075, 4539, 6159, 9525, 8622,  967, 3022, 7877, 9807,
        5243, 3136, 1554, 5355, 4284, 3041, 5014, 4597, 7593, 1324, 4064, 4886,
        7167, 4549, 7643, 6493, 9435, 5662, 4689, 2710])

AttributeError                            Traceback (most recent call last)
/home/xxxx/xxxx.ipynb Cell 21 line 5

File ~/.conda/envs/xxx/lib/python3.9/site-packages/accelerate/data_loader.py:838, in prepare_data_loader(dataloader, device, num_processes, process_index, split_batches, put_on_device, rng_types, dispatch_batches, even_batches, slice_fn_for_dispatch)
    836     sampler = dataloader.sampler.sampler
    837 else:
--> 838     sampler = dataloader.batch_sampler.sampler
    839 if isinstance(sampler, RandomSampler) and num_processes > 1:
    840     # When iterating through the dataloader during distributed processes
    841     # we want to ensure that on each process we are iterating through the same
    842     # samples in the same order if a seed is set. This requires a tweak
    843     # to the `torch.utils.data.RandomSampler` class (if used).
    844     sampler = SeedableRandomSampler(
    845         data_source=sampler.data_source,
    846         replacement=sampler.replacement,
    847         num_samples=sampler._num_samples,
    848         generator=getattr(sampler, "generator", torch.Generator()),
    849     )

AttributeError: 'MySampler' object has no attribute 'sampler'
```
Expected behavior
The accelerate-prepared dataloader should iterate normally, with no error raised, just as it does with `accelerate==0.20.3`.
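As a stop-gap, a possible workaround (untested beyond avoiding the attribute access shown in the traceback) is to give the custom batch sampler a `sampler` attribute so that `prepare_data_loader()` can read it. In the failing code path above, the attribute is only inspected via `isinstance(..., RandomSampler)`, so setting it to `None` should get past line 838; I have not verified that the rest of the 0.24.0 code path then behaves exactly like 0.20.3.

```python
# Hypothetical workaround sketch, reusing MySampler and dataset from the
# reproduction above: expose a `sampler` attribute so that
# dataloader.batch_sampler.sampler no longer raises.
batch_sampler = MySampler(len(dataset), 32)
batch_sampler.sampler = None  # only read via isinstance(..., RandomSampler) in the traceback
dataloader_pt = DataLoader(dataset, batch_sampler=batch_sampler)
dataloader_al = accelerate.data_loader.prepare_data_loader(dataloader_pt)
```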