diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 782f572c..bd3922d5 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -51,6 +51,8 @@ WithBuffered, WithDropout, WithModuleTrackingRunningStats, + WithModuleWithStringArg, + WithModuleWithStringOutput, WithNoTensorOutput, WithRNN, WithSideEffect, @@ -164,6 +166,8 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc Randomness, WithModuleTrackingRunningStats, param(WithRNN, marks=mark.xfail_if_cuda), + WithModuleWithStringArg, + param(WithModuleWithStringOutput, marks=mark.xfail), ], ) @mark.parametrize("batch_size", [1, 3, 32]) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 0b26d89e..073420d2 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -772,6 +772,54 @@ def forward(self, input: Tensor) -> Tensor: return input @ self.linear.weight.T + self.linear.bias +class WithModuleWithStringArg(ShapedModule): + """Model containing a module that has a string argument.""" + + INPUT_SHAPES = (2,) + OUTPUT_SHAPES = (3,) + + class WithStringArg(nn.Module): + def __init__(self): + super().__init__() + self.matrix = nn.Parameter(torch.randn(2, 3)) + + def forward(self, s: str, input: Tensor) -> Tensor: + if s == "two": + return input @ self.matrix * 2.0 + else: + return input @ self.matrix + + def __init__(self): + super().__init__() + self.with_string_arg = self.WithStringArg() + + def forward(self, input: Tensor) -> Tensor: + return self.with_string_arg("two", input) + + +class WithModuleWithStringOutput(ShapedModule): + """Model containing a module that has a string output.""" + + INPUT_SHAPES = (2,) + OUTPUT_SHAPES = (3,) + + class WithStringOutput(nn.Module): + def __init__(self): + super().__init__() + self.matrix = nn.Parameter(torch.randn(2, 3)) + + def forward(self, input: Tensor) -> tuple[str, Tensor]: + return "test", input @ self.matrix + + def __init__(self): + super().__init__() + self.with_string_output = self.WithStringOutput() + + def forward(self, input: Tensor) -> Tensor: + _, output = self.with_string_output(input) + return output + + class FreeParam(ShapedModule): """ Model that contains a free (i.e. not contained in a submodule) parameter, that is used at the