diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index ebab8a8c057..6ac346c0cc6 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -259,7 +259,9 @@ def noop(*args, **kwargs): pass with ( - patch(f"torch.{torch_device}.set_device", noop), + # Some backends such as MPS do not expose a module-level `set_device`. + # This test only exercises env var parsing, so a synthetic attribute is enough. + patch(f"torch.{torch_device}.set_device", noop, create=True), patch_environment(ACCELERATE_TORCH_DEVICE=f"{torch_device}:64"), ): accelerator = Accelerator()