From 317ff8d5e672540920a78f850cd9d090d3bb4d7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 30 Sep 2025 13:58:02 +0200 Subject: [PATCH 1/4] Add WithModuleWithStringArg --- tests/utils/architectures.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 0b26d89e..26cdb277 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -772,6 +772,31 @@ def forward(self, input: Tensor) -> Tensor: return input @ self.linear.weight.T + self.linear.bias +class WithModuleWithStringArg(ShapedModule): + """Model containing a module that has a string argument.""" + + INPUT_SHAPES = (2,) + OUTPUT_SHAPES = (3,) + + class WithStringArg(nn.Module): + def __init__(self): + super().__init__() + self.matrix = nn.Parameter(torch.randn(2, 3)) + + def forward(self, input: Tensor, s: str) -> Tensor: + if s == "two": + return input @ self.matrix * 2.0 + else: + return input @ self.matrix + + def __init__(self): + super().__init__() + self.with_string_arg = self.WithStringArg() + + def forward(self, input: Tensor) -> Tensor: + return self.with_string_arg(input, "two") + + class FreeParam(ShapedModule): """ Model that contains a free (i.e. not contained in a submodule) parameter, that is used at the From d8ef771017e9797f00bbf5e9b7a8dc310465c7d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 13:48:13 +0200 Subject: [PATCH 2/4] Change order between string and tensor args in WithModuleWithStringArg * Having the string first has more chance of messing things up --- tests/utils/architectures.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 26cdb277..4b93d2d4 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -783,7 +783,7 @@ def __init__(self): super().__init__() self.matrix = nn.Parameter(torch.randn(2, 3)) - def forward(self, input: Tensor, s: str) -> Tensor: + def forward(self, s: str, input: Tensor) -> Tensor: if s == "two": return input @ self.matrix * 2.0 else: @@ -794,7 +794,7 @@ def __init__(self): self.with_string_arg = self.WithStringArg() def forward(self, input: Tensor) -> Tensor: - return self.with_string_arg(input, "two") + return self.with_string_arg("two", input) class FreeParam(ShapedModule): From 42ebcff0db706d396ed4160da105b2d218254f6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 14:22:32 +0200 Subject: [PATCH 3/4] Add WithModuleWithStringOutput --- tests/utils/architectures.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 4b93d2d4..073420d2 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -797,6 +797,29 @@ def forward(self, input: Tensor) -> Tensor: return self.with_string_arg("two", input) +class WithModuleWithStringOutput(ShapedModule): + """Model containing a module that has a string output.""" + + INPUT_SHAPES = (2,) + OUTPUT_SHAPES = (3,) + + class WithStringOutput(nn.Module): + def __init__(self): + super().__init__() + self.matrix = nn.Parameter(torch.randn(2, 3)) + + def forward(self, input: Tensor) -> tuple[str, Tensor]: + return "test", input @ self.matrix + + def __init__(self): + super().__init__() + self.with_string_output = self.WithStringOutput() + + def forward(self, input: Tensor) -> Tensor: + _, output = self.with_string_output(input) + return output + + class FreeParam(ShapedModule): """ Model that contains a free (i.e. not contained in a submodule) parameter, that is used at the From d48d055e5b82f774c47ab4e80b7730dd9798235a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 14:27:14 +0200 Subject: [PATCH 4/4] Use new architectures in tests (mostly xfail). --- tests/unit/autogram/test_engine.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 782f572c..bd3922d5 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -51,6 +51,8 @@ WithBuffered, WithDropout, WithModuleTrackingRunningStats, + WithModuleWithStringArg, + WithModuleWithStringOutput, WithNoTensorOutput, WithRNN, WithSideEffect, @@ -164,6 +166,8 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc Randomness, WithModuleTrackingRunningStats, param(WithRNN, marks=mark.xfail_if_cuda), + WithModuleWithStringArg, + param(WithModuleWithStringOutput, marks=mark.xfail), ], ) @mark.parametrize("batch_size", [1, 3, 32])