From 5fec5f7ea6d72f38e04140e7e6264a9740867354 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Oct 2025 01:20:32 +0200 Subject: [PATCH] test: Remove trivial ShapedModules * Add support for RNN, BatchNorm2d and InstanceNorm2d in get_in_out_shapes * Remove WithRNN, WithBatchNorm and WithModuleTrackingRunningStats - use simple factories instead * Whenever a ShapedModule is simply a wrapper around a single nn.Module, we can replace it by a line in get_in_out_shapes and make a factory directly out of the nn.Module. It's not strictly equivalent (because now the module isn't wrapped), but it's closer to what a user would do. --- tests/unit/autogram/test_engine.py | 20 +++++------ tests/utils/architectures.py | 54 +++++++----------------------- 2 files changed, 22 insertions(+), 52 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 00a3855d..3e33cbe4 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -6,7 +6,7 @@ import torch from pytest import mark, param from torch import Tensor -from torch.nn import Linear +from torch.nn import RNN, BatchNorm2d, InstanceNorm2d, Linear from torch.optim import SGD from torch.testing import assert_close from utils.architectures import ( @@ -47,10 +47,8 @@ SomeUnusedOutput, SomeUnusedParam, SqueezeNet, - WithBatchNorm, WithBuffered, WithDropout, - WithModuleTrackingRunningStats, WithModuleWithHybridPyTreeArg, WithModuleWithHybridPyTreeKwarg, WithModuleWithStringArg, @@ -58,7 +56,6 @@ WithModuleWithStringOutput, WithMultiHeadAttention, WithNoTensorOutput, - WithRNN, WithSideEffect, WithSomeFrozenModule, WithTransformer, @@ -178,11 +175,14 @@ def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int @mark.parametrize( "factory", [ - ModuleFactory(WithBatchNorm), + ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False), ModuleFactory(WithSideEffect), ModuleFactory(Randomness), - ModuleFactory(WithModuleTrackingRunningStats), - param(ModuleFactory(WithRNN), marks=mark.xfail_if_cuda), + ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True), + param( + ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True), + marks=mark.xfail_if_cuda, + ), ], ) @mark.parametrize("batch_size", [1, 3, 32]) @@ -397,9 +397,9 @@ def test_autograd_while_modules_are_hooked( @mark.parametrize( ["factory", "batch_dim"], [ - (ModuleFactory(WithModuleTrackingRunningStats), 0), - (ModuleFactory(WithRNN), 0), - (ModuleFactory(WithBatchNorm), 0), + (ModuleFactory(InstanceNorm2d, num_features=3, affine=True, track_running_stats=True), 0), + (ModuleFactory(RNN, input_size=8, hidden_size=5, batch_first=True), 0), + (ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False), 0), ], ) def test_incompatible_modules(factory: ModuleFactory, batch_dim: int | None): diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 35eb13a8..3337be09 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -44,6 +44,18 @@ def __init_subclass__(cls): def get_in_out_shapes(module: nn.Module) -> tuple[PyTree, PyTree]: if isinstance(module, ShapedModule): return module.INPUT_SHAPES, module.OUTPUT_SHAPES + + elif isinstance(module, nn.RNN): + assert module.batch_first + SEQ_LEN = 20 # Arbitrary choice + return (SEQ_LEN, module.input_size), (SEQ_LEN, module.hidden_size) + + elif isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)): + HEIGHT = 6 # Arbitrary choice + WIDTH = 6 # Arbitrary choice + shape = (module.num_features, HEIGHT, WIDTH) + return shape, shape + else: raise ValueError("Unknown input / output shapes of module", module) @@ -725,48 +737,6 @@ def forward(self, input: Tensor) -> Tensor: return torch.einsum("bi,icdef->bcdef", input, self.tensor) -class WithRNN(ShapedModule): - """Simple model containing an RNN module.""" - - INPUT_SHAPES = (20, 8) # Size 20, dim input_size (8) - OUTPUT_SHAPES = (20, 5) # Size 20, dim hidden_size (5) - - def __init__(self): - super().__init__() - self.rnn = nn.RNN(input_size=8, hidden_size=5, batch_first=True) - - def forward(self, input: Tensor) -> Tensor: - return self.rnn(input) - - -class WithModuleTrackingRunningStats(ShapedModule): - """Simple model containing a module that has side-effects and modifies tensors in-place.""" - - INPUT_SHAPES = (3, 6, 6) - OUTPUT_SHAPES = (3, 6, 6) - - def __init__(self): - super().__init__() - self.instance_norm = nn.InstanceNorm2d(3, affine=True, track_running_stats=True) - - def forward(self, input: Tensor) -> Tensor: - return self.instance_norm(input) - - -class WithBatchNorm(ShapedModule): - """Simple model containing a BatchNorm layer.""" - - INPUT_SHAPES = (3, 6, 6) - OUTPUT_SHAPES = (3, 6, 6) - - def __init__(self): - super().__init__() - self.batch_norm = nn.BatchNorm2d(3, affine=True, track_running_stats=False) - - def forward(self, input: Tensor) -> Tensor: - return self.batch_norm(input) - - class WithDropout(ShapedModule): """Simple model containing Dropout layers."""