Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions test/unit_test/pass_test/test_convert_permute_to_reshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from tico.passes import ops
from tico.passes.convert_permute_to_reshape import ConvertPermuteToReshape

from test.utils.helper import num_of_ops
from test.utils.pass_value_test import SinglePassValueTest


class PermuteBasic(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.permute(x, (1, 2, 3, 0))

def get_example_inputs(self):
return (torch.rand([1, 5, 1, 3]),), {}
Comment on lines +24 to +32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, by simply moving this class into test/op/..., operation test can be done.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh...?
You mean passes are automatically applied to test/op/... during the operation tests?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It simply perform 'tico.convert' to the graph module, so it will perform your newly-added pass.



class PermuteBasicTest(SinglePassValueTest):
def test_pass(self):
self.setup(PermuteBasic())
self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 1)
self.run_value_test(ConvertPermuteToReshape(True))
self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 0)
self.assertEqual(num_of_ops(self.exported_program(), ops.aten.reshape), 1)


class PermuteBasicNegative(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.permute(x, (2, 3, 0, 1))

def get_example_inputs(self):
return (torch.rand([1, 5, 1, 3]),), {}


class PermuteBasicNegativeTest(SinglePassValueTest):
def test_pass(self):
self.setup(PermuteBasicNegative())
self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 1)
self.run_value_test(ConvertPermuteToReshape(True))
self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 1)
self.assertEqual(num_of_ops(self.exported_program(), ops.aten.reshape), 0)
106 changes: 106 additions & 0 deletions tico/passes/convert_permute_to_reshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
import torch.fx
import torch
from torch.export import ExportedProgram

from tico.passes import ops
from tico.serialize.circle_mapping import extract_shape
from tico.utils import logging
from tico.utils.graph import create_node
from tico.utils.passes import PassBase, PassResult
from tico.utils.trace_decorators import trace_graph_diff_on_pass
from tico.utils.utils import is_target_node
from tico.utils.validate_args_kwargs import PermuteArgs


@trace_graph_diff_on_pass
class ConvertPermuteToReshape(PassBase):
"""
This pass replaces `aten.permute` to `aten.reshape` when
the order of output data is exactly same as input data.
"""

def __init__(self, enabled: bool = False):
super().__init__()
self.enabled = enabled

def call(self, exported_program: ExportedProgram) -> PassResult:
if not self.enabled:
return PassResult(False)

logger = logging.getLogger(__name__)

graph_module = exported_program.graph_module
graph = graph_module.graph
modified = False

for node in graph.nodes:
if not isinstance(node, torch.fx.Node) or not is_target_node(
node, ops.aten.permute
):
continue

# Extract permute arguments
args = PermuteArgs(*node.args, **node.kwargs) # type: ignore[arg-type]

input = args.input
dims = args.dims

input_shape = extract_shape(input)
normalized_dims = [(d if d >= 0 else d + len(input_shape)) for d in dims]

# When permute dims with non-1 values have same order,
# we can replace permute to reshape
#
# For example, if
# - input.shape = [1, x, 1, y]
# - torch.permute(input, dims = [1, 2, 3, 0])
# then dims[0] --> dims[2] keeps same order for 'x' --> 'y'.
is_same_order = True
last_dim = -1
for dim in normalized_dims:
if input_shape[dim] == 1:
Comment on lines +71 to +78
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you normalize the input shape for negative integer cases?

ndims = len(input_shape)
normalized_dims = [(d if d >= 0 else d + ndims) for d in dims]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed the negative integer cases :(
Thanks!!

continue

if last_dim < dim:
last_dim = dim
else:
is_same_order = False
break

if is_same_order == True:
with graph.inserting_before(node):
reshape = create_node(
graph,
torch.ops.aten.reshape.default,
args=(input, [input_shape[dim] for dim in normalized_dims]),
origin=node,
)

node.replace_all_uses_with(reshape, propagate_meta=False)
modified = True
logger.debug(
f"{node.name} is replaced with {reshape.name} operators"
)

graph.eliminate_dead_code()
graph.lint()
graph_module.recompile()

return PassResult(modified)
2 changes: 2 additions & 0 deletions tico/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tico.passes.convert_expand_to_slice_cat import ConvertExpandToSliceCat
from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear
from tico.passes.convert_permute_to_reshape import ConvertPermuteToReshape
from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy
from tico.passes.convert_sym_size_to_circle_shape import ConvertSymSizeToCircleShape
from tico.passes.convert_to_relu6 import ConvertToReLU6
Expand Down Expand Up @@ -273,6 +274,7 @@ def convert_exported_module_to_circle(
*LowerToSlicePasses(),
FuseLeadingUnsqueezeReshape(),
CastClampMixedTypeArgs(),
ConvertPermuteToReshape(),
]
)
circle_legalize.run(exported_program)
Expand Down