diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 7c88ae9d..d4cac1ec 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -20,6 +20,7 @@ InterModuleParamReuse, MIMOBranched, MISOBranched, + ModelAlsoUsingSubmoduleParamsDirectly, ModelUsingSubmoduleParamsDirectly, ModuleReuse, MultiInputMultiOutput, @@ -190,7 +191,9 @@ def test_compute_gramian_with_weird_modules( @mark.xfail -@mark.parametrize("architecture", [ModelUsingSubmoduleParamsDirectly]) +@mark.parametrize( + "architecture", [ModelUsingSubmoduleParamsDirectly, ModelAlsoUsingSubmoduleParamsDirectly] +) @mark.parametrize("batch_size", [1, 3, 32]) @mark.parametrize("batch_dim", [0, None]) def test_compute_gramian_unsupported_architectures( diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 3dc126ad..c537c2d7 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -772,6 +772,22 @@ def forward(self, input: Tensor) -> Tensor: return input @ self.linear.weight.T + self.linear.bias +class ModelAlsoUsingSubmoduleParamsDirectly(ShapedModule): + """ + Model that uses its submodule's parameters directly but that also calls its submodule's forward. + """ + + INPUT_SHAPES = (2,) + OUTPUT_SHAPES = (3,) + + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 3) + + def forward(self, input: Tensor) -> Tensor: + return input @ self.linear.weight.T + self.linear.bias + self.linear(input) + + class _WithStringArg(nn.Module): def __init__(self): super().__init__()