From a9e2dd5e6a104e3a8fcba2c67d353a64ed4464dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 3 Oct 2025 15:33:09 +0200 Subject: [PATCH] Add ModelAlsoUsingSubmoduleParamsDirectly --- tests/unit/autogram/test_engine.py | 5 ++++- tests/utils/architectures.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 7c88ae9d..d4cac1ec 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -20,6 +20,7 @@ InterModuleParamReuse, MIMOBranched, MISOBranched, + ModelAlsoUsingSubmoduleParamsDirectly, ModelUsingSubmoduleParamsDirectly, ModuleReuse, MultiInputMultiOutput, @@ -190,7 +191,9 @@ def test_compute_gramian_with_weird_modules( @mark.xfail -@mark.parametrize("architecture", [ModelUsingSubmoduleParamsDirectly]) +@mark.parametrize( + "architecture", [ModelUsingSubmoduleParamsDirectly, ModelAlsoUsingSubmoduleParamsDirectly] +) @mark.parametrize("batch_size", [1, 3, 32]) @mark.parametrize("batch_dim", [0, None]) def test_compute_gramian_unsupported_architectures( diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 3dc126ad..c537c2d7 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -772,6 +772,22 @@ def forward(self, input: Tensor) -> Tensor: return input @ self.linear.weight.T + self.linear.bias +class ModelAlsoUsingSubmoduleParamsDirectly(ShapedModule): + """ + Model that uses its submodule's parameters directly but that also calls its submodule's forward. + """ + + INPUT_SHAPES = (2,) + OUTPUT_SHAPES = (3,) + + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 3) + + def forward(self, input: Tensor) -> Tensor: + return input @ self.linear.weight.T + self.linear.bias + self.linear(input) + + class _WithStringArg(nn.Module): def __init__(self): super().__init__()