diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py
index cfe012e7..1bac7674 100644
--- a/tests/unit/autogram/test_engine.py
+++ b/tests/unit/autogram/test_engine.py
@@ -154,7 +154,8 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc
 
 
 @mark.parametrize(
-    "architecture", [WithBatchNorm, WithSideEffect, Randomness, WithModuleTrackingRunningStats]
+    "architecture",
+    [WithBatchNorm, WithSideEffect, Randomness, WithModuleTrackingRunningStats, WithRNN],
 )
 @mark.parametrize("batch_size", [1, 3, 32])
 @mark.parametrize("batch_dim", [param(0, marks=mark.xfail), None])
diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py
index 7a231528..fc47c32c 100644
--- a/tests/utils/architectures.py
+++ b/tests/utils/architectures.py
@@ -698,15 +698,18 @@ def forward(self, input: Tensor) -> Tensor:
         return torch.einsum("bi,icdef->bcdef", input, self.tensor)
 
 
-class WithRNN(nn.Module):
-    """Simple model containing an RNN module (that is not even used)."""
+class WithRNN(ShapedModule):
+    """Simple model containing an RNN module."""
+
+    INPUT_SHAPES = (20, 8)  # Size 20, dim input_size (8)
+    OUTPUT_SHAPES = (20, 5)  # Size 20, dim hidden_size (5)
 
     def __init__(self):
         super().__init__()
-        self.rnn = nn.RNN(input_size=10, hidden_size=5)
+        self.rnn = nn.RNN(input_size=8, hidden_size=5, batch_first=True)
 
-    def forward(self, input: Tensor) -> None:
-        pass
+    def forward(self, input: Tensor) -> Tensor:
+        return self.rnn(input)[0]
 
 
 class WithModuleTrackingRunningStats(ShapedModule):