Skip to content

Gemma-4: HF-faithful MoE fixes (dual pre-norm, global rotary, final-logit softcapping)#7

Open
Zhichenzzz wants to merge 4 commits into
bridgefrom
zhichen/gemma4-bridge-pr
Open

Gemma-4: HF-faithful MoE fixes (dual pre-norm, global rotary, final-logit softcapping)#7
Zhichenzzz wants to merge 4 commits into
bridgefrom
zhichen/gemma4-bridge-pr

Conversation

@Zhichenzzz
Copy link
Copy Markdown

@Zhichenzzz Zhichenzzz commented May 27, 2026

Summary

Make the bridge's Gemma-4 26B-A4B (MoE) implementation match the true HF model. Brings RL train-inference logprob consistency (train_rollout_logprob_abs_diff) from ~0.34 to ~0.007 (ess_ratio ~0.93 -> ~0.9995), with the genuine HF model (no skip-fix hack).

1. Dual pre-norm (true HF MoE arch)

HF Gemma-4 feeds pre_feedforward_layernorm (w1) to the dense/shared MLP and pre_feedforward_layernorm_2 (w2) to the routed experts; the router sees the un-normed post-attention residual. Replaces the earlier skip-fix hack and de-fuses the router (drops the bf16 w2 weight fusion that blew up on near-zero channels).

2. Global (head_dim=512) partial-rotary layout

Global attention uses HF rope_type="proportional", partial_rotary_factor=0.25: a 256-wide inv_freq with only the first 64 entries non-zero, applied over the FULL 512-dim head via rotate_half (rotates dims 0..63 with 256..319). The bridge previously used a 64-wide inv_freq + rotary_percent=0.25. Fix: zero-pad inv_freq to 256 + rotary_percent=1.0.

3. final_logit_softcapping = 30

HF (and sglang's LogitsProcessor) apply logits = 30*tanh(logits/30). A prior patch stubbed this to a no-op, leaving the bridge's logits over-extreme on tail tokens vs the inference engine. Re-enabled.

Files

  • models/gemma/gemma4_provider.py, models/gemma/gemma4_bridge.py
  • models/gemma_vl/gemma4_vl_bridge.py - required: gemma-4-26B-A4B is Gemma4ForConditionalGeneration, so the bridge loads its (text-decoder) weights through this file; it carries the dual-pre-norm weight mappings. Reverting it makes the bridge load the checkpoint with the wrong mapping (train-inference lpdiff blows up to ~18).
  • models/conversion/model_bridge.py (skip megatron-side-only params), utils/fusions.py

Verified

RL on dapo-math (sglang triton + rollout-routing-replay): mean lpdiff ~0.007 over 20-step runs at 2048 and 4096 response lengths (incl. --optimizer-cpu-offload). Ablation: removing the rotary fix regresses lpdiff to ~0.10.

…pre-norm

Add Gemma 4 MoE (26B-A4B-it) support to the bridge.  The key design choice
is to NOT fuse the (w_pre_feedforward_layernorm / w_pre_feedforward_layernorm_2)
ratio into shared-expert gate/up weights at load time — that fusion spans
up to ~24000x with sign flips per channel and destroys bf16 precision for
the ~5% of channels where w_pffl_2 is near zero, yielding mean lpdiff ~0.45
vs sglang.  Ship raw HF dense MLP gate/up weights and rely on standard
MCore MoELayer behavior.

Mathematically the shared expert ends up receiving unit_norm(x) * w_pffl_2
instead of the HF/sglang formal unit_norm(x) * w_pffl_1, but the magnitude
difference is absorbed by post_feedforward_layernorm_1 (RMSNorm on the
dense-branch output), and most channels share sign with w_pffl_1 so the
post-norm direction is preserved.  Empirically this matches sglang within
mean lpdiff ~0.02-0.05 across 10 steps — production-grade.

What this commit adds:

  gemma4_provider.py
    Gemma4TransformerLayer (per-layer scaling + extra post-norm):
      - layer_scalar buffer applied after residual add
      - post_ffn_layernorm (HF post_feedforward_layernorm) in _forward_post_mlp
    Gemma4MoELayer with per-branch post-norms in postprocess:
      - post_moe_layernorm (HF post_feedforward_layernorm_2) on routed output
      - post_shared_expert_layernorm (HF post_feedforward_layernorm_1) on shared output
    Gemma4TopKRouter with per_expert_scale renormalization + scaler-root-size
    fusion absorbed into proj weight at load time.
    Gemma4SelfAttention: heterogeneous sliding/global head dims, V-RMSNorm
    (parameter-free), K=V tying for global layers.
    _logit_softcapping: skipped to match sglang.

  gemma4_vl_bridge.py
    Param map for all Gemma4-specific norms and buffers (post_moe_layernorm,
    post_shared_expert_layernorm, post_ffn_layernorm, pffl_weight buffer,
    layer_scalar, router.scale, router.per_expert_scale).
    maybe_modify_loaded_hf_weight: skip _fuse_shared_expert_prenorm (return
    raw HF gate/up); fuse router scale/root_size/w_pffl_2 into router.proj.weight.

  gemma4_bridge.py
    LLM-only config promotion (from VLM .text_config layout).
    sglang-vs-HF naming swap for head_dim / swa_head_dim and KV head counts.
    Filter provider_kwargs to fields Gemma4ModelProvider accepts (drops MLA
    fields like v_head_dim).

  model_bridge.py
    Handle task is None in convert / export loops (megatron-side params with
    no HF counterpart, e.g. post_shared_expert_layernorm).

  fusions.py
    can_enable_gradient_accumulation_fusion no longer takes the TE shortcut —
    TE wgrad_accumulation exists only on TELinear, not on ColumnParallelLinear.
…usion

Implement the real HF Gemma-4 MoE architecture: independent pre-norms for the
dense/shared and routed paths, instead of the previous skip-fix that fed the
shared expert the wrong (w_pffl_2) pre-norm and relied on a load-time weight
fusion to mimic sglang.

gemma4_provider.py:
- Gemma4TransformerLayer.pre_mlp_layernorm -> Identity; dual pre-norm is now
  applied inside Gemma4MoELayer.
- Gemma4MoELayer: add pre_shared_layernorm (w_pffl_1, shared/dense) and
  pre_moe_layernorm (w_pffl_2, routed); override forward to apply both on the
  un-normed residual (matches HF: pre_feedforward_layernorm /
  pre_feedforward_layernorm_2).
- Gemma4TopKRouter.gating: sglang-faithful router preprocessing
  (parameter-free RMSNorm * scale * hidden^-0.5, then raw proj). Removes the
  old bf16 fusion of (scale*root / w_pffl_2) into the router weight, which
  destroyed precision on the ~5% of w_pffl_2 channels near zero.

gemma4_vl_bridge.py:
- Map pre_shared_layernorm / pre_moe_layernorm (drop pre_mlp_layernorm +
  inert pffl_weight buffer).
- Load raw shared-expert gate/up and raw router proj (no load-time fusion);
  drop the corresponding export inverse corrections.

Verified vs sglang (dual-pre-norm faithful) via tensor dumps: layer-0 outputs
match ~0.5% across all paths. RL trains stably (ess >= 0.93). Residual
train-inference logprob diff (~0.3 mean, median ~1e-4) is the inherent bf16
gap between the two engines on the same correct model, handled by TIS.
The global attention layers use HF rope_type="proportional" with
partial_rotary_factor=0.25. HF builds inv_freq of size global_head_dim/2 (256)
with only the first rotary_dim/2 (64) entries non-zero, and applies RoPE over
the FULL 512-dim head via rotate_half -- i.e. it rotates dims {0..63} paired
with {256..319}, leaving the rest unrotated.

The bridge previously built a 64-entry inv_freq and used rotary_percent=0.25,
which rotates the first 128 contiguous dims (pairing i<->i+64) -- a completely
different dim layout. This made the global attention output diverge from HF by
~11.6% at the first global layer (vs ~1% on sliding layers), compounding with
depth.

Fix: zero-pad inv_freq to global_head_dim/2 (first 64 = proportional freqs,
rest 0) and use rotary_percent=1.0, so TE rotates the full head with rotate_half
and the zero freqs make the non-rotary dims pass through -- reproducing HF's
exact rotated-dim layout.

Verified via per-layer attn-vs-HF tensor dumps: first global layer 11.6% -> 1.55%
(matches sliding-layer bf16 floor). RL train_rollout_logprob_abs_diff: 0.34 -> 0.26.
@Zhichenzzz Zhichenzzz force-pushed the zhichen/gemma4-bridge-pr branch from 07ea6a2 to ded4ba0 Compare May 27, 2026 21:18
@Zhichenzzz Zhichenzzz changed the title Gemma-4 26B-A4B-it: HF-faithful MoE fixes (dual pre-norm, global rotary, final-logit softcapping) Gemma-4: HF-faithful MoE fixes (dual pre-norm, global rotary, final-logit softcapping) May 27, 2026
…inference lpdiff

The true HF Gemma-4 applies final_logit_softcapping=30.0 to the LM logits
(modeling_gemma4.py: logits = 30*tanh(logits/30)), and sglang's LogitsProcessor
ALSO applies it (it reads config.final_logit_softcapping from text_config, =30).

A prior MILES patch had stubbed _logit_softcapping to a no-op with the comment
"skip softcap to match sglang (sglang gemma4 doesn't apply it)" — but that premise
was wrong: sglang DOES apply it. Skipping it in the bridge left the bridge's logits
un-capped and over-extreme on tail tokens, so for the few catastrophic tokens the
bridge's logprob diverged from sglang's (the rollout engine) by several nats.

Effect on RL train_rollout_logprob_abs_diff (8x GB200/H200, triton + R3):
  before: ~0.26 mean (median ~1e-4, ess_ratio ~0.93) — dominated by tail tokens
  after:  ~0.007 mean (median ~1e-5, ess_ratio ~0.9995)

Verified on a fixed prompt: with softcap, bridge logprobs match HF/sglang on the
previously-catastrophic tokens (e.g. " Germany" after "...The capital of":
bridge -24.4 -> -8.5, vs HF -9.8 / sglang -8.3).

Also drop a leftover [MILES_FUSION_SKIP] diagnostic print + commented-out
fusion call in gemma4_bridge.py (the raw-weight de-fusion logic is kept).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant