From dd434519e361181b2d07b74f4884dd8fd3920277 Mon Sep 17 00:00:00 2001 From: KrishVenky Date: Thu, 21 May 2026 19:37:11 +0530 Subject: [PATCH 1/2] feat: add DispatchDataLoader for custom iterable dataloaders Closes #2975. Adds DispatchDataLoader, a lightweight wrapper that makes any Python iterable usable with Accelerator.prepare() without requiring it to be a torch.utils.data.DataLoader. - data_loader.py: new DispatchDataLoader class with device placement - utils/dataclasses.py: custom_classes field on DataLoaderConfiguration - accelerator.py: _prepare_one recognises DispatchDataLoader and custom_classes - __init__.py: exports DispatchDataLoader - tests/test_data_loader.py: 8 new tests, 39/39 passing --- src/accelerate/__init__.py | 2 +- src/accelerate/accelerator.py | 14 +++- src/accelerate/data_loader.py | 50 ++++++++++++ src/accelerate/utils/dataclasses.py | 13 ++++ tests/test_data_loader.py | 117 ++++++++++++++++++++++++++++ 5 files changed, 194 insertions(+), 2 deletions(-) diff --git a/src/accelerate/__init__.py b/src/accelerate/__init__.py index c2d82779660..87a57f65155 100644 --- a/src/accelerate/__init__.py +++ b/src/accelerate/__init__.py @@ -23,7 +23,7 @@ init_on_device, load_checkpoint_and_dispatch, ) -from .data_loader import skip_first_batches +from .data_loader import DispatchDataLoader, skip_first_batches from .inference import prepare_pippy from .launchers import debug_launcher, notebook_launcher from .parallelism_config import ParallelismConfig diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 9c8f4a3c24e..d2486dcc4df 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -36,7 +36,7 @@ from .big_modeling import _attach_context_parallel_hooks from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state -from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches +from .data_loader import DataLoaderDispatcher, DispatchDataLoader, prepare_data_loader, skip_first_batches from .logging import get_logger from .optimizer import AcceleratedOptimizer from .parallelism_config import ParallelismConfig @@ -1399,6 +1399,18 @@ def _prepare_one(self, obj, first_pass=False, device_placement=None): if first_pass: if isinstance(obj, torch.utils.data.DataLoader): return self.prepare_data_loader(obj, device_placement=device_placement) + elif isinstance(obj, DispatchDataLoader): + if device_placement: + obj._device = self.device + return obj + elif ( + self.dataloader_config.custom_classes is not None + and isinstance(obj, self.dataloader_config.custom_classes) + ): + wrapped = DispatchDataLoader(obj) + if device_placement: + wrapped._device = self.device + return wrapped elif isinstance(obj, torch.nn.Module): return self.prepare_model(obj, device_placement=device_placement) elif isinstance(obj, torch.optim.Optimizer): diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index a8d7eaa01a0..b1bf40aa814 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -1459,3 +1459,53 @@ def skip_first_batches(dataloader, num_batches=0): dataloader = MpDeviceLoaderWrapper(dataloader, device) return dataloader + + +class DispatchDataLoader: + """ + A lightweight wrapper that makes any Python iterable compatible with `Accelerator.prepare()`. + + Useful when you have a custom iterable object that yields batches and you want to use it with + the Accelerator without converting it to a `torch.utils.data.DataLoader`. Register the custom + class via `DataLoaderConfiguration(custom_classes=(MyClass,))` so that `Accelerator.prepare()` + recognises it and wraps it automatically. + + The wrapped iterable is responsible for any distributed sharding — `DispatchDataLoader` only + handles device placement of the yielded batches. + + Args: + dataloader (`Iterable`): + Any iterable that yields batches (dicts, tensors, or nested structures thereof). + + Example: + + ```python + import torch + from accelerate import Accelerator + from accelerate.utils import DataLoaderConfiguration + + class MyDataSource: + def __iter__(self): + for i in range(10): + yield {"input_ids": torch.tensor([i])} + + dataloader_config = DataLoaderConfiguration(custom_classes=(MyDataSource,)) + accelerator = Accelerator(dataloader_config=dataloader_config) + loader = accelerator.prepare(MyDataSource()) + for batch in loader: + ... # batch tensors are on accelerator.device + ``` + """ + + def __init__(self, dataloader): + self.dataloader = dataloader + self._device = None + + def __iter__(self): + for batch in self.dataloader: + if self._device is not None: + batch = send_to_device(batch, self._device) + yield batch + + def __len__(self): + return len(self.dataloader) diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 9c586f0057d..403bfbdca18 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -851,6 +851,11 @@ class DataLoaderConfiguration: If set to `True`, the dataloader prepared by the Accelerator will be backed by [torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed. + custom_classes (`tuple` of types, defaults to `None`): + A tuple of custom class types that should be treated as iterable dataloaders when passed to + `Accelerator.prepare()`. Instances of these classes are wrapped in a `DispatchDataLoader` which + calls `__iter__` on the underlying object and handles automatic device placement. The classes must + be iterable (i.e. implement `__iter__`). """ split_batches: bool = field( @@ -909,6 +914,14 @@ class DataLoaderConfiguration: "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed." }, ) + custom_classes: Optional[tuple] = field( + default=None, + metadata={ + "help": "A tuple of custom class types to treat as iterable dataloaders in `Accelerator.prepare()`. " + "Instances are wrapped in a `DispatchDataLoader` that calls `__iter__` on the object and handles " + "device placement. Each class must implement `__iter__`." + }, + ) @dataclass diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 2057990a967..a01c01e95c7 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -26,6 +26,7 @@ DataLoaderDispatcher, DataLoaderShard, DataLoaderStateMixin, + DispatchDataLoader, IterableDatasetShard, SkipBatchSampler, SkipDataLoader, @@ -925,3 +926,119 @@ def get_all_batches(dl, device): gradient_state = GradientState() assert gradient_state.active_dataloader is None + + +class DispatchDataLoaderTester(AccelerateTestCase): + """Tests for DispatchDataLoader and DataLoaderConfiguration.custom_classes.""" + + def test_dispatch_dataloader_iterates(self): + """DispatchDataLoader yields every item from the wrapped iterable.""" + + class SimpleSource: + def __iter__(self): + yield from [{"x": torch.tensor(i)} for i in range(5)] + + def __len__(self): + return 5 + + loader = DispatchDataLoader(SimpleSource()) + batches = list(loader) + assert len(batches) == 5 + for i, batch in enumerate(DispatchDataLoader(SimpleSource())): + assert batch["x"].item() == i + + def test_dispatch_dataloader_len(self): + """DispatchDataLoader forwards __len__ to the wrapped object.""" + + class SizedSource: + def __iter__(self): + return iter([]) + + def __len__(self): + return 42 + + assert len(DispatchDataLoader(SizedSource())) == 42 + + def test_dispatch_dataloader_device_placement(self): + """Setting _device moves tensors onto the target device on each iteration.""" + device = torch.device("cpu") + + class TensorSource: + def __iter__(self): + yield {"val": torch.tensor(1)} + + loader = DispatchDataLoader(TensorSource()) + loader._device = device + batch = next(iter(loader)) + assert batch["val"].device == device + + def test_custom_classes_wraps_iterable(self): + """Accelerator.prepare() wraps a custom-class instance in DispatchDataLoader.""" + from accelerate.utils import DataLoaderConfiguration + + class MySource: + def __iter__(self): + yield from [{"x": torch.tensor(i)} for i in range(3)] + + config = DataLoaderConfiguration(custom_classes=(MySource,)) + accelerator = Accelerator(dataloader_config=config) + prepared = accelerator._prepare_one(MySource(), first_pass=True, device_placement=False) + assert isinstance(prepared, DispatchDataLoader) + batches = list(prepared) + assert len(batches) == 3 + + def test_custom_classes_none_does_not_wrap(self): + """Without custom_classes, non-DataLoader iterables pass through unprepared.""" + from accelerate.utils import DataLoaderConfiguration + + class MySource: + def __iter__(self): + yield from [] + + config = DataLoaderConfiguration(custom_classes=None) + accelerator = Accelerator(dataloader_config=config) + source = MySource() + result = accelerator._prepare_one(source, first_pass=True, device_placement=False) + assert result is source + + def test_dispatch_dataloader_prepare_directly(self): + """A DispatchDataLoader passed directly to _prepare_one is returned as-is.""" + from accelerate.utils import DataLoaderConfiguration + + loader = DispatchDataLoader(iter([])) + accelerator = Accelerator(dataloader_config=DataLoaderConfiguration()) + result = accelerator._prepare_one(loader, first_pass=True, device_placement=False) + assert result is loader + + def test_custom_classes_full_prepare_flow(self): + """accelerator.prepare() end-to-end with custom_classes yields correct batches.""" + from accelerate.utils import DataLoaderConfiguration + + class BatchSource: + def __iter__(self): + yield from [{"val": torch.tensor(i)} for i in range(4)] + + config = DataLoaderConfiguration(custom_classes=(BatchSource,)) + accelerator = Accelerator(dataloader_config=config) + prepared = accelerator.prepare(BatchSource()) + assert isinstance(prepared, DispatchDataLoader) + values = [batch["val"].item() for batch in prepared] + assert values == [0, 1, 2, 3] + + def test_custom_classes_multiple_types(self): + """custom_classes tuple with multiple types wraps instances of any listed class.""" + from accelerate.utils import DataLoaderConfiguration + + class SourceA: + def __iter__(self): + yield {"tag": torch.tensor(0)} + + class SourceB: + def __iter__(self): + yield {"tag": torch.tensor(1)} + + config = DataLoaderConfiguration(custom_classes=(SourceA, SourceB)) + accelerator = Accelerator(dataloader_config=config) + for source in [SourceA(), SourceB()]: + result = accelerator._prepare_one(source, first_pass=True, device_placement=False) + assert isinstance(result, DispatchDataLoader) From 06689d18b81a25a9224dada50bb8b45522993b0e Mon Sep 17 00:00:00 2001 From: KrishVenky Date: Sat, 23 May 2026 23:18:15 +0530 Subject: [PATCH 2/2] add __getattr__ delegation to DispatchDataLoader Forwards unknown attribute access to the wrapped iterable, matching the pattern used by DataLoaderAdapter. Raises AttributeError cleanly when the attribute does not exist on the wrapped object either. Adds two tests covering delegation and missing-attribute behaviour. --- src/accelerate/data_loader.py | 6 ++++++ tests/test_data_loader.py | 25 +++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index b1bf40aa814..c225312eea9 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -1507,5 +1507,11 @@ def __iter__(self): batch = send_to_device(batch, self._device) yield batch + def __getattr__(self, name): + try: + return getattr(object.__getattribute__(self, "dataloader"), name) + except AttributeError: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + def __len__(self): return len(self.dataloader) diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index a01c01e95c7..0f5f5bfcc33 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -1042,3 +1042,28 @@ def __iter__(self): for source in [SourceA(), SourceB()]: result = accelerator._prepare_one(source, first_pass=True, device_placement=False) assert isinstance(result, DispatchDataLoader) + + def test_dispatch_dataloader_getattr_delegates(self): + """__getattr__ forwards attribute access to the wrapped object.""" + + class SourceWithMeta: + batch_size = 16 + dataset_name = "my_dataset" + + def __iter__(self): + return iter([]) + + loader = DispatchDataLoader(SourceWithMeta()) + assert loader.batch_size == 16 + assert loader.dataset_name == "my_dataset" + + def test_dispatch_dataloader_getattr_missing_raises(self): + """__getattr__ raises AttributeError for attributes that don't exist.""" + + class SimpleSource: + def __iter__(self): + return iter([]) + + loader = DispatchDataLoader(SimpleSource()) + with pytest.raises(AttributeError): + _ = loader.nonexistent_attr