diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 1ba8cc20..a7749b60 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -20,6 +20,7 @@ InterModuleParamReuse, MIMOBranched, MISOBranched, + ModelUsingSubmoduleParamsDirectly, ModuleReuse, MultiInputMultiOutput, MultiInputSingleOutput, @@ -172,6 +173,21 @@ def test_compute_gramian_with_weird_modules( _assert_gramian_is_equivalent_to_autograd(architecture, batch_size, batch_dim) +@mark.xfail +@mark.parametrize("architecture", [ModelUsingSubmoduleParamsDirectly]) +@mark.parametrize("batch_size", [1, 3, 32]) +@mark.parametrize("batch_dim", [0, None]) +def test_compute_gramian_unsupported_architectures( + architecture: type[ShapedModule], batch_size: int, batch_dim: int | None +): + """ + Tests compute_gramian on some architectures that are known to be unsupported. It is expected to + fail. + """ + + _assert_gramian_is_equivalent_to_autograd(architecture, batch_size, batch_dim) + + @mark.parametrize("batch_size", [1, 3, 16]) @mark.parametrize( ["reduction", "movedim_source", "movedim_destination", "batch_dim"], diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index facf852c..0b26d89e 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -755,6 +755,23 @@ def forward(self, input: Tensor) -> Tensor: return self.dropout(self.conv2d(self.dropout(input))) +class ModelUsingSubmoduleParamsDirectly(ShapedModule): + """ + Model that uses its submodule's parameters directly and that does not call its submodule's + forward. + """ + + INPUT_SHAPES = (2,) + OUTPUT_SHAPES = (3,) + + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 3) + + def forward(self, input: Tensor) -> Tensor: + return input @ self.linear.weight.T + self.linear.bias + + class FreeParam(ShapedModule): """ Model that contains a free (i.e. not contained in a submodule) parameter, that is used at the