From bf1f1c764c2cc9c44d990cede5ebe63beccfe82e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 29 Sep 2025 14:04:27 +0200 Subject: [PATCH 1/2] Add possibility to mark test as xfail_if_cuda --- tests/unit/conftest.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 91bca4f9..2537ba50 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -37,13 +37,14 @@ def pytest_addoption(parser): def pytest_configure(config): config.addinivalue_line("markers", "slow: mark test as slow to run") + config.addinivalue_line("markers", "xfail_if_cuda: mark test as xfail if running on cuda") def pytest_collection_modifyitems(config, items): - if config.getoption("--runslow"): - return - skip_slow = mark.skip(reason="Slow test. Use --runslow to run it.") + xfail_cuda = mark.xfail(reason=f"Test expected to fail on {DEVICE}") for item in items: - if "slow" in item.keywords: + if "slow" in item.keywords and not config.getoption("--runslow"): item.add_marker(skip_slow) + if "xfail_if_cuda" in item.keywords and str(DEVICE).startswith("cuda"): + item.add_marker(xfail_cuda) From 5a7dd47e05901154928c477c35aaa11cefcf080a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 29 Sep 2025 14:04:45 +0200 Subject: [PATCH 2/2] Mark WithRNN as xfail_if_cuda --- tests/unit/autogram/test_engine.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 1ba8cc20..87e0fe1c 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -157,7 +157,13 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc @mark.parametrize( "architecture", - [WithBatchNorm, WithSideEffect, Randomness, WithModuleTrackingRunningStats, WithRNN], + [ + WithBatchNorm, + WithSideEffect, + Randomness, + WithModuleTrackingRunningStats, + param(WithRNN, marks=mark.xfail_if_cuda), + ], ) @mark.parametrize("batch_size", [1, 3, 32]) @mark.parametrize("batch_dim", [param(0, marks=mark.xfail), None])