diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index db4df51e..871d1c1c 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -4,6 +4,8 @@ functions here to test them. """ +from pytest import mark + def test_amp(): import torch @@ -198,12 +200,13 @@ def test_autogram(): test_autogram() +@mark.filterwarnings( + "ignore::DeprecationWarning", "ignore::lightning.fabric.utilities.warnings.PossibleUserWarning" +) def test_lightning_integration(): # Extra ---------------------------------------------------------------------------------------- import logging - import warnings - warnings.filterwarnings("ignore") logging.disable(logging.INFO) # ---------------------------------------------------------------------------------------------- diff --git a/tests/unit/aggregation/test_values.py b/tests/unit/aggregation/test_values.py index af523412..5fed3869 100644 --- a/tests/unit/aggregation/test_values.py +++ b/tests/unit/aggregation/test_values.py @@ -1,6 +1,4 @@ -import warnings - -from pytest import mark +from pytest import mark, param from torch import Tensor, tensor from torch.testing import assert_close @@ -98,7 +96,12 @@ from torchjd.aggregation import NashMTL AGGREGATOR_PARAMETRIZATIONS.append( - (NashMTL(n_tasks=2), J_base, tensor([0.0542, 0.7061, 0.7061])) + param( + NashMTL(n_tasks=2), + J_base, + tensor([0.0542, 0.7061, 0.7061]), + marks=mark.filterwarnings("ignore::UserWarning"), + ) ) except ImportError: @@ -109,9 +112,6 @@ def test_aggregator_output(A: Aggregator, J: Tensor, expected_output: Tensor): """Test that the output values of an aggregator are fixed (on cpu).""" - if str(A).startswith("NashMTL"): - warnings.filterwarnings("ignore") - assert_close(A(J), expected_output, rtol=0, atol=1e-4)