Gemma-4: HF-faithful MoE fixes (dual pre-norm, global rotary, final-logit softcapping)#7
Open
Zhichenzzz wants to merge 4 commits into
Open
Gemma-4: HF-faithful MoE fixes (dual pre-norm, global rotary, final-logit softcapping)#7Zhichenzzz wants to merge 4 commits into
Zhichenzzz wants to merge 4 commits into
Conversation
fa83d19 to
07ea6a2
Compare
…pre-norm
Add Gemma 4 MoE (26B-A4B-it) support to the bridge. The key design choice
is to NOT fuse the (w_pre_feedforward_layernorm / w_pre_feedforward_layernorm_2)
ratio into shared-expert gate/up weights at load time — that fusion spans
up to ~24000x with sign flips per channel and destroys bf16 precision for
the ~5% of channels where w_pffl_2 is near zero, yielding mean lpdiff ~0.45
vs sglang. Ship raw HF dense MLP gate/up weights and rely on standard
MCore MoELayer behavior.
Mathematically the shared expert ends up receiving unit_norm(x) * w_pffl_2
instead of the HF/sglang formal unit_norm(x) * w_pffl_1, but the magnitude
difference is absorbed by post_feedforward_layernorm_1 (RMSNorm on the
dense-branch output), and most channels share sign with w_pffl_1 so the
post-norm direction is preserved. Empirically this matches sglang within
mean lpdiff ~0.02-0.05 across 10 steps — production-grade.
What this commit adds:
gemma4_provider.py
Gemma4TransformerLayer (per-layer scaling + extra post-norm):
- layer_scalar buffer applied after residual add
- post_ffn_layernorm (HF post_feedforward_layernorm) in _forward_post_mlp
Gemma4MoELayer with per-branch post-norms in postprocess:
- post_moe_layernorm (HF post_feedforward_layernorm_2) on routed output
- post_shared_expert_layernorm (HF post_feedforward_layernorm_1) on shared output
Gemma4TopKRouter with per_expert_scale renormalization + scaler-root-size
fusion absorbed into proj weight at load time.
Gemma4SelfAttention: heterogeneous sliding/global head dims, V-RMSNorm
(parameter-free), K=V tying for global layers.
_logit_softcapping: skipped to match sglang.
gemma4_vl_bridge.py
Param map for all Gemma4-specific norms and buffers (post_moe_layernorm,
post_shared_expert_layernorm, post_ffn_layernorm, pffl_weight buffer,
layer_scalar, router.scale, router.per_expert_scale).
maybe_modify_loaded_hf_weight: skip _fuse_shared_expert_prenorm (return
raw HF gate/up); fuse router scale/root_size/w_pffl_2 into router.proj.weight.
gemma4_bridge.py
LLM-only config promotion (from VLM .text_config layout).
sglang-vs-HF naming swap for head_dim / swa_head_dim and KV head counts.
Filter provider_kwargs to fields Gemma4ModelProvider accepts (drops MLA
fields like v_head_dim).
model_bridge.py
Handle task is None in convert / export loops (megatron-side params with
no HF counterpart, e.g. post_shared_expert_layernorm).
fusions.py
can_enable_gradient_accumulation_fusion no longer takes the TE shortcut —
TE wgrad_accumulation exists only on TELinear, not on ColumnParallelLinear.
…usion Implement the real HF Gemma-4 MoE architecture: independent pre-norms for the dense/shared and routed paths, instead of the previous skip-fix that fed the shared expert the wrong (w_pffl_2) pre-norm and relied on a load-time weight fusion to mimic sglang. gemma4_provider.py: - Gemma4TransformerLayer.pre_mlp_layernorm -> Identity; dual pre-norm is now applied inside Gemma4MoELayer. - Gemma4MoELayer: add pre_shared_layernorm (w_pffl_1, shared/dense) and pre_moe_layernorm (w_pffl_2, routed); override forward to apply both on the un-normed residual (matches HF: pre_feedforward_layernorm / pre_feedforward_layernorm_2). - Gemma4TopKRouter.gating: sglang-faithful router preprocessing (parameter-free RMSNorm * scale * hidden^-0.5, then raw proj). Removes the old bf16 fusion of (scale*root / w_pffl_2) into the router weight, which destroyed precision on the ~5% of w_pffl_2 channels near zero. gemma4_vl_bridge.py: - Map pre_shared_layernorm / pre_moe_layernorm (drop pre_mlp_layernorm + inert pffl_weight buffer). - Load raw shared-expert gate/up and raw router proj (no load-time fusion); drop the corresponding export inverse corrections. Verified vs sglang (dual-pre-norm faithful) via tensor dumps: layer-0 outputs match ~0.5% across all paths. RL trains stably (ess >= 0.93). Residual train-inference logprob diff (~0.3 mean, median ~1e-4) is the inherent bf16 gap between the two engines on the same correct model, handled by TIS.
The global attention layers use HF rope_type="proportional" with
partial_rotary_factor=0.25. HF builds inv_freq of size global_head_dim/2 (256)
with only the first rotary_dim/2 (64) entries non-zero, and applies RoPE over
the FULL 512-dim head via rotate_half -- i.e. it rotates dims {0..63} paired
with {256..319}, leaving the rest unrotated.
The bridge previously built a 64-entry inv_freq and used rotary_percent=0.25,
which rotates the first 128 contiguous dims (pairing i<->i+64) -- a completely
different dim layout. This made the global attention output diverge from HF by
~11.6% at the first global layer (vs ~1% on sliding layers), compounding with
depth.
Fix: zero-pad inv_freq to global_head_dim/2 (first 64 = proportional freqs,
rest 0) and use rotary_percent=1.0, so TE rotates the full head with rotate_half
and the zero freqs make the non-rotary dims pass through -- reproducing HF's
exact rotated-dim layout.
Verified via per-layer attn-vs-HF tensor dumps: first global layer 11.6% -> 1.55%
(matches sliding-layer bf16 floor). RL train_rollout_logprob_abs_diff: 0.34 -> 0.26.
07ea6a2 to
ded4ba0
Compare
…inference lpdiff The true HF Gemma-4 applies final_logit_softcapping=30.0 to the LM logits (modeling_gemma4.py: logits = 30*tanh(logits/30)), and sglang's LogitsProcessor ALSO applies it (it reads config.final_logit_softcapping from text_config, =30). A prior MILES patch had stubbed _logit_softcapping to a no-op with the comment "skip softcap to match sglang (sglang gemma4 doesn't apply it)" — but that premise was wrong: sglang DOES apply it. Skipping it in the bridge left the bridge's logits un-capped and over-extreme on tail tokens, so for the few catastrophic tokens the bridge's logprob diverged from sglang's (the rollout engine) by several nats. Effect on RL train_rollout_logprob_abs_diff (8x GB200/H200, triton + R3): before: ~0.26 mean (median ~1e-4, ess_ratio ~0.93) — dominated by tail tokens after: ~0.007 mean (median ~1e-5, ess_ratio ~0.9995) Verified on a fixed prompt: with softcap, bridge logprobs match HF/sglang on the previously-catastrophic tokens (e.g. " Germany" after "...The capital of": bridge -24.4 -> -8.5, vs HF -9.8 / sglang -8.3). Also drop a leftover [MILES_FUSION_SKIP] diagnostic print + commented-out fusion call in gemma4_bridge.py (the raw-weight de-fusion logic is kept).
ded4ba0 to
c36f3c5
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Make the bridge's Gemma-4 26B-A4B (MoE) implementation match the true HF model. Brings RL train-inference logprob consistency (train_rollout_logprob_abs_diff) from ~0.34 to ~0.007 (ess_ratio ~0.93 -> ~0.9995), with the genuine HF model (no skip-fix hack).
1. Dual pre-norm (true HF MoE arch)
HF Gemma-4 feeds pre_feedforward_layernorm (w1) to the dense/shared MLP and pre_feedforward_layernorm_2 (w2) to the routed experts; the router sees the un-normed post-attention residual. Replaces the earlier skip-fix hack and de-fuses the router (drops the bf16 w2 weight fusion that blew up on near-zero channels).
2. Global (head_dim=512) partial-rotary layout
Global attention uses HF rope_type="proportional", partial_rotary_factor=0.25: a 256-wide inv_freq with only the first 64 entries non-zero, applied over the FULL 512-dim head via rotate_half (rotates dims 0..63 with 256..319). The bridge previously used a 64-wide inv_freq + rotary_percent=0.25. Fix: zero-pad inv_freq to 256 + rotary_percent=1.0.
3. final_logit_softcapping = 30
HF (and sglang's LogitsProcessor) apply logits = 30*tanh(logits/30). A prior patch stubbed this to a no-op, leaving the bridge's logits over-extreme on tail tokens vs the inference engine. Re-enabled.
Files
Verified
RL on dapo-math (sglang triton + rollout-routing-replay): mean lpdiff ~0.007 over 20-step runs at 2048 and 4096 response lengths (incl. --optimizer-cpu-offload). Ablation: removing the rotary fix regresses lpdiff to ~0.10.