Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions src/accelerate/accelerator.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@
SCALER_NAME,
)
from .utils.modeling import get_state_dict_offloaded_model
from .utils.other import compile_regions, compile_regions_deepspeed, is_compiled_module
from .utils.other import compile_regions, compile_regions_deepspeed, has_compiled_regions, is_compiled_module


if is_deepspeed_available():
Expand Down Expand Up @@ -1885,6 +1885,15 @@ def prepare_model(
device_ids, output_device = [self.local_process_index], self.local_process_index
else:
device_ids, output_device = None, None
# When torch.compile is enabled, compile the inner model before wrapping with DDP.
# This prevents DDP's internal methods (e.g. logger.set_runtime_stats_and_log)
# from being traced by Dynamo, which would cause an "Unsupported method call" error.
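# NOTE(review): PyTorch's DistributedDataParallel docs advise applying the DDP wrapper
# *before* torch.compile so that Dynamo's DDPOptimizer can split graphs by bucket size.
# Compiling the inner module first gives up that optimization in exchange for avoiding
# the tracing error described here; confirm this trade-off is intended.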
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
if self.state.dynamo_plugin.use_regional_compilation:
model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs())
else:
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=device_ids, output_device=output_device, **kwargs
)
Expand Down Expand Up @@ -2046,6 +2055,13 @@ def prepare_model(
self._models[-1] = model
elif self.distributed_type == DistributedType.MULTI_CPU:
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler else {}
# When torch.compile is enabled, compile the inner model before wrapping with DDP
# to avoid Dynamo tracing DDP internals (see MULTI_GPU path above for details).
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
if self.state.dynamo_plugin.use_regional_compilation:
model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs())
else:
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
Expand All @@ -2054,8 +2070,11 @@ def prepare_model(
# Now we can apply the FP8 autocast
if self.fp8_backend == FP8BackendType.TE and self.delayed_fp8_autocast:
model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
# torch.compile should be called last and only if the model isn't already compiled
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
# torch.compile should be called last and only if the model isn't already compiled.
# Also skip if the model has compiled submodules (e.g. the inner module of DDP was already
# compiled before DDP wrapping to avoid Dynamo tracing DDP internals like
# logger.set_runtime_stats_and_log which would cause an "Unsupported method call" error).
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model) and not has_compiled_regions(model):
if self.state.dynamo_plugin.use_regional_compilation:
model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs())
else:
Expand Down