diff --git a/src/accelerate/__init__.py b/src/accelerate/__init__.py index c2d82779660..87a57f65155 100644 --- a/src/accelerate/__init__.py +++ b/src/accelerate/__init__.py @@ -23,7 +23,7 @@ init_on_device, load_checkpoint_and_dispatch, ) -from .data_loader import skip_first_batches +from .data_loader import DispatchDataLoader, skip_first_batches from .inference import prepare_pippy from .launchers import debug_launcher, notebook_launcher from .parallelism_config import ParallelismConfig diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 9c8f4a3c24e..d2486dcc4df 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -36,7 +36,7 @@ from .big_modeling import _attach_context_parallel_hooks from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state -from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches +from .data_loader import DataLoaderDispatcher, DispatchDataLoader, prepare_data_loader, skip_first_batches from .logging import get_logger from .optimizer import AcceleratedOptimizer from .parallelism_config import ParallelismConfig @@ -1399,6 +1399,18 @@ def _prepare_one(self, obj, first_pass=False, device_placement=None): if first_pass: if isinstance(obj, torch.utils.data.DataLoader): return self.prepare_data_loader(obj, device_placement=device_placement) + elif isinstance(obj, DispatchDataLoader): + if device_placement: + obj._device = self.device + return obj + elif ( + self.dataloader_config.custom_classes is not None + and isinstance(obj, self.dataloader_config.custom_classes) + ): + wrapped = DispatchDataLoader(obj) + if device_placement: + wrapped._device = self.device + return wrapped elif isinstance(obj, torch.nn.Module): return self.prepare_model(obj, device_placement=device_placement) elif isinstance(obj, torch.optim.Optimizer): diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 
class DispatchDataLoader:
    """
    A lightweight wrapper that makes any Python iterable compatible with `Accelerator.prepare()`.

    Useful when you have a custom iterable object that yields batches and you want to use it with
    the Accelerator without converting it to a `torch.utils.data.DataLoader`. Register the custom
    class via `DataLoaderConfiguration(custom_classes=(MyClass,))` so that `Accelerator.prepare()`
    recognises it and wraps it automatically.

    The wrapped iterable is responsible for any distributed sharding — `DispatchDataLoader` only
    handles device placement of the yielded batches. Attributes that are not found on the wrapper
    are looked up on the wrapped object, and `len()` is forwarded to it.

    Args:
        dataloader (`Iterable`):
            Any iterable that yields batches (dicts, tensors, or nested structures thereof).
        device (`torch.device` or `str`, *optional*, defaults to `None`):
            Device every yielded batch is sent to. When `None`, batches are yielded untouched.

    Example:

    ```python
    import torch
    from accelerate import Accelerator
    from accelerate.utils import DataLoaderConfiguration

    class MyDataSource:
        def __iter__(self):
            for i in range(10):
                yield {"input_ids": torch.tensor([i])}

    dataloader_config = DataLoaderConfiguration(custom_classes=(MyDataSource,))
    accelerator = Accelerator(dataloader_config=dataloader_config)
    loader = accelerator.prepare(MyDataSource())
    for batch in loader:
        ...  # batch tensors are moved to accelerator.device when device placement is enabled
    ```
    """

    def __init__(self, dataloader, device=None):
        self.dataloader = dataloader
        # Kept under the private name because callers assign `_device` directly after construction.
        self._device = device

    def __iter__(self):
        """Yield each batch of the wrapped iterable, sent to the configured device if there is one."""
        for batch in self.dataloader:
            # Checked per batch so that a device assigned after iteration started is still honoured.
            if self._device is not None:
                batch = send_to_device(batch, self._device)
            yield batch

    def __getattr__(self, name):
        """
        Forward lookups of attributes missing on the wrapper to the wrapped object.

        Raises:
            `AttributeError`: if the wrapped object does not provide `name` either, or if the wrapper has no
            `dataloader` attribute yet.
        """
        try:
            # `object.__getattribute__` avoids re-entering `__getattr__` when `dataloader` itself is missing.
            wrapped = object.__getattribute__(self, "dataloader")
            return getattr(wrapped, name)
        except AttributeError as err:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from err

    def __len__(self):
        """Return `len()` of the wrapped object; raises `TypeError` when it is unsized."""
        return len(self.dataloader)

    def __bool__(self):
        """
        Truth value of the wrapper.

        Without this method Python falls back to `__len__`, which raises `TypeError` for unsized iterables
        (generators, iterators, classes without `__len__`). Sized sources keep the length-based result;
        unsized ones are considered truthy.
        """
        try:
            return len(self.dataloader) > 0
        except TypeError:
            return True

    def __repr__(self):
        return f"{type(self).__name__}({self.dataloader!r}, device={self._device!r})"
class DispatchDataLoaderTester(AccelerateTestCase):
    """Covers the DispatchDataLoader wrapper and the DataLoaderConfiguration.custom_classes option."""

    def test_dispatch_dataloader_iterates(self):
        """Every item produced by the wrapped iterable comes out of the wrapper, in order."""

        class FiveItems:
            def __len__(self):
                return 5

            def __iter__(self):
                for idx in range(5):
                    yield {"x": torch.tensor(idx)}

        collected = list(DispatchDataLoader(FiveItems()))
        assert len(collected) == 5
        # A fresh wrapper is iterated a second time to check the values themselves.
        position = 0
        for item in DispatchDataLoader(FiveItems()):
            assert item["x"].item() == position
            position += 1

    def test_dispatch_dataloader_len(self):
        """`len()` on the wrapper reports the length of the wrapped object."""

        class FortyTwo:
            def __len__(self):
                return 42

            def __iter__(self):
                return iter([])

        wrapper = DispatchDataLoader(FortyTwo())
        assert len(wrapper) == 42

    def test_dispatch_dataloader_device_placement(self):
        """Once `_device` is assigned, yielded tensors live on that device."""
        target = torch.device("cpu")

        class OneTensor:
            def __iter__(self):
                yield {"val": torch.tensor(1)}

        wrapper = DispatchDataLoader(OneTensor())
        wrapper._device = target
        first = next(iter(wrapper))
        assert first["val"].device == target

    def test_custom_classes_wraps_iterable(self):
        """An instance of a registered custom class is returned wrapped in DispatchDataLoader."""
        from accelerate.utils import DataLoaderConfiguration

        class ThreeItems:
            def __iter__(self):
                for idx in range(3):
                    yield {"x": torch.tensor(idx)}

        accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(custom_classes=(ThreeItems,)))
        wrapped = accelerator._prepare_one(ThreeItems(), first_pass=True, device_placement=False)
        assert isinstance(wrapped, DispatchDataLoader)
        assert len(list(wrapped)) == 3

    def test_custom_classes_none_does_not_wrap(self):
        """With `custom_classes=None`, an arbitrary iterable is handed back untouched."""
        from accelerate.utils import DataLoaderConfiguration

        class EmptySource:
            def __iter__(self):
                yield from ()

        accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(custom_classes=None))
        original = EmptySource()
        assert accelerator._prepare_one(original, first_pass=True, device_placement=False) is original

    def test_dispatch_dataloader_prepare_directly(self):
        """An object that already is a DispatchDataLoader is returned as the very same object."""
        from accelerate.utils import DataLoaderConfiguration

        wrapper = DispatchDataLoader(iter([]))
        accelerator = Accelerator(dataloader_config=DataLoaderConfiguration())
        assert accelerator._prepare_one(wrapper, first_pass=True, device_placement=False) is wrapper

    def test_custom_classes_full_prepare_flow(self):
        """The public `prepare()` entry point wraps a registered class and yields its batches in order."""
        from accelerate.utils import DataLoaderConfiguration

        class FourItems:
            def __iter__(self):
                for idx in range(4):
                    yield {"val": torch.tensor(idx)}

        accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(custom_classes=(FourItems,)))
        prepared = accelerator.prepare(FourItems())
        assert isinstance(prepared, DispatchDataLoader)
        seen = []
        for batch in prepared:
            seen.append(batch["val"].item())
        assert seen == [0, 1, 2, 3]

    def test_custom_classes_multiple_types(self):
        """Each class listed in the `custom_classes` tuple gets wrapped."""
        from accelerate.utils import DataLoaderConfiguration

        class First:
            def __iter__(self):
                yield {"tag": torch.tensor(0)}

        class Second:
            def __iter__(self):
                yield {"tag": torch.tensor(1)}

        accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(custom_classes=(First, Second)))
        for candidate in (First(), Second()):
            outcome = accelerator._prepare_one(candidate, first_pass=True, device_placement=False)
            assert isinstance(outcome, DispatchDataLoader)

    def test_dispatch_dataloader_getattr_delegates(self):
        """Attributes absent on the wrapper are read from the wrapped object."""

        class WithMetadata:
            dataset_name = "my_dataset"
            batch_size = 16

            def __iter__(self):
                return iter([])

        wrapper = DispatchDataLoader(WithMetadata())
        assert wrapper.batch_size == 16
        assert wrapper.dataset_name == "my_dataset"

    def test_dispatch_dataloader_getattr_missing_raises(self):
        """Asking for an attribute neither object has raises AttributeError."""

        class Bare:
            def __iter__(self):
                return iter([])

        wrapper = DispatchDataLoader(Bare())
        with pytest.raises(AttributeError):
            _ = wrapper.nonexistent_attr