diff --git a/tests/settings.py b/tests/settings.py index c032e194..b7fe2345 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -2,22 +2,25 @@ import torch +_POSSIBLE_TEST_DEVICES = {"cpu", "cuda:0"} +_POSSIBLE_TEST_DTYPES = {"float32", "float64"} + try: _device_str = os.environ["PYTEST_TORCH_DEVICE"] except KeyError: _device_str = "cpu" # Default to cpu if environment variable not set -if _device_str != "cuda:0" and _device_str != "cpu": - raise ValueError(f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}") +if _device_str not in _POSSIBLE_TEST_DEVICES: + raise ValueError( + f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}.\n" + f"Possible values: {_POSSIBLE_TEST_DEVICES}." + ) if _device_str == "cuda:0" and not torch.cuda.is_available(): raise ValueError('Requested device "cuda:0" but cuda is not available.') DEVICE = torch.device(_device_str) - -_POSSIBLE_TEST_DTYPES = {"float32", "float64"} - try: _dtype_str = os.environ["PYTEST_TORCH_DTYPE"] except KeyError: