diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index d72782f1..7a231528 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -49,7 +49,7 @@ def __init__(self): nn.Linear(13, 14), ) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return self.seq(input) @@ -64,7 +64,7 @@ def __init__(self): self.matrix1 = nn.Parameter(torch.randn(50, 60)) self.matrix2 = nn.Parameter(torch.randn(50, 60)) - def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: + def forward(self, inputs: tuple[Tensor, Tensor]) -> Tensor: input1, input2 = inputs output = input1 @ self.matrix1 + input2 @ self.matrix2 return output @@ -133,7 +133,7 @@ def __init__(self): self.matrix4 = nn.Parameter(torch.randn(12, 80)) self.matrix5 = nn.Parameter(torch.randn(14, 90)) - def forward(self, inputs: PyTree) -> PyTree: + def forward(self, inputs: PyTree) -> Tensor: input1 = inputs["one"][0][0] input2 = inputs["one"][0][1][0] input3 = inputs["one"][0][1][1] @@ -259,7 +259,7 @@ def __init__(self): super().__init__() self.sipo = SingleInputPyTreeOutput() - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: first, second, third = self.sipo(input).values() output1, output23 = first output2, output3 = output23 @@ -282,7 +282,7 @@ def __init__(self): super().__init__() self.piso = PyTreeInputSingleOutput() - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: input1 = input[:, 0:10] input2 = input[:, 10:30] input3 = input[:, 30:60] @@ -363,7 +363,7 @@ def __init__(self, shape: tuple[int, ...]): super().__init__() self.matrix = nn.Parameter(torch.randn(shape)) - def forward(self, _: PyTree) -> PyTree: + def forward(self, _: PyTree) -> tuple: return tuple() class _EmptyPytreeOutput(nn.Module): @@ -382,7 +382,7 @@ def __init__(self): self.empty_pytree_output = self._EmptyPytreeOutput((27, 10)) self.linear = nn.Linear(27, 10) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: _ = self.none_output(input) _ = self.none_pytree_output(input) _ = self.empty_tuple_output(input) @@ -400,7 +400,7 @@ def __init__(self): super().__init__() self.matrix = nn.Parameter(torch.randn(50, 10)) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return input @ self.matrix + (input**2 / 5.0) @ self.matrix @@ -429,7 +429,7 @@ def __init__(self): self.module1 = self._MatMulModule(matrix) self.module2 = self._MatMulModule(matrix) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return self.module1(input) + self.module2(input**2 / 5.0) @@ -443,7 +443,7 @@ def __init__(self): super().__init__() self.module = nn.Linear(50, 10) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return self.module(input) + self.module(input**2 / 5.0) @@ -458,7 +458,7 @@ def __init__(self): self.unused_param = nn.Parameter(torch.randn(50, 10)) self.matrix = nn.Parameter(torch.randn(50, 10)) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return input @ self.matrix @@ -476,7 +476,7 @@ def __init__(self): self.frozen_param = nn.Parameter(torch.randn(50, 10), requires_grad=False) self.matrix = nn.Parameter(torch.randn(50, 10)) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return input @ self.matrix + (input**2 / 5.0) @ self.frozen_param @@ -494,7 +494,7 @@ def __init__(self): self.all_frozen = nn.Linear(50, 10) self.all_frozen.requires_grad_(False) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return self.all_frozen(input) + self.non_frozen(input**2 / 5.0) @@ -522,7 +522,7 @@ def __init__(self): self.frozen_param = nn.Parameter(torch.randn(50, 10), requires_grad=False) self.non_frozen_param = nn.Parameter(torch.randn(50, 10)) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return input @ self.frozen_param def __init__(self): @@ -530,7 +530,7 @@ def __init__(self): self.weird_module = self.SomeFrozenParamAndUnusedTrainableParam() self.normal_module = nn.Linear(10, 3) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return self.normal_module(self.weird_module(input)) @@ -548,7 +548,7 @@ def __init__(self): self.frozen_param = nn.Parameter(torch.randn(50, 10), requires_grad=False) self.matrix = nn.Parameter(torch.randn(50, 10)) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> tuple[Tensor, Tensor]: return (input**2 / 5.0) @ self.frozen_param, input @ self.matrix @@ -563,7 +563,7 @@ def __init__(self): super().__init__() self.buffer = nn.Buffer(torch.tensor(1.5)) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return input * self.buffer def __init__(self): @@ -571,7 +571,7 @@ def __init__(self): self.module_with_buffer = self._Buffered() self.linear = nn.Linear(27, 10) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return self.linear(self.module_with_buffer(input)) @@ -585,7 +585,7 @@ def __init__(self): super().__init__() self.matrix = nn.Parameter(torch.randn(9, 10)) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: noise = torch.zeros_like(input) noise.normal_() return (input * noise) @ self.matrix @@ -602,7 +602,7 @@ def __init__(self): self.matrix = nn.Parameter(torch.randn(9, 10)) self.buffer = nn.Buffer(torch.zeros((9,))) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: self.buffer = self.buffer + 1.0 return (input + self.buffer) @ self.matrix @@ -621,7 +621,7 @@ def __init__(self): self.linear1 = nn.Linear(9, 12) self.linear2 = nn.Linear(9, 10) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: _ = self.linear1(input) output = self.linear2(input) return output @@ -705,7 +705,7 @@ def __init__(self): super().__init__() self.rnn = nn.RNN(input_size=10, hidden_size=5) - def forward(self, input: Tensor) -> Tensor: + def forward(self, input: Tensor) -> None: pass @@ -755,7 +755,7 @@ def __init__(self): self.linear3 = nn.Linear(60, 70) self.linear4 = nn.Linear(70, 80) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: output = self.relu(input @ self.matrix) output = self.relu(self.linear1(output)) output = self.relu(self.linear2(output)) @@ -782,7 +782,7 @@ def __init__(self): self.linear3 = nn.Linear(60, 70) self.linear4 = nn.Linear(70, 80) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: output = self.relu(self.linear0(input)) output = self.relu(self.linear1(output)) output = self.relu(self.linear2(output)) @@ -864,7 +864,7 @@ def __init__(self): super().__init__() self.alexnet = torchvision.models.alexnet() - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return self.alexnet(input) @@ -883,7 +883,7 @@ def __init__(self): norm_layer=partial(nn.InstanceNorm2d, track_running_stats=False, affine=True) ) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return self.resnet18(input) / 5.0 @@ -899,7 +899,7 @@ def __init__(self): norm_layer=partial(nn.GroupNorm, 2, affine=True) ) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return self.mobile_net(input) @@ -913,7 +913,7 @@ def __init__(self): super().__init__() self.squeezenet = torchvision.models.squeezenet1_0() - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return self.squeezenet(input) @@ -929,7 +929,7 @@ def __init__(self): norm_layer=partial(nn.InstanceNorm2d, track_running_stats=False, affine=True) ) - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return self.mobilenet(input) / 10.0