Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/accelerate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
init_on_device,
load_checkpoint_and_dispatch,
)
from .data_loader import skip_first_batches
from .data_loader import DispatchDataLoader, skip_first_batches
from .inference import prepare_pippy
from .launchers import debug_launcher, notebook_launcher
from .parallelism_config import ParallelismConfig
Expand Down
14 changes: 13 additions & 1 deletion src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from .big_modeling import _attach_context_parallel_hooks
from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
from .data_loader import DataLoaderDispatcher, DispatchDataLoader, prepare_data_loader, skip_first_batches
from .logging import get_logger
from .optimizer import AcceleratedOptimizer
from .parallelism_config import ParallelismConfig
Expand Down Expand Up @@ -1399,6 +1399,18 @@ def _prepare_one(self, obj, first_pass=False, device_placement=None):
if first_pass:
if isinstance(obj, torch.utils.data.DataLoader):
return self.prepare_data_loader(obj, device_placement=device_placement)
elif isinstance(obj, DispatchDataLoader):
if device_placement:
obj._device = self.device
return obj
elif (
self.dataloader_config.custom_classes is not None
and isinstance(obj, self.dataloader_config.custom_classes)
):
wrapped = DispatchDataLoader(obj)
if device_placement:
wrapped._device = self.device
return wrapped
elif isinstance(obj, torch.nn.Module):
return self.prepare_model(obj, device_placement=device_placement)
elif isinstance(obj, torch.optim.Optimizer):
Expand Down
56 changes: 56 additions & 0 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1459,3 +1459,59 @@ def skip_first_batches(dataloader, num_batches=0):
dataloader = MpDeviceLoaderWrapper(dataloader, device)

return dataloader


class DispatchDataLoader:
"""
A lightweight wrapper that makes any Python iterable compatible with `Accelerator.prepare()`.

Useful when you have a custom iterable object that yields batches and you want to use it with
the Accelerator without converting it to a `torch.utils.data.DataLoader`. Register the custom
class via `DataLoaderConfiguration(custom_classes=(MyClass,))` so that `Accelerator.prepare()`
recognises it and wraps it automatically.

The wrapped iterable is responsible for any distributed sharding — `DispatchDataLoader` only
handles device placement of the yielded batches.

Args:
dataloader (`Iterable`):
Any iterable that yields batches (dicts, tensors, or nested structures thereof).

Example:

```python
import torch
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

class MyDataSource:
def __iter__(self):
for i in range(10):
yield {"input_ids": torch.tensor([i])}

dataloader_config = DataLoaderConfiguration(custom_classes=(MyDataSource,))
accelerator = Accelerator(dataloader_config=dataloader_config)
loader = accelerator.prepare(MyDataSource())
for batch in loader:
... # batch tensors are on accelerator.device
```
"""

def __init__(self, dataloader):
self.dataloader = dataloader
self._device = None

def __iter__(self):
for batch in self.dataloader:
if self._device is not None:
batch = send_to_device(batch, self._device)
yield batch

def __getattr__(self, name):
try:
return getattr(object.__getattribute__(self, "dataloader"), name)
except AttributeError:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

def __len__(self):
return len(self.dataloader)
13 changes: 13 additions & 0 deletions src/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,11 @@ class DataLoaderConfiguration:
If set to `True`, the dataloader prepared by the Accelerator will be backed by
[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed.
custom_classes (`tuple` of types, defaults to `None`):
A tuple of custom class types that should be treated as iterable dataloaders when passed to
`Accelerator.prepare()`. Instances of these classes are wrapped in a `DispatchDataLoader` which
calls `__iter__` on the underlying object and handles automatic device placement. The classes must
be iterable (i.e. implement `__iter__`).
"""

split_batches: bool = field(
Expand Down Expand Up @@ -909,6 +914,14 @@ class DataLoaderConfiguration:
"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
},
)
custom_classes: Optional[tuple] = field(
default=None,
metadata={
"help": "A tuple of custom class types to treat as iterable dataloaders in `Accelerator.prepare()`. "
"Instances are wrapped in a `DispatchDataLoader` that calls `__iter__` on the object and handles "
"device placement. Each class must implement `__iter__`."
},
)


@dataclass
Expand Down
142 changes: 142 additions & 0 deletions tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
DataLoaderDispatcher,
DataLoaderShard,
DataLoaderStateMixin,
DispatchDataLoader,
IterableDatasetShard,
SkipBatchSampler,
SkipDataLoader,
Expand Down Expand Up @@ -925,3 +926,144 @@ def get_all_batches(dl, device):

gradient_state = GradientState()
assert gradient_state.active_dataloader is None


class DispatchDataLoaderTester(AccelerateTestCase):
"""Tests for DispatchDataLoader and DataLoaderConfiguration.custom_classes."""

def test_dispatch_dataloader_iterates(self):
"""DispatchDataLoader yields every item from the wrapped iterable."""

class SimpleSource:
def __iter__(self):
yield from [{"x": torch.tensor(i)} for i in range(5)]

def __len__(self):
return 5

loader = DispatchDataLoader(SimpleSource())
batches = list(loader)
assert len(batches) == 5
for i, batch in enumerate(DispatchDataLoader(SimpleSource())):
assert batch["x"].item() == i

def test_dispatch_dataloader_len(self):
"""DispatchDataLoader forwards __len__ to the wrapped object."""

class SizedSource:
def __iter__(self):
return iter([])

def __len__(self):
return 42

assert len(DispatchDataLoader(SizedSource())) == 42

def test_dispatch_dataloader_device_placement(self):
"""Setting _device moves tensors onto the target device on each iteration."""
device = torch.device("cpu")

class TensorSource:
def __iter__(self):
yield {"val": torch.tensor(1)}

loader = DispatchDataLoader(TensorSource())
loader._device = device
batch = next(iter(loader))
assert batch["val"].device == device

def test_custom_classes_wraps_iterable(self):
"""Accelerator.prepare() wraps a custom-class instance in DispatchDataLoader."""
from accelerate.utils import DataLoaderConfiguration

class MySource:
def __iter__(self):
yield from [{"x": torch.tensor(i)} for i in range(3)]

config = DataLoaderConfiguration(custom_classes=(MySource,))
accelerator = Accelerator(dataloader_config=config)
prepared = accelerator._prepare_one(MySource(), first_pass=True, device_placement=False)
assert isinstance(prepared, DispatchDataLoader)
batches = list(prepared)
assert len(batches) == 3

def test_custom_classes_none_does_not_wrap(self):
"""Without custom_classes, non-DataLoader iterables pass through unprepared."""
from accelerate.utils import DataLoaderConfiguration

class MySource:
def __iter__(self):
yield from []

config = DataLoaderConfiguration(custom_classes=None)
accelerator = Accelerator(dataloader_config=config)
source = MySource()
result = accelerator._prepare_one(source, first_pass=True, device_placement=False)
assert result is source

def test_dispatch_dataloader_prepare_directly(self):
"""A DispatchDataLoader passed directly to _prepare_one is returned as-is."""
from accelerate.utils import DataLoaderConfiguration

loader = DispatchDataLoader(iter([]))
accelerator = Accelerator(dataloader_config=DataLoaderConfiguration())
result = accelerator._prepare_one(loader, first_pass=True, device_placement=False)
assert result is loader

def test_custom_classes_full_prepare_flow(self):
"""accelerator.prepare() end-to-end with custom_classes yields correct batches."""
from accelerate.utils import DataLoaderConfiguration

class BatchSource:
def __iter__(self):
yield from [{"val": torch.tensor(i)} for i in range(4)]

config = DataLoaderConfiguration(custom_classes=(BatchSource,))
accelerator = Accelerator(dataloader_config=config)
prepared = accelerator.prepare(BatchSource())
assert isinstance(prepared, DispatchDataLoader)
values = [batch["val"].item() for batch in prepared]
assert values == [0, 1, 2, 3]

def test_custom_classes_multiple_types(self):
"""custom_classes tuple with multiple types wraps instances of any listed class."""
from accelerate.utils import DataLoaderConfiguration

class SourceA:
def __iter__(self):
yield {"tag": torch.tensor(0)}

class SourceB:
def __iter__(self):
yield {"tag": torch.tensor(1)}

config = DataLoaderConfiguration(custom_classes=(SourceA, SourceB))
accelerator = Accelerator(dataloader_config=config)
for source in [SourceA(), SourceB()]:
result = accelerator._prepare_one(source, first_pass=True, device_placement=False)
assert isinstance(result, DispatchDataLoader)

def test_dispatch_dataloader_getattr_delegates(self):
"""__getattr__ forwards attribute access to the wrapped object."""

class SourceWithMeta:
batch_size = 16
dataset_name = "my_dataset"

def __iter__(self):
return iter([])

loader = DispatchDataLoader(SourceWithMeta())
assert loader.batch_size == 16
assert loader.dataset_name == "my_dataset"

def test_dispatch_dataloader_getattr_missing_raises(self):
"""__getattr__ raises AttributeError for attributes that don't exist."""

class SimpleSource:
def __iter__(self):
return iter([])

loader = DispatchDataLoader(SimpleSource())
with pytest.raises(AttributeError):
_ = loader.nonexistent_attr