-
Notifications
You must be signed in to change notification settings - Fork 25
[passes] Introduce ConvertPermuteToReshape pass #521
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| # Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import torch | ||
|
|
||
| from tico.passes import ops | ||
| from tico.passes.convert_permute_to_reshape import ConvertPermuteToReshape | ||
|
|
||
| from test.utils.helper import num_of_ops | ||
| from test.utils.pass_value_test import SinglePassValueTest | ||
|
|
||
|
|
||
| class PermuteBasic(torch.nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
|
|
||
| def forward(self, x): | ||
| return torch.permute(x, (1, 2, 3, 0)) | ||
|
|
||
| def get_example_inputs(self): | ||
| return (torch.rand([1, 5, 1, 3]),), {} | ||
|
|
||
|
|
||
| class PermuteBasicTest(SinglePassValueTest): | ||
| def test_pass(self): | ||
| self.setup(PermuteBasic()) | ||
| self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 1) | ||
| self.run_value_test(ConvertPermuteToReshape(True)) | ||
| self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 0) | ||
| self.assertEqual(num_of_ops(self.exported_program(), ops.aten.reshape), 1) | ||
|
|
||
|
|
||
| class PermuteBasicNegative(torch.nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
|
|
||
| def forward(self, x): | ||
| return torch.permute(x, (2, 3, 0, 1)) | ||
|
|
||
| def get_example_inputs(self): | ||
| return (torch.rand([1, 5, 1, 3]),), {} | ||
|
|
||
|
|
||
| class PermuteBasicNegativeTest(SinglePassValueTest): | ||
| def test_pass(self): | ||
| self.setup(PermuteBasicNegative()) | ||
| self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 1) | ||
| self.run_value_test(ConvertPermuteToReshape(True)) | ||
| self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 1) | ||
| self.assertEqual(num_of_ops(self.exported_program(), ops.aten.reshape), 0) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| # Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| if TYPE_CHECKING: | ||
| import torch.fx | ||
| import torch | ||
| from torch.export import ExportedProgram | ||
|
|
||
| from tico.passes import ops | ||
| from tico.serialize.circle_mapping import extract_shape | ||
| from tico.utils import logging | ||
| from tico.utils.graph import create_node | ||
| from tico.utils.passes import PassBase, PassResult | ||
| from tico.utils.trace_decorators import trace_graph_diff_on_pass | ||
| from tico.utils.utils import is_target_node | ||
| from tico.utils.validate_args_kwargs import PermuteArgs | ||
|
|
||
|
|
||
| @trace_graph_diff_on_pass | ||
| class ConvertPermuteToReshape(PassBase): | ||
| """ | ||
| This pass replaces `aten.permute` to `aten.reshape` when | ||
| the order of output data is exactly same as input data. | ||
| """ | ||
|
|
||
| def __init__(self, enabled: bool = False): | ||
| super().__init__() | ||
| self.enabled = enabled | ||
|
|
||
| def call(self, exported_program: ExportedProgram) -> PassResult: | ||
| if not self.enabled: | ||
| return PassResult(False) | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| graph_module = exported_program.graph_module | ||
| graph = graph_module.graph | ||
| modified = False | ||
|
|
||
| for node in graph.nodes: | ||
| if not isinstance(node, torch.fx.Node) or not is_target_node( | ||
| node, ops.aten.permute | ||
| ): | ||
| continue | ||
|
|
||
| # Extract permute arguments | ||
| args = PermuteArgs(*node.args, **node.kwargs) # type: ignore[arg-type] | ||
|
|
||
| input = args.input | ||
| dims = args.dims | ||
|
|
||
| input_shape = extract_shape(input) | ||
| normalized_dims = [(d if d >= 0 else d + len(input_shape)) for d in dims] | ||
|
|
||
| # When permute dims with non-1 values have same order, | ||
| # we can replace permute to reshape | ||
| # | ||
| # For example, if | ||
| # - input.shape = [1, x, 1, y] | ||
| # - torch.permute(input, dims = [1, 2, 3, 0]) | ||
| # then dims[0] --> dims[2] keeps same order for 'x' --> 'y'. | ||
| is_same_order = True | ||
| last_dim = -1 | ||
| for dim in normalized_dims: | ||
| if input_shape[dim] == 1: | ||
|
Comment on lines
+71
to
+78
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you normalize the input shape for negative integer cases? ndims = len(input_shape)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I missed the negative integer cases :( |
||
| continue | ||
|
|
||
| if last_dim < dim: | ||
| last_dim = dim | ||
| else: | ||
| is_same_order = False | ||
| break | ||
|
|
||
| if is_same_order == True: | ||
| with graph.inserting_before(node): | ||
| reshape = create_node( | ||
| graph, | ||
| torch.ops.aten.reshape.default, | ||
| args=(input, [input_shape[dim] for dim in normalized_dims]), | ||
| origin=node, | ||
| ) | ||
|
|
||
| node.replace_all_uses_with(reshape, propagate_meta=False) | ||
| modified = True | ||
| logger.debug( | ||
| f"{node.name} is replaced with {reshape.name} operators" | ||
| ) | ||
|
|
||
| graph.eliminate_dead_code() | ||
| graph.lint() | ||
| graph_module.recompile() | ||
|
|
||
| return PassResult(modified) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI, by simply moving this class into test/op/..., operation test can be done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh...?
You mean passes are automatically applied to
test/op/...during the operation tests?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It simply perform 'tico.convert' to the graph module, so it will perform your newly-added pass.