diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py
index 3e33cbe4..296ab5a9 100644
--- a/tests/unit/autogram/test_engine.py
+++ b/tests/unit/autogram/test_engine.py
@@ -6,7 +6,7 @@
 import torch
 from pytest import mark, param
 from torch import Tensor
-from torch.nn import RNN, BatchNorm2d, InstanceNorm2d, Linear
+from torch.nn import BatchNorm2d, InstanceNorm2d, Linear
 from torch.optim import SGD
 from torch.testing import assert_close
 from utils.architectures import (
@@ -56,6 +56,7 @@
     WithModuleWithStringOutput,
     WithMultiHeadAttention,
     WithNoTensorOutput,
+    WithRNN,
     WithSideEffect,
     WithSomeFrozenModule,
     WithTransformer,
@@ -179,10 +180,7 @@ def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int
         ModuleFactory(WithSideEffect),
         ModuleFactory(Randomness),
         ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True),
-        param(
-            ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True),
-            marks=mark.xfail_if_cuda,
-        ),
+        param(ModuleFactory(WithRNN), marks=mark.xfail_if_cuda),
     ],
 )
 @mark.parametrize("batch_size", [1, 3, 32])
@@ -398,7 +396,7 @@ def test_autograd_while_modules_are_hooked(
     ["factory", "batch_dim"],
     [
         (ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True), 0),
-        (ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True), 0),
+        param(ModuleFactory(WithRNN), 0),
         (ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False), 0),
     ],
 )
diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py
index 3337be09..46aaa592 100644
--- a/tests/utils/architectures.py
+++ b/tests/utils/architectures.py
@@ -45,11 +45,6 @@ def get_in_out_shapes(module: nn.Module) -> tuple[PyTree, PyTree]:
     if isinstance(module, ShapedModule):
         return module.INPUT_SHAPES, module.OUTPUT_SHAPES
 
-    elif isinstance(module, nn.RNN):
-        assert module.batch_first
-        SEQ_LEN = 20  # Arbitrary choice
-        return (SEQ_LEN, module.input_size), (SEQ_LEN, module.hidden_size)
-
     elif isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)):
         HEIGHT = 6  # Arbitrary choice
         WIDTH = 6  # Arbitrary choice
@@ -737,6 +732,21 @@ def forward(self, input: Tensor) -> Tensor:
         return torch.einsum("bi,icdef->bcdef", input, self.tensor)
 
 
+class WithRNN(ShapedModule):
+    """Simple model containing an RNN module."""
+
+    INPUT_SHAPES = (20, 8)  # (seq_len=20, input_size=8)
+    OUTPUT_SHAPES = (20, 5)  # (seq_len=20, hidden_size=5)
+
+    def __init__(self):
+        super().__init__()
+        self.rnn = nn.RNN(input_size=8, hidden_size=5, batch_first=True)
+
+    def forward(self, input: Tensor) -> Tensor:
+        output, _ = self.rnn(input)  # Discard the final hidden state
+        return output
+
+
 class WithDropout(ShapedModule):
     """Simple model containing Dropout layers."""
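
For reference, a minimal sketch (not part of the diff) of how the new `WithRNN` wrapper is expected to behave, assuming, as in `test_engine.py`, that `utils.architectures` is importable from the test root and that `ShapedModule` shape tuples exclude the batch dimension:

```python
import torch

from utils.architectures import WithRNN, get_in_out_shapes

module = WithRNN()
# get_in_out_shapes now hits the generic ShapedModule branch instead of an
# nn.RNN special case: ((20, 8), (20, 5)).
in_shapes, out_shapes = get_in_out_shapes(module)

x = torch.randn(3, *in_shapes)  # batch of 3; the wrapped nn.RNN uses batch_first=True
assert module(x).shape == (3, *out_shapes)
```

This is why the RNN-specific branch in `get_in_out_shapes` can be deleted: the wrapper carries its own shape metadata, and its `forward` returns a plain `Tensor`, so the test factories no longer need to know that `nn.RNN` returns an `(output, h_n)` tuple.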