diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index 0a8779b02..b1e6343c6 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -1063,7 +1063,12 @@ def autotune(self, *, skip_cache: bool = False) -> Config: # metadata are needed for execution. env_overrides = {"TRITON_LOCAL_BUILD": "1"} if "TRITON_STORE_BINARY_ONLY" not in os.environ: - env_overrides["TRITON_STORE_BINARY_ONLY"] = "1" + from .._compat import supports_mtia_tunables + + # TRITON_STORE_BINARY_ONLY is incompatible with the MTIA + # Triton backend which raises KeyError("Unknown key: 'bin'"). + if not supports_mtia_tunables(): + env_overrides["TRITON_STORE_BINARY_ONLY"] = "1" exit_stack.enter_context(patch.dict(os.environ, env_overrides, clear=False)) assert self._precompile_tmpdir is None tempdir = tempfile.TemporaryDirectory() diff --git a/helion/autotuner/local_cache.py b/helion/autotuner/local_cache.py index 64bd36fbd..9439c8a8c 100644 --- a/helion/autotuner/local_cache.py +++ b/helion/autotuner/local_cache.py @@ -151,6 +151,12 @@ def _generate_key(self) -> LooseAutotuneCacheKey: runtime_name = getattr(torch_tpu, "__version__", "unknown") except ImportError: runtime_name = "unknown" + elif dev.type == "mtia": + hardware = hardware or "mtia" + try: + runtime_name = str(torch.mtia.get_device_properties(dev)) + except Exception: + runtime_name = "unknown" assert hardware is not None and runtime_name is not None config_spec_hash = self.kernel.config_spec.structural_fingerprint_hash()