diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py old mode 100755 new mode 100644 index a1cbfaa8d68..4f86e05b714 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -129,7 +129,7 @@ SCALER_NAME, ) from .utils.modeling import get_state_dict_offloaded_model -from .utils.other import compile_regions, compile_regions_deepspeed, is_compiled_module +from .utils.other import compile_regions, compile_regions_deepspeed, has_compiled_regions, is_compiled_module if is_deepspeed_available(): @@ -1885,6 +1885,15 @@ def prepare_model( device_ids, output_device = [self.local_process_index], self.local_process_index else: device_ids, output_device = None, None + # When torch.compile is enabled, compile the inner model before wrapping with DDP. + # This prevents DDP's internal methods (e.g. logger.set_runtime_stats_and_log) + # from being traced by Dynamo, which would cause an "Unsupported method call" error. + # This follows the PyTorch-recommended pattern for DDP + torch.compile. + if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model): + if self.state.dynamo_plugin.use_regional_compilation: + model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs()) + else: + model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs()) model = torch.nn.parallel.DistributedDataParallel( model, device_ids=device_ids, output_device=output_device, **kwargs ) @@ -2046,6 +2055,13 @@ def prepare_model( self._models[-1] = model elif self.distributed_type == DistributedType.MULTI_CPU: kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler else {} + # When torch.compile is enabled, compile the inner model before wrapping with DDP + # to avoid Dynamo tracing DDP internals (see MULTI_GPU path above for details). + if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model): + if self.state.dynamo_plugin.use_regional_compilation: + model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs()) + else: + model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs()) model = torch.nn.parallel.DistributedDataParallel(model, **kwargs) if self.ddp_handler is not None: self.ddp_handler.register_comm_hook(model) @@ -2054,8 +2070,11 @@ def prepare_model( # Now we can apply the FP8 autocast if self.fp8_backend == FP8BackendType.TE and self.delayed_fp8_autocast: model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler) - # torch.compile should be called last and only if the model isn't already compiled - if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model): + # torch.compile should be called last and only if the model isn't already compiled. + # Also skip if the model has compiled submodules (e.g. the inner module of DDP was already + # compiled before DDP wrapping to avoid Dynamo tracing DDP internals like + # logger.set_runtime_stats_and_log which would cause an "Unsupported method call" error). + if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model) and not has_compiled_regions(model): if self.state.dynamo_plugin.use_regional_compilation: model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs()) else: