diff --git a/backends/arm/test/ops/test_var.py b/backends/arm/test/ops/test_var.py index a7943bfc19b..23bf5f4cf37 100644 --- a/backends/arm/test/ops/test_var.py +++ b/backends/arm/test/ops/test_var.py @@ -32,6 +32,11 @@ class Var(torch.nn.Module): ), } + test_parameters_ethosu = { + "var_4d_keep_dim_0_correction": lambda: (torch.randn(1, 50, 10, 20), True, 0), + "var_4d_keep_dim_1_correction": lambda: (torch.randn(1, 30, 15, 20), True, 1), + } + def __init__(self, keepdim: bool = True, correction: int = 0): super().__init__() self.keepdim = keepdim @@ -170,7 +175,7 @@ def test_var_dim_tosa_INT_no_dim(test_data: Tuple): pipeline.run() -@common.parametrize("test_data", Var.test_parameters) +@common.parametrize("test_data", Var.test_parameters_ethosu) @common.XfailIfNoCorstone300 def test_var_dim_u55_INT_no_dim(test_data: Tuple): test_data, keepdim, correction = test_data() @@ -183,7 +188,7 @@ def test_var_dim_u55_INT_no_dim(test_data: Tuple): pipeline.run() -@common.parametrize("test_data", Var.test_parameters) +@common.parametrize("test_data", Var.test_parameters_ethosu) @common.XfailIfNoCorstone320 def test_var_dim_u85_INT_no_dim(test_data: Tuple): test_data, keepdim, correction = test_data() @@ -224,6 +229,36 @@ def test_var_dim_vgf_quant_no_dim(test_data: Tuple): pipeline.run() +@common.parametrize("test_data", Var.test_parameters_ethosu) +@common.XfailIfNoCorstone300 +def test_var_a16w8_u55_INT(test_data: Tuple): + test_data, keepdim, correction = test_data() + pipeline = EthosU55PipelineINT[input_t1]( + Var(keepdim, correction), + (test_data,), + aten_ops=[], + exir_ops=[], + a16w8_quantization=True, + symmetric_io_quantization=True, + ) + pipeline.run() + + +@common.parametrize("test_data", Var.test_parameters_ethosu) +@common.XfailIfNoCorstone320 +def test_var_a16w8_u85_INT(test_data: Tuple): + test_data, keepdim, correction = test_data() + pipeline = EthosU85PipelineINT[input_t1]( + Var(keepdim, correction), + (test_data,), + aten_ops=[], + exir_ops=[], + a16w8_quantization=True, + symmetric_io_quantization=True, + ) + pipeline.run() + + ############# ## VarDim ### ############# diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index b8030ae7ba8..30fa348414f 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -39,6 +39,7 @@ def define_arm_tests(): "ops/test_exp.py", "ops/test_reciprocal.py", "ops/test_mean_dim.py", + "ops/test_var.py", ] # Quantization