
Add mHC fused kernels + LigerMHC API + benchmarks #1065

Open

yukiu00 wants to merge 14 commits into linkedin:main from yukiu00:feat/mhc-kernel

Conversation


@yukiu00 yukiu00 commented Feb 4, 2026

Add mHC fused kernels + LigerMHC API + benchmarks

Reference Issue

Summary

This PR adds an opt-in, paper-aligned mHC implementation to Liger-Kernel: fused Triton kernels, a LigerMHC module, functional APIs, tests, and benchmarks.
No existing default patching behavior is changed.

Background (Paper)

mHC: Manifold-Constrained Hyper-Connections (arXiv:2512.24880v2)
https://arxiv.org/abs/2512.24880

Key idea: constrain H_res via Sinkhorn-Knopp projection onto the set of doubly stochastic matrices (the Birkhoff polytope), restoring identity-mapping stability while preserving the benefits of multi-stream residuals. The paper also emphasizes fused kernels, mixed precision, and recomputation (Sec. 4.3.1, Eq. (14)–(19)).
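For intuition, a minimal plain-PyTorch sketch of the Sinkhorn-Knopp normalization idea (an illustration only, not this PR's fused kernel; the function name, iteration count, and the exact parameterization of Eq. (14)–(19) are assumptions):

import torch

def sinkhorn_knopp_ref(h_res: torch.Tensor, n_iters: int = 10, eps: float = 1e-6) -> torch.Tensor:
    # Push a non-negative [..., HC, HC] matrix toward the Birkhoff polytope
    # by alternating row and column normalization.
    m = h_res.clamp_min(eps)
    for _ in range(n_iters):
        m = m / m.sum(dim=-1, keepdim=True)  # rows sum to 1
        m = m / m.sum(dim=-2, keepdim=True)  # columns sum to 1
    return m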

What’s included

  • Triton mHC kernels (coeffs / Sinkhorn / apply; fwd + bwd).
  • API: LigerMHC + liger_mhc_* functional APIs (Liger naming).
  • allow_fp32 opt-in (default remains BF16/FP16 mixed precision; intended for specific/debug use cases).
  • Benchmarks: benchmark/scripts/benchmark_mhc_lm.py
  • Tests: ops correctness, transformer-level tests, convergence test.

Benchmarks (RTX 3090, BF16, B=2, T=256, n(HC)=4, layers=2, heads=8, vocab=4096)

see #1066 (comment)

Out of scope

  • Block-wise recomputation (groups of $L_r$ layers, paper Sec. 4.3.2) is out of scope for Liger-Kernel. Users can obtain the memory savings described in the paper by applying torch.utils.checkpoint to groups of mHC layers in their training loop (see the sketch after this list).
  • DualPipe schedule / distributed pipeline optimization (paper Sec. 4.3.3)
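A rough sketch of that checkpoint-based recomputation (the block list and helper names here are hypothetical, not this PR's API):

from torch.utils.checkpoint import checkpoint

def run_blocks_with_recompute(blocks, x, group_size=2):
    # blocks: a list of mHC-wrapped transformer layers (hypothetical).
    # Activations inside each group are recomputed during backward instead of stored.
    def make_group_fn(group):
        def group_fn(h):
            for blk in group:
                h = blk(h)
            return h
        return group_fn

    for i in range(0, len(blocks), group_size):
        x = checkpoint(make_group_fn(blocks[i:i + group_size]), x, use_reentrant=False)
    return x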

- Add mHC kernels and APIs
- Provide reference implementations for tests and benchmarks
- Add/adjust tests, tolerances, and benchmarks
- Document memory trade-offs, usage notes, and options
@yukiu00 yukiu00 marked this pull request as ready for review February 4, 2026 09:50
@yukiu00 yukiu00 changed the title Add mHC fused kernels + LigerMHC API + benchmarks (paper-aligned) Add mHC fused kernels + LigerMHC API + benchmarks Feb 4, 2026
Collaborator

@Tcc0403 Tcc0403 left a comment

Great work🚀 Looking forward to having mHC in Liger Kernel!

I've only skimmed through the code, so most comments are about code structure. I'll do a thorough review of the implementation tomorrow. Thanks for your patience!

Comment on lines 13 to 23
def _time_loop(fn, iters=200, warmup=50) -> float:
    torch.cuda.synchronize()
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    t1 = time.time()
    return (t1 - t0) * 1e3 / iters
Collaborator

We prefer using triton.testing.do_bench() to benchmark our kernels; refer to the other benchmark scripts for concrete examples.

Contributor Author

Acknowledged. I switched benchmark_mhc.py to use triton.testing.do_bench() with QUANTILES for all forward/backward measurements.

Comment on lines 26 to 30
def _peak_bytes(fn) -> int:
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return int(torch.cuda.max_memory_allocated())
Collaborator

def _test_memory(

we have utils._test_memory() for checking memory.

Contributor Author

Done. Memory measurement now uses utils._test_memory() consistently.
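For reference, the measurement pattern would look roughly like the sketch below (the forward/backward callable is hypothetical, and the exact _test_memory keyword arguments are an assumption based on the other benchmark scripts):

from utils import QUANTILES, _test_memory  # benchmark/scripts/utils.py

def full():
    # hypothetical: forward + backward of the mHC module under test
    y = liger_mhc(x)
    y.sum().backward()

mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)  # assumed call signature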


import torch

from utils import mhc_coeffs_ref
Collaborator

Let's keep only one reference implementation as the single source of truth. Instead of writing another one in benchmark/scripts/utils.py, reuse the one from test.transformers.test_mhc to avoid inconsistent references in future updates.

Add the root directory to your path:

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

Then you can access the test directory in your function:

from test.transformers.test_dyt import TorchDyT

Contributor Author

I moved mhc_sinkhorn_ref / mhc_coeffs_ref into test_mhc.py and updated benchmarks to import from there. I also added the repo root to sys.path in the benchmark scripts.

Comment on lines 8 to 15
from liger_kernel.triton.mhc import mhc_mm_norm_bwd
from liger_kernel.triton.mhc import mhc_mm_norm_fwd
from liger_kernel.triton.mhc import mhc_post_res_bwd
from liger_kernel.triton.mhc import mhc_post_res_fwd
from liger_kernel.triton.mhc import mhc_pre_bwd
from liger_kernel.triton.mhc import mhc_pre_fwd
from liger_kernel.triton.mhc import mhc_sinkhorn_bwd
from liger_kernel.triton.mhc import mhc_split_sinkhorn_fwd
Collaborator

Let's define them in this file directly for codebase structure. Happy to discuss other approaches if you feel the file becomes too large.

Contributor Author

Done. I moved the Triton kernels into this file directly and removed liger_kernel/triton/mhc.py.

def test_mhc_mini_lm_convergence():
    set_seed(0)

    device = "cuda"
Collaborator

We use liger_kernel.utils.infer_device to get the device, since we support multiple backends, not just CUDA.
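For example, a minimal sketch of the suggested pattern:

import torch

from liger_kernel.utils import infer_device

device = infer_device()  # resolves to the available backend (e.g. "cuda", "xpu") instead of hardcoding
x = torch.randn(2, 256, 512, device=device)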

HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_multimodal.py
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_with_logits.py
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mhc_mini_lm.py

Collaborator

I'll discuss with folks and decide whether to test the mHC architecture or not.

yukiu00 and others added 4 commits February 5, 2026 10:00
- Remove backward-compatible alias functions (108 lines)
- Add docstring and comments to _post_res_default_meta
- Use Union[] instead of | for Python 3.9 compatibility
- Replace assert with ValueError for better debugging
- Add runtime warning when hc > 16
Collaborator

Perhaps we can move this part to test_mhc.py. No need to check whether losses decrease; we only have to check whether the outputs generated by two models, one with torch refs and the other with Liger's mHC components, are close enough, using torch.testing.assert_close() or utils.assert_verbose_allclose().
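A minimal sketch of that equivalence check (the two block names, shapes, and tolerances are hypothetical):

import torch

from liger_kernel.utils import infer_device

device = infer_device()
torch.manual_seed(42)
x = torch.randn(2, 256, 4, 512, device=device, dtype=torch.bfloat16)  # [B, T, HC, C]

ref_out = torch_ref_mhc_block(x)   # block built from the torch reference implementations
liger_out = liger_mhc_block(x)     # same block using Liger's mHC components

torch.testing.assert_close(liger_out, ref_out, rtol=1e-2, atol=1e-2)  # bf16 tolerances are illustrative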

Collaborator

Follow the structure of the other benchmark scripts so that we can track the benchmarking results with CI. Take benchmark_dpo_loss.py as an example; there are some key points you need to cover:

Pack configs, sample ranges, etc. in a dictionary

common_configs = {
    "kernel_name": "dpo_loss",
    "x_name": "B",
    "x_label": "Batch Size (B)",
    "x_values": [2**i for i in range(1, 6)],
    "kernel_providers": ["liger", "huggingface"],
    "extra_benchmark_configs": [
        {
            "T": 512,
            "H": 1024,
            "V": 128256,
            "mode": "forward",
            "dtype": torch.bfloat16,
            "bias": True,
            "beta": 0.1,
            "ignore_index": 42,
        }
    ],
    "overwrite": args.overwrite,
}

Define your bench_speed/memory functions

def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO
    from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO

Extract configs in your custom functions

B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
beta = input.extra_benchmark_config["beta"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider
mode = input.kernel_operation_mode

Use do_bench() to get statistics

if mode == "forward":
    ms_50, ms_20, ms_80 = triton.testing.do_bench(
        fwd,
        rep=100,
        quantiles=QUANTILES,
    )
elif mode == "backward":
    y = fwd()
    ms_50, ms_20, ms_80 = triton.testing.do_bench(
        lambda: y.backward(retain_graph=True),
        grad_to_none=[_input],
        rep=100,
        quantiles=QUANTILES,
    )
elif mode == "full":

Pack your statistics into SingleBenchmarkRunOutput and return

return SingleBenchmarkRunOutput(
    y_20=ms_20,
    y_50=ms_50,
    y_80=ms_80,
)

Call run_benchmarks() in the main function

run_benchmarks(
    bench_test_fn=bench_speed_dpo_loss,
    kernel_operation_modes=["forward", "backward", "full"],
    metric_name="speed",
    metric_unit="ms",
    **common_configs,
)

Contributor Author

Done. Restructured benchmark_mhc.py to follow the standard framework: uses SingleBenchmarkRunInput/Output, run_benchmarks, _test_memory, and parse_benchmark_script_args, following benchmark_dpo_loss.py as reference. All configs are packed into extra_benchmark_configs.

Collaborator

ditto

Contributor Author

Done. Applied the same restructuring to benchmark_mhc_lm.py.

Comment on lines +330 to +334
int(tmax),
float(rms_eps),
float(pre_eps),
float(sinkhorn_eps),
float(post_mult),
Collaborator

Any considerations why we have to cast them?

Contributor Author

The casts (int(tmax), float(rms_eps), etc.) convert config scalars from tensors or numpy types into plain Python types, ensuring they are not accidentally included in the autograd graph. Added a clarifying comment at L322.

Comment on lines 16 to 19
def _as_scalar(x: Union[torch.Tensor, float]) -> float:
    if isinstance(x, torch.Tensor):
        return float(x.item())
    return float(x)
Collaborator

Can we just read the tensor in the Triton code instead of turning it into a Python value? Moving tensors to the CPU might cause the GPU to idle; we want to avoid that as much as possible.

Contributor Author

Done. Alpha values (alpha_pre, alpha_post, alpha_res) are now passed as tensor pointers to the Triton kernels and read directly on GPU via tl.load(). There are no .item() or .cpu() calls.
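For illustration, a standalone toy kernel (not this PR's kernel) showing the pattern of reading a scalar parameter from a tensor pointer on-device:

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_kernel(x_ptr, out_ptr, alpha_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    alpha = tl.load(alpha_ptr)  # scalar read on the device; no host-side .item() sync
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * alpha, mask=mask)

x = torch.randn(1024, device="cuda")
alpha = torch.tensor(0.5, device="cuda")
out = torch.empty_like(x)
_scale_kernel[(triton.cdiv(x.numel(), 256),)](x, out, alpha, x.numel(), BLOCK_SIZE=256)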

h_post: [..., HC] FP32
h_res: [..., HC, HC] FP32
"""
assert x.is_cuda, "CUDA only"
Collaborator

no cuda only

Contributor Author

Done.

Comment on lines 1847 to 1851
def liger_mhc_pre(x: torch.Tensor, h_pre: torch.Tensor) -> torch.Tensor:
    """
    Apply H_pre: x_in = sum_i h_pre[i] * x[i]
    """
    return LigerMHCPreFunction.apply(x, h_pre)
Collaborator

Move functionals to transformers/functional.py

Contributor Author

Done. All functional APIs (liger_mhc_coeffs, liger_mhc_pre, liger_mhc_post_res, liger_mhc_apply, liger_mhc_forward) are now in transformers/functional.py.

return LigerMHCPostResFunction.apply(x, f_out, h_post, h_res)


def liger_mhc_apply(
Collaborator

ditto

Contributor Author

Done.

return x_out


def liger_mhc_forward(
Collaborator

ditto

Contributor Author

Done.

Wraps a layer F: [..., C] -> [..., C] with mHC residual streams: [..., HC, C].
Args:
layer: module applied to the aggregated stream input
Collaborator

Can you add an example in the docstring? It's still unclear to me how we can wrap the existing modules with LigerMHC

Contributor Author

Done. Added an Example:: section in the docstring (L61-77) that shows how to wrap a linear layer with LigerMHC and how to use it inside a transformer block.
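For readers following along, a rough sketch of the wrapping pattern (the constructor arguments and import are hypothetical; see the actual docstring example in the PR for the real signature):

import torch
import torch.nn as nn

hidden = 512
mlp = nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Linear(4 * hidden, hidden))

# hypothetical: wrap the sublayer so it operates on mHC residual streams [..., HC, C]
mhc_mlp = LigerMHC(layer=mlp, hidden_size=hidden, hc=4)

x = torch.randn(2, 256, 4, hidden)  # [B, T, HC, C]
y = mhc_mlp(x)                      # [B, T, HC, C]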

- Remove all `assert x.is_cuda` checks for multi-backend support
- Eliminate `_as_scalar()` GPU sync by passing alpha params as pointers
  to Triton kernels (use `tl.load()` instead of `.item()`)
- Merge duplicate TC/TF32 kernel pairs into unified kernels with
  `CAST_FP32: tl.constexpr` compile-time flag (~180 lines removed)
- Replace `view(N, ...)` with `view(-1, ...)` across autograd Functions
- Move functional APIs from `ops.mhc` to `transformers.functional`
- Improve `LigerMHC` docstring with architecture, args, and examples
- Rewrite `benchmark_mhc.py` to standard framework (run_benchmarks)
- Use `infer_device()` in convergence test instead of hardcoded "cuda"
… skipif

- benchmark_mhc.py: pass all config params via extra_benchmark_configs
  following the DPO benchmark pattern
- test_mhc_mini_lm.py: remove redundant torch.cuda.is_available() skipif
  (supports_bfloat16() already covers this case)
…st_mhc.py for improved organization and maintainability of convergence tests.
- Remove CUDA-only skipif decorators from tests for multi-backend support
- Simplify _flatten_tokens to return x_shape, remove _unflatten_tokens helper
- Remove dead Makefile reference to deleted test_mhc_mini_lm.py
…istency and clarity. Update function signatures in `_post_res_default_meta` and `_post_res_meta` to use `Tuple` from the `typing` module.
- Remove no-op mask=True from Sinkhorn backward kernels
- Drop unused rms_eps/pre_eps from ctx.meta in coeffs backward
- Remove redundant .contiguous() calls inside @ensure_contiguous methods
- Simplify grad_x reshape to use x_shape directly
- Simplify device detection in LigerMHC to try/except pattern
- Replace torch.allclose with assert_verbose_allclose in tests
- Standardize seed to set_seed(42) across all tests
- Merge test_mhc_coeffs_allow_fp32 into test_mhc_coeffs_forward_backward
- Add backward coverage to test_mhc_pre_and_post_res_match_reference
- Widen bf16 tolerance for layer.weight.grad and phi.grad in module test
- Move hardcoded B into extra_benchmark_configs (benchmark_mhc.py)
- Rename MiniMHCLM to BenchMiniMHCLM in benchmark_mhc_lm.py
- Split _build_models into single-provider _build_model

yukiu00 commented Feb 9, 2026

Thank you for the detailed review, @Tcc0403! All review comments have been addressed — I've replied to each one individually above.

In addition to the review feedback, I also made a few minor cleanups:

  • Replaced torch.allclose with assert_verbose_allclose across all tests
  • Standardized random seeds to set_seed(42)
  • Merged the test_mhc_coeffs_allow_fp32 test into the main parametrized test_mhc_coeffs_forward_backward
  • Added backward pass coverage to test_mhc_pre_and_post_res_match_reference
  • Removed redundant .contiguous() calls inside @ensure_contiguous-decorated methods
  • Simplified ctx.meta tuple by dropping unused fields (rms_eps, pre_eps)
  • Simplified device detection in LigerMHC.__init__
  • Renamed benchmark model class to BenchMiniMHCLM to avoid name collision with test
  • Split _build_models into single-provider _build_model to avoid unnecessary GPU allocation
  • Moved hardcoded B into extra_benchmark_configs in benchmark_mhc.py

Please let me know if there's anything else that needs attention!


Development

Successfully merging this pull request may close these issues.

Add mHC (Manifold-Constrained Hyper-Connections) fused kernels to Liger-Kernel
