diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 1bac7674..1ba8cc20 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -48,6 +48,7 @@ SqueezeNet, WithBatchNorm, WithBuffered, + WithDropout, WithModuleTrackingRunningStats, WithNoTensorOutput, WithRNN, @@ -104,6 +105,7 @@ (Ndim2Output, 32), (Ndim3Output, 32), (Ndim4Output, 32), + (WithDropout, 32), (FreeParam, 32), (NoFreeParam, 32), param(Cifar10Model, 16, marks=mark.slow), diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index fc47c32c..facf852c 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -740,6 +740,21 @@ def forward(self, input: Tensor) -> Tensor: return self.batch_norm(input) +class WithDropout(ShapedModule): + """Simple model containing Dropout layers.""" + + INPUT_SHAPES = (3, 6, 6) + OUTPUT_SHAPES = (3, 4, 4) + + def __init__(self): + super().__init__() + self.conv2d = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) + self.dropout = nn.Dropout2d(p=0.5) + + def forward(self, input: Tensor) -> Tensor: + return self.dropout(self.conv2d(self.dropout(input))) + + class FreeParam(ShapedModule): """ Model that contains a free (i.e. not contained in a submodule) parameter, that is used at the