From 392f10f71e9c4b64565a74cd09e4d6a255a6fdfb Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Sat, 23 May 2026 16:28:17 +0800 Subject: [PATCH 1/2] [Feat] Support dataclass & automatic triton_kernel wrapping for triton_op registration --- magi_compiler/_magi_register_custom_op.py | 355 ++++- magi_compiler/_triton_introspect.py | 685 ++++++++++ magi_compiler/api.py | 62 +- tests/api_tests/_triton_external_helpers.py | 87 ++ tests/api_tests/test_register_custom_op.py | 39 + tests/api_tests/test_register_triton_op.py | 1358 +++++++++++++++++++ 6 files changed, 2577 insertions(+), 9 deletions(-) create mode 100644 magi_compiler/_triton_introspect.py create mode 100644 tests/api_tests/_triton_external_helpers.py create mode 100644 tests/api_tests/test_register_triton_op.py diff --git a/magi_compiler/_magi_register_custom_op.py b/magi_compiler/_magi_register_custom_op.py index 2fa1d7d..ea44538 100644 --- a/magi_compiler/_magi_register_custom_op.py +++ b/magi_compiler/_magi_register_custom_op.py @@ -53,11 +53,17 @@ import dataclasses import functools import inspect -from typing import Any, Callable, get_args, get_origin +from typing import Any, Callable, Literal, get_args, get_origin import torch import torch.utils._pytree as pytree +from ._triton_introspect import ( + DEFAULT_MAX_INTROSPECT_DEPTH, + IntrospectionResult, + introspect_fn, + rewrite_fn_with_wrap_triton, +) from .config import get_compile_config from .utils.logger import magi_logger @@ -65,6 +71,8 @@ # BLOCK 0 -- VALIDATE op signature constraints # # Helpers: +# - op-name guards: +# _assert_op_name_valid # - type predicates: # _is_frozen_dataclass # - assertion primitives: @@ -75,6 +83,42 @@ # ============================================================================== +# Op names already registered via this decorator in the current process. +_REGISTERED_OP_NAMES: set[str] = set() + + +def _assert_op_name_valid(op_name: str) -> None: + """Reject ``op_name`` if it lacks the ``namespace::op_name`` form or has + already been registered (here or on ``torch.ops``). Surfaces a clear error + instead of ``torch.library``'s opaque schema-fingerprint mismatch.""" + if "::" not in op_name: + raise ValueError( + f"@magi_register_custom_op: op name {op_name!r} is missing a " + "namespace. Use ``namespace::op_name`` (e.g. " + "``my_lib::my_op``). Pick a unique namespace for your project to " + "avoid clashing with other libraries." + ) + if op_name in _REGISTERED_OP_NAMES: + raise RuntimeError( + f"@magi_register_custom_op: op name {op_name!r} is already " + "registered. Each magi op must use a unique " + "``namespace::op_name``. If you really want to override, delete " + "the previous registration with " + "``torch.library._del_library_impl`` first, or pass an explicit " + "``name=`` to disambiguate." + ) + ns, _, opname = op_name.partition("::") + if ns and opname: + ns_obj = getattr(torch.ops, ns, None) + if ns_obj is not None and hasattr(ns_obj, opname): + raise RuntimeError( + f"@magi_register_custom_op: op name {op_name!r} is already " + f"defined on torch.ops.{ns}. Use a different name (or pass an " + "explicit ``name=`` to your decorator) to avoid clashing with " + "an existing operator." + ) + + def _is_frozen_dataclass(tp) -> bool: """Return True if ``tp`` is a frozen dataclass type.""" return ( @@ -559,6 +603,14 @@ def _lower_through( # _create_identity_meta_fn, _create_meta_fn_from_param_names # - op-name generation: # _generate_op_name +# - triton introspection consumer (works off a precomputed +# :class:`IntrospectionResult` -- a single recursive BFS feeds the +# heuristics-rejection, kernel-resolution, and nested-op classification +# paths, eliminating the old asymmetry where nested ops were only +# scanned at the top level while kernels were scanned recursively): +# _reject_heuristics_outermost +# - registration-path decision: +# _classify_nested_ops, _decide_registration_path # Core: _register_torch_op # ============================================================================== @@ -677,10 +729,203 @@ def _generate_op_name(fn: Callable) -> str: return f"{namespace}::{func_name}" +# ------------------------------------------------------------------------------ +# helpers: triton introspection & wrap_triton intercept +# ------------------------------------------------------------------------------ + + +def _reject_heuristics_outermost( + introspection: IntrospectionResult, + extra_triton_kernels: list[Any] | tuple[Any, ...] | None, +) -> None: + """Reject kernels whose outermost decorator is ``@triton.heuristics`` + (``wrap_triton`` only accepts JIT/Autotuner; surface here before the + path decision so the bug isn't silently demoted to ``custom_op``).""" + candidates = list(extra_triton_kernels or ()) + list(introspection.referenced_heuristics) + if not candidates: + return + from triton.runtime.autotuner import Autotuner, Heuristics + from triton.runtime.jit import JITFunction + + for k in candidates: + if isinstance(k, (JITFunction, Autotuner)): + continue + if isinstance(k, Heuristics): + name = getattr(getattr(k, "fn", None), "__name__", repr(k)) + raise RuntimeError( + f"@magi_register_custom_op: triton kernel {name!r} has " + "@triton.heuristics as its outermost decorator. " + "torch.library.wrap_triton (and therefore triton_op / Inductor) " + "only accepts triton.jit or triton.autotune at the top level. " + "Either remove @triton.heuristics, or place @triton.autotune " + "outside it: @triton.autotune -> @triton.heuristics -> @triton.jit." + ) + + +# ------------------------------------------------------------------------------ +# helpers: classify nested ops & decide registration path +# ------------------------------------------------------------------------------ + + +@dataclasses.dataclass(frozen=True) +class _NestedOpClassification: + """``"ns::op"`` strings bucketed into triton_op / custom_op / unresolved + (unresolved = not-yet-registered or torch lacks the introspection API; + treated as custom_op downstream -- conservative fusion-barrier choice).""" + + triton_op_names: tuple[str, ...] + custom_op_names: tuple[str, ...] + unresolved_names: tuple[str, ...] + + +def _classify_nested_ops(nested_op_calls: tuple[str, ...]) -> _NestedOpClassification: + """Bucket ``nested_op_calls`` into triton_op / custom_op / unresolved via + ``torch._library.triton.get_triton_kernels_for_op`` (non-empty -> triton_op) + plus ``torch._library.custom_ops.OPDEFS`` (definitive custom_op).""" + if not nested_op_calls: + return _NestedOpClassification((), (), ()) + + try: + from torch._library.triton import get_triton_kernels_for_op + except ImportError: # older PyTorch + get_triton_kernels_for_op = None # type: ignore[assignment] + + try: + from torch._library.custom_ops import OPDEFS + except ImportError: # very old PyTorch + OPDEFS = {} # type: ignore[assignment] + + triton_op_names: list[str] = [] + custom_op_names: list[str] = [] + unresolved_names: list[str] = [] + for op_name in nested_op_calls: + is_triton_op: bool | None = None + if get_triton_kernels_for_op is not None: + try: + # Non-empty kernel list => registered via triton_op. + is_triton_op = bool(get_triton_kernels_for_op(op_name)) + except Exception: + is_triton_op = None + + if is_triton_op is True: + triton_op_names.append(op_name) + elif is_triton_op is False and op_name in OPDEFS: + # Definitively known to ``custom_ops``: it's a plain custom_op. + custom_op_names.append(op_name) + else: + # Op not yet registered, or torch lacks the introspection API. + unresolved_names.append(op_name) + + return _NestedOpClassification( + triton_op_names=tuple(triton_op_names), + custom_op_names=tuple(custom_op_names), + unresolved_names=tuple(unresolved_names), + ) + + +# mode: "triton_op" (Inductor sees through) | "custom_op" (opaque barrier) | +# "none" (skip registration so Inductor inlines and fuses nested triton_ops). +@dataclasses.dataclass(frozen=True) +class _RegistrationDecision: + mode: Literal["triton_op", "custom_op", "none"] + reason: str + + +def _decide_registration_path( + fn: Callable, + has_direct_kernel: bool, + nested: _NestedOpClassification, + force_register_mode: Literal["triton_op", "custom_op"] | None, +) -> _RegistrationDecision: + """Pick triton_op / custom_op / none from fn body content (8-case matrix). + + ============================================ ========================= ============== ============== + body ``None`` (default) ``"triton_op"`` ``"custom_op"`` + ============================================ ========================= ============== ============== + 1. direct kernel only triton_op triton_op custom_op + 2. nested triton_op only none + warning triton_op custom_op + 3. nested custom_op only custom_op ValueError custom_op + 4. nested triton_op + custom_op (no kernel) none + warning ValueError custom_op + 5. direct kernel + nested triton_op triton_op triton_op custom_op + 6. direct kernel + nested custom_op ValueError ValueError custom_op + 7. direct kernel + nested t.op + nested c.op ValueError ValueError custom_op + 8. nothing custom_op ValueError custom_op + ============================================ ========================= ============== ============== + + 6/7 reject because bare-kernel + fusion-barrier in one body is almost + always a bug; 2/4 inline because that maximises fusion across the + nested triton_ops at zero cost. ``force=custom_op`` is the universal + "I know what I'm doing" override. + """ + has_triton_op = bool(nested.triton_op_names) + has_custom_op = bool(nested.custom_op_names) or bool(nested.unresolved_names) # unresolved -> assume custom_op + fn_label = getattr(fn, "__qualname__", getattr(fn, "__name__", repr(fn))) + + custom_op_list = list(nested.custom_op_names) or list(nested.unresolved_names) + + def _err_force_triton_with_custom_op() -> str: + return ( + f"@magi_register_custom_op: cannot register {fn_label!r} as triton_op: " + f"body calls custom_op(s) {custom_op_list!r} (fusion barriers, not allowed " + "inside triton_op). Drop force_register_mode='triton_op' or move the call out." + ) + + def _err_direct_kernel_plus_custom_op() -> str: + return ( + f"@magi_register_custom_op: {fn_label!r} mixes a bare triton kernel with " + f"custom_op(s) {custom_op_list!r} (fusion barrier). Move one out, or pass " + "force_register_mode='custom_op' to silence." + ) + + # forced modes + if force_register_mode == "custom_op": + return _RegistrationDecision("custom_op", "explicit force_register_mode='custom_op'") + + if force_register_mode == "triton_op": + if has_custom_op: # cases 3,4,6,7 + raise ValueError(_err_force_triton_with_custom_op()) + if not has_direct_kernel and not has_triton_op: # case 8 + raise ValueError( + f"@magi_register_custom_op: cannot register {fn_label!r} as triton_op: " + "no direct triton kernel nor nested triton_op. Drop force_register_mode." + ) + return _RegistrationDecision("triton_op", "explicit force_register_mode='triton_op'") + + # auto-decision (force_register_mode is None) + if has_direct_kernel and has_custom_op: # cases 6,7 + raise ValueError(_err_direct_kernel_plus_custom_op()) + + if has_direct_kernel: # cases 1,5 + return _RegistrationDecision("triton_op", "direct triton kernel(s) detected") + + if has_triton_op and not has_custom_op: # case 2 + magi_logger.warning( + "@magi_register_custom_op: %r has only nested triton_op(s) %s; skipping " + "registration so Inductor can inline for fusion. Pass force_register_mode to override.", + fn_label, + list(nested.triton_op_names), + ) + return _RegistrationDecision("none", "nested triton_op(s) only; inlining for fusion") + + if has_custom_op and has_triton_op: # case 4 + magi_logger.warning( + "@magi_register_custom_op: %r has nested triton_op(s) %s + custom_op(s) %s, " + "no direct kernel; skipping registration so Inductor can inline. " + "Pass force_register_mode='custom_op' to override.", + fn_label, + list(nested.triton_op_names), + list(nested.custom_op_names) + list(nested.unresolved_names), + ) + return _RegistrationDecision("none", "nested triton_op + custom_op only; inlining") + + if has_custom_op: # case 3 + return _RegistrationDecision("custom_op", "nested custom_op(s) only") + + return _RegistrationDecision("custom_op", "no triton content detected") # case 8 + + # ------------------------------------------------------------------------------ # core: _register_torch_op -# -# Forward reference: ``_DataclassRuntimeAdapter`` using ``from __future__ import annotations``. # ------------------------------------------------------------------------------ @@ -691,21 +936,28 @@ def _register_torch_op( infer_output_meta_fn: Callable | list[str] | None, setup_context_fn: Callable | None, backward_fn: Callable | None, + mode: Literal["triton_op", "custom_op"], dataclass_runtime_adapter: _DataclassRuntimeAdapter | None = None, + bare_triton_kernels: list[Any] | None = None, + excluded_kernel_ids: set[int] | None = None, ): - """Register the op in torch.library.custom_op.""" + """Register via ``torch.library`` on the path selected by ``mode``; on + ``"triton_op"`` first shadow ``bare_triton_kernels`` via + :func:`rewrite_fn_with_wrap_triton`, and fall back to ``custom_op`` on failure.""" effective_mutates_args = ( dataclass_runtime_adapter.expand_mutates_args(mutates_args) if dataclass_runtime_adapter is not None else mutates_args ) - torch_registered_op = torch.library.custom_op(op_name, mutates_args=effective_mutates_args)(fn) - # Build & register the meta/fake function. + # Build the meta/fake fn up front -- both registration paths need it. if infer_output_meta_fn is None: meta_fn = _create_identity_meta_fn(fn) + user_supplied_meta = False elif isinstance(infer_output_meta_fn, list): meta_fn = _create_meta_fn_from_param_names(fn, infer_output_meta_fn) + user_supplied_meta = True elif dataclass_runtime_adapter is None: # No flattening scenario meta_fn = infer_output_meta_fn + user_supplied_meta = True else: # Flattening scenario user_meta = infer_output_meta_fn @@ -714,7 +966,47 @@ def _bridged_meta_fn(*args, **kwargs): _bridged_meta_fn.__signature__ = inspect.signature(fn) meta_fn = _bridged_meta_fn - torch.library.register_fake(op_name)(meta_fn) + user_supplied_meta = True + + torch_registered_op = None + if mode == "triton_op": + try: + from torch.library import triton_op + except ImportError: + triton_op = None # type: ignore[assignment] + magi_logger.warning( + "torch.library.triton_op not available; falling back to torch.library.custom_op for op %s", + op_name, + ) + if triton_op is not None: + try: + fn_for_register = rewrite_fn_with_wrap_triton( + fn, bare_triton_kernels or [], excluded_kernel_ids=excluded_kernel_ids + ) + # rewriter rebuilds via FunctionType; restamp sig/annotations for infer_schema + if fn_for_register is not fn: + sig_override = fn.__dict__.get("__signature__") + if sig_override is not None: + fn_for_register.__signature__ = sig_override + if getattr(fn, "__annotations__", None): + fn_for_register.__annotations__ = dict(fn.__annotations__) + torch_registered_op = triton_op(op_name, mutates_args=effective_mutates_args)(fn_for_register) + # triton_op self-registers fake; only override if user-supplied + if user_supplied_meta: + torch_registered_op.register_fake(meta_fn) + except Exception: + magi_logger.warning( + "triton_op registration failed for %s; falling back to " + "custom_op + register_fake. Inductor will not be able to " + "see through the op.", + op_name, + exc_info=True, + ) + torch_registered_op = None + + if torch_registered_op is None: + torch_registered_op = torch.library.custom_op(op_name, mutates_args=effective_mutates_args)(fn) + torch.library.register_fake(op_name)(meta_fn) # Register autograd. if backward_fn is not None: @@ -894,11 +1186,15 @@ def _magi_register_custom_op_impl( backward_fn: Callable | None = None, is_compute_sensitive: bool = False, is_subgraph_boundary: bool = False, + extra_triton_kernels: list[Any] | tuple[Any, ...] | None = None, + force_register_mode: Literal["triton_op", "custom_op"] | None = None, + max_introspect_depth: int = DEFAULT_MAX_INTROSPECT_DEPTH, ): def decorator(fn: Callable) -> Callable: # A 4-slot pipeline. op_name = name if name is not None else _generate_op_name(fn) + _assert_op_name_valid(op_name) if is_compute_sensitive: get_compile_config().recompute_config.custom_compute_sensitive_ops.append(op_name) if is_subgraph_boundary: @@ -908,6 +1204,41 @@ def decorator(fn: Callable) -> Callable: original_sig, lowered_sig, param_mapping_tree = _lower_op_signature(fn) needs_flattening = any(kind == "dataclass" for kind, *_ in param_mapping_tree) + # Step A: single AST pass feeds every downstream check. + introspection = introspect_fn( + fn, + extra_triton_kernels=extra_triton_kernels, + max_depth=max_introspect_depth, + ) + + # Reject top-level ``@triton.heuristics`` before path decision for a precise error. + _reject_heuristics_outermost(introspection, extra_triton_kernels) + + nested = _classify_nested_ops(introspection.nested_op_calls) + decision = _decide_registration_path( + fn, + has_direct_kernel=introspection.has_direct_kernel, + nested=nested, + force_register_mode=force_register_mode, + ) + magi_logger.debug( + "@magi_register_custom_op: %s -> mode=%s (%s)", + op_name, + decision.mode, + decision.reason, + ) + + # mode="none" -> skip registration so Inductor inlines fn (warning already emitted). + if decision.mode == "none": + return fn + + # Step B: project introspection -> rewriter inputs (custom_op skips wrap_triton). + if decision.mode == "triton_op": + bare_triton_kernels = list(introspection.bare_triton_kernels) + user_wrapped_ids = set(introspection.user_wrapped_kernel_ids) + else: + bare_triton_kernels, user_wrapped_ids = [], set() + if not needs_flattening: # ----- No-flattening scenario ----- # Path: fn -> [lowered_fn ->] torch_registered_op @@ -932,9 +1263,13 @@ def lowered_fn(*args, **kwargs): infer_output_meta_fn=infer_output_meta_fn, setup_context_fn=setup_context_fn, backward_fn=backward_fn, + mode=decision.mode, dataclass_runtime_adapter=None, + bare_triton_kernels=bare_triton_kernels, + excluded_kernel_ids=user_wrapped_ids, ) + _REGISTERED_OP_NAMES.add(op_name) # Return bare torch-level op (slot 2). return torch_registered_op @@ -962,6 +1297,8 @@ def lowered_fn(*args, **kwargs): dataclass_runtime_adapter.apply_lowered_signature(lowered_fn) # Step 2: Register the op in torch and get ``torch_registered_op``. + # On this path the rewriter shadows kernels on ``lowered_fn`` + # (its ``__globals__`` were copied from ``fn`` by functools.wraps). torch_registered_op = _register_torch_op( op_name=op_name, fn=lowered_fn, @@ -969,6 +1306,9 @@ def lowered_fn(*args, **kwargs): infer_output_meta_fn=infer_output_meta_fn, setup_context_fn=setup_context_fn, backward_fn=backward_fn, + mode=decision.mode, + bare_triton_kernels=bare_triton_kernels, + excluded_kernel_ids=user_wrapped_ids, dataclass_runtime_adapter=dataclass_runtime_adapter, ) @@ -981,6 +1321,7 @@ def magi_exposed_op(*args, **kwargs): magi_exposed_op._magi_torch_registered_op = torch_registered_op magi_exposed_op._magi_param_mapping_tree = param_mapping_tree + _REGISTERED_OP_NAMES.add(op_name) # Return magi-level op. return magi_exposed_op diff --git a/magi_compiler/_triton_introspect.py b/magi_compiler/_triton_introspect.py new file mode 100644 index 0000000..2cefc92 --- /dev/null +++ b/magi_compiler/_triton_introspect.py @@ -0,0 +1,685 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Triton kernel introspection used by ``magi_register_custom_op``. + +:func:`introspect_fn` walks ``fn`` + helpers ONCE, returning every kernel / +nested-op / heuristics reference downstream consumers need. +:func:`rewrite_fn_with_wrap_triton` shadows kernel refs in ``fn``'s globals +and closures with ``wrap_triton(k)`` so Inductor can trace through them. +""" + +from __future__ import annotations + +import ast +import dataclasses +import functools +import inspect +import logging +import types +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + + +# Default helper-recursion depth; overridable via +# ``magi_register_custom_op(max_introspect_depth=...)``. +DEFAULT_MAX_INTROSPECT_DEPTH: int = 5 + + +__all__ = [ + "DEFAULT_MAX_INTROSPECT_DEPTH", + "IntrospectionResult", + "introspect_fn", + "rewrite_fn_with_wrap_triton", +] + + +# ============================================================================== +# SECTION 0 -- Single-pass AST scan (shared primitive) +# ============================================================================== + + +@dataclasses.dataclass(frozen=True) +class AstScanResult: + """Raw AST-level findings about ``fn`` (identifiers as written in source; + caller resolves them via the function's globals/closure). + + Attributes: + wrapped_kernel_names: identifiers passed to ``wrap_triton`` / + ``capture_triton`` (already wrapped -- do NOT re-shadow). + bare_kernel_names: identifiers launched as ``k[grid](...)``, + ``mod.k[grid](...)``, ``k.run(...)`` or surfaced via bare + ``return k``; dotted forms recorded as a single string. + called_helpers: plain function-call identifiers (helpers, recursed). + nested_op_calls: ``"ns::op"`` strings for ``torch.ops..(...)`` + (NOT recursed into -- registered ops stay opaque). + assignments: ``var -> [RHS expr, ...]`` for alias tracing + (``k = make_kernel(); k[grid](...)``). + """ + + wrapped_kernel_names: tuple[str, ...] + bare_kernel_names: tuple[str, ...] + called_helpers: tuple[str, ...] + nested_op_calls: tuple[str, ...] + assignments: dict[str, list[ast.expr]] + + +def _dotted_attr_name(node: ast.AST) -> Optional[str]: + """Return ``"a.b.c"`` if ``node`` is an Attribute chain rooted at a + ``Name``, else ``None`` (forms like ``factory().k`` are not statically + resolvable and fall back to ``extra_triton_kernels``).""" + parts: list[str] = [] + cur = node + while isinstance(cur, ast.Attribute): + parts.append(cur.attr) + cur = cur.value + if isinstance(cur, ast.Name): + parts.append(cur.id) + return ".".join(reversed(parts)) + return None + + +def _is_wrap_triton_call(node: ast.AST) -> bool: + """True if ``node`` is a ``wrap_triton``/``capture_triton`` call (any of + the forms ``_AstCollector.visit_Call`` recognises).""" + if not isinstance(node, ast.Call): + return False + triton_func_names = ("capture_triton", "wrap_triton") + triton_wrap_modules = ("_library", "library") # public + origin + f = node.func + if isinstance(f, ast.Name) and f.id in triton_func_names: + return True + if ( + isinstance(f, ast.Attribute) + and f.attr in triton_func_names + and isinstance(f.value, ast.Attribute) + and f.value.attr in triton_wrap_modules + and isinstance(f.value.value, ast.Name) + and f.value.value.id == "torch" + ): + return True + return False + + +def _names_outside_wrap_calls(expr: ast.expr) -> list[str]: + """Collect every ``Name`` in ``expr`` except those inside a + ``wrap_triton``/``capture_triton`` call (already counted by ``visit_Call``).""" + names: list[str] = [] + + class _Collector(ast.NodeVisitor): + def visit_Call(self, node: ast.Call) -> None: + if _is_wrap_triton_call(node): + return + self.generic_visit(node) + + def visit_Name(self, node: ast.Name) -> None: + names.append(node.id) + + _Collector().visit(expr) + return names + + +class _AstCollector(ast.NodeVisitor): + """Single AST walker producing the data for :class:`AstScanResult`.""" + + _TRITON_FUNC_NAMES = ("capture_triton", "wrap_triton") + _TRITON_WRAP_MODULES = ("_library", "library") + + def __init__(self) -> None: + self.wrapped_kernel_names: list[str] = [] + self.bare_kernel_names: list[str] = [] + self.called_helpers: list[str] = [] + self.nested_op_calls: list[str] = [] + self.assignments: dict[str, list[ast.expr]] = {} + + def visit_Return(self, node: ast.Return) -> None: + # A helper may surface a kernel via ``return k`` for the caller to + # launch. Names inside ``wrap_triton(...)`` are skipped here and + # picked up by ``visit_Call`` so we don't double-count them as bare. + if node.value is not None: + for name in _names_outside_wrap_calls(node.value): + self.bare_kernel_names.append(name) + self.generic_visit(node) + + def visit_Call(self, node: ast.Call) -> None: + # Recognised shapes: + # A.1 torch.[_]library.wrap_triton(k) -> wrapped_kernel_names + # A.2 torch.ops..(...) -> nested_op_calls (opaque) + # A.3 wrap_triton(k) / capture_triton(k) -> wrapped_kernel_names + # A.4 any other Name(...) call -> called_helpers (recursed) + # A.5 .run(*args, grid=...) -> bare_kernel_names + # (Triton's low-level launch API; non-kernel .run filtered at resolve) + # Bare k[grid](...) / mod.k[grid](...) -> Subscript branch below + if isinstance(node.func, ast.Attribute): + attr = node.func + handled = False + if isinstance(attr.value, ast.Attribute): + if ( + isinstance(attr.value.value, ast.Name) + and attr.value.value.id == "torch" + and attr.value.attr in self._TRITON_WRAP_MODULES + and attr.attr in self._TRITON_FUNC_NAMES + ): + # A.1 + if node.args and isinstance(node.args[0], ast.Name): + self.wrapped_kernel_names.append(node.args[0].id) + handled = True + elif ( + isinstance(attr.value.value, ast.Attribute) + and isinstance(attr.value.value.value, ast.Name) + and attr.value.value.value.id == "torch" + and attr.value.value.attr == "ops" + ): + # A.2 + self.nested_op_calls.append(f"{attr.value.attr}::{attr.attr}") + handled = True + if not handled and attr.attr == "run": + # A.5 + dotted = _dotted_attr_name(attr.value) + if dotted is not None: + self.bare_kernel_names.append(dotted) + elif isinstance(node.func, ast.Name): + if node.func.id in self._TRITON_FUNC_NAMES: + # A.3 + if node.args and isinstance(node.args[0], ast.Name): + self.wrapped_kernel_names.append(node.args[0].id) + else: + # A.4 + self.called_helpers.append(node.func.id) + + # Subscript launch: ``Name[grid](...)`` or attribute-chain rooted at a + # Name (``mod.k[grid](...)``). ``self.k[grid](...)`` records but can't + # be resolved statically -- caller must pass ``extra_triton_kernels``. + if isinstance(node.func, ast.Subscript): + base = node.func.value + if isinstance(base, ast.Name): + self.bare_kernel_names.append(base.id) + elif isinstance(base, ast.Attribute): + dotted = _dotted_attr_name(base) + if dotted is not None: + self.bare_kernel_names.append(dotted) + + self.generic_visit(node) + + +def scan_fn_ast(fn: Callable[..., Any]) -> Optional[AstScanResult]: + """Parse ``fn``'s source once and return the raw collector data; returns + ``None`` when source is unavailable (builtins, C-extensions, REPL). + Only inspects *this* frame -- :func:`introspect_fn` drives the recursion.""" + try: + fn = inspect.unwrap(fn) + except ValueError: + pass + + try: + source = inspect.getsource(fn) + except (OSError, TypeError): + return None + + from torch._inductor.utils import IndentedBuffer + + buffer = IndentedBuffer() + buffer.splice(source, strip=True) + + try: + tree = ast.parse(buffer.getrawvalue()) + except SyntaxError: + return None + + collector = _AstCollector() + collector.visit(tree) + + return AstScanResult( + wrapped_kernel_names=tuple(collector.wrapped_kernel_names), + bare_kernel_names=tuple(collector.bare_kernel_names), + called_helpers=tuple(collector.called_helpers), + nested_op_calls=tuple(collector.nested_op_calls), + assignments=collector.assignments, + ) + + +def _build_fn_namespace(func_obj: object) -> dict[str, Any]: + """Combined globals + closures + nonlocals view of ``func_obj``; empty + dict for non-callables or callables without ``__code__``.""" + if callable(func_obj): + try: + func_obj = inspect.unwrap(func_obj) + except ValueError: + pass + if not callable(func_obj) or not hasattr(func_obj, "__code__"): + return {} + closure_vars = inspect.getclosurevars(func_obj) + namespace: dict[str, Any] = {} + namespace.update(closure_vars.builtins) + namespace.update(closure_vars.globals) + namespace.update(closure_vars.nonlocals) + if hasattr(func_obj, "__globals__"): + namespace.update(func_obj.__globals__) + return namespace + + +# ============================================================================== +# SECTION 1 -- Unified single-pass introspection +# Vendored / extended from torch._library.triton (v2.11.0, +# https://github.com/pytorch/pytorch/blob/v2.11.0/torch/_library/triton.py, +# BSD-licensed; original: ``get_inner_triton_kernels``). +# ============================================================================== + + +@dataclasses.dataclass(frozen=True) +class IntrospectionResult: + """Output of :func:`introspect_fn`; all sequences dedup'd (kernels by + ``id``, op names by string) and ordered by first-discovery in the BFS. + + Attributes: + bare_triton_kernels: ``JITFunction``/``Autotuner`` launched as + ``k[grid]``, ``k.run(...)``, or surfaced via ``return k``. + ``Heuristics`` wrappers are *peeled* (inner kernel recorded); + the raw wrapper goes to ``referenced_heuristics``. + user_wrapped_kernels: kernels already passed to + ``wrap_triton``/``capture_triton`` (rewriter must skip these). + referenced_heuristics: raw ``Heuristics`` objects -- used to reject + ``@triton.heuristics``-as-outermost early. + nested_op_calls: ``"ns::op"`` strings for every reached + ``torch.ops..`` (anywhere in the call tree). + """ + + bare_triton_kernels: tuple[Any, ...] + user_wrapped_kernels: tuple[Any, ...] + referenced_heuristics: tuple[Any, ...] + nested_op_calls: tuple[str, ...] + + @property + def has_direct_kernel(self) -> bool: + """True iff a triton kernel that needs the ``triton_op`` path was + found (heuristics-outermost is already rejected upstream, so by + the time this runs ``referenced_heuristics`` is guaranteed empty).""" + return bool(self.bare_triton_kernels) + + @property + def user_wrapped_kernel_ids(self) -> frozenset[int]: + """``id``-set of already-wrapped kernels for the rewriter's + ``excluded_kernel_ids`` argument.""" + return frozenset(id(k) for k in self.user_wrapped_kernels) + + +def _extract_names_from_expr(expr: ast.expr) -> list[str]: + """Pull every ``ast.Name`` reachable from ``expr`` (descends into nested + calls so ``k = factory(arg).method`` surfaces ``factory``).""" + names: list[str] = [] + + class _NameExtractor(ast.NodeVisitor): + def visit_Name(self, node: ast.Name) -> None: + names.append(node.id) + + def visit_Call(self, node: ast.Call) -> None: + self.generic_visit(node) + + _NameExtractor().visit(expr) + return names + + +def introspect_fn( + fn: Callable[..., Any], + *, + extra_triton_kernels: list[Any] | tuple[Any, ...] | None = None, + max_depth: int = DEFAULT_MAX_INTROSPECT_DEPTH, +) -> IntrospectionResult: + """Walk ``fn`` + helpers once (BFS, depth-capped) and return everything + needed downstream; recurses into helper calls but NOT ``torch.ops.*``, + traces ``k = make_kernel()`` aliases, peels ``Heuristics``. Never raises.""" + try: + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + except ImportError: + logger.warning("Triton not available, introspect_fn returns empty result") + return IntrospectionResult((), (), (), ()) + try: + from triton.runtime.autotuner import Heuristics + except ImportError: + Heuristics = None # older triton -- heuristics rejection no-ops + + kernel_types: tuple[type, ...] = (JITFunction, Autotuner) + + # Dedup-on-insert ordered bucket keyed by ``id(obj)`` (or value, for strings). + class _Bucket: + __slots__ = ("items", "_seen", "_key") + + def __init__(self, key=id): + self.items: list[Any] = [] + self._seen: set[Any] = set() + self._key = key + + def add(self, obj: Any) -> None: + k = self._key(obj) + if k not in self._seen: + self._seen.add(k) + self.items.append(obj) + + bare = _Bucket() + user_wrapped = _Bucket() + heuristics = _Bucket() + nested_ops = _Bucket(key=lambda s: s) + + for k in extra_triton_kernels or (): + bare.add(k) + + visited_fns: set[int] = set() + + def _walk(func: Callable[..., Any], depth: int) -> None: + try: + f = inspect.unwrap(func) + except ValueError: + f = func + if id(f) in visited_fns: + return + if depth > max_depth: + logger.debug("reached max introspect depth (%s) in introspect_fn", max_depth) + return + visited_fns.add(id(f)) + + scan = scan_fn_ast(f) + if scan is None: + return + + for op_name in scan.nested_op_calls: + nested_ops.add(op_name) + + namespace = _build_fn_namespace(f) + + def _lookup_dotted(name: str) -> object | None: + """Resolve ``"a.b.c"`` via successive ``getattr`` on ``namespace[a]``.""" + if "." not in name: + return namespace.get(name) + root, *rest = name.split(".") + if root not in namespace: + return None + obj: object = namespace[root] + for attr in rest: + try: + obj = getattr(obj, attr) + except AttributeError: + return None + return obj + + def _classify_name( + name: str, + *, + as_bare: bool, + visited_names: set[str], + ) -> None: + """Classify ``name`` and route it to the right bucket: a kernel + bucket (``bare`` vs ``user_wrapped`` per ``as_bare``), the + ``heuristics`` bucket if the resolved object is a Triton + ``Heuristics``; otherwise recurse into a user helper via + ``_walk`` or follow ``k = make_kernel()`` assignment chains.""" + if name in visited_names: + return + visited_names.add(name) + + obj = _lookup_dotted(name) + if obj is not None: + # Raw-object check first so Heuristics is captured before peel. + if Heuristics is not None and isinstance(obj, Heuristics): + heuristics.add(obj) + kernel = _resolve_kernel(obj, kernel_types) + if kernel is not None: + (bare if as_bare else user_wrapped).add(kernel) + return + if callable(obj): + try: + unwrapped = inspect.unwrap(obj) + except ValueError: + unwrapped = obj + if hasattr(unwrapped, "__code__"): + _walk(unwrapped, depth + 1) + return + logger.debug("failed to resolve %s to a triton kernel", name) + return + + # Trace local aliases like ``k = make_kernel()``. + if name in scan.assignments: + for rhs_expr in scan.assignments[name]: + for sub in _extract_names_from_expr(rhs_expr): + _classify_name(sub, as_bare=as_bare, visited_names=visited_names) + else: + logger.debug("%s not found in namespace or assignments", name) + + # Per-frame visited set: an alias chain shouldn't escape its frame. + bare_visited: set[str] = set() + for n in scan.bare_kernel_names: + _classify_name(n, as_bare=True, visited_names=bare_visited) + + wrapped_visited: set[str] = set() + for n in scan.wrapped_kernel_names: + _classify_name(n, as_bare=False, visited_names=wrapped_visited) + + for helper_name in scan.called_helpers: + helper_obj = namespace.get(helper_name) + if helper_obj is None or not callable(helper_obj): + continue + # ``Heuristics`` is callable (``h(args)`` launches inner kernel); + # record before the ``__code__`` bail so bare ``h(args)`` form is caught. + if Heuristics is not None and isinstance(helper_obj, Heuristics): + heuristics.add(helper_obj) + if not hasattr(helper_obj, "__code__"): + continue + try: + _walk(helper_obj, depth + 1) + except Exception: + logger.debug( + "failed to analyze called helper %s", helper_name, exc_info=True + ) + + _walk(fn, 0) + + return IntrospectionResult( + bare_triton_kernels=tuple(bare.items), + user_wrapped_kernels=tuple(user_wrapped.items), + referenced_heuristics=tuple(heuristics.items), + nested_op_calls=tuple(nested_ops.items), + ) + + +# ============================================================================== +# SECTION 2 -- Runtime ``wrap_triton`` shadow rewriter +# Inductor needs ``wrap_triton(k)[grid]``; this clones ``fn`` with globals/ +# closures rewritten so bare ``k`` resolves to the wrapped version. +# ============================================================================== + + +def _resolve_kernel(obj: object, kernel_types: tuple[type, ...]) -> Optional[object]: + """Peel ``obj`` to the underlying ``JITFunction``/``Autotuner`` (returns + ``obj`` if it already is one, unwrapped ``obj.fn`` for thin wrappers like + ``Heuristics``, else ``None``). The result feeds ``wrap_triton``.""" + if isinstance(obj, kernel_types): + return obj + if callable(obj) and hasattr(obj, "fn"): + try: + inner = obj.fn + except Exception: + return None + if isinstance(inner, kernel_types): + return inner + return None + + +def _is_user_helper(obj: object) -> bool: + """True if ``obj`` is a plain Python function we can recursively rebuild + (excludes triton kernels, builtins, and torch/triton internals).""" + if not isinstance(obj, types.FunctionType): + return False + code = getattr(obj, "__code__", None) + if code is None: + return False + mod = getattr(obj, "__module__", "") or "" + if mod.startswith(("torch._library", "triton.")): + return False + return True + + +def rewrite_fn_with_wrap_triton( + fn: Callable[..., Any], kernels: list[object], excluded_kernel_ids: Optional[set[int]] = None +) -> Callable[..., Any]: + """Return a clone of ``fn`` whose globals/closures shadow each ``k`` in + ``kernels`` with ``wrap_triton(k)``; helper functions called from ``fn`` + are rebuilt the same way. Original objects are not modified.""" + if not kernels: + return fn + + # Triton must be importable here: callers only reach this point after + # ``introspect_fn`` produced non-empty kernels (which requires triton). + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + + kernel_types: tuple[type, ...] = (JITFunction, Autotuner) + + try: + from torch.library import wrap_triton + except ImportError: + try: + from torch._library.triton import wrap_triton # type: ignore + except ImportError: + logger.debug("wrap_triton unavailable; skipping rewrite") + return fn + + # ``id(kernel) -> wrap_triton(kernel)`` (cache so identical kernels share + # one wrapper; ``wrapped_value_ids`` is the O(1) inverse for ``_maybe_wrap``). + wrapped_cache: dict[int, Any] = {} + wrapped_value_ids: set[int] = set() + + def _wrap_once(k: object) -> Any: + kid = id(k) + if kid not in wrapped_cache: + wrapper = wrap_triton(k) + wrapped_cache[kid] = wrapper + wrapped_value_ids.add(id(wrapper)) + return wrapped_cache[kid] + + # Pre-populate cache with explicitly detected kernels so identical objects + # encountered later resolve to the same wrapper. + target_ids: set[int] = set() + for k in kernels: + if isinstance(k, kernel_types): + _wrap_once(k) + target_ids.add(id(k)) + else: + resolved = _resolve_kernel(k, kernel_types) + if resolved is not None: + _wrap_once(resolved) + target_ids.add(id(resolved)) + + excluded_kernel_ids = set(excluded_kernel_ids or set()) + + def _maybe_wrap(obj: object) -> Optional[Any]: + """Return ``wrap_triton(obj)`` if it's a target kernel; ``None`` if + ``obj`` should be left alone (already-wrapped, excluded, or non-kernel).""" + if id(obj) in wrapped_value_ids: # already a wrap_triton wrapper + return None + + resolved = _resolve_kernel(obj, kernel_types) + if resolved is None: + return None + # Caller flagged this kernel as already user-wrapped in source; don't + # shadow its module-globals ref or ``wrap_triton(wrap_triton(k))`` results. + if id(resolved) in excluded_kernel_ids: + return None + if id(resolved) in target_ids or isinstance(resolved, kernel_types): + # Wrap any encountered kernel (not just initially-detected ones) so + # dynamically-resolved kernels in helper globals are also captured. + return _wrap_once(resolved) + return None + + rebuilt_fns: dict[int, Callable[..., Any]] = {} + + # All functions in a module share one ``__globals__`` dict; rewrite it once + # per module (else O(N_helpers * N_globals_per_module) blows up). + rebuilt_globals: dict[int, dict[str, Any]] = {} + + def _build_new_globals(old_globals: dict[str, Any]) -> dict[str, Any]: + gid = id(old_globals) + if gid in rebuilt_globals: + return rebuilt_globals[gid] + new_globals: dict[str, Any] = dict(old_globals) + # Pre-register so reentrant _rebuild (helper back-refs module) terminates. + rebuilt_globals[gid] = new_globals + + for name, obj in list(old_globals.items()): + wrapped = _maybe_wrap(obj) + if wrapped is not None: + new_globals[name] = wrapped + continue + if _is_user_helper(obj): + try: + new_globals[name] = _rebuild(obj) + except Exception: + logger.debug("failed to rebuild helper %s", name, exc_info=True) + return new_globals + + def _rebuild(f: Callable[..., Any]) -> Callable[..., Any]: + if not isinstance(f, types.FunctionType): + return f + if id(f) in rebuilt_fns: + return rebuilt_fns[id(f)] + + # Pre-register a placeholder so back-references through globals/closures + # don't recurse forever; the real new_fn replaces it at the bottom. + rebuilt_fns[id(f)] = f + + new_globals = _build_new_globals(f.__globals__) + + new_closure: Optional[tuple] = None + if f.__closure__ is not None: + new_cells = [] + for cell in f.__closure__: + try: + contents = cell.cell_contents + except ValueError: + new_cells.append(cell) # empty cell + continue + + wrapped = _maybe_wrap(contents) + if wrapped is not None: + new_cells.append(types.CellType(wrapped)) + continue + if _is_user_helper(contents) and id(contents) != id(f): + try: + new_cells.append(types.CellType(_rebuild(contents))) + continue + except Exception: + logger.debug("failed to rebuild closure helper %s", getattr(contents, "__name__", "?"), exc_info=True) + new_cells.append(cell) + new_closure = tuple(new_cells) + + new_fn = types.FunctionType(f.__code__, new_globals, f.__name__, f.__defaults__, new_closure) + # Preserve metadata for infer_schema / register_fake. + try: + functools.update_wrapper(new_fn, f, updated=()) + except Exception: + pass + new_fn.__kwdefaults__ = f.__kwdefaults__ + new_fn.__module__ = f.__module__ + new_fn.__qualname__ = f.__qualname__ + # Drop __wrapped__: ``inspect.unwrap`` must stop at the rewritten fn, + # otherwise it walks back to ``f`` whose globals lack wrap_triton. + try: + del new_fn.__wrapped__ + except AttributeError: + pass + + rebuilt_fns[id(f)] = new_fn + return new_fn + + return _rebuild(fn) diff --git a/magi_compiler/api.py b/magi_compiler/api.py index 3f94989..b933e8a 100644 --- a/magi_compiler/api.py +++ b/magi_compiler/api.py @@ -15,7 +15,7 @@ import copy import functools import inspect -from typing import Callable, TypeVar +from typing import Any, Callable, Literal, TypeVar from ._api import ( _check_dynamic_arg_dims, @@ -202,6 +202,9 @@ def magi_register_custom_op( backward_fn: Callable | None = None, is_compute_sensitive: bool = False, is_subgraph_boundary: bool = False, + extra_triton_kernels: list[Any] | tuple[Any, ...] | None = None, + force_register_mode: Literal["triton_op", "custom_op"] | None = None, + max_introspect_depth: int = 5, ): """ A unified decorator to register a custom operator with PyTorch's library. @@ -233,12 +236,64 @@ def magi_register_custom_op( ops are prioritised for saving rather than recomputing. is_subgraph_boundary: Split the FX graph at this op during compilation. Each sub-graph between boundary ops is compiled independently. + extra_triton_kernels: Escape hatch for triton kernels the AST scanner + can't pick up (e.g. ``self.kernel[grid](...)`` -- subscripted + attributes are not statically resolvable). Listed kernels are + treated as bare and shadowed via ``wrap_triton``. Also forces + ``has_direct_kernel = True`` in the decision matrix below. + force_register_mode: Override the auto-selected registration path: + - ``None`` (default): auto-decide; may also choose to *not* + register at all (returns ``fn`` unchanged with a warning) so + Inductor can inline the body for maximum fusion across nested + ops. See the decision matrix below. + - ``"triton_op"``: force ``torch.library.triton_op`` registration + (Inductor traces through). Raises ``ValueError`` if the body + contains a ``custom_op`` (fusion barriers cannot live inside + a triton_op) or has no triton content at all. + - ``"custom_op"``: force ``torch.library.custom_op`` (opaque + fusion barrier). Always succeeds. + max_introspect_depth: How many levels of helper-function calls the + AST scanner follows when looking for triton kernels. Default + ``5``. Doesn't bound flat AST scanning of ``fn`` itself, only + recursion into its callees. Nested ``torch.ops..(...)`` + calls are never followed. + + Registration-path decision matrix (when ``force_register_mode is None``): + + +----+--------------------------------------+---------------------------+ + | | body of ``fn`` | default path | + +====+======================================+===========================+ + | 1 | direct triton kernel only | ``triton_op`` | + +----+--------------------------------------+---------------------------+ + | 2 | nested triton_op only | ``none`` (warns, inlines) | + +----+--------------------------------------+---------------------------+ + | 3 | nested custom_op only | ``custom_op`` | + +----+--------------------------------------+---------------------------+ + | 4 | nested triton_op + custom_op only | ``none`` (warns, inlines) | + +----+--------------------------------------+---------------------------+ + | 5 | direct kernel + nested triton_op | ``triton_op`` | + +----+--------------------------------------+---------------------------+ + | 6 | direct kernel + nested custom_op | ``ValueError`` (mistake) | + +----+--------------------------------------+---------------------------+ + | 7 | direct kernel + nested triton_op + | ``ValueError`` (mistake) | + | | nested custom_op | | + +----+--------------------------------------+---------------------------+ + | 8 | nothing triton-related | ``custom_op`` | + +----+--------------------------------------+---------------------------+ + + Cases 6 / 7 are rejected because mixing a bare triton kernel with an + opaque ``custom_op`` in the same body is almost always a mistake (the + barrier already prevents fusing the kernel with anything else). Use + ``force_register_mode="custom_op"`` to silence the check. Returns: A callable with the user's original signature. Examples: - 1. Basic usage (forward only, auto-generated name and meta function): + + #### Basic usage + + 1. Forward only, auto-generated name and meta function: >>> @magi_register_custom_op() ... def my_relu(x: torch.Tensor) -> torch.Tensor: @@ -313,4 +368,7 @@ def magi_register_custom_op( backward_fn=backward_fn, is_compute_sensitive=is_compute_sensitive, is_subgraph_boundary=is_subgraph_boundary, + extra_triton_kernels=extra_triton_kernels, + force_register_mode=force_register_mode, + max_introspect_depth=max_introspect_depth, ) diff --git a/tests/api_tests/_triton_external_helpers.py b/tests/api_tests/_triton_external_helpers.py new file mode 100644 index 0000000..0547d9d --- /dev/null +++ b/tests/api_tests/_triton_external_helpers.py @@ -0,0 +1,87 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""External helper module used by ``test_register_triton_op.py``. + +The helpers below intentionally live in their own module so that, when the +test file imports them and calls them inside a ``magi_register_custom_op``- +decorated function, the helpers' ``__globals__`` are *this* module, not the +test module. That exercises the truly cross-module rebuild path in +``rewrite_fn_with_wrap_triton``. +""" + +from __future__ import annotations + +""" +External helper module for test_register_triton_op.py to verify +cross-module triton kernel introspection. +""" +import torch + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: # pragma: no cover + HAS_TRITON = False + + +if HAS_TRITON: + + @triton.jit + def external_neg_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, -x, mask=mask) + + @triton.jit + def external_double_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, x * 2, mask=mask) + + def external_neg_launcher(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + external_neg_kernel[((n + 127) // 128,)](x, out, n, BLOCK_SIZE=128) + return out + + def external_double_launcher(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + external_double_kernel[((n + 127) // 128,)](x, out, n, BLOCK_SIZE=128) + return out + + def maybe_capture(kernel): + """Third-party-style thin wrapper around a triton kernel. + + Some libraries return objects with a ``.fn`` attribute pointing back to + the underlying ``JITFunction``; we mimic that pattern here so the test + can confirm ``rewrite_fn_with_wrap_triton`` still recognises the + underlying kernel when users write ``maybe_capture(kernel)[grid](...)``. + """ + + class _Captured: + def __init__(self, k): + self.fn = k # introspector recognises objects with .fn + + def __getitem__(self, grid): + return self.fn[grid] + + return _Captured(kernel) diff --git a/tests/api_tests/test_register_custom_op.py b/tests/api_tests/test_register_custom_op.py index 071cee0..c645109 100644 --- a/tests/api_tests/test_register_custom_op.py +++ b/tests/api_tests/test_register_custom_op.py @@ -2127,5 +2127,44 @@ def outer(x: torch.Tensor) -> torch.Tensor: assert_close(outer(x), x * 3.0 + 1) +class TestDuplicateOpNameRejected: + """Re-registering the same ``namespace::op_name`` should raise a clear + error instead of letting ``torch.library`` complain about schema + fingerprints. + """ + + def test_duplicate_name_rejected(self): + @magi_register_custom_op(name="test::dup_name_first") + def _op_a(x: torch.Tensor) -> torch.Tensor: + return x + 1 + + with pytest.raises(RuntimeError, match="already"): + + @magi_register_custom_op(name="test::dup_name_first") + def _op_b(x: torch.Tensor) -> torch.Tensor: + return x + 2 + + +class TestOpNameNamespaceRequired: + """``torch.library`` requires ``namespace::op_name``; a bare name causes a + confusing low-level error. We surface a clear, actionable message + pointing at the convention. + """ + + def test_missing_namespace_rejected(self): + with pytest.raises(ValueError, match="namespace"): + + @magi_register_custom_op(name="missing_namespace_op") + def _op(x: torch.Tensor) -> torch.Tensor: + return x + + def test_namespaced_name_accepted(self): + @magi_register_custom_op(name="test::ns_ok") + def _op(x: torch.Tensor) -> torch.Tensor: + return x + 1 + + assert_close(_op(torch.zeros(2)), torch.ones(2)) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/api_tests/test_register_triton_op.py b/tests/api_tests/test_register_triton_op.py new file mode 100644 index 0000000..2fb687b --- /dev/null +++ b/tests/api_tests/test_register_triton_op.py @@ -0,0 +1,1358 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +""" +This test suite covers the ``triton_op`` auto-detection and registration path. + +Coverage Matrix & Sections: +--------------------------- +SECTION 1: Direct Kernel Launch Patterns + - Flat direct kernel call (``kernel[grid](...)``) + - Multiple kernels in sequence + - Kernel launched inside a closure + - Multilevel nesting & Helper functions launching kernels + +SECTION 2: Wrapped, Dynamic & Exotic Retrievals + - Helper launchers (local, cross-module, 3rd party wrappers) + - ``wrap_triton`` idempotency (Mixing wrapped and bare kernels safely) + - Explicit ``extra_triton_kernels`` override & deduplication + - Staticmethod / Classmethod kernels + - Dynamically fetched / runtime-imported kernels + +SECTION 3: Autotune, Heuristics & Autograd in Triton + - ``@triton.autotune`` kernels (single & multiple configs) + - ``@triton.heuristics`` rejection & graceful fallback + - Autograd combined with Triton kernels + +SECTION 4: End-to-End Tracing + - Pure Inductor see-through proof (AOT graph verification) +""" + +import pytest +import torch +from torch.testing import assert_close + +triton = pytest.importorskip("triton") +tl = pytest.importorskip("triton.language") + +from magi_compiler.api import magi_register_custom_op # noqa: E402 + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="triton kernels require CUDA") + + +# --------------------------------------------------------------------------- +# Module-level kernels (so they live in fn.__globals__ for several scenarios) +# --------------------------------------------------------------------------- + + +@triton.jit +def _cos_kernel(in_ptr0, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = tl.cos(x) + tl.store(out_ptr + offsets, output, mask=mask) + + +@triton.jit +def _scale_kernel(in_ptr0, out_ptr, n_elements, scale, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = x * scale + tl.store(out_ptr + offsets, output, mask=mask) + + +@triton.jit +def _add_kernel(a_ptr, b_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, a + b, mask=mask) + + +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE": 128}, num_warps=4), triton.Config({"BLOCK_SIZE": 256}, num_warps=4)], + key=["n_elements"], +) +@triton.jit +def _autotuned_cos_kernel(in_ptr0, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + tl.store(out_ptr + offsets, tl.cos(x), mask=mask) + + +def _grid_1d(n: int): + return ((n + 127) // 128,) + + +# Module-level frozen dataclass fixtures used by the dataclass+triton tests +# below. Defined at module scope (not inside the test methods) so that +# ``typing.get_type_hints`` / ``eval`` on the function's stringified +# annotations (PEP 563 / ``from __future__ import annotations``) can find +# them via ``fn.__globals__``. +from dataclasses import dataclass as _dc_dataclass # noqa: E402 + + +@_dc_dataclass(frozen=True) +class _DcCosCfg: + block_size: int + + +@_dc_dataclass(frozen=True) +class _DcKernelCfg: + block_size: int + extra_offset: float + + +@_dc_dataclass(frozen=True) +class _DcOuterCfg: + kernel: _DcKernelCfg + scale: float + + +@_dc_dataclass(frozen=True) +class _DcShapeCfg: + out_dim: int + + +@_dc_dataclass(frozen=True) +class _DcProjCfg: + shape: _DcShapeCfg + block_size: int + + +# that by defining it at module scope but in its own helper that fn calls. + + +def _scale_launcher(x: torch.Tensor, factor: float) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _scale_kernel[_grid_1d(n)](x, out, n, factor, BLOCK_SIZE=128) + return out + + +def _add_launcher(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(a) + n = a.numel() + _add_kernel[_grid_1d(n)](a, b, out, n, BLOCK_SIZE=128) + return out + + +def _make_cos_kernel(): + @triton.jit + def _kernel(in_ptr, out_ptr, n_elements, BLOCK: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK + offsets = block_start + tl.arange(0, BLOCK) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, tl.cos(x), mask=mask) + + return _kernel + + +def _inner_launcher(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + +def _dispatch_launcher(x: torch.Tensor) -> torch.Tensor: + return _inner_launcher(x) + + +@triton.heuristics({"BLOCK_SIZE": lambda args: 128}) +@triton.jit +def _heuristics_top_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, x, mask=mask) + + +@triton.autotune(configs=[triton.Config({}, num_warps=4)], key=["n_elements"]) +@triton.heuristics({"BLOCK_SIZE": lambda args: 128}) +@triton.jit +def _autotune_then_heuristics_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, x, mask=mask) + + +class _KernelHolder: + """Holder with kernel exposed via classmethod / staticmethod. The + introspector cannot statically follow ``Holder.get()`` to a kernel at + decoration time; users must use ``extra_triton_kernels=`` instead. + """ + + @staticmethod + def get_static(): + return _scale_kernel + + @classmethod + def get_class(cls): + return _scale_kernel + + +# ============================================================================ +# SECTION 1: Direct Kernel Launch Patterns +# ============================================================================ + + +class TestFlatDirectKernel: + def test_basic_cos(self): + @magi_register_custom_op(name="magi_test::flat_cos") + def mycos(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + out = mycos(x) + assert_close(out, torch.cos(x), atol=1e-5, rtol=1e-5) + + def test_op_is_triton_op(self): + """Sanity: the registered op should be a triton_op-style CustomOpDef + and torch.compile should be able to see through it.""" + + @magi_register_custom_op(name="magi_test::seethrough_cos") + def mycos(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + compiled = torch.compile(mycos, backend="inductor", fullgraph=True) + x = torch.randn(2048, device="cuda") + out = compiled(x) + assert_close(out, torch.cos(x), atol=1e-5, rtol=1e-5) + + +class TestMultiKernelSequence: + def test_chain(self): + @magi_register_custom_op(name="magi_test::cos_then_scale") + def fn(x: torch.Tensor, scale: float) -> torch.Tensor: + tmp = torch.empty_like(x) + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, tmp, n, BLOCK_SIZE=128) + _scale_kernel[_grid_1d(n)](tmp, out, n, scale, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + out = fn(x, 2.5) + assert_close(out, torch.cos(x) * 2.5, atol=1e-5, rtol=1e-5) + + +class TestKernelInsideClosure: + def test_closure_kernel(self): + def make_op(kernel): + @magi_register_custom_op(name=f"magi_test::closure_{id(kernel)}") + def op(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + kernel[_grid_1d(n)](x, out, n, BLOCK=128) + return out + + return op + + kernel = _make_cos_kernel() + op = make_op(kernel) + x = torch.randn(2048, device="cuda") + assert_close(op(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + +# extra_triton_kernels escape hatch (scenario 9-style: kernel hidden behind +# an attribute access the introspector cannot trace). + + +class TestMultiLevelNesting: + def test_fn_to_dispatch_to_launcher_to_kernel(self): + @magi_register_custom_op(name="magi_test::multi_level_cos") + def fn(x: torch.Tensor) -> torch.Tensor: + return _dispatch_launcher(x) + + x = torch.randn(2048, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + def test_introspection_walks_all_levels(self): + from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper + + from magi_compiler._triton_introspect import introspect_fn, rewrite_fn_with_wrap_triton + + def fn(x): + return _dispatch_launcher(x) + + kernels = list(introspect_fn(fn).bare_triton_kernels) + assert _cos_kernel in kernels + + rewritten = rewrite_fn_with_wrap_triton(fn, kernels) + rebuilt_dispatch = rewritten.__globals__["_dispatch_launcher"] + rebuilt_inner = rebuilt_dispatch.__globals__["_inner_launcher"] + assert isinstance(rebuilt_inner.__globals__["_cos_kernel"], TraceableTritonKernelWrapper) + + +# Third-party "thin wrapper" pattern: some libraries return objects with a +# ``.fn`` attribute pointing at the underlying triton kernel; the introspector +# already knows how to unwrap that, so kernels invoked via +# ``maybe_capture(kernel)[grid](...)`` should still register as a triton_op. + + +class TestFactoryInsideFn: + def test_factory_inside_fn_runtime(self): + @magi_register_custom_op(name="magi_test::factory_inside_fn") + def fn(x: torch.Tensor) -> torch.Tensor: + kernel = _make_cos_kernel() + out = torch.empty_like(x) + n = x.numel() + kernel[_grid_1d(n)](x, out, n, BLOCK=128) + return out + + x = torch.randn(1024, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + +# True cross-module launcher: helpers and kernels live in +# ``tests/api_tests/_triton_external_helpers.py``. The decorated function +# imports them, so ``_rebuild`` has to descend into a helper whose +# ``__globals__`` is a *different* module dict than ``fn.__globals__``. + + +class TestNnModuleSelfKernel: + def test_kernel_on_self(self): + from torch import nn + + class CosModule(nn.Module): + def __init__(self, kernel): + super().__init__() + self._kernel = kernel + self.fn = self._build_fn() + + def _build_fn(self): + kernel = self._kernel + + @magi_register_custom_op(name=f"magi_test::module_self_kernel_{id(self)}", extra_triton_kernels=[kernel]) + def op(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + return op + + def forward(self, x): + return self.fn(x) + + mod = CosModule(_cos_kernel).to("cuda") + x = torch.randn(1024, device="cuda") + assert_close(mod(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + +# Factory created *inside* fn (kernel is a local variable, not a closure +# captured from outside). The introspector detects the bare ``kernel[grid]`` +# call but the actual kernel object lives only in the runtime locals, so +# rewrite has nothing to shadow. This must still execute correctly because +# ``wrap_triton`` is optional for runtime correctness (only required for +# torch.compile traceability). + + +# ============================================================================ +# SECTION 2: Wrapped, Dynamic & Exotic Retrievals +# ============================================================================ + + +class TestHelperLauncher: + def test_helper_launcher(self): + @magi_register_custom_op(name="magi_test::add_via_launcher") + def add_op(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return _add_launcher(a, b) + + a = torch.randn(2048, device="cuda") + b = torch.randn(2048, device="cuda") + assert_close(add_op(a, b), a + b, atol=1e-5, rtol=1e-5) + + +class TestCrossModuleLauncher: + def test_scale_via_external_launcher(self): + @magi_register_custom_op(name="magi_test::scale_via_external") + def scale_op(x: torch.Tensor, factor: float) -> torch.Tensor: + return _scale_launcher(x, factor) + + x = torch.randn(2048, device="cuda") + assert_close(scale_op(x, 0.25), x * 0.25, atol=1e-5, rtol=1e-5) + + +class TestThirdPartyThinWrapper: + def test_thin_wrapper_kernel(self): + from tests.api_tests._triton_external_helpers import maybe_capture + + @magi_register_custom_op( + name="magi_test::cos_via_thin_wrapper", + # Even though the introspector handles ``.fn``-style wrappers, we + # also pass the raw kernel as ``extra_triton_kernels`` to confirm + # the deduplication path works with this style of call. + extra_triton_kernels=[_cos_kernel], + ) + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + wrapped = maybe_capture(_cos_kernel) + wrapped[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + +# wrap_triton idempotency: if the user already wrote ``wrap_triton(kernel)`` +# explicitly, we must not produce a wrap_triton(wrap_triton(kernel)). + + +class TestTrueCrossModuleLauncher: + def test_external_neg_launcher(self): + from tests.api_tests._triton_external_helpers import external_neg_launcher + + @magi_register_custom_op(name="magi_test::true_cross_module_neg") + def fn(x: torch.Tensor) -> torch.Tensor: + return external_neg_launcher(x) + + x = torch.randn(2048, device="cuda") + assert_close(fn(x), -x, atol=1e-5, rtol=1e-5) + + def test_rewrite_descends_into_other_module(self): + from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper + + from magi_compiler._triton_introspect import introspect_fn, rewrite_fn_with_wrap_triton + from tests.api_tests._triton_external_helpers import external_double_kernel, external_double_launcher + + def fn(x): + # Bare Name call so the introspector can follow it across modules + # via ``called_functions``. + return external_double_launcher(x) + + kernels = list(introspect_fn(fn).bare_triton_kernels) + assert external_double_kernel in kernels + + rewritten = rewrite_fn_with_wrap_triton(fn, kernels) + + # ``external_double_launcher`` was captured from the enclosing test + # method's locals, so it lives in ``fn``'s closure (NOT in + # ``__globals__``). The rewrite pass must still descend into it and + # produce a rebuilt copy whose globals reference the wrap_triton- + # aware kernel. + rebuilt_launcher = None + for cell in rewritten.__closure__ or (): + try: + contents = cell.cell_contents + except ValueError: + continue + if callable(contents) and getattr(contents, "__name__", None) == ("external_double_launcher"): + rebuilt_launcher = contents + break + assert rebuilt_launcher is not None, ( + "expected rewrite_fn_with_wrap_triton to keep the launcher in " "the rewritten function's closure" + ) + assert isinstance(rebuilt_launcher.__globals__["external_double_kernel"], TraceableTritonKernelWrapper), ( + "rewrite_fn_with_wrap_triton should rebuild cross-module helpers " + "so the kernel reference inside them is wrap_triton-aware." + ) + + # The ORIGINAL helper module's globals must NOT be mutated; only the + # rebuilt copy carries the wrapper. + from tests.api_tests import _triton_external_helpers as ext_mod + + assert not isinstance( + ext_mod.external_double_launcher.__globals__["external_double_kernel"], TraceableTritonKernelWrapper + ), ( + "rewrite_fn_with_wrap_triton must not mutate the helper's home " + "module globals (other unrelated callers would be affected)." + ) + + +class TestMixedWrappedAndBareKernels: + """When the user has manually wrapped some kernels with ``wrap_triton`` + but left others bare (a common state during incremental migration), the + decorator must wrap only the bare ones (no double-wrap) and the op must + still run. + """ + + def test_mixed_wrapped_and_bare(self): + from torch.library import wrap_triton + + @magi_register_custom_op(name="magi_test::mixed_wrap_state") + def myop(x: torch.Tensor) -> torch.Tensor: + n = x.numel() + mid = torch.empty_like(x) + wrap_triton(_cos_kernel)[_grid_1d(n)](x, mid, n, BLOCK_SIZE=128) + out = torch.empty_like(x) + _scale_kernel[_grid_1d(n)](mid, out, n, 2.0, BLOCK_SIZE=128) + return out + + x = torch.randn(512, device="cuda") + out = myop(x) + assert_close(out, torch.cos(x) * 2.0, atol=1e-5, rtol=1e-5) + + +class TestWrapTritonIdempotent: + def test_user_already_wrapped(self): + from torch.library import wrap_triton + + @magi_register_custom_op(name="magi_test::cos_user_wrapped") + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + wrap_triton(_cos_kernel)[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + def test_rewrite_does_not_double_wrap(self): + """Direct unit test: passing the already-wrapped kernel back through + ``rewrite_fn_with_wrap_triton`` must not produce a double wrapper.""" + from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper + from torch.library import wrap_triton + + from magi_compiler._triton_introspect import rewrite_fn_with_wrap_triton + + wrapped_kernel = wrap_triton(_cos_kernel) + + def fn(x): + out = torch.empty_like(x) + n = x.numel() + wrapped_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + # Pass the wrapped kernel as the "kernels" argument; the rewrite path + # should pass it through ``_resolve_kernel`` and not re-wrap. + rewritten = rewrite_fn_with_wrap_triton(fn, [wrapped_kernel]) + # The closure cell for ``wrapped_kernel`` (or the rebuilt globals + # entry, depending on closure capture order) must still be a single + # TraceableTritonKernelWrapper, not nested. + seen = [] + if rewritten.__closure__ is not None: + for cell in rewritten.__closure__: + try: + seen.append(cell.cell_contents) + except ValueError: + pass + seen.extend(rewritten.__globals__.values()) + wrappers = [v for v in seen if isinstance(v, TraceableTritonKernelWrapper)] + assert wrappers, "expected at least one wrap_triton wrapper to be present" + for w in wrappers: + inner = getattr(w, "kernel", None) or getattr(w, "fn", None) + assert not isinstance( + inner, TraceableTritonKernelWrapper + ), "rewrite_fn_with_wrap_triton produced a double-wrapped kernel" + + +# infer_output_meta_fn override: both the ``list[str]`` shorthand and the +# explicit ``Callable`` form should be honoured even when we go down the +# triton_op path (because triton_op pre-registers ``fn`` itself as the fake). + + +class TestExtraTritonKernels: + def test_explicit_kernel_list(self): + kernels_holder = type("KH", (), {})() + kernels_holder.k = _cos_kernel + + @magi_register_custom_op(name="magi_test::cos_via_extra", extra_triton_kernels=[_cos_kernel]) + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + kernels_holder.k[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + +# Fallback: no triton kernels => still works (custom_op path). + + +class TestExtraTritonKernelsDedup: + def test_dedup_in_resolve_and_rewrite(self): + from magi_compiler._triton_introspect import introspect_fn, rewrite_fn_with_wrap_triton + + def fn(x): + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + introspection = introspect_fn(fn, extra_triton_kernels=[_cos_kernel]) + resolved_bare = list(introspection.bare_triton_kernels) + # Should appear exactly once even though it's both passed explicitly + # and discovered by introspection. + assert resolved_bare.count(_cos_kernel) == 1 + assert len(resolved_bare) == 1 + + rewritten = rewrite_fn_with_wrap_triton(fn, resolved_bare) + from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper + + wrapped = rewritten.__globals__["_cos_kernel"] + assert isinstance(wrapped, TraceableTritonKernelWrapper) + + def test_dedup_e2e(self): + @magi_register_custom_op(name="magi_test::dedup_cos", extra_triton_kernels=[_cos_kernel]) + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + # Confirm we still went down the triton_op path even though the kernel + # was specified twice (auto-detected + extra_triton_kernels). + assert TestTritonOpRegistryAssertion._was_registered_as_triton_op( + "magi_test::dedup_cos" + ), "expected the op to be registered as a triton_op" + + +class TestExtraTritonKernelsForStaticOrClassmethod: + """``staticmethod`` / ``classmethod`` selectors are opaque to source + introspection. ``extra_triton_kernels`` keeps the op on the triton_op + path even so. + """ + + def test_staticmethod_selected_kernel(self): + @magi_register_custom_op(name="magi_test::sm_kernel", extra_triton_kernels=[_scale_kernel]) + def myop(x: torch.Tensor) -> torch.Tensor: + kernel = _KernelHolder.get_static() + out = torch.empty_like(x) + n = x.numel() + kernel[_grid_1d(n)](x, out, n, 2.0, BLOCK_SIZE=128) + return out + + x = torch.randn(256, device="cuda") + out = myop(x) + assert_close(out, x * 2.0) + + def test_classmethod_selected_kernel(self): + @magi_register_custom_op(name="magi_test::cm_kernel", extra_triton_kernels=[_scale_kernel]) + def myop(x: torch.Tensor) -> torch.Tensor: + kernel = _KernelHolder.get_class() + out = torch.empty_like(x) + n = x.numel() + kernel[_grid_1d(n)](x, out, n, 3.0, BLOCK_SIZE=128) + return out + + x = torch.randn(256, device="cuda") + out = myop(x) + assert_close(out, x * 3.0) + + +class TestExtraTritonKernelsForRuntimeImport: + """A kernel imported inside the function body (runtime import) is invisible + to source introspection. ``extra_triton_kernels`` works around that. + """ + + def test_runtime_imported_kernel(self): + # The kernel object lives at module scope (we can't actually do a fresh + # ``import`` in a way that hides it from source scanning AND lets the + # function still call it). Simulate the runtime-import case by stuffing + # the kernel into a local ``import``-like alias derived from globals, + # so source introspection cannot statically resolve it. + @magi_register_custom_op(name="magi_test::runtime_import_kernel", extra_triton_kernels=[_cos_kernel]) + def myop(x: torch.Tensor) -> torch.Tensor: + module_globals = globals() + # Indirect lookup hides the kernel from static introspection of + # ``myop``'s globals/closure. + kernel = module_globals["_cos_kernel"] + out = torch.empty_like(x) + n = x.numel() + kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(256, device="cuda") + out = myop(x) + assert_close(out, torch.cos(x)) + + +class TestNoTritonFallback: + def test_no_kernel_uses_custom_op(self): + @magi_register_custom_op(name="magi_test::pure_python_op") + def fn(x: torch.Tensor) -> torch.Tensor: + return x * 2 + 1 + + x = torch.randn(8, 8) + assert_close(fn(x), x * 2 + 1) + + +# Triton path + autograd combination. + + +class TestIntrospection: + def test_introspect_fn_flat(self): + from magi_compiler._triton_introspect import introspect_fn + + def fn(x): + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + assert _cos_kernel in introspect_fn(fn).bare_triton_kernels + + def test_introspect_fn_nested(self): + from magi_compiler._triton_introspect import introspect_fn + + def fn(a, b): + return _add_launcher(a, b) + + assert _add_kernel in introspect_fn(fn).bare_triton_kernels + + def test_rewrite_replaces_kernel_with_wrap_triton(self): + from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper + + from magi_compiler._triton_introspect import introspect_fn, rewrite_fn_with_wrap_triton + + def fn(x): + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + kernels = list(introspect_fn(fn).bare_triton_kernels) + rewritten = rewrite_fn_with_wrap_triton(fn, kernels) + + # _cos_kernel name in the rewritten globals should now point to a + # TraceableTritonKernelWrapper, not the bare JITFunction. + assert isinstance(rewritten.__globals__["_cos_kernel"], TraceableTritonKernelWrapper) + # Originals untouched. + from triton.runtime.jit import JITFunction + + assert isinstance(_cos_kernel, JITFunction) + + def test_rewrite_propagates_through_helpers(self): + from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper + + from magi_compiler._triton_introspect import introspect_fn, rewrite_fn_with_wrap_triton + + def fn(a, b): + return _add_launcher(a, b) + + kernels = list(introspect_fn(fn).bare_triton_kernels) + rewritten = rewrite_fn_with_wrap_triton(fn, kernels) + + rebuilt_launcher = rewritten.__globals__["_add_launcher"] + assert isinstance(rebuilt_launcher.__globals__["_add_kernel"], TraceableTritonKernelWrapper) + + def test_introspect_fn_dot_run(self): + """``kernel.run(*args, grid=...)`` is Triton's low-level launch + API. It is what ``kernel[grid](*args)`` desugars to and what + PyTorch Inductor's generated code uses. Verify the AST scanner + recognises it as a bare kernel launch. + """ + from magi_compiler._triton_introspect import introspect_fn + + def fn(x): + out = torch.empty_like(x) + n = x.numel() + _cos_kernel.run(x, out, n, BLOCK_SIZE=128, grid=_grid_1d(n), warmup=False) + return out + + assert _cos_kernel in introspect_fn(fn).bare_triton_kernels + + def test_introspect_fn_dotted_module_attr(self): + """``mod.kernel[grid](...)`` references a kernel through a module + attribute. The collector must record ``"mod.kernel"`` and the + resolver must walk the dotted path via ``getattr`` to recover + the underlying ``JITFunction``. + """ + from magi_compiler._triton_introspect import introspect_fn + from tests.api_tests import _triton_external_helpers as ext + + def fn(x): + out = torch.empty_like(x) + n = x.numel() + ext.external_neg_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + assert ext.external_neg_kernel in introspect_fn(fn).bare_triton_kernels + + def test_introspect_fn_class_attr(self): + """``Holder.kernel[grid](...)`` references a kernel through a + class attribute. Same mechanism as module attributes -- the + collector records ``"Holder.kernel"`` and the resolver walks + it via ``getattr``. + """ + from magi_compiler._triton_introspect import introspect_fn + + class Holder: + kernel = _cos_kernel + + def fn(x): + out = torch.empty_like(x) + n = x.numel() + Holder.kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + assert _cos_kernel in introspect_fn(fn).bare_triton_kernels + + def test_introspect_fn_dot_run_with_dotted_receiver(self): + """Combination of A.5 and dotted lookup: ``mod.kernel.run(...)``.""" + from magi_compiler._triton_introspect import introspect_fn + from tests.api_tests import _triton_external_helpers as ext + + def fn(x): + out = torch.empty_like(x) + n = x.numel() + ext.external_double_kernel.run( + x, out, n, BLOCK_SIZE=128, grid=_grid_1d(n), warmup=False + ) + return out + + assert ext.external_double_kernel in introspect_fn(fn).bare_triton_kernels + + +# Multi-level nesting: fn -> dispatch -> launcher -> kernel. +# Verifies that kernels several call-graph hops away are still detected and +# that ``rewrite_fn_with_wrap_triton`` rebuilds every helper along the path. + + +class TestInferOutputMetaOverride: + def test_meta_list_form(self): + @magi_register_custom_op(name="magi_test::triton_meta_list", infer_output_meta_fn=["x"]) + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + # And inside torch.compile (forces the fake/meta path to be used). + compiled = torch.compile(fn, backend="inductor", fullgraph=True) + assert_close(compiled(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + def test_meta_callable_form(self): + called = {"count": 0} + + def custom_meta(x: torch.Tensor) -> torch.Tensor: + called["count"] += 1 + return torch.empty_like(x) + + @magi_register_custom_op(name="magi_test::triton_meta_callable", infer_output_meta_fn=custom_meta) + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + compiled = torch.compile(fn, backend="inductor", fullgraph=True) + assert_close(compiled(x), torch.cos(x), atol=1e-5, rtol=1e-5) + # Tracing through torch.compile should have invoked the user-provided + # meta at least once. + assert called["count"] >= 1 + + +# Explicit registry-level assertion that we actually went down the +# ``torch.library.triton_op`` path (i.e. Inductor would be able to inline +# the kernel), distinguishing it from the silent custom_op fallback. + + +class TestTritonOpRegistryAssertion: + """Verify we actually take the ``torch.library.triton_op`` registration + path (so Inductor / make_fx can see through the op) instead of silently + falling back to plain ``custom_op`` (which would be opaque).""" + + @staticmethod + def _was_registered_as_triton_op(op_or_name) -> bool: + # ``triton_op`` installs a torch_dispatch on FunctionalTensorMode that + # decomposes the op into ``triton_kernel_wrapper_mutation`` calls. + # Plain ``custom_op`` does not. + from torch._library.custom_ops import OPDEFS + from torch._subclasses.functional_tensor import FunctionalTensorMode + + if isinstance(op_or_name, str): + opdef = OPDEFS.get(op_or_name) + if opdef is None: + return False + else: + opdef = op_or_name + dispatch_fns = getattr(opdef, "_torch_dispatch_fns", {}) or {} + return FunctionalTensorMode in dispatch_fns + + def test_registered_as_triton_op(self): + @magi_register_custom_op(name="magi_test::registry_cos") + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + assert self._was_registered_as_triton_op("magi_test::registry_cos"), ( + "magi_test::registry_cos should have been registered via " + "torch.library.triton_op (so make_fx decomposes it into " + "triton_kernel_wrapper_mutation), not via plain custom_op." + ) + + def test_pure_python_op_not_registered_as_triton(self): + @magi_register_custom_op(name="magi_test::registry_pure_python") + def fn(x: torch.Tensor) -> torch.Tensor: + return x * 2 + 1 + + assert not self._was_registered_as_triton_op("magi_test::registry_pure_python"), ( + "magi_test::registry_pure_python has no triton kernels; it should " + "have fallen back to the custom_op path and remain opaque to " + "make_fx." + ) + + +# extra_triton_kernels deduplication: a kernel that is *both* auto-detected +# and listed in ``extra_triton_kernels`` should appear exactly once after +# resolution and must not be wrap_triton-wrapped twice. + + +# ============================================================================ +# SECTION 3: Autotune, Heuristics & Autograd in Triton +# ============================================================================ + + +class TestAutotuneKernels: + def test_autotuned(self): + @magi_register_custom_op(name="magi_test::autotuned_cos") + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + # autotuner picks BLOCK_SIZE; grid uses meta lambda + _autotuned_cos_kernel[(triton.cdiv(n, 128),)](x, out, n) + return out + + x = torch.randn(2048, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + +class TestMultipleAutotuneKernelsSameOp: + """A single op may launch several differently-autotuned kernels (a common + FlashAttention / Mamba pattern). Verify both kernels are detected and the + op runs end-to-end through the triton_op path. + """ + + def test_two_autotune_kernels_in_same_op(self): + # Build a *second* autotuned kernel locally so we can be sure both + # kernel objects appear in the op's call graph. + @triton.autotune( + configs=[triton.Config({"BLOCK_SIZE": 128}, num_warps=4), triton.Config({"BLOCK_SIZE": 256}, num_warps=4)], + key=["n_elements"], + ) + @triton.jit + def _autotuned_sin_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, tl.sin(x), mask=mask) + + @magi_register_custom_op(name="magi_test::two_autotune_kernels", extra_triton_kernels=[_autotuned_sin_kernel]) + def myop(x: torch.Tensor) -> torch.Tensor: + n = x.numel() + mid = torch.empty_like(x) + _autotuned_cos_kernel[_grid_1d(n)](x, mid, n) + out = torch.empty_like(x) + _autotuned_sin_kernel[_grid_1d(n)](mid, out, n) + return out + + x = torch.randn(2048, device="cuda") + out = myop(x) + assert_close(out, torch.sin(torch.cos(x)), atol=1e-4, rtol=1e-4) + + +class TestHeuristicsRejection: + """``torch.library.wrap_triton`` only accepts ``JITFunction`` and + ``Autotuner``. A top-level ``@triton.heuristics`` produces a + ``Heuristics`` instance that fails ``wrap_triton`` with a confusing + error. ``@magi_register_custom_op`` rejects this case up front with a + clearer message, while still accepting the recommended layering of + ``@triton.autotune -> @triton.heuristics -> @triton.jit``. + """ + + def test_top_level_heuristics_rejected_with_clear_message(self): + """Bare ``@triton.heuristics`` on a kernel referenced from the op + body must be rejected at registration time, not deep inside + ``wrap_triton``.""" + with pytest.raises(RuntimeError, match="triton.heuristics"): + + @magi_register_custom_op(name="magi_test::heuristics_top") + def myop(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _heuristics_top_kernel[_grid_1d(n)](x, out, n) + return out + + def test_top_level_heuristics_via_extra_triton_kernels_rejected(self): + """Same constraint applies when the user passes the offending kernel + through the ``extra_triton_kernels`` escape hatch (no auto-detection + involved).""" + with pytest.raises(RuntimeError, match="triton.heuristics"): + + @magi_register_custom_op(name="magi_test::heuristics_extra", extra_triton_kernels=[_heuristics_top_kernel]) + def myop(x: torch.Tensor) -> torch.Tensor: + # Body doesn't reference the kernel at all; rejection comes + # purely from the extra_triton_kernels list. + return x.clone() + + def test_autotune_outside_heuristics_is_accepted(self): + """The recommended layering ``@triton.autotune -> @triton.heuristics + -> @triton.jit`` produces an ``Autotuner`` at the top level and is + accepted (and end-to-end functional).""" + + @magi_register_custom_op(name="magi_test::autotune_over_heuristics") + def myop(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _autotune_then_heuristics_kernel[_grid_1d(n)](x, out, n) + return out + + x = torch.randn(512, device="cuda") + out = myop(x) + assert_close(out, x) + + +# #15 / #16: kernels not statically discoverable -> extra_triton_kernels= + + +class TestTritonWithAutograd: + def test_triton_with_backward(self): + def setup_ctx(ctx, inputs, output): + (x,) = inputs + ctx.save_for_backward(x) + + def backward(ctx, grad_out): + (x,) = ctx.saved_tensors + # d/dx cos(x) = -sin(x) + return grad_out * (-torch.sin(x)) + + @magi_register_custom_op(name="magi_test::triton_cos_grad", setup_context_fn=setup_ctx, backward_fn=backward) + def mycos(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda", requires_grad=True) + out = mycos(x) + loss = out.sum() + loss.backward() + assert_close(x.grad, -torch.sin(x.detach()), atol=1e-5, rtol=1e-5) + + +# ============================================================================ +# SECTION 4: Dataclass + Triton Bridge +# ---------------------------------------------------------------------------- +# The dataclass-aware registration path lowers each dataclass parameter into +# flat primitive leaves before handing the function off to torch.library. +# This section verifies that the triton_op auto-detection still kicks in on +# that lowered path, including nested dataclasses, custom meta functions, +# autograd hooks, per-field grads, and ``is_compute_sensitive``. +# ============================================================================ + + +class TestDataclassWithTritonKernel: + def test_dataclass_input_with_triton(self): + @magi_register_custom_op(name="magi_test::dc_cos") + def fn(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=cfg.block_size) + return out + + x = torch.randn(1024, device="cuda") + cfg = _DcCosCfg(block_size=128) + assert_close(fn(x, cfg), torch.cos(x), atol=1e-5, rtol=1e-5) + + # The dataclass-aware path registers an inner op under the requested + # name; that inner op should still be a triton_op so Inductor can see + # through it. + assert TestTritonOpRegistryAssertion._was_registered_as_triton_op( + "magi_test::dc_cos" + ), "dataclass+triton path should still register the inner op as a triton_op" + + +class TestNestedDataclassWithTritonKernel: + def test_two_level_nested_dc_with_triton(self): + """Outer dataclass containing an inner dataclass; both are lowered + into flat primitive parameters.""" + + @magi_register_custom_op(name="magi_test::nested_dc_cos_scale") + def fn(x: torch.Tensor, cfg: _DcOuterCfg) -> torch.Tensor: + tmp = torch.empty_like(x) + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, tmp, n, BLOCK_SIZE=cfg.kernel.block_size) + _scale_kernel[_grid_1d(n)](tmp, out, n, cfg.scale, BLOCK_SIZE=cfg.kernel.block_size) + return out + cfg.kernel.extra_offset + + x = torch.randn(1024, device="cuda") + cfg = _DcOuterCfg(kernel=_DcKernelCfg(block_size=128, extra_offset=0.5), scale=2.5) + out = fn(x, cfg) + expected = torch.cos(x) * 2.5 + 0.5 + assert_close(out, expected, atol=1e-5, rtol=1e-5) + + # Sanity-check the param_mapping_tree exposes the expected lowered + # leaf names. + plan = fn._magi_param_mapping_tree + cfg_node = plan[1] + assert cfg_node[0] == "dataclass" and cfg_node[1] == "cfg" + flat_names: list[str] = [] + + def _collect(node): + if node[0] == "primitive": + flat_names.append(node[2]) + else: + for child in node[3]: + _collect(child) + + _collect(cfg_node) + assert {"cfg__kernel__block_size", "cfg__kernel__extra_offset", "cfg__scale"}.issubset(flat_names) + + assert TestTritonOpRegistryAssertion._was_registered_as_triton_op( + "magi_test::nested_dc_cos_scale" + ), "nested-dataclass + triton path should still register the inner op as a triton_op" + + def test_nested_dc_with_triton_and_meta_fn(self): + """User-supplied meta function expressed in nested-dataclass terms, + combined with a triton kernel call.""" + + def _meta(x: torch.Tensor, cfg: _DcProjCfg) -> torch.Tensor: + return x.new_empty((*x.shape[:-1], cfg.shape.out_dim)) + + @magi_register_custom_op(name="magi_test::nested_dc_cos_proj", infer_output_meta_fn=_meta) + def fn(x: torch.Tensor, cfg: _DcProjCfg) -> torch.Tensor: + sliced = x[..., : cfg.shape.out_dim].contiguous() + out = torch.empty_like(sliced) + n = sliced.numel() + _cos_kernel[_grid_1d(n)](sliced, out, n, BLOCK_SIZE=cfg.block_size) + return out + + x = torch.randn(2, 8, device="cuda") + cfg = _DcProjCfg(shape=_DcShapeCfg(out_dim=3), block_size=128) + out = fn(x, cfg) + expected = torch.cos(x[..., :3].contiguous()) + assert out.shape == (2, 3) + assert_close(out, expected, atol=1e-5, rtol=1e-5) + + +class TestDataclassWithTritonKernelAndBackward: + def test_triton_dc_backward_basic(self): + """End-to-end backward against a dc + triton op: use the cos kernel + (analytical grad: -sin(x)) so we can verify exact grads.""" + + def _setup(ctx, inputs, output): + x, cfg = inputs + assert isinstance(cfg, _DcCosCfg) + ctx.save_for_backward(x) + ctx.block_size = cfg.block_size + + def _bwd(ctx, grad_out): + (x,) = ctx.saved_tensors + return grad_out * (-torch.sin(x)), None + + @magi_register_custom_op(name="magi_test::dc_cos_grad", setup_context_fn=_setup, backward_fn=_bwd) + def mycos(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=cfg.block_size) + return out + + x = torch.randn(1024, device="cuda", requires_grad=True) + cfg = _DcCosCfg(block_size=128) + out = mycos(x, cfg) + out.sum().backward() + assert_close(x.grad, -torch.sin(x.detach()), atol=1e-5, rtol=1e-5) + + assert TestTritonOpRegistryAssertion._was_registered_as_triton_op("magi_test::dc_cos_grad") + + def test_triton_nested_dc_backward(self): + """Nested dataclass + triton + backward. The bridge must spread the + whole-nested-dc ``None`` grad over every flat slot under that + dataclass.""" + + def _setup(ctx, inputs, output): + x, cfg = inputs + assert isinstance(cfg, _DcOuterCfg) + assert isinstance(cfg.kernel, _DcKernelCfg) + ctx.save_for_backward(x) + ctx.scale = cfg.scale + + def _bwd(ctx, grad_out): + (x,) = ctx.saved_tensors + return grad_out * (-torch.sin(x)) * ctx.scale, None + + @magi_register_custom_op(name="magi_test::nested_dc_cos_grad", setup_context_fn=_setup, backward_fn=_bwd) + def fn(x: torch.Tensor, cfg: _DcOuterCfg) -> torch.Tensor: + tmp = torch.empty_like(x) + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, tmp, n, BLOCK_SIZE=cfg.kernel.block_size) + _scale_kernel[_grid_1d(n)](tmp, out, n, cfg.scale, BLOCK_SIZE=cfg.kernel.block_size) + return out + cfg.kernel.extra_offset + + x = torch.randn(1024, device="cuda", requires_grad=True) + cfg = _DcOuterCfg(kernel=_DcKernelCfg(block_size=128, extra_offset=0.5), scale=2.5) + out = fn(x, cfg) + out.sum().backward() + expected = -torch.sin(x.detach()) * 2.5 + assert_close(x.grad, expected, atol=1e-5, rtol=1e-5) + + def test_triton_dc_backward_with_per_field_grad(self): + """User returns per-field grads (as a same-shape dataclass with + ``None`` leaves) for the dc slot. The triton path must still work.""" + + def _setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x) + ctx.block_size = cfg.block_size + + def _bwd(ctx, grad_out): + (x,) = ctx.saved_tensors + return (grad_out * (-torch.sin(x)), _DcCosCfg(block_size=None)) + + @magi_register_custom_op(name="magi_test::dc_cos_per_field_grad", setup_context_fn=_setup, backward_fn=_bwd) + def mycos(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=cfg.block_size) + return out + + x = torch.randn(512, device="cuda", requires_grad=True) + out = mycos(x, _DcCosCfg(block_size=128)) + out.sum().backward() + assert_close(x.grad, -torch.sin(x.detach()), atol=1e-5, rtol=1e-5) + + def test_triton_dc_backward_with_dict_grad(self): + """User returns the dataclass slot's grad as a plain ``dict``; the + bridge must spread it through ``__getitem__``-style access into the + underlying flat slots.""" + + def _setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x) + ctx.block_size = cfg.block_size + + def _bwd(ctx, grad_out): + (x,) = ctx.saved_tensors + return grad_out * (-torch.sin(x)), {"block_size": None} + + @magi_register_custom_op(name="magi_test::dc_cos_dict_grad", setup_context_fn=_setup, backward_fn=_bwd) + def mycos(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=cfg.block_size) + return out + + x = torch.randn(512, device="cuda", requires_grad=True) + out = mycos(x, _DcCosCfg(block_size=128)) + out.sum().backward() + assert_close(x.grad, -torch.sin(x.detach()), atol=1e-5, rtol=1e-5) + + +class TestDataclassTritonComputeSensitiveSmoke: + """The dataclass-aware bridge composes cleanly with + ``is_compute_sensitive=True`` on the triton path: registration succeeds, + the op runs, and its name lands in the compute-sensitive registry. + """ + + def test_dataclass_triton_compute_sensitive(self): + from magi_compiler.config import get_compile_config + + @magi_register_custom_op(name="magi_test::dc_triton_cs", is_compute_sensitive=True) + def myop(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=cfg.block_size) + return out + + x = torch.randn(256, device="cuda") + out = myop(x, _DcCosCfg(block_size=128)) + assert_close(out, torch.cos(x)) + assert "magi_test::dc_triton_cs" in get_compile_config().recompute_config.custom_compute_sensitive_ops + + +# Direct unit tests for the introspection / rewrite helpers. + + +class TestInductorSeesTritonKernel: + """The whole point of the triton_op auto-detection is that + ``torch.compile`` (Inductor) sees through the op to the underlying + triton kernel rather than treating it as opaque. Verify by inspecting + the FX graph captured by Inductor for the wrap_triton-functional HOP. + """ + + def test_triton_kernel_visible_in_aot_graph(self): + from torch._functorch.aot_autograd import aot_function + + @magi_register_custom_op(name="magi_test::inductor_visible_cos") + def mycos(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + # Run the op through AOTAutograd directly with a custom forward + # compiler that just records the post-functionalization graph. This + # is exactly the layer where ``triton_op``s decompose into the + # ``triton_kernel_wrapper_functional`` HOP; the presence of that + # node in the captured graph proves Inductor (which runs *after* + # AOTAutograd) sees the underlying triton kernel rather than an + # opaque ``torch.ops.magi_test.inductor_visible_cos`` call. + captured_graphs: list[str] = [] + + def _capture(gm, _example_inputs): + captured_graphs.append(gm.code) + return gm.forward + + x = torch.randn(1024, device="cuda") + torch._dynamo.reset() + compiled_aot = aot_function(mycos, fw_compiler=_capture, bw_compiler=_capture) + out = compiled_aot(x) + assert_close(out, torch.cos(x), atol=1e-5, rtol=1e-5) + + joined = "\n".join(captured_graphs) + assert "triton_kernel_wrapper_functional" in joined or "triton_kernel_wrapper_mutation" in joined, ( + "AOT graph did not decompose magi_test::inductor_visible_cos " + "into the triton_kernel_wrapper HOP; Inductor will treat it " + "as opaque. Captured AOT graph:\n" + joined + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 56b30ea9b23038e9b0d5be51715e348dbdbb4d56 Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Sat, 23 May 2026 16:31:12 +0800 Subject: [PATCH 2/2] [Chores] fix code style --- magi_compiler/_magi_register_custom_op.py | 31 +++++----------------- magi_compiler/_triton_introspect.py | 18 +++---------- tests/api_tests/test_register_triton_op.py | 4 +-- 3 files changed, 10 insertions(+), 43 deletions(-) diff --git a/magi_compiler/_magi_register_custom_op.py b/magi_compiler/_magi_register_custom_op.py index ea44538..ede6948 100644 --- a/magi_compiler/_magi_register_custom_op.py +++ b/magi_compiler/_magi_register_custom_op.py @@ -58,12 +58,7 @@ import torch import torch.utils._pytree as pytree -from ._triton_introspect import ( - DEFAULT_MAX_INTROSPECT_DEPTH, - IntrospectionResult, - introspect_fn, - rewrite_fn_with_wrap_triton, -) +from ._triton_introspect import DEFAULT_MAX_INTROSPECT_DEPTH, IntrospectionResult, introspect_fn, rewrite_fn_with_wrap_triton from .config import get_compile_config from .utils.logger import magi_logger @@ -735,8 +730,7 @@ def _generate_op_name(fn: Callable) -> str: def _reject_heuristics_outermost( - introspection: IntrospectionResult, - extra_triton_kernels: list[Any] | tuple[Any, ...] | None, + introspection: IntrospectionResult, extra_triton_kernels: list[Any] | tuple[Any, ...] | None ) -> None: """Reject kernels whose outermost decorator is ``@triton.heuristics`` (``wrap_triton`` only accepts JIT/Autotuner; surface here before the @@ -975,8 +969,7 @@ def _bridged_meta_fn(*args, **kwargs): except ImportError: triton_op = None # type: ignore[assignment] magi_logger.warning( - "torch.library.triton_op not available; falling back to torch.library.custom_op for op %s", - op_name, + "torch.library.triton_op not available; falling back to torch.library.custom_op for op %s", op_name ) if triton_op is not None: try: @@ -1205,28 +1198,16 @@ def decorator(fn: Callable) -> Callable: needs_flattening = any(kind == "dataclass" for kind, *_ in param_mapping_tree) # Step A: single AST pass feeds every downstream check. - introspection = introspect_fn( - fn, - extra_triton_kernels=extra_triton_kernels, - max_depth=max_introspect_depth, - ) + introspection = introspect_fn(fn, extra_triton_kernels=extra_triton_kernels, max_depth=max_introspect_depth) # Reject top-level ``@triton.heuristics`` before path decision for a precise error. _reject_heuristics_outermost(introspection, extra_triton_kernels) nested = _classify_nested_ops(introspection.nested_op_calls) decision = _decide_registration_path( - fn, - has_direct_kernel=introspection.has_direct_kernel, - nested=nested, - force_register_mode=force_register_mode, - ) - magi_logger.debug( - "@magi_register_custom_op: %s -> mode=%s (%s)", - op_name, - decision.mode, - decision.reason, + fn, has_direct_kernel=introspection.has_direct_kernel, nested=nested, force_register_mode=force_register_mode ) + magi_logger.debug("@magi_register_custom_op: %s -> mode=%s (%s)", op_name, decision.mode, decision.reason) # mode="none" -> skip registration so Inductor inlines fn (warning already emitted). if decision.mode == "none": diff --git a/magi_compiler/_triton_introspect.py b/magi_compiler/_triton_introspect.py index 2cefc92..4ae19ea 100644 --- a/magi_compiler/_triton_introspect.py +++ b/magi_compiler/_triton_introspect.py @@ -38,12 +38,7 @@ DEFAULT_MAX_INTROSPECT_DEPTH: int = 5 -__all__ = [ - "DEFAULT_MAX_INTROSPECT_DEPTH", - "IntrospectionResult", - "introspect_fn", - "rewrite_fn_with_wrap_triton", -] +__all__ = ["DEFAULT_MAX_INTROSPECT_DEPTH", "IntrospectionResult", "introspect_fn", "rewrite_fn_with_wrap_triton"] # ============================================================================== @@ -414,12 +409,7 @@ def _lookup_dotted(name: str) -> object | None: return None return obj - def _classify_name( - name: str, - *, - as_bare: bool, - visited_names: set[str], - ) -> None: + def _classify_name(name: str, *, as_bare: bool, visited_names: set[str]) -> None: """Classify ``name`` and route it to the right bucket: a kernel bucket (``bare`` vs ``user_wrapped`` per ``as_bare``), the ``heuristics`` bucket if the resolved object is a Triton @@ -479,9 +469,7 @@ def _classify_name( try: _walk(helper_obj, depth + 1) except Exception: - logger.debug( - "failed to analyze called helper %s", helper_name, exc_info=True - ) + logger.debug("failed to analyze called helper %s", helper_name, exc_info=True) _walk(fn, 0) diff --git a/tests/api_tests/test_register_triton_op.py b/tests/api_tests/test_register_triton_op.py index 2fb687b..0a6a024 100644 --- a/tests/api_tests/test_register_triton_op.py +++ b/tests/api_tests/test_register_triton_op.py @@ -832,9 +832,7 @@ def test_introspect_fn_dot_run_with_dotted_receiver(self): def fn(x): out = torch.empty_like(x) n = x.numel() - ext.external_double_kernel.run( - x, out, n, BLOCK_SIZE=128, grid=_grid_1d(n), warmup=False - ) + ext.external_double_kernel.run(x, out, n, BLOCK_SIZE=128, grid=_grid_1d(n), warmup=False) return out assert ext.external_double_kernel in introspect_fn(fn).bare_triton_kernels