Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,22 +150,17 @@ def __init__(
split_batches: bool = False,
even_batches: bool = True,
):
if split_batches and batch_sampler.batch_size % num_processes != 0:
raise ValueError(
f"To use `BatchSamplerShard` in `split_batches` mode, the batch size ({batch_sampler.batch_size}) "
f"needs to be a round multiple of the number of processes ({num_processes})."
)
self.batch_sampler = batch_sampler
self.num_processes = num_processes
self.process_index = process_index
self.split_batches = split_batches
self.even_batches = even_batches
self.batch_size = getattr(batch_sampler, "batch_size", None)
self.drop_last = getattr(batch_sampler, "drop_last", False)
if self.batch_size is None and self.even_batches:
if split_batches and (self.batch_size is None or self.batch_size % num_processes != 0):
raise ValueError(
"You need to use `even_batches=False` when the batch sampler has no batch size. If you "
"are not calling this method directly, set `accelerator.even_batches=False` instead."
f"To use `BatchSamplerShard` in `split_batches` mode, the batch size ({self.batch_size}) "
f"needs to be a round multiple of the number of processes ({num_processes})."
)

@property
Expand Down Expand Up @@ -217,11 +212,15 @@ def _iter_with_split(self):

def _iter_with_no_split(self):
initial_data = []
batch_to_yield = []
batch_to_yield = None
for idx, batch in enumerate(self.batch_sampler):
# We gather the initial indices in case we need to circle back at the end.
if not self.drop_last and idx < self.num_processes:
initial_data += batch
if self.batch_size is None:
# If batch size is None, `batch` is considered to be a list of indices with dynamic length.
initial_data.append(batch)
else:
initial_data += batch
# We identify the batch to yield but wait until we ar sure every process gets a full batch before actually
# yielding it.
if idx % self.num_processes == self.process_index:
Expand All @@ -230,35 +229,44 @@ def _iter_with_no_split(self):
self.batch_size is None or len(batch) == self.batch_size
):
yield batch_to_yield
batch_to_yield = []
batch_to_yield = None

# If drop_last is True, iteration is over, otherwise...
if not self.drop_last and len(initial_data) > 0:
if not self.even_batches:
if len(batch_to_yield) > 0:
if batch_to_yield:
yield batch_to_yield
else:
# ... we yield the complete batch we had saved before if it has the proper length
if len(batch_to_yield) == self.batch_size:
if batch_to_yield and (self.batch_size is None or len(batch_to_yield) == self.batch_size):
yield batch_to_yield

# For degenerate cases where the dataset has less than num_process * batch_size samples
while len(initial_data) < self.num_processes * self.batch_size:
_min_length_needed = (
self.num_processes * self.batch_size if self.batch_size is not None else self.num_processes
)
while len(initial_data) < _min_length_needed:
initial_data += initial_data

# If the last batch seen was of the proper size, it has been yielded by its process so we move to the next
if len(batch) == self.batch_size:
if self.batch_size is None or len(batch) == self.batch_size:
batch = []
idx += 1

# Make sure we yield a multiple of self.num_processes batches
cycle_index = 0
while idx % self.num_processes != 0 or len(batch) > 0:
end_index = cycle_index + self.batch_size - len(batch)
batch += initial_data[cycle_index:end_index]
if idx % self.num_processes == self.process_index:
yield batch
cycle_index = end_index
if self.batch_size is None:
batch = initial_data[cycle_index]
if idx % self.num_processes == self.process_index:
yield batch
cycle_index += 1
else:
end_index = cycle_index + self.batch_size - len(batch)
batch += initial_data[cycle_index:end_index]
if idx % self.num_processes == self.process_index:
yield batch
cycle_index = end_index
batch = []
idx += 1

Expand Down
101 changes: 101 additions & 0 deletions tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,107 @@ def test_batch_sampler_with_varying_batch_size(self):
assert list(batch_sampler_shards[0]) == [[0, 1, 2], [5, 6, 7, 8], [12, 13]]
assert list(batch_sampler_shards[1]) == [[3, 4], [9, 10, 11]]

def test_batch_sampler_varying_batch_size_even_batches(self):
"""
Tests for the new dynamic batch size + even_batches=True support in BatchSamplerShard.
This covers the modifications to _iter_with_no_split when batch_size is None.
"""
# --- Case 1: even number of batches, no padding needed ---
# 4 batches total, 2 processes -> each process sees 2 batches
batch_sampler = [[0, 1, 2], [3, 4], [5, 6, 7, 8], [9, 10, 11]]
shards = [BatchSamplerShard(batch_sampler, 2, i, even_batches=True) for i in range(2)]
assert list(shards[0]) == [[0, 1, 2], [5, 6, 7, 8]]
assert list(shards[1]) == [[3, 4], [9, 10, 11]]

# --- Case 2: odd number of batches -> process 1 needs padding from initial_data ---
# 5 batches, 2 processes: P0 gets batches 0,2,4; P1 gets 1,3 then is padded.
# Padding cycles initial_data (first num_processes batches = [[0,1,2],[3,4]]) from index 0,
# so P1's pad batch is initial_data[0] = [0,1,2].
batch_sampler = [[0, 1, 2], [3, 4], [5, 6, 7, 8], [9, 10, 11], [12, 13]]
shards = [BatchSamplerShard(batch_sampler, 2, i, even_batches=True) for i in range(2)]
assert list(shards[0]) == [[0, 1, 2], [5, 6, 7, 8], [12, 13]]
assert list(shards[1]) == [[3, 4], [9, 10, 11], [0, 1, 2]]

# --- Case 3: single batch (degenerate, dataset smaller than num_processes) ---
# 1 batch, 2 processes: initial_data=[B0], needs to double to length 2.
# P0 gets B0; P1 gets padding = initial_data[0] = B0.
batch_sampler = [[0, 1, 2]]
shards = [BatchSamplerShard(batch_sampler, 2, i, even_batches=True) for i in range(2)]
assert list(shards[0]) == [[0, 1, 2]]
assert list(shards[1]) == [[0, 1, 2]]

# --- Case 3b: 1 batch, 3 processes (cycle_index must reach beyond initial length after doubling) ---
# initial_data=[B0], doubled to [B0, B0]; P0 gets B0, P1 pads B0, P2 pads B0.
batch_sampler = [[0, 1, 2]]
shards = [BatchSamplerShard(batch_sampler, 3, i, even_batches=True) for i in range(3)]
assert list(shards[0]) == [[0, 1, 2]]
assert list(shards[1]) == [[0, 1, 2]]
assert list(shards[2]) == [[0, 1, 2]]

# --- Case 4: drop_last=True with dynamic batch size ---
# The last incomplete "round" of num_processes batches is dropped
batch_sampler = [[0, 1, 2], [3, 4], [5, 6, 7, 8], [9, 10, 11], [12, 13]]
# This is a raw list sampler; we need to set drop_last on the shard itself.
# drop_last is read via getattr(batch_sampler, "drop_last", False), defaults to False for a plain list.
# We test by wrapping in an object with drop_last=True.

class DropLastBatchSampler:
drop_last = True

def __init__(self, data):
self.data = data

def __iter__(self):
return iter(self.data)

def __len__(self):
return len(self.data)

dl = DropLastBatchSampler(batch_sampler)
shards = [BatchSamplerShard(dl, 2, i, even_batches=True) for i in range(2)]
# drop_last=True means the tail batch ([12, 13]) and its round are dropped entirely
assert list(shards[0]) == [[0, 1, 2], [5, 6, 7, 8]]
assert list(shards[1]) == [[3, 4], [9, 10, 11]]

def test_batch_sampler_varying_batch_size_many_processes(self):
"""
Tests dynamic batch size sharding with num_processes > 2.
"""
# 6 batches, 3 processes -> each gets 2 batches, no padding needed
batch_sampler = [[0], [1, 2], [3, 4, 5], [6], [7, 8], [9, 10, 11]]
shards = [BatchSamplerShard(batch_sampler, 3, i, even_batches=True) for i in range(3)]
assert list(shards[0]) == [[0], [6]]
assert list(shards[1]) == [[1, 2], [7, 8]]
assert list(shards[2]) == [[3, 4, 5], [9, 10, 11]]

# 7 batches, 3 processes -> needs padding to reach 9 total (3 rounds)
batch_sampler = [[0], [1, 2], [3, 4, 5], [6], [7, 8], [9, 10, 11], [12, 13]]
shards = [BatchSamplerShard(batch_sampler, 3, i, even_batches=True) for i in range(3)]
# batch indices: 0->p0, 1->p1, 2->p2, 3->p0, 4->p1, 5->p2, 6->p0(batch_to_yield=[12,13])
# After main loop with batch_to_yield not yielded: cycle padding fills rest of round 3
# cycle_index starts at 0: idx=7->p1 gets initial_data[0]=[0]; idx=8->p2 gets initial_data[1]=[1,2]
assert list(shards[0]) == [[0], [6], [12, 13]]
assert list(shards[1]) == [[1, 2], [7, 8], [0]]
assert list(shards[2]) == [[3, 4, 5], [9, 10, 11], [1, 2]]

def test_split_batches_validates_dynamic_batch_size(self):
"""
Tests that split_batches=True raises ValueError when batch_size is None.
This validates the new combined validation in __init__.
"""
# A plain list has no .batch_size attribute -> batch_size will be None
batch_sampler = [[0, 1, 2, 3], [4, 5, 6, 7]]
with pytest.raises(ValueError, match="split_batches"):
BatchSamplerShard(batch_sampler, 2, 0, split_batches=True, even_batches=True)

def test_split_batches_validates_non_divisible_batch_size(self):
"""
Tests that split_batches=True raises ValueError when batch_size is not divisible by num_processes.
"""
batch_sampler = BatchSampler(range(20), batch_size=3, drop_last=False)
with pytest.raises(ValueError, match="round multiple"):
BatchSamplerShard(batch_sampler, 2, 0, split_batches=True)

def check_iterable_dataset_shards(
self, dataset, seed, batch_size, drop_last=False, num_processes=2, split_batches=False
):
Expand Down