Add mHC fused kernels + LigerMHC API + benchmarks #1065
yukiu00 wants to merge 14 commits into linkedin:main from
Conversation
- Add mHC kernels and APIs
- Provide reference implementations for tests and benchmarks
- Add/adjust tests, tolerances, and benchmarks
- Document memory trade-offs, usage notes, and options
benchmark/scripts/benchmark_mhc.py
Outdated
def _time_loop(fn, iters=200, warmup=50) -> float:
    torch.cuda.synchronize()
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    t1 = time.time()
    return (t1 - t0) * 1e3 / iters
We prefer using triton.testing.do_bench() to benchmark our kernels. Refer to the other benchmark scripts for concrete examples.
Acknowledged. I switched benchmark_mhc.py to use triton.testing.do_bench() with QUANTILES for all forward/backward measurements.
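For readers following along, this is roughly what the do_bench()-based timing looks like; the wrapped closure and the QUANTILES choice below are illustrative, not the exact benchmark_mhc.py code.

```python
# Sketch only: do_bench() handles warmup, repetition, and device synchronization internally.
from triton.testing import do_bench

QUANTILES = [0.5, 0.2, 0.8]  # median plus 20th/80th percentiles (assumed repo convention)


def time_ms(run_once):
    # returns (median, p20, p80) latency in milliseconds for the given closure
    ms_50, ms_20, ms_80 = do_bench(run_once, quantiles=QUANTILES, rep=100)
    return ms_50, ms_20, ms_80
```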
benchmark/scripts/benchmark_mhc.py
Outdated
def _peak_bytes(fn) -> int:
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return int(torch.cuda.max_memory_allocated())
We have utils._test_memory() for checking memory (see Liger-Kernel/benchmark/scripts/utils.py, line 92 in 83cdcf8).
Done. Memory measurement now uses utils._test_memory() consistently.
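The memory side, as a sketch; it assumes _test_memory() is called the way the other benchmark scripts call it (a callable plus quantiles, returning three percentile values), so treat the exact call as an approximation.

```python
# Sketch only: _test_memory measures peak memory over repeated runs of the closure.
from utils import SingleBenchmarkRunOutput, _test_memory

QUANTILES = [0.5, 0.2, 0.8]  # assumed to match the speed measurements


def bench_memory(run_full_pass):
    mem_50, mem_20, mem_80 = _test_memory(run_full_pass, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
```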
benchmark/scripts/benchmark_mhc.py
Outdated
import torch

from utils import mhc_coeffs_ref
Let's keep only one reference for a single source of truth. Instead of writing another one in benchmark/scripts/utils.py, reuse the one from test.transformers.test_mhc to avoid inconsistent references in future updates.
Add the repo root to your path:
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
Then you can access the test directory in your function.
I moved mhc_sinkhorn_ref / mhc_coeffs_ref into test_mhc.py and updated benchmarks to import from there. I also added the repo root to sys.path in the benchmark scripts.
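Putting the two points together, the benchmark-side wiring looks roughly like this (paths and function names taken from the discussion above):

```python
# Sketch: make the repo root importable, then reuse the test-suite reference implementations.
import os
import sys

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from test.transformers.test_mhc import mhc_coeffs_ref, mhc_sinkhorn_ref  # noqa: E402
```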
src/liger_kernel/ops/mhc.py
Outdated
from liger_kernel.triton.mhc import mhc_mm_norm_bwd
from liger_kernel.triton.mhc import mhc_mm_norm_fwd
from liger_kernel.triton.mhc import mhc_post_res_bwd
from liger_kernel.triton.mhc import mhc_post_res_fwd
from liger_kernel.triton.mhc import mhc_pre_bwd
from liger_kernel.triton.mhc import mhc_pre_fwd
from liger_kernel.triton.mhc import mhc_sinkhorn_bwd
from liger_kernel.triton.mhc import mhc_split_sinkhorn_fwd
Let's define them in this file directly for codebase structure. Happy to discuss other approaches if you feel the file becomes too large.
Done. I moved the Triton kernels into ops/mhc.py and removed triton/mhc.py.
def test_mhc_mini_lm_convergence():
    set_seed(0)

    device = "cuda"
We use liger_kernel.utils.infer_device to get the device since we support multiple backends, not just CUDA.
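A minimal sketch of that pattern (the tensor is only there to show device placement):

```python
# infer_device() returns the available accelerator backend instead of hardcoding "cuda".
import torch

from liger_kernel.utils import infer_device

device = infer_device()
x = torch.randn(8, 16, device=device)  # illustrative tensor on the inferred device
```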
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_multimodal.py
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_with_logits.py
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mhc_mini_lm.py
I'll discuss with folks and decide whether to test the mHC architecture or not.
- Remove backward-compatible alias functions (108 lines)
- Add docstring and comments to _post_res_default_meta
- Use Union[] instead of | for Python 3.9 compatibility
- Replace assert with ValueError for better debugging
- Add runtime warning when hc > 16
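A rough sketch of the validation changes this commit describes; the helper name `_check_streams` and the message strings are illustrative, while the hc > 16 warning threshold comes from the commit note.

```python
import warnings
from typing import Union  # Union[...] rather than `X | Y` keeps Python 3.9 compatibility


def _check_streams(hc: Union[int, float]) -> int:
    # ValueError instead of assert so the check survives `python -O` and gives a clearer message
    if hc < 1:
        raise ValueError(f"hc must be >= 1, got {hc}")
    if hc > 16:
        warnings.warn(f"hc={hc} is larger than 16; the kernels are not tuned for this regime")
    return int(hc)
```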
Perhaps we can move this part to test_mhc.py. No need to check whether the losses decrease; we only have to check whether the outputs generated by two models, one with the torch references and the other with Liger's mHC components, are close enough, using torch.testing.assert_close() or utils.assert_verbose_allclose().
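A sketch of the test shape this suggests; build_ref_model and build_liger_model are hypothetical helpers standing in for the two mini-model constructions, and the shapes and tolerances are placeholders.

```python
import torch

from liger_kernel.utils import infer_device


def test_mhc_mini_lm_matches_reference():
    device = infer_device()
    torch.manual_seed(42)
    tokens = torch.randint(0, 4096, (2, 256), device=device)

    ref_model = build_ref_model().to(device)        # hypothetical: torch-reference mHC model
    liger_model = build_liger_model().to(device)    # hypothetical: same model on Liger mHC components
    liger_model.load_state_dict(ref_model.state_dict())  # identical weights for a fair comparison

    torch.testing.assert_close(liger_model(tokens), ref_model(tokens), rtol=1e-2, atol=1e-2)
```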
Follow the structure of the other benchmark scripts so that we can track the benchmarking results with CI. Take benchmark_dpo_loss.py as an example; there are some key points you need to cover (a condensed sketch follows after the replies below):
- Pack configs, sample ranges, etc. in a dictionary (Liger-Kernel/benchmark/scripts/benchmark_dpo_loss.py, lines 144 to 162 in ad6f0a7)
- Define your bench_speed/memory functions (lines 70 to 72 in ad6f0a7)
- Extract configs in your custom functions (lines 74 to 83 in ad6f0a7)
- Use do_bench() to get statistics (lines 108 to 122 in ad6f0a7)
- Pack your statistics into SingleBenchmarkRunOutput and return it (lines 134 to 138 in ad6f0a7)
- Call run_benchmark() in the main function (lines 165 to 171 in ad6f0a7)
Done. Restructured benchmark_mhc.py to follow the standard framework: uses SingleBenchmarkRunInput/Output, run_benchmarks, _test_memory, and parse_benchmark_script_args, following benchmark_dpo_loss.py as reference. All configs are packed into extra_benchmark_configs.
Done. Applied the same restructuring to benchmark_mhc_lm.py.
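Condensed, the skeleton described above looks roughly like this; run_mhc_forward is a hypothetical workload closure, and the exact field and keyword names are approximations of benchmark_dpo_loss.py rather than verified signatures.

```python
# Sketch of the standard benchmark skeleton (field/kwarg names are approximate).
from triton.testing import do_bench

from utils import (
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    parse_benchmark_script_args,
    run_benchmarks,
)

QUANTILES = [0.5, 0.2, 0.8]


def bench_speed_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    seq_len = input.x                       # the swept x-axis value
    cfg = input.extra_benchmark_config      # fixed config dict (B, hc, dtype, ...)

    def fwd():
        run_mhc_forward(cfg["B"], seq_len, cfg["hc"])  # hypothetical mHC workload closure

    ms_50, ms_20, ms_80 = do_bench(fwd, quantiles=QUANTILES, rep=100)
    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)


if __name__ == "__main__":
    args = parse_benchmark_script_args()
    run_benchmarks(
        bench_test_fn=bench_speed_mhc,
        kernel_name="mhc",
        metric_name="speed",
        metric_unit="ms",
        x_name="T",
        x_label="sequence length",
        x_values=[256, 512, 1024, 2048],
        kernel_providers=["liger", "torch"],
        extra_benchmark_configs=[{"B": 2, "hc": 4}],
        overwrite=args.overwrite,
    )
```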
    int(tmax),
    float(rms_eps),
    float(pre_eps),
    float(sinkhorn_eps),
    float(post_mult),
Any considerations on why we have to cast them?
The casts (int(tmax), float(rms_eps), etc.) convert config scalars from tensors or numpy types into plain Python types, ensuring they are not accidentally included in the autograd graph. Added a clarifying comment at L322.
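A toy illustration of the point, with made-up values rather than the PR's actual config object:

```python
import torch

# Config scalars may arrive as 0-d tensors; casting to plain Python numbers keeps them
# out of the autograd graph so they are treated as constants of the kernel launch.
tmax_t = torch.tensor(8)             # e.g. a 0-d tensor coming from some config object
tmax = int(tmax_t)                   # plain Python int: no grad tracking
rms_eps = float(torch.tensor(1e-6))  # plain Python float
```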
src/liger_kernel/ops/mhc.py
Outdated
def _as_scalar(x: Union[torch.Tensor, float]) -> float:
    if isinstance(x, torch.Tensor):
        return float(x.item())
    return float(x)
Can we just read the tensor in the Triton code instead of turning it into a Python value? Moving tensors to the CPU might cause the GPU to idle; we probably want to avoid that where possible.
Done. Alpha values (alpha_pre, alpha_post, alpha_res) are now passed as tensor pointers to the Triton kernels and read directly on GPU via tl.load(). There are no .item() or .cpu() calls.
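Not the mHC kernel itself, but a minimal Triton sketch of the pattern this reply describes: the scalar stays on the GPU and is read with tl.load() inside the kernel, so the host never calls .item().

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _scale_kernel(x_ptr, alpha_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    alpha = tl.load(alpha_ptr)               # scalar read on-device; no host-side sync
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * alpha, mask=mask)


def scale(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    _scale_kernel[grid](x, alpha, out, n, BLOCK=1024)
    return out
```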
src/liger_kernel/ops/mhc.py
Outdated
    h_post: [..., HC] FP32
    h_res: [..., HC, HC] FP32
    """
    assert x.is_cuda, "CUDA only"
src/liger_kernel/ops/mhc.py
Outdated
def liger_mhc_pre(x: torch.Tensor, h_pre: torch.Tensor) -> torch.Tensor:
    """
    Apply H_pre: x_in = sum_i h_pre[i] * x[i]
    """
    return LigerMHCPreFunction.apply(x, h_pre)
Move functionals to transformers/functional.py
Done. All functional APIs (liger_mhc_coeffs, liger_mhc_pre, liger_mhc_post_res, liger_mhc_apply, liger_mhc_forward) are now in transformers/functional.py.
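A usage sketch after the move; the import path follows this reply, while the tensor shapes are inferred from the docstring excerpts above and should be treated as assumptions.

```python
import torch

from liger_kernel.transformers.functional import liger_mhc_pre
from liger_kernel.utils import infer_device

device = infer_device()
x = torch.randn(2, 256, 4, 128, device=device)                        # [B, T, HC, C] streams (assumed layout)
h_pre = torch.softmax(torch.randn(2, 256, 4, device=device), dim=-1)  # [B, T, HC] mixing weights (assumed)
x_in = liger_mhc_pre(x, h_pre)                                        # x_in = sum_i h_pre[i] * x[i]
```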
src/liger_kernel/ops/mhc.py
Outdated
    return LigerMHCPostResFunction.apply(x, f_out, h_post, h_res)


def liger_mhc_apply(
src/liger_kernel/ops/mhc.py
Outdated
    return x_out


def liger_mhc_forward(
src/liger_kernel/transformers/mhc.py
Outdated
    Wraps a layer F: [..., C] -> [..., C] with mHC residual streams: [..., HC, C].

    Args:
        layer: module applied to the aggregated stream input
Can you add an example in the docstring? It's still unclear to me how we can wrap existing modules with LigerMHC.
Done. Added an Example:: section in the docstring (L61-77) that shows how to wrap a linear layer with LigerMHC and how to use it inside a transformer block.
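For readers of the thread, a sketch of what such a wrapping example can look like; the LigerMHC constructor arguments shown here (a wrapped layer plus an hc stream count) are assumptions, and the authoritative version is the Example:: section in the docstring.

```python
import torch
import torch.nn as nn

from liger_kernel.transformers.mhc import LigerMHC

# Hypothetical constructor arguments: a wrapped layer F: [..., C] -> [..., C] and the
# number of residual streams. Check the actual docstring for the real signature.
ffn = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
block = LigerMHC(ffn, hc=4)

x = torch.randn(2, 256, 4, 512)  # [B, T, HC, C] residual streams
y = block(x)                     # [B, T, HC, C]
```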
- Remove all `assert x.is_cuda` checks for multi-backend support
- Eliminate `_as_scalar()` GPU sync by passing alpha params as pointers to Triton kernels (use `tl.load()` instead of `.item()`)
- Merge duplicate TC/TF32 kernel pairs into unified kernels with `CAST_FP32: tl.constexpr` compile-time flag (~180 lines removed)
- Replace `view(N, ...)` with `view(-1, ...)` across autograd Functions
- Move functional APIs from `ops.mhc` to `transformers.functional`
- Improve `LigerMHC` docstring with architecture, args, and examples
- Rewrite `benchmark_mhc.py` to standard framework (run_benchmarks)
- Use `infer_device()` in convergence test instead of hardcoded "cuda"
… skipif
- benchmark_mhc.py: pass all config params via extra_benchmark_configs following the DPO benchmark pattern
- test_mhc_mini_lm.py: remove redundant torch.cuda.is_available() skipif (supports_bfloat16() already covers this case)
…st_mhc.py for improved organization and maintainability of convergence tests.
- Remove CUDA-only skipif decorators from tests for multi-backend support
- Simplify _flatten_tokens to return x_shape, remove _unflatten_tokens helper
- Remove dead Makefile reference to deleted test_mhc_mini_lm.py
…istency and clarity. Update function signatures in `_post_res_default_meta` and `_post_res_meta` to use `Tuple` from the `typing` module.
- Remove no-op mask=True from Sinkhorn backward kernels
- Drop unused rms_eps/pre_eps from ctx.meta in coeffs backward
- Remove redundant .contiguous() calls inside @ensure_contiguous methods
- Simplify grad_x reshape to use x_shape directly
- Simplify device detection in LigerMHC to try/except pattern
- Replace torch.allclose with assert_verbose_allclose in tests
- Standardize seed to set_seed(42) across all tests
- Merge test_mhc_coeffs_allow_fp32 into test_mhc_coeffs_forward_backward
- Add backward coverage to test_mhc_pre_and_post_res_match_reference
- Widen bf16 tolerance for layer.weight.grad and phi.grad in module test
- Move hardcoded B into extra_benchmark_configs (benchmark_mhc.py)
- Rename MiniMHCLM to BenchMiniMHCLM in benchmark_mhc_lm.py
- Split _build_models into single-provider _build_model
Thank you for the detailed review, @Tcc0403! All review comments have been addressed; I've replied to each one individually above. In addition to the review feedback, I also made a few minor cleanups.
Please let me know if there's anything else that needs attention!
Add mHC fused kernels + LigerMHC API + benchmarks
Reference Issue
Summary
This PR adds an opt-in, paper-aligned mHC implementation to Liger-Kernel: fused Triton kernels, a LigerMHC module, functional APIs, tests, and benchmarks. No existing default patching behavior is changed.
Background (Paper)
mHC: Manifold-Constrained Hyper-Connections (arXiv:2512.24880v2)
https://arxiv.org/abs/2512.24880
Key idea: constrain H_res via Sinkhorn-Knopp onto the doubly-stochastic set (Birkhoff polytope), restoring identity-mapping stability while preserving multi-stream residual benefits. The paper also emphasizes fused kernels + mixed precision + recompute (Sec. 4.3.1, Eq.(14)–(19)).
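For orientation, a generic Sinkhorn-Knopp iteration in PyTorch; this is not the PR's fused kernel or its exact reference implementation, just the textbook projection the paper uses to constrain H_res.

```python
import torch


def sinkhorn_knopp(scores: torch.Tensor, n_iters: int = 10, eps: float = 1e-6) -> torch.Tensor:
    """Push an [..., HC, HC] positive matrix toward the Birkhoff polytope
    (doubly-stochastic matrices) by alternating row/column normalization."""
    p = scores.exp()
    for _ in range(n_iters):
        p = p / (p.sum(dim=-1, keepdim=True) + eps)  # normalize rows
        p = p / (p.sum(dim=-2, keepdim=True) + eps)  # normalize columns
    return p
```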
What's included
- LigerMHC + liger_mhc_* functional APIs (Liger naming).
- allow_fp32 opt-in (default remains BF16/FP16 mixed precision; intended for specific/debug use cases).
- benchmark/scripts/benchmark_mhc_lm.py
- Benchmarks (RTX 3090, BF16, B=2, T=256, n(HC)=4, layers=2, heads=8, vocab=4096): see #1066 (comment)
Out of scope