From 5b22cb37a0a75d422596ee08852365cdd4f295a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 28 Sep 2025 15:47:07 +0200 Subject: [PATCH 1/2] Make proper implementation of WithRNN --- tests/utils/architectures.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 7a231528..fc47c32c 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -698,15 +698,18 @@ def forward(self, input: Tensor) -> Tensor: return torch.einsum("bi,icdef->bcdef", input, self.tensor) -class WithRNN(nn.Module): - """Simple model containing an RNN module (that is not even used).""" +class WithRNN(ShapedModule): + """Simple model containing an RNN module.""" + + INPUT_SHAPES = (20, 8) # Size 20, dim input_size (8) + OUTPUT_SHAPES = (20, 5) # Size 20, dim hidden_size (5) def __init__(self): super().__init__() - self.rnn = nn.RNN(input_size=10, hidden_size=5) + self.rnn = nn.RNN(input_size=8, hidden_size=5, batch_first=True) - def forward(self, input: Tensor) -> None: - pass + def forward(self, input: Tensor) -> Tensor: + return self.rnn(input) class WithModuleTrackingRunningStats(ShapedModule): From 83c1c0fd0518a9de4d67785dd2ccc0a8847583ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 28 Sep 2025 15:47:36 +0200 Subject: [PATCH 2/2] Add WithRNN to test_compute_gramian_with_weird_modules parametrization --- tests/unit/autogram/test_engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index cfe012e7..1bac7674 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -154,7 +154,8 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc @mark.parametrize( - "architecture", [WithBatchNorm, WithSideEffect, Randomness, WithModuleTrackingRunningStats] + "architecture", + [WithBatchNorm, WithSideEffect, Randomness, WithModuleTrackingRunningStats, WithRNN], ) @mark.parametrize("batch_size", [1, 3, 32]) @mark.parametrize("batch_dim", [param(0, marks=mark.xfail), None])