feat: Add dataset truncation to prevent deadlocks during training #3615

Open · wants to merge 1 commit into main

48 changes: 48 additions & 0 deletions src/accelerate/data_loader.py
@@ -133,6 +133,11 @@ class BatchSamplerShard(BatchSampler):
even_batches (`bool`, *optional*, defaults to `True`):
Whether or not to loop back at the beginning of the sampler when the number of samples is not a round
multiple of (original batch size / number of processes).
truncate_dataset (`bool`, *optional*, defaults to `False`):
Whether to truncate the dataset so that every process yields the same number of batches per epoch,
computed from `gradient_steps`. When enabled, this prevents deadlocks in multi-GPU training with
sample tracking mechanisms, where all ranks must process the same number of batches.
gradient_steps (`int`, *optional*, defaults to `None`):
Number of gradient accumulation steps. Required when `truncate_dataset` is `True`.

<Tip warning={true}>

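A minimal usage sketch of the new arguments (assuming this PR's branch is installed; the values are illustrative):

from torch.utils.data import BatchSampler
from accelerate.data_loader import BatchSamplerShard

# Shard a 100-sample dataset across 2 processes, truncated so that each
# process yields exactly `gradient_steps` batches per epoch.
batch_sampler = BatchSampler(range(100), batch_size=4, drop_last=False)
shard = BatchSamplerShard(
    batch_sampler,
    num_processes=2,
    process_index=0,
    truncate_dataset=True,  # new in this PR
    gradient_steps=10,      # required when truncate_dataset=True
)
print(len(shard))  # 10 == (4 * 10 * 2) // 4 // 2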
@@ -148,6 +153,8 @@ def __init__(
process_index: int = 0,
split_batches: bool = False,
even_batches: bool = True,
truncate_dataset: bool = False,
gradient_steps: Optional[int] = None,
):
if split_batches and batch_sampler.batch_size % num_processes != 0:
raise ValueError(
@@ -166,12 +173,28 @@ def __init__(
"You need to use `even_batches=False` when the batch sampler has no batch size. If you "
"are not calling this method directly, set `accelerator.even_batches=False` instead."
)
self.truncate_dataset = truncate_dataset
self.gradient_steps = gradient_steps
if truncate_dataset and gradient_steps is None:
raise ValueError("`gradient_steps` must be provided when `truncate_dataset=True`.")
self._optimal_size = self._calculate_optimal_size()

def _calculate_optimal_size(self):
"""Calculate the optimal dataset size to prevent GPU deadlocks."""
if not self.truncate_dataset:
return None
# Calculate total samples needed per epoch to ensure even distribution
samples_per_epoch = self.batch_size * self.gradient_steps * self.num_processes
return samples_per_epoch

@property
def total_length(self):
return len(self.batch_sampler)

def __len__(self):
if self.truncate_dataset and self._optimal_size is not None:
total_batches = self._optimal_size // self.batch_size
return total_batches // self.num_processes
if self.split_batches:
# Split batches does not change the length of the batch sampler
return len(self.batch_sampler)
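To make the truncated length concrete, a worked example of the arithmetic above (illustrative values):

batch_size, gradient_steps, num_processes = 4, 10, 2

optimal_size = batch_size * gradient_steps * num_processes  # 80 samples per epoch
total_batches = optimal_size // batch_size                  # 20 batches overall
per_shard = total_batches // num_processes                  # 10 batches per process
assert per_shard == gradient_steps

Note that the per-shard length always reduces to `gradient_steps`: the `batch_size` and `num_processes` factors cancel.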
@@ -195,15 +218,28 @@ def __iter__(self):
def _iter_with_split(self):
initial_data = []
batch_length = self.batch_sampler.batch_size // self.num_processes
batch_yielded = 0

max_batches = None
if self.truncate_dataset and self._optimal_size is not None:
max_batches = self._optimal_size // self.batch_size
max_batches = max_batches // self.num_processes

for idx, batch in enumerate(self.batch_sampler):
if max_batches is not None and batch_yielded >= max_batches:
break
if idx == 0:
initial_data = batch
if len(batch) == self.batch_size:
# If the batch is full, we yield the part of it this process is responsible for.
yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
batch_yielded += 1

# If drop_last is True or the last batch was full, iteration is over, otherwise...
if not self.drop_last and len(initial_data) > 0 and len(batch) < self.batch_size:
if max_batches is not None and batch_yielded >= max_batches:
return

if not self.even_batches:
if len(batch) > batch_length * self.process_index:
yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
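Under `split_batches`, each source batch is sliced across processes, and the new cap bounds how many slices each process yields. A quick check of the capped iteration (again assuming this PR's branch; illustrative values):

from torch.utils.data import BatchSampler
from accelerate.data_loader import BatchSamplerShard

batch_sampler = BatchSampler(range(100), batch_size=4, drop_last=False)
shards = [
    BatchSamplerShard(
        batch_sampler, 2, i, split_batches=True,
        truncate_dataset=True, gradient_steps=10,
    )
    for i in range(2)
]
# Without truncation each shard would yield 25 half-batches; with it, both
# ranks stop after the same capped count.
assert [len(list(s)) for s in shards] == [10, 10]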
@@ -217,7 +253,16 @@ def _iter_with_split(self):
def _iter_with_no_split(self):
initial_data = []
batch_to_yield = []
batches_yielded = 0

max_batches = None
if self.truncate_dataset and self._optimal_size is not None:
max_batches = self._optimal_size // self.batch_size
max_batches = max_batches // self.num_processes

for idx, batch in enumerate(self.batch_sampler):
if max_batches is not None and batches_yielded >= max_batches:
break
# We gather the initial indices in case we need to circle back at the end.
if not self.drop_last and idx < self.num_processes:
initial_data += batch
Expand All @@ -229,9 +274,12 @@ def _iter_with_no_split(self):
self.batch_size is None or len(batch) == self.batch_size
):
yield batch_to_yield
batches_yielded += 1
batch_to_yield = []

# If drop_last is True, iteration is over, otherwise...
if max_batches is not None and batches_yielded >= max_batches:
return
if not self.drop_last and len(initial_data) > 0:
if not self.even_batches:
if len(batch_to_yield) > 0:
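The no-split path hands whole batches to processes round-robin, and the same cap makes both ranks stop together. A short check (same assumptions as above):

from torch.utils.data import BatchSampler
from accelerate.data_loader import BatchSamplerShard

batch_sampler = BatchSampler(range(100), batch_size=4, drop_last=False)
shards = [
    BatchSamplerShard(batch_sampler, 2, i, truncate_dataset=True, gradient_steps=10)
    for i in range(2)
]
out = [list(s) for s in shards]
assert all(len(o) == 10 for o in out)  # equal batch counts on both ranks
# Round-robin assignment means the two ranks see disjoint batches.
assert not set(map(tuple, out[0])) & set(map(tuple, out[1]))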
58 changes: 57 additions & 1 deletion tests/test_data_loader.py
@@ -98,7 +98,7 @@ def set_epoch(self, epoch):
class DataLoaderTester(AccelerateTestCase):
def check_batch_sampler_shards(self, batch_sampler, expected, split_batches=False, even_batches=True):
batch_sampler_shards = [
BatchSamplerShard(batch_sampler, 2, i, split_batches=split_batches, even_batches=even_batches)
BatchSamplerShard(batch_sampler, 2, i, split_batches=split_batches, even_batches=even_batches, truncate_dataset=False, gradient_steps=None)
for i in range(2)
]
batch_sampler_lists = [list(batch_sampler_shard) for batch_sampler_shard in batch_sampler_shards]
@@ -557,6 +557,62 @@ def __call__(self, *args, **kwds):
assert dataloader_ref() is None
assert gradient_state_ref() is None

def test_batch_sampler_shards_with_truncation(self):
# Test with truncation enabled
batch_sampler = BatchSampler(range(100), batch_size=4, drop_last=False)
gradient_steps = 10
num_processes = 2

# Calculate expected size
samples_per_step = batch_sampler.batch_size * num_processes
optimal_size = gradient_steps * samples_per_step

batch_sampler_shards = [
BatchSamplerShard(
batch_sampler,
num_processes,
i,
split_batches=False,
even_batches=True,
truncate_dataset=True,
gradient_steps=gradient_steps
)
for i in range(num_processes)
]

# Check that each shard has the optimal size
for shard in batch_sampler_shards:
assert len(shard) == optimal_size // (batch_sampler.batch_size * num_processes)  # == gradient_steps

# Test that truncation is disabled by default
batch_sampler_shards_no_truncate = [
BatchSamplerShard(
batch_sampler,
num_processes,
i,
split_batches=False,
even_batches=True
)
for i in range(num_processes)
]

# Check that without truncation, we get the full dataset
for shard in batch_sampler_shards_no_truncate:
if shard.even_batches:
expected = (len(batch_sampler) // num_processes) + 1
else:
expected = len(batch_sampler) // num_processes
assert len(shard) == expected

# Test that gradient_steps is required when truncate_dataset is True
with self.assertRaises(ValueError):
BatchSamplerShard(
batch_sampler,
num_processes,
0,
truncate_dataset=True
)
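For context, a minimal sketch of the failure mode this PR targets (the model and data below are placeholders): with gradient accumulation, `accelerator.backward` triggers collective operations at sync points, so if one rank exhausts its dataloader before another, the remaining ranks block indefinitely. Truncating to a common batch count keeps every rank in lockstep.

import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=10)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = torch.utils.data.TensorDataset(torch.randn(100, 8), torch.randn(100, 1))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for x, y in dataloader:
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)  # sync point: every rank must reach it
        optimizer.step()
        optimizer.zero_grad()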


class StatefulDataLoaderTester(AccelerateTestCase):
@require_torchdata_stateful_dataloader