From c424b98db9904ebc42cdc10b140d5b7ba2192b8a Mon Sep 17 00:00:00 2001 From: Hamlin Li Date: Wed, 1 Apr 2026 09:20:09 -0700 Subject: [PATCH] Enable autotuning for layernorm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: ## Motivation Enables Helion autotuning (both FiniteSearch and LFBOTreeSearch) for layernorm kernels on MTIA. Previously, autotuning was completely broken on MTIA — even FiniteSearch would crash immediately. ## Change Summary Four changes across 4 files: 1. helion/autotuner/local_cache.py — Adds an elif dev.type == "mtia": branch in _generate_key() so that hardware and runtime_name are populated for MTIA devices. Without this, the method hit assert hardware is not None and runtime_name is not None and crashed. 2. helion/autotuner/base_search.py — Skips setting TRITON_STORE_BINARY_ONLY=1 on MTIA (guarded by supports_mtia_tunables()). The MTIA Triton backend uses binary_ext="bin", which isn't in the upstream hardcoded allowlist ("cubin", "hsaco", "json"), causing a KeyError("Unknown key: 'bin'"). 3. ads_mkl/ops/helion/tests/helion_layernorm_autotune_test.py (new) — Adds three test cases: - test_autotune_finite_search — FiniteSearch with 3 explicit MTIA configs - test_autotune_full_search — LFBOTreeSearch with autotune_effort="quick" - test_pointer_indexing — Verifies pointer indexing works on MTIA 4. ads_mkl/ops/helion/tests/BUCK — Adds the python_unittest_athena target for the new test. ## Background There is currently no test that combines layernorm autotuning via FiniteSearch on MTIA. Here's why: What exists today 1. FiniteSearch tests (helion/test/test_autotuner.py) -- only use simple add/multiply kernels, no layernorm, no MTIA. 2. 
Layernorm on MTIA (ads_mkl/ops/helion/) -- bypasses autotuning entirely and uses hardcoded configs: - layer_norm.py -- get_hardcoded_layernorm_fwd_kernel_mtia() returns a fixed Config(block_sizes=[64], indexing="block_ptr", pid_type="flat") - layer_norm.py -- get_hardcoded_layernorm_bwd_kernel_mtia() similarly hardcoded - The test (tests/helion_layernorm_test.py) runs on MTIA Athena but always hits these hardcoded paths. 3. MTIA tunable tests (helion/test/fb/test_mtia_tunables.py) -- test cb_multiplier_strategy / dual_core_strategy with autotune_effort="none", using simple kernels. The gap: no test exercises the combination of: - A layernorm kernel - FiniteSearch (or any real autotuning) - MTIA hardware/device FiniteSearch did not work on MTIA. The call chain was: 1. kernel.autotune(args, force=False) 2. -> Backend.autotune() -> creates FiniteSearch 3. -> FiniteSearch wraps in LocalAutotuneCache (via autotuner_fn) 4. -> LocalAutotuneCache.__init__() calls _generate_key() 5. -> AssertionError -- no MTIA branch, runtime_name is None ## Revisions merged with: - D99065250, support autotune via FiniteSearch - D99066221, support full autotuner via LFBOTreeSearch Differential Revision: D99064834 --- helion/autotuner/base_search.py | 7 ++++++- helion/autotuner/local_cache.py | 6 ++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index 0a8779b02..b1e6343c6 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -1063,7 +1063,12 @@ def autotune(self, *, skip_cache: bool = False) -> Config: # metadata are needed for execution. env_overrides = {"TRITON_LOCAL_BUILD": "1"} if "TRITON_STORE_BINARY_ONLY" not in os.environ: - env_overrides["TRITON_STORE_BINARY_ONLY"] = "1" + from .._compat import supports_mtia_tunables + + # TRITON_STORE_BINARY_ONLY is incompatible with the MTIA + # Triton backend which raises KeyError("Unknown key: 'bin'"). 
+ if not supports_mtia_tunables(): + env_overrides["TRITON_STORE_BINARY_ONLY"] = "1" exit_stack.enter_context(patch.dict(os.environ, env_overrides, clear=False)) assert self._precompile_tmpdir is None tempdir = tempfile.TemporaryDirectory() diff --git a/helion/autotuner/local_cache.py b/helion/autotuner/local_cache.py index 64bd36fbd..9439c8a8c 100644 --- a/helion/autotuner/local_cache.py +++ b/helion/autotuner/local_cache.py @@ -151,6 +151,12 @@ def _generate_key(self) -> LooseAutotuneCacheKey: runtime_name = getattr(torch_tpu, "__version__", "unknown") except ImportError: runtime_name = "unknown" + elif dev.type == "mtia": + hardware = hardware or "mtia" + try: + runtime_name = str(torch.mtia.get_device_properties(dev)) + except Exception: + runtime_name = "unknown" assert hardware is not None and runtime_name is not None config_spec_hash = self.kernel.config_spec.structural_fingerprint_hash()