From d6295b6dc62641c1d16f9cc67333e4465bad1495 Mon Sep 17 00:00:00 2001 From: Ethan Ng Date: Fri, 15 May 2026 14:11:20 -0700 Subject: [PATCH] Prevent _safe_softmax decomposition in traceand rewire replaceSafeSoftmaxWithSoftmax MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Prevent _safe_softmax from being decomposed in trace(). Re-add `ReplaceSafeSoftmaxWithSoftmax` pass in `apply_torch_ops_passes()`, before `to_edge()`. ISS performance (FLLM_v2_wrist on Artemis_HiFi4_UT_v3): - Before: 421,662,921 cycles (843ms @ 500MHz) - After: 214,034,116 cycles (428ms @ 500MHz) — 1.97x speedup Differential Revision: D105367634 --- backends/cadence/aot/compiler_funcs.py | 1 + backends/cadence/aot/passes.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/backends/cadence/aot/compiler_funcs.py b/backends/cadence/aot/compiler_funcs.py index bf7f79127a0..02dcde7fd39 100644 --- a/backends/cadence/aot/compiler_funcs.py +++ b/backends/cadence/aot/compiler_funcs.py @@ -35,6 +35,7 @@ def trace( model.eval() decomp_table = torch.export.default_decompositions() + ops_to_keep = [*(ops_to_keep or []), torch.ops.aten._safe_softmax.default] # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any remove_decompositions(decomp_table, ops_to_keep) program = torch.export.export(model, inputs, strict=strict).run_decompositions( diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py index 8a03d72420e..f43ac3e4d2c 100644 --- a/backends/cadence/aot/passes.py +++ b/backends/cadence/aot/passes.py @@ -33,6 +33,7 @@ from executorch.backends.cadence.aot.replace_ops import ( CadenceReplaceOpsInGraph, ReplaceMulTensorWithMulAndFullOpsPass, + ReplaceSafeSoftmaxWithSoftmax, ) from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph from executorch.backends.cadence.aot.type_dispatch import CompileTimeTypeDispatchPass @@ -131,7 +132,8 @@ def apply_torch_ops_passes(expo_program: ExportedProgram) -> ExportedProgram: """ aten_passes: List[Callable[[torch.fx.GraphModule], Optional[PassResult]]] = [ - ReplaceMulTensorWithMulAndFullOpsPass() + ReplaceSafeSoftmaxWithSoftmax(), + ReplaceMulTensorWithMulAndFullOpsPass(), ] # TODO(T230417247): Use PassResult which is currently ignored. PassManager(aten_passes)(expo_program.graph_module)