Skip to content

[gemma4] feat: add Gemma-4 31B dense model support#8

Open
Zhichenzzz wants to merge 1 commit into
zhichen/gemma4-bridge-prfrom
zhichen/gemma4-dense
Open

[gemma4] feat: add Gemma-4 31B dense model support#8
Zhichenzzz wants to merge 1 commit into
zhichen/gemma4-bridge-prfrom
zhichen/gemma4-dense

Conversation

@Zhichenzzz
Copy link
Copy Markdown

Stacked on #7 (Gemma-4 MoE HF-faithfulness fixes).

Ports upstream NVIDIA-NeMo/Megatron-Bridge#3885 ("limited Gemma4 dense support") onto this fork, adapted to our fork's norm fusion, adding Gemma-4 31B-it (dense, non-MoE) RL support.

Changes

gemma4_provider.py

  • Add attention_k_eq_v: bool = False field.
  • _install_tied_kv: gate on attention_k_eq_v (covers both MoE and dense) instead of num_moe_experts is None. Both 26B-A4B-it and 31B-it set attention_k_eq_v=True, so this is a no-op for the 26B MoE path.

gemma4_vl_bridge.py

  • Unblock the dense path for models without per-layer embeddings (hidden_size_per_layer_input == 0, e.g. 31B-it). Dense models with PLE still error since MCore lacks PLE support.
  • Branch provider config on enable_moe_block: dense path sets num_moe_experts=None and ffn_hidden_size=intermediate_size. MoE path unchanged.
  • Pass provider.attention_k_eq_v = text_config.attention_k_eq_v.
  • Add dense weight mappings (inert on MoE — those keys don't exist there):
    • mlp.linear_fc1.weight via GatedMLPMapping(gate_proj, up_proj)
    • mlp.linear_fc2.weightdown_proj
    • mlp.linear_fc1.layer_norm_weightpre_feedforward_layernorm.weightdiffers from upstream [model] feat: Add limited Gemma4 dense model support NVIDIA-NeMo/Megatron-Bridge#3885, which maps it to post_attention_layernorm. Our fork fuses post_attention_layernorm into linear_proj.post_layernorm (TERowParallelLinearLayerNorm), so the MLP's fused fc1 norm corresponds to HF's MLP input norm (pre_feedforward_layernorm), not post_attention_layernorm. Wrong norm → high lpdiff; lpdiff ~0.007 confirms this mapping.

Validation

31B-it dense (e2e RL on dapo-math-17k, 8× H200, TP=4, no expert parallelism, n=8 / response 768 for reward variance):

  • 0 missing / 0 unexpected keys on load.
  • Coherent rollouts (real math reasoning).
  • train_rollout_logprob_abs_diff0.007 (range 0.004–0.008, well under the 0.02 target).
  • Stable through real gradient updates: weight_version 1 → 2 → 3, lpdiff stays in the 0.004–0.008 band (no blowup).

26B-A4B-it MoE regression check on this branch: lpdiff 0.0086 — matches the validated baseline from #7. No regression.

Notes

Port of upstream NVIDIA-NeMo#3885 (limited Gemma4 dense support) onto the radixark fork, adapted to this forks norm fusion.

- gemma4_provider.py: add attention_k_eq_v field; gate _install_tied_kv on attention_k_eq_v (covers MoE and dense; 26B and 31B both set it True, so no regression to the 26B MoE path).

- gemma4_vl_bridge.py: unblock the dense path for models without per-layer embeddings (hidden_size_per_layer_input==0, e.g. 31B); branch provider config on enable_moe_block; add dense GatedMLP weight mappings (inert on MoE). Dense mlp.linear_fc1.layer_norm_weight maps to HF pre_feedforward_layernorm (linear_proj already carries post_attention_layernorm).

Validated e2e: 31B-it RL train-inference lpdiff ~0.007, stable through weight updates; 26B MoE not regressed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant