From 01c75a0198d358b644b8aa8ca201a3fdb91dcbfa Mon Sep 17 00:00:00 2001 From: YU Xinyuan Date: Tue, 10 Mar 2026 16:20:21 +0800 Subject: [PATCH] Support dynamic batch size in BatchSamplerShard with even_batches --- src/accelerate/data_loader.py | 48 +++++++++------- tests/test_data_loader.py | 101 ++++++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 20 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index a8d7eaa01a0..f370ea3fb79 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -150,11 +150,6 @@ def __init__( split_batches: bool = False, even_batches: bool = True, ): - if split_batches and batch_sampler.batch_size % num_processes != 0: - raise ValueError( - f"To use `BatchSamplerShard` in `split_batches` mode, the batch size ({batch_sampler.batch_size}) " - f"needs to be a round multiple of the number of processes ({num_processes})." - ) self.batch_sampler = batch_sampler self.num_processes = num_processes self.process_index = process_index @@ -162,10 +157,10 @@ def __init__( self.even_batches = even_batches self.batch_size = getattr(batch_sampler, "batch_size", None) self.drop_last = getattr(batch_sampler, "drop_last", False) - if self.batch_size is None and self.even_batches: + if split_batches and (self.batch_size is None or self.batch_size % num_processes != 0): raise ValueError( - "You need to use `even_batches=False` when the batch sampler has no batch size. If you " - "are not calling this method directly, set `accelerator.even_batches=False` instead." + f"To use `BatchSamplerShard` in `split_batches` mode, the batch size ({self.batch_size}) " + f"needs to be a round multiple of the number of processes ({num_processes})." ) @property @@ -217,11 +212,15 @@ def _iter_with_split(self): def _iter_with_no_split(self): initial_data = [] - batch_to_yield = [] + batch_to_yield = None for idx, batch in enumerate(self.batch_sampler): # We gather the initial indices in case we need to circle back at the end. if not self.drop_last and idx < self.num_processes: - initial_data += batch + if self.batch_size is None: + # If batch size is None, `batch` is considered to be a list of indices with dynamic length. + initial_data.append(batch) + else: + initial_data += batch # We identify the batch to yield but wait until we ar sure every process gets a full batch before actually # yielding it. if idx % self.num_processes == self.process_index: @@ -230,35 +229,44 @@ def _iter_with_no_split(self): self.batch_size is None or len(batch) == self.batch_size ): yield batch_to_yield - batch_to_yield = [] + batch_to_yield = None # If drop_last is True, iteration is over, otherwise... if not self.drop_last and len(initial_data) > 0: if not self.even_batches: - if len(batch_to_yield) > 0: + if batch_to_yield: yield batch_to_yield else: # ... we yield the complete batch we had saved before if it has the proper length - if len(batch_to_yield) == self.batch_size: + if batch_to_yield and (self.batch_size is None or len(batch_to_yield) == self.batch_size): yield batch_to_yield # For degenerate cases where the dataset has less than num_process * batch_size samples - while len(initial_data) < self.num_processes * self.batch_size: + _min_length_needed = ( + self.num_processes * self.batch_size if self.batch_size is not None else self.num_processes + ) + while len(initial_data) < _min_length_needed: initial_data += initial_data # If the last batch seen was of the proper size, it has been yielded by its process so we move to the next - if len(batch) == self.batch_size: + if self.batch_size is None or len(batch) == self.batch_size: batch = [] idx += 1 # Make sure we yield a multiple of self.num_processes batches cycle_index = 0 while idx % self.num_processes != 0 or len(batch) > 0: - end_index = cycle_index + self.batch_size - len(batch) - batch += initial_data[cycle_index:end_index] - if idx % self.num_processes == self.process_index: - yield batch - cycle_index = end_index + if self.batch_size is None: + batch = initial_data[cycle_index] + if idx % self.num_processes == self.process_index: + yield batch + cycle_index += 1 + else: + end_index = cycle_index + self.batch_size - len(batch) + batch += initial_data[cycle_index:end_index] + if idx % self.num_processes == self.process_index: + yield batch + cycle_index = end_index batch = [] idx += 1 diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 2057990a967..2d438f79a8d 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -358,6 +358,107 @@ def test_batch_sampler_with_varying_batch_size(self): assert list(batch_sampler_shards[0]) == [[0, 1, 2], [5, 6, 7, 8], [12, 13]] assert list(batch_sampler_shards[1]) == [[3, 4], [9, 10, 11]] + def test_batch_sampler_varying_batch_size_even_batches(self): + """ + Tests for the new dynamic batch size + even_batches=True support in BatchSamplerShard. + This covers the modifications to _iter_with_no_split when batch_size is None. + """ + # --- Case 1: even number of batches, no padding needed --- + # 4 batches total, 2 processes -> each process sees 2 batches + batch_sampler = [[0, 1, 2], [3, 4], [5, 6, 7, 8], [9, 10, 11]] + shards = [BatchSamplerShard(batch_sampler, 2, i, even_batches=True) for i in range(2)] + assert list(shards[0]) == [[0, 1, 2], [5, 6, 7, 8]] + assert list(shards[1]) == [[3, 4], [9, 10, 11]] + + # --- Case 2: odd number of batches -> process 1 needs padding from initial_data --- + # 5 batches, 2 processes: P0 gets batches 0,2,4; P1 gets 1,3 then is padded. + # Padding cycles initial_data (first num_processes batches = [[0,1,2],[3,4]]) from index 0, + # so P1's pad batch is initial_data[0] = [0,1,2]. + batch_sampler = [[0, 1, 2], [3, 4], [5, 6, 7, 8], [9, 10, 11], [12, 13]] + shards = [BatchSamplerShard(batch_sampler, 2, i, even_batches=True) for i in range(2)] + assert list(shards[0]) == [[0, 1, 2], [5, 6, 7, 8], [12, 13]] + assert list(shards[1]) == [[3, 4], [9, 10, 11], [0, 1, 2]] + + # --- Case 3: single batch (degenerate, dataset smaller than num_processes) --- + # 1 batch, 2 processes: initial_data=[B0], needs to double to length 2. + # P0 gets B0; P1 gets padding = initial_data[0] = B0. + batch_sampler = [[0, 1, 2]] + shards = [BatchSamplerShard(batch_sampler, 2, i, even_batches=True) for i in range(2)] + assert list(shards[0]) == [[0, 1, 2]] + assert list(shards[1]) == [[0, 1, 2]] + + # --- Case 3b: 1 batch, 3 processes (cycle_index must reach beyond initial length after doubling) --- + # initial_data=[B0], doubled to [B0, B0]; P0 gets B0, P1 pads B0, P2 pads B0. + batch_sampler = [[0, 1, 2]] + shards = [BatchSamplerShard(batch_sampler, 3, i, even_batches=True) for i in range(3)] + assert list(shards[0]) == [[0, 1, 2]] + assert list(shards[1]) == [[0, 1, 2]] + assert list(shards[2]) == [[0, 1, 2]] + + # --- Case 4: drop_last=True with dynamic batch size --- + # The last incomplete "round" of num_processes batches is dropped + batch_sampler = [[0, 1, 2], [3, 4], [5, 6, 7, 8], [9, 10, 11], [12, 13]] + # This is a raw list sampler; we need to set drop_last on the shard itself. + # drop_last is read via getattr(batch_sampler, "drop_last", False), defaults to False for a plain list. + # We test by wrapping in an object with drop_last=True. + + class DropLastBatchSampler: + drop_last = True + + def __init__(self, data): + self.data = data + + def __iter__(self): + return iter(self.data) + + def __len__(self): + return len(self.data) + + dl = DropLastBatchSampler(batch_sampler) + shards = [BatchSamplerShard(dl, 2, i, even_batches=True) for i in range(2)] + # drop_last=True means the tail batch ([12, 13]) and its round are dropped entirely + assert list(shards[0]) == [[0, 1, 2], [5, 6, 7, 8]] + assert list(shards[1]) == [[3, 4], [9, 10, 11]] + + def test_batch_sampler_varying_batch_size_many_processes(self): + """ + Tests dynamic batch size sharding with num_processes > 2. + """ + # 6 batches, 3 processes -> each gets 2 batches, no padding needed + batch_sampler = [[0], [1, 2], [3, 4, 5], [6], [7, 8], [9, 10, 11]] + shards = [BatchSamplerShard(batch_sampler, 3, i, even_batches=True) for i in range(3)] + assert list(shards[0]) == [[0], [6]] + assert list(shards[1]) == [[1, 2], [7, 8]] + assert list(shards[2]) == [[3, 4, 5], [9, 10, 11]] + + # 7 batches, 3 processes -> needs padding to reach 9 total (3 rounds) + batch_sampler = [[0], [1, 2], [3, 4, 5], [6], [7, 8], [9, 10, 11], [12, 13]] + shards = [BatchSamplerShard(batch_sampler, 3, i, even_batches=True) for i in range(3)] + # batch indices: 0->p0, 1->p1, 2->p2, 3->p0, 4->p1, 5->p2, 6->p0(batch_to_yield=[12,13]) + # After main loop with batch_to_yield not yielded: cycle padding fills rest of round 3 + # cycle_index starts at 0: idx=7->p1 gets initial_data[0]=[0]; idx=8->p2 gets initial_data[1]=[1,2] + assert list(shards[0]) == [[0], [6], [12, 13]] + assert list(shards[1]) == [[1, 2], [7, 8], [0]] + assert list(shards[2]) == [[3, 4, 5], [9, 10, 11], [1, 2]] + + def test_split_batches_validates_dynamic_batch_size(self): + """ + Tests that split_batches=True raises ValueError when batch_size is None. + This validates the new combined validation in __init__. + """ + # A plain list has no .batch_size attribute -> batch_size will be None + batch_sampler = [[0, 1, 2, 3], [4, 5, 6, 7]] + with pytest.raises(ValueError, match="split_batches"): + BatchSamplerShard(batch_sampler, 2, 0, split_batches=True, even_batches=True) + + def test_split_batches_validates_non_divisible_batch_size(self): + """ + Tests that split_batches=True raises ValueError when batch_size is not divisible by num_processes. + """ + batch_sampler = BatchSampler(range(20), batch_size=3, drop_last=False) + with pytest.raises(ValueError, match="round multiple"): + BatchSamplerShard(batch_sampler, 2, 0, split_batches=True) + def check_iterable_dataset_shards( self, dataset, seed, batch_size, drop_last=False, num_processes=2, split_batches=False ):