diff --git a/test/unit_test/pass_test/test_convert_permute_to_reshape.py b/test/unit_test/pass_test/test_convert_permute_to_reshape.py new file mode 100644 index 00000000..97a19c06 --- /dev/null +++ b/test/unit_test/pass_test/test_convert_permute_to_reshape.py @@ -0,0 +1,61 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from tico.passes import ops +from tico.passes.convert_permute_to_reshape import ConvertPermuteToReshape + +from test.utils.helper import num_of_ops +from test.utils.pass_value_test import SinglePassValueTest + + +class PermuteBasic(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.permute(x, (1, 2, 3, 0)) + + def get_example_inputs(self): + return (torch.rand([1, 5, 1, 3]),), {} + + +class PermuteBasicTest(SinglePassValueTest): + def test_pass(self): + self.setup(PermuteBasic()) + self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 1) + self.run_value_test(ConvertPermuteToReshape(True)) + self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 0) + self.assertEqual(num_of_ops(self.exported_program(), ops.aten.reshape), 1) + + +class PermuteBasicNegative(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.permute(x, (2, 3, 0, 1)) + + def get_example_inputs(self): + return (torch.rand([1, 5, 1, 3]),), {} + + +class PermuteBasicNegativeTest(SinglePassValueTest): + def test_pass(self): + self.setup(PermuteBasicNegative()) + self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 1) + self.run_value_test(ConvertPermuteToReshape(True)) + self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 1) + self.assertEqual(num_of_ops(self.exported_program(), ops.aten.reshape), 0) diff --git a/tico/passes/convert_permute_to_reshape.py b/tico/passes/convert_permute_to_reshape.py new file mode 100644 index 00000000..859f3cf8 --- /dev/null +++ b/tico/passes/convert_permute_to_reshape.py @@ -0,0 +1,106 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import torch.fx +import torch +from torch.export import ExportedProgram + +from tico.passes import ops +from tico.serialize.circle_mapping import extract_shape +from tico.utils import logging +from tico.utils.graph import create_node +from tico.utils.passes import PassBase, PassResult +from tico.utils.trace_decorators import trace_graph_diff_on_pass +from tico.utils.utils import is_target_node +from tico.utils.validate_args_kwargs import PermuteArgs + + +@trace_graph_diff_on_pass +class ConvertPermuteToReshape(PassBase): + """ + This pass replaces `aten.permute` to `aten.reshape` when + the order of output data is exactly same as input data. + """ + + def __init__(self, enabled: bool = False): + super().__init__() + self.enabled = enabled + + def call(self, exported_program: ExportedProgram) -> PassResult: + if not self.enabled: + return PassResult(False) + + logger = logging.getLogger(__name__) + + graph_module = exported_program.graph_module + graph = graph_module.graph + modified = False + + for node in graph.nodes: + if not isinstance(node, torch.fx.Node) or not is_target_node( + node, ops.aten.permute + ): + continue + + # Extract permute arguments + args = PermuteArgs(*node.args, **node.kwargs) # type: ignore[arg-type] + + input = args.input + dims = args.dims + + input_shape = extract_shape(input) + normalized_dims = [(d if d >= 0 else d + len(input_shape)) for d in dims] + + # When permute dims with non-1 values have same order, + # we can replace permute to reshape + # + # For example, if + # - input.shape = [1, x, 1, y] + # - torch.permute(input, dims = [1, 2, 3, 0]) + # then dims[0] --> dims[2] keeps same order for 'x' --> 'y'. + is_same_order = True + last_dim = -1 + for dim in normalized_dims: + if input_shape[dim] == 1: + continue + + if last_dim < dim: + last_dim = dim + else: + is_same_order = False + break + + if is_same_order == True: + with graph.inserting_before(node): + reshape = create_node( + graph, + torch.ops.aten.reshape.default, + args=(input, [input_shape[dim] for dim in normalized_dims]), + origin=node, + ) + + node.replace_all_uses_with(reshape, propagate_meta=False) + modified = True + logger.debug( + f"{node.name} is replaced with {reshape.name} operators" + ) + + graph.eliminate_dead_code() + graph.lint() + graph_module.recompile() + + return PassResult(modified) diff --git a/tico/utils/convert.py b/tico/utils/convert.py index eb08dd0b..1de56d01 100644 --- a/tico/utils/convert.py +++ b/tico/utils/convert.py @@ -29,6 +29,7 @@ from tico.passes.convert_expand_to_slice_cat import ConvertExpandToSliceCat from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear +from tico.passes.convert_permute_to_reshape import ConvertPermuteToReshape from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy from tico.passes.convert_sym_size_to_circle_shape import ConvertSymSizeToCircleShape from tico.passes.convert_to_relu6 import ConvertToReLU6 @@ -273,6 +274,7 @@ def convert_exported_module_to_circle( *LowerToSlicePasses(), FuseLeadingUnsqueezeReshape(), CastClampMixedTypeArgs(), + ConvertPermuteToReshape(), ] ) circle_legalize.run(exported_program)