diff --git a/tests/utils/contexts.py b/tests/utils/contexts.py index dc508130..db2d54eb 100644 --- a/tests/utils/contexts.py +++ b/tests/utils/contexts.py @@ -10,7 +10,7 @@ @contextmanager def fork_rng(seed: int = 0) -> Generator[Any, None, None]: - devices = [DEVICE] if DEVICE.type == "cuda" else [] + devices = [] if DEVICE.type == "cpu" else [DEVICE] with torch.random.fork_rng(devices=devices, device_type=DEVICE.type) as ctx: torch.manual_seed(seed) yield ctx