diff --git a/gemma/gm/ckpts/_checkpoint.py b/gemma/gm/ckpts/_checkpoint.py
index 9449b9a6..2cf55081 100644
--- a/gemma/gm/ckpts/_checkpoint.py
+++ b/gemma/gm/ckpts/_checkpoint.py
@@ -150,6 +150,13 @@ def make_tree_for_params(
     if self.has_mm_params and not params.has_mm_params:
       ckpt_params = _add_skip_mm_params(ckpt_params, metadata)
 
+    # Reconcile known structural mismatches between model-init and
+    # checkpoint (e.g. Gemma4 LoRA: empty wrapper stubs from split_params,
+    # leaf-vs-dict format from nn.share_scope). Only triggers when
+    # mismatches are detected; Gemma3/legacy paths are unchanged.
+    if _needs_reconciliation(ckpt_params, self.nested_tree):
+      ckpt_params = _reconcile_tree(ckpt_params, self.nested_tree)
+
     # 2. Reformat the nested tree to match the checkpoint structure.
     if self.type == _CheckpointType.NESTED:
       target_params = ckpt_params  # No need to reformat
@@ -170,7 +177,11 @@ def make_tree_for_params(
 
   @functools.cached_property
   def has_mm_params(self) -> bool:
-    return 'vision_encoder' in self.nested_tree
+    # Check for any known multimodal encoder (vision or audio).
+    return (
+        'vision_encoder' in self.nested_tree
+        or 'audio_encoder' in self.nested_tree
+    )
 
   @functools.cached_property
   def has_audio_input_embedding(self) -> bool:
@@ -395,10 +406,22 @@ def _remove_mm_params(params):
   # TODO(epot): Once orbax supports partial restore, we would not need to
   # load those extra params in the first place.
 
-  del params['vision_encoder']
-  for k in ('mm_input_projection', 'mm_soft_embedding_norm'):
+  # Vision params
+  if 'vision_encoder' in params:
+    del params['vision_encoder']
+  for k in ('mm_input_projection', 'mm_soft_embedding_norm',
+            'mm_pre_projection_norm', 'mm_input_embedding_extra'):
+    if k in params.get('embedder', {}):
+      del params['embedder'][k]
+
+  # Audio params (Gemma4)
+  if 'audio_encoder' in params:
+    del params['audio_encoder']
+  for k in ('audio_input_projection', 'audio_soft_embedding_norm',
+            'audio_input_embedding', 'audio_input_embedding_extra'):
     if k in params.get('embedder', {}):
       del params['embedder'][k]
+
   return params
 
 
@@ -407,14 +430,129 @@ def _add_skip_mm_params(params: Params, metadata: _CheckpointTree) -> Params:
   params = etree.copy(params)
   params_with_mm = metadata.nested_tree
 
-  params['vision_encoder'] = params_with_mm['vision_encoder']
-  for k in ('mm_input_projection', 'mm_soft_embedding_norm'):
-    if k in params_with_mm.get('embedder', {}):
-      params['embedder'][k] = params_with_mm['embedder'][k]
+  # Known top-level multimodal encoder keys.
+  _MM_TOP_LEVEL_KEYS = ('vision_encoder', 'audio_encoder')
+  # Known embedder-level multimodal projection/norm keys.
+  _MM_EMBEDDER_KEYS = (
+      # Vision
+      'mm_input_projection',
+      'mm_soft_embedding_norm',
+      'mm_pre_projection_norm',
+      'mm_input_embedding_extra',
+      # Audio
+      'audio_input_projection',
+      'audio_soft_embedding_norm',
+      'audio_input_embedding',
+      'audio_input_embedding_extra',
+  )
+
+  for k in _MM_TOP_LEVEL_KEYS:
+    if k in params_with_mm and k not in params:
+      params[k] = params_with_mm[k]
+
+  embedder_mm = params_with_mm.get('embedder', {})
+  for k in _MM_EMBEDDER_KEYS:
+    if k in embedder_mm and k not in params.get('embedder', {}):
+      params['embedder'][k] = embedder_mm[k]
 
   return params
 
 
+def _needs_reconciliation(params: Params, metadata_tree: Params) -> bool:
+  """Returns True if the model params tree has known structural mismatches.
+
+  Detects two patterns that arise when LoRA interceptors interact with
+  models using ``nn.share_scope`` (e.g. Gemma4 FeedForward):
+
+  1. **Empty ``{}`` stubs** left by ``peft.split_params`` at LoRA wrapper
+     scopes (e.g. ``_LoRAEinsum_0``). These keys exist in the model tree
+     but not in the checkpoint.
+  2. **Leaf-vs-dict format**: ``nn.share_scope`` flattens ``{'w': array}``
+     to bare ``ArrayImpl`` in the model-init tree, while the checkpoint
+     keeps the dict format.
+
+  This check is intentionally conservative — it returns ``False`` for
+  Gemma3 and legacy checkpoints, so their restore paths are unchanged.
+  """
+  if not isinstance(params, dict) or not isinstance(metadata_tree, dict):
+    return False
+
+  for k, p_val in params.items():
+    if k not in metadata_tree:
+      # Key in model but not in checkpoint (e.g. LoRA stub).
+      if isinstance(p_val, dict) and not p_val:
+        return True
+      continue
+    m_val = metadata_tree[k]
+    # Leaf-vs-dict mismatch.
+    if not isinstance(p_val, dict) and isinstance(m_val, dict):
+      return True
+    # Recurse into sub-dicts.
+    if isinstance(p_val, dict) and isinstance(m_val, dict):
+      if _needs_reconciliation(p_val, m_val):
+        return True
+
+  return False
+
+
+def _reconcile_tree(params: Params, metadata_tree: Params) -> Params:
+  """Align model-init params tree to match checkpoint metadata structure.
+
+  Only called when :func:`_needs_reconciliation` returns ``True``.
+
+  Handles two known mismatches between ``model.init()`` and on-disk
+  checkpoints:
+
+  1. **Empty stubs**: LoRA wrappers (or other interceptors) may leave
+     empty dict scopes in the params tree that don't exist in the
+     checkpoint. These are dropped.
+  2. **Leaf-vs-dict format**: ``nn.share_scope`` in Gemma4 ``FeedForward``
+     flattens ``{'w': array}`` to bare ``ArrayImpl`` during model init.
+     When the checkpoint stores ``{'w': array}``, the leaf is wrapped to
+     match.
+
+  Args:
+    params: The model-init params tree (may contain stubs / format
+      mismatches).
+    metadata_tree: The checkpoint metadata tree (ground-truth structure).
+
+  Returns:
+    A new params tree aligned to the checkpoint metadata structure.
+  """
+  if not isinstance(params, dict) or not isinstance(metadata_tree, dict):
+    return params
+
+  result = {}
+  for k in metadata_tree:
+    if k not in params:
+      # Key in checkpoint but not in model (e.g. MM params handled by
+      # _add_skip_mm_params separately) — skip.
+      continue
+    p_val = params[k]
+    m_val = metadata_tree[k]
+
+    if isinstance(p_val, dict) and isinstance(m_val, dict):
+      # Both dicts — recurse.
+      reconciled = _reconcile_tree(p_val, m_val)
+      if reconciled:  # Drop if empty after reconciliation.
+        result[k] = reconciled
+    elif not isinstance(p_val, dict) and isinstance(m_val, dict):
+      # Model has leaf (ArrayImpl), checkpoint has dict ({'w': ...}).
+      # Wrap the leaf to match checkpoint format.
+      if len(m_val) == 1:
+        inner_key = next(iter(m_val))
+        result[k] = {inner_key: p_val}
+      else:
+        result[k] = p_val  # Fallback: keep as-is.
+    else:
+      # Both leaves, or model has dict but checkpoint has leaf.
+      result[k] = p_val
+
+  # Keys in params but NOT in metadata are intentionally dropped.
+  # This strips LoRA wrapper stubs (_LoRAEinsum_0, etc.).
+  return result
+
+
 def _is_flat_layout(params: Params) -> bool:
   """Returns True is the structure is the legacy one."""
   return (not _is_stacked_layout(params)) and all(
diff --git a/gemma/gm/data/_tasks.py b/gemma/gm/data/_tasks.py
index 664c92dc..edde38dc 100644
--- a/gemma/gm/data/_tasks.py
+++ b/gemma/gm/data/_tasks.py
@@ -21,13 +21,42 @@
 import einops
 from etils.etree import jax as etree  # pylint: disable=g-importing-member
 from gemma.gm.data import _functional
-from gemma.gm.text import _template
 from gemma.gm.text import _tokenizer
 from grain import python as grain
 import jax
 from kauldron import kd
 import numpy as np
 
+# Turn tag strings indexed by tokenizer FORMAT.
+# Gemma4 uses '<|turn>' / '<turn|>', all others use '<start_of_turn>' /
+# '<end_of_turn>'. Importing the `dialog` library is intentionally avoided
+# to keep the dep footprint small; the two known format strings are inlined.
+_TURN_TAGS: dict[str, tuple[str, str]] = {'gemma4': ('<|turn>', '<turn|>')}
+
+
+def _get_turn_tags(
+    tokenizer: _tokenizer.Tokenizer,
+) -> tuple[str, str]:
+  """Returns (start_of_turn, end_of_turn) tag strings for *tokenizer*."""
+  fmt = getattr(tokenizer, 'FORMAT', None)
+  # dialog.Format.GEMMA4 has value 'gemma4' (StrEnum).
+  key = '' if fmt is None else str(fmt).lower()
+  # Default: Gemma3 / legacy format.
+  default = ('<start_of_turn>', '<end_of_turn>')
+  return _TURN_TAGS.get(key, default)
+
+
+def _format_prompt(prompt: str, tokenizer: _tokenizer.Tokenizer) -> str:
+  """Formats *prompt* with the correct turn tags for *tokenizer*."""
+  sot, eot = _get_turn_tags(tokenizer)
+  return f'{sot}user\n{prompt}{eot}\n{sot}model\n'
+
+
+def _format_answer(response: str, tokenizer: _tokenizer.Tokenizer) -> str:
+  """Formats *response* with the correct turn tags for *tokenizer*."""
+  _, eot = _get_turn_tags(tokenizer)
+  return f'{response}{eot}'
+
 
 @dataclasses.dataclass(kw_only=True, frozen=True)
 class Seq2SeqTask(grain.MapTransform):
@@ -115,10 +144,10 @@ def map(self, element):
     prompt = _decode_bytes(prompt)
     response = _decode_bytes(response)
 
-    # Format the input to match the expected dialog template.
-    # TODO(epot): Add a `template` protocol to allow customizing this.
-    prompt = _template.PROMPT.format(prompt)
-    response = _template.ANSWER.format(response)
+    # Format the input using tokenizer-aware turn tags.
+    # TODO(epot): Add a `template` protocol for full customization.
+    prompt = _format_prompt(prompt, self.tokenizer)
+    response = _format_answer(response, self.tokenizer)
 
     # For sampling, we don't need to tokenize the input.
     if self.sampling:
@@ -219,11 +248,11 @@ def map(self, element):
     chosen = _decode_bytes(chosen)
     rejected = _decode_bytes(rejected)
 
-    # Format the input to match the expected dialog template.
-    # TODO(epot): Move this in a separate FormatDialog transform.
-    prompt = _template.PROMPT.format(prompt)
-    chosen = _template.ANSWER.format(chosen)
-    rejected = _template.ANSWER.format(rejected)
+    # Format the input using tokenizer-aware turn tags.
+    # TODO(epot): Extract into a standalone FormatDialog transform.
+    prompt = _format_prompt(prompt, self.tokenizer)
+    chosen = _format_answer(chosen, self.tokenizer)
+    rejected = _format_answer(rejected, self.tokenizer)
 
     # Tokenize the input and the responses.
     # Note: Input should start with begin-of-sequence token.
diff --git a/gemma/gm/nn/_lora.py b/gemma/gm/nn/_lora.py
index 11b5a98e..ea7f439e 100644
--- a/gemma/gm/nn/_lora.py
+++ b/gemma/gm/nn/_lora.py
@@ -21,13 +21,26 @@
 from flax import linen as nn
 from gemma import peft
 from gemma.gm.nn import _layers
+from gemma.gm.nn.gemma3n import _layers as _gemma3n_layers
+from gemma.gm.nn.gemma4 import _layers as _gemma4_layers
 import jax
 import jax.numpy as jnp
 from kauldron import kontext
 import numpy as np
 
 
-_SUPPORTED_MODULES = (nn.Dense, nn.Einsum, nn.DenseGeneral, _layers.Einsum)
+_SUPPORTED_MODULES = (
+    nn.Dense,
+    nn.Einsum,
+    nn.DenseGeneral,
+    _layers.Einsum,
+    _gemma4_layers.Einsum,
+    _gemma4_layers.ClippedEinsum,
+    _gemma3n_layers.Einsum,
+    # NOTE: nano._layers.Einsum is excluded because nano:nano depends on
+    # //third_party/py/gemma/gm, creating a circular BUILD dependency.
+    # To add it, nano._layers needs its own fine-grained BUILD target.
+)
 
 
 class LoRA(nn.Module):
@@ -115,11 +128,11 @@ def _replace_by_lora(
       return peft.LoRAEinsum(rank=rank, dtype=dtype, wrapped=module)
     case nn.DenseGeneral():
       return peft.LoRADenseGeneral(rank=rank, dtype=dtype, wrapped=module)
-    case _layers.Einsum():
-      # This hack is required because the FeedForward layer call two different
-      # Einsum with using `nn.share_scope`, so the two wrappers need a different
-      # name.
-      # This seems to be a bug in flax interceptor.
+    case _ if isinstance(module, _SUPPORTED_MODULES):
+      # All custom Einsum variants (gm.nn, gemma4, gemma3n) listed in
+      # `_SUPPORTED_MODULES` use `_LoRAEinsum`. The name hack is required as
+      # FeedForward uses `nn.share_scope` to flatten two Einsum modules
+      # into the same param scope — the two wrappers need distinct names.
       if module.weight_name != 'w':
         name = f'_LoRAEinsum_{module.weight_name}'
       else:
@@ -135,7 +148,7 @@ class _LoRAEinsum(nn.Module):
   _: dataclasses.KW_ONLY
   rank: int
   dtype: np.dtype
-  wrapped: _layers.Einsum
+  wrapped: nn.Module  # Any custom Einsum variant (gm.nn, gemma4, gemma3n)
 
   # Do not use `nn.share_scope` here as the `wrapped` module inside
   # `FeedForward` already uses `nn.share_scope`, so the two Einsum used in
diff --git a/gemma/gm/nn/_lora_test.py b/gemma/gm/nn/_lora_test.py
new file mode 100644
index 00000000..28df47c8
--- /dev/null
+++ b/gemma/gm/nn/_lora_test.py
@@ -0,0 +1,250 @@
+# Copyright 2026 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for gm.nn.LoRA with various Einsum implementations."""
+
+from flax import linen as nn
+from gemma import peft
+from gemma.gm.ckpts import _checkpoint
+from gemma.gm.nn import _lora
+from gemma.gm.nn.gemma3n import _layers as _gemma3n_layers
+from gemma.gm.nn.gemma4 import _layers as _gemma4_layers
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+class _ModelWithGemma4Einsum(nn.Module):
+  """Test model using Gemma4 Einsum modules."""
+
+  @nn.compact
+  def __call__(self, x):
+    y = _gemma4_layers.Einsum(shape=(4, 3))('bi,io->bo', x)
+    return y
+
+
+class _ModelWithClippedEinsum(nn.Module):
+  """Test model using Gemma4 ClippedEinsum modules."""
+
+  @nn.compact
+  def __call__(self, x):
+    y = _gemma4_layers.ClippedEinsum(shape=(4, 3))('bi,io->bo', x)
+    return y
+
+
+class _ModelWithGemma3nEinsum(nn.Module):
+  """Test model using Gemma3n Einsum modules."""
+
+  @nn.compact
+  def __call__(self, x):
+    y = _gemma3n_layers.Einsum(shape=(4, 3))('bi,io->bo', x)
+    return y
+
+
+def _make_replace_fn(rank=2, dtype=jnp.bfloat16):
+  """Returns a LoRA replacement function for use with ModuleInterceptor."""
+  return lambda m: _lora._replace_by_lora(
+      m, rank=rank, dtype=dtype, verbose=False
+  )
+
+
+def _init_with_lora(model, input_shape=(1, 4)):
+  """Initialize model with LoRA and return (params, lora_params)."""
+  with peft.ModuleInterceptor(_make_replace_fn()):
+    params = model.init(jax.random.key(0), jnp.zeros(input_shape))['params']
+  _, lora_params = peft.split_params(params)
+  return params, lora_params
+
+
+def test_lora_gemma4_einsum():
+  """LoRA wraps Gemma4 Einsum and produces lora params."""
+  _, lora_params = _init_with_lora(_ModelWithGemma4Einsum())
+  leaves = jax.tree.leaves(lora_params)
+  assert leaves, 'Expected LoRA params for Gemma4 Einsum'
+
+
+def test_lora_gemma4_clipped_einsum():
+  """LoRA wraps Gemma4 ClippedEinsum and produces lora params."""
+  _, lora_params = _init_with_lora(_ModelWithClippedEinsum())
+  leaves = jax.tree.leaves(lora_params)
+  assert leaves, 'Expected LoRA params for Gemma4 ClippedEinsum'
+
+
+def test_lora_gemma3n_einsum():
+  """LoRA wraps Gemma3n Einsum and produces lora params."""
+  _, lora_params = _init_with_lora(_ModelWithGemma3nEinsum())
+  leaves = jax.tree.leaves(lora_params)
+  assert leaves, 'Expected LoRA params for Gemma3n Einsum'
+
+
+
+def test_lora_params_have_a_and_b():
+  """LoRA params contain 'a' and 'b' matrices."""
+  params, _ = _init_with_lora(_ModelWithGemma4Einsum())
+  # The Einsum_0 should have a '_LoRAEinsum_0' sub-module with 'lora/a' and
+  # 'lora/b'. The wrapper doesn't use nn.share_scope, so the LoRA adapter
+  # lives in a nested sub-dict.
+  einsum_params = params['Einsum_0']
+  assert '_LoRAEinsum_0' in einsum_params, (
+      f'Missing _LoRAEinsum_0 key in {einsum_params.keys()}'
+  )
+  lora_sub = einsum_params['_LoRAEinsum_0']
+  assert 'lora' in lora_sub, f'Missing lora key in {lora_sub.keys()}'
+  assert 'a' in lora_sub['lora'], 'Missing lora/a matrix'
+  assert 'b' in lora_sub['lora'], 'Missing lora/b matrix'
+
+
+# ---------------------------------------------------------------------------
+# Checkpoint tree reconciliation tests
+# ---------------------------------------------------------------------------
+
+
+def test_needs_reconciliation_false_for_matching_trees():
+  """Gemma3-like tree — structures match, no reconciliation needed."""
+  params = {'layer': {'attn': {'w': np.zeros(2)}, 'mlp': {'w': np.zeros(3)}}}
+  metadata = {'layer': {'attn': {'w': None}, 'mlp': {'w': None}}}
+  assert not _checkpoint._needs_reconciliation(params, metadata)
+
+
+def test_needs_reconciliation_true_for_empty_stubs():
+  """LoRA stubs: empty {} dicts in model tree, absent from checkpoint."""
+  params = {'layer': {'attn': {'w': np.zeros(2)}, '_LoRAEinsum_0': {}}}
+  metadata = {'layer': {'attn': {'w': None}}}
+  assert _checkpoint._needs_reconciliation(params, metadata)
+
+
+def test_needs_reconciliation_true_for_format_mismatch():
+  """Gemma4 share_scope: model has ArrayImpl, checkpoint has {'w': ...}."""
+  params = {'mlp': {'gating_einsum': np.zeros(4)}}
+  metadata = {'mlp': {'gating_einsum': {'w': None}}}
+  assert _checkpoint._needs_reconciliation(params, metadata)
+
+
+def test_needs_reconciliation_false_for_non_dict_leaves():
+  """Both leaves are non-dicts — no mismatch."""
+  params = {'a': np.zeros(2)}
+  metadata = {'a': None}
+  assert not _checkpoint._needs_reconciliation(params, metadata)
+
+
+def test_needs_reconciliation_nested_detection():
+  """Mismatch buried deep in the tree is still detected."""
+  params = {
+      'layer_0': {
+          'attn': {'w': np.zeros(2)},
+          'mlp': {'gating_einsum': np.zeros(3)},
+      }
+  }
+  metadata = {
+      'layer_0': {
+          'attn': {'w': None},
+          'mlp': {'gating_einsum': {'w': None}},
+      }
+  }
+  assert _checkpoint._needs_reconciliation(params, metadata)
+
+
+def test_reconcile_drops_empty_stubs():
+  """Empty {} stubs from LoRA wrappers are removed."""
+  params = {
+      'layer': {
+          'attn': {'w': 1},
+          '_LoRAEinsum_0': {},
+          '_LoRAEinsum_gating_einsum': {},
+      }
+  }
+  metadata = {'layer': {'attn': {'w': None}}}
+  result = _checkpoint._reconcile_tree(params, metadata)
+
+  assert result == {'layer': {'attn': {'w': 1}}}
+  assert '_LoRAEinsum_0' not in result.get('layer', {})
+
+
+def test_reconcile_wraps_leaf_to_dict():
+  """ArrayImpl leaf is wrapped to match checkpoint {'w': ...} format."""
+  arr = np.zeros(4)
+  params = {'mlp': {'gating_einsum': arr, 'linear': arr}}
+  metadata = {'mlp': {'gating_einsum': {'w': None}, 'linear': {'w': None}}}
+  result = _checkpoint._reconcile_tree(params, metadata)
+
+  assert list(result['mlp']['gating_einsum'].keys()) == ['w']
+  assert result['mlp']['gating_einsum']['w'] is arr
+  assert list(result['mlp']['linear'].keys()) == ['w']
+  assert result['mlp']['linear']['w'] is arr
+
+
+def test_reconcile_passthrough_matching():
+  """No changes when trees already match (Gemma3 case)."""
+  params = {'a': {'b': 1, 'c': 2}}
+  metadata = {'a': {'b': None, 'c': None}}
+  result = _checkpoint._reconcile_tree(params, metadata)
+
+  assert result == {'a': {'b': 1, 'c': 2}}
+
+
+def test_reconcile_full_gemma4_like_tree():
+  """End-to-end test with a Gemma4-like layer structure."""
+  arr = np.zeros(2)
+  params = {
+      'layer_0': {
+          'attn': {
+              'q_einsum': {'w': arr},
+              'kv_einsum': {'w': arr},
+              'attn_vec_einsum': {'w': arr},
+              '_LoRAEinsum_0': {},
+          },
+          'mlp': {
+              'gating_einsum': arr,
+              'linear': arr,
+              '_LoRAEinsum_gating_einsum': {},
+              '_LoRAEinsum_linear': {},
+          },
+          'per_layer_input_gate': {
+              'w': arr,
+              '_LoRAEinsum_0': {},
+          },
+      },
+      'embedder': {'input_embedding': arr},
+  }
+  metadata = {
+      'layer_0': {
+          'attn': {
+              'q_einsum': {'w': None},
+              'kv_einsum': {'w': None},
+              'attn_vec_einsum': {'w': None},
+          },
+          'mlp': {
+              'gating_einsum': {'w': None},
+              'linear': {'w': None},
+          },
+          'per_layer_input_gate': {'w': None},
+      },
+      'embedder': {'input_embedding': None},
+  }
+  result = _checkpoint._reconcile_tree(params, metadata)
+
+  # LoRA stubs removed
+  assert '_LoRAEinsum_0' not in result['layer_0']['attn']
+  assert '_LoRAEinsum_gating_einsum' not in result['layer_0']['mlp']
+  assert '_LoRAEinsum_linear' not in result['layer_0']['mlp']
+  assert '_LoRAEinsum_0' not in result['layer_0']['per_layer_input_gate']
+
+  # MLP leaves wrapped to dict format
+  assert result['layer_0']['mlp']['gating_einsum'] == {'w': arr}
+  assert result['layer_0']['mlp']['linear'] == {'w': arr}
+
+  # Normal params preserved
+  assert result['layer_0']['attn']['q_einsum'] == {'w': arr}
+  assert result['embedder'] == {'input_embedding': arr}
+