ENH Add KaSA (Knowledge-aware Singular-value Adaptation) as a LoRA variant#3298

Draft
robbiebusinessacc wants to merge 1 commit into huggingface:main from robbiebusinessacc:contrib/add-kasa-lora-variant

@robbiebusinessacc

Closes part of #2516 (Call for contribution: KaSA).

Implements KaSA (Knowledge-aware Singular-value Adaptation, arXiv:2412.06071) using the LoRA-variant framework from #2443, following the SVD-based variants (CorDA/DoRA) and without adding if-branches to core LoRA logic. Reference
implementation: https://github.com/juyongjiang/KaSA.

Method

KaSA changes vanilla LoRA in two ways:

  1. Knowledge-based SVD truncation of the frozen base weight (one-time, destructive). At init the base weight W is SVD-factored and its r smallest ("noisy"/long-tail) singular components are discarded, leaving the rank-(k-r) approximation
    (k = min(in_features, out_features)) as the new frozen base. The trainable branch re-learns in the discarded residual subspace.
  2. Knowledge-aware singular-value adaptation (trainable update). A learnable diagonal of singular values lora_diag (ΔΣ) is inserted between LoRA A and B: ΔW = scaling * B @ diag(ΔΣ) @ A. lora_diag is the only new per-layer parameter
    (an r-vector); B stays zero-init so the update is 0 at step 0.
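
The two steps above can be sketched on a single weight matrix. This is a minimal NumPy illustration, not the PR's code: `truncate_base_weight` and `kasa_delta_weight` are hypothetical names, and NumPy arrays stand in for the layer's tensors.

```python
import numpy as np


def truncate_base_weight(W: np.ndarray, r: int) -> np.ndarray:
    """Drop the r smallest singular components of W (hypothetical helper)."""
    k = min(W.shape)  # k = min(out_features, in_features)
    if r >= k:
        raise ValueError(f"r={r} must be smaller than min(W.shape)={k}")
    # np.linalg.svd returns singular values sorted in descending order.
    U, S, Vh = np.linalg.svd(W, full_matrices=False)
    keep = k - r
    # Rebuild the rank-(k-r) approximation from the leading components.
    return (U[:, :keep] * S[:keep]) @ Vh[:keep, :]


def kasa_delta_weight(A, B, diag, scaling):
    """scaling * B @ diag(diag) @ A, without materializing the diagonal matrix."""
    # Multiplying the columns of B by diag is the same as B @ np.diag(diag).
    return scaling * (B * diag) @ A


rng = np.random.default_rng(0)
out_features, in_features, r = 6, 10, 2
W = rng.standard_normal((out_features, in_features))
W_trunc = truncate_base_weight(W, r)

A = rng.standard_normal((r, in_features))  # LoRA A
B = np.zeros((out_features, r))            # LoRA B, zero-initialized
diag = rng.standard_normal(r)              # stand-in for lora_diag
delta = kasa_delta_weight(A, B, diag, scaling=2.0)
```

With B at its zero init, `delta` is all zeros, so the layer initially computes with the truncated base alone.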

The paper additionally trains with two auxiliary regularizers, numbered L2 and L3 in its loss: a squared-magnitude penalty on the singular values (L2, sum(lora_diag**2)) and an orthogonal regularization (L3, ||B^T B - I||_F + ||A A^T - I||_F) that softly enforces the semi-orthogonality the SVD
parametrization assumes. The PEFT variant forward has no channel to inject a scalar into the training loss, so these are exposed via a get_kasa_regularization_loss(model) helper whose result the user adds to their task loss.
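
For one layer, the two terms can be written out as below. This is a sketch under assumptions, not the helper from this PR: `kasa_regularization_sketch` is a hypothetical name, it takes raw NumPy arrays rather than a model, and it assumes beta weights the singular-value term and gamma weights the orthogonal term.

```python
import numpy as np


def kasa_regularization_sketch(A, B, diag, beta=1e-4, gamma=1e-3):
    """beta * sum(diag**2) + gamma * (||B^T B - I||_F + ||A A^T - I||_F).

    Hypothetical single-layer version; shapes are A: (r, in), B: (out, r),
    diag: (r,). The beta/gamma pairing is an assumption.
    """
    r = diag.shape[0]
    eye = np.eye(r)
    singular_value_penalty = np.sum(diag ** 2)
    orthogonal_penalty = (
        np.linalg.norm(B.T @ B - eye, "fro")
        + np.linalg.norm(A @ A.T - eye, "fro")
    )
    return beta * singular_value_penalty + gamma * orthogonal_penalty


rng = np.random.default_rng(0)
out_features, in_features, r = 6, 10, 2

# Semi-orthogonal factors built from QR, so the orthogonal term vanishes.
B_orth, _ = np.linalg.qr(rng.standard_normal((out_features, r)))
Q, _ = np.linalg.qr(rng.standard_normal((in_features, r)))
A_orth = Q.T
diag = np.array([1.0, 2.0])

reg_orthonormal = kasa_regularization_sketch(A_orth, B_orth, diag, beta=0.5, gamma=1.0)
reg_zero_B = kasa_regularization_sketch(A_orth, np.zeros((out_features, r)), diag, beta=0.0, gamma=1.0)
```

In a training loop the returned scalar would be added to the task loss before calling backward. Note that a zero-initialized B already incurs a nonzero orthogonal penalty, since B^T B - I is then -I.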

Integration

  • New KasaConfig sub-config (beta, gamma, both validated non-negative; reference GLUE defaults 1e-4 / 1e-3) and LoraConfig.kasa_config field. resolve_lora_variant selects KaSA when kasa_config is not None (config-object
    pattern, mirrors velora/monteclora/arrow), with dict-coercion + TypeError guard in __post_init__.
  • KasaLinearVariant(LoraVariant) implements init (SVD truncation + lora_diag), forward, merge_safe/merge_unsafe/unmerge, get_delta_weight. lora_diag is registered in adapter_layer_names so it is saved/loaded.
  • Explicit guards reject KaSA on Embedding / Conv / MultiheadAttention / ParamWrapper / fan_in_fan_out (Conv1D) layers, consistent with the VeLoRA guard pattern.
  • Generalizes the reference's svd_rank = in_features - r to min(in_features, out_features) - r so wide layers are handled correctly; raises ValueError if r >= min(in, out).
  • Top-level exports of KasaConfig and get_kasa_regularization_loss.
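
The config-object pattern from the first bullet can be illustrated with stand-in dataclasses. Everything here is hypothetical (`KasaConfigSketch`, `LoraConfigSketch`, `resolve_variant_sketch`); the real classes live in PEFT and this PR, and using 1e-4 / 1e-3 as field defaults is an assumption.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class KasaConfigSketch:
    """Stand-in for the sub-config; both weights must be non-negative."""
    beta: float = 1e-4   # assumed default
    gamma: float = 1e-3  # assumed default

    def __post_init__(self):
        for name in ("beta", "gamma"):
            if getattr(self, name) < 0:
                raise ValueError(f"{name} must be non-negative")


@dataclass
class LoraConfigSketch:
    """Stand-in for the parent config holding an optional sub-config."""
    r: int = 8
    kasa_config: Optional[Union[KasaConfigSketch, dict]] = None

    def __post_init__(self):
        if isinstance(self.kasa_config, dict):
            # Dict coercion, e.g. after a JSON round-trip of the config.
            self.kasa_config = KasaConfigSketch(**self.kasa_config)
        elif self.kasa_config is not None and not isinstance(
            self.kasa_config, KasaConfigSketch
        ):
            raise TypeError("kasa_config must be a KasaConfigSketch or a dict")


def resolve_variant_sketch(config: LoraConfigSketch) -> Optional[str]:
    """Pick the variant purely from the presence of the sub-config."""
    return "kasa" if config.kasa_config is not None else None


plain = LoraConfigSketch(r=4)
from_dict = LoraConfigSketch(r=4, kasa_config={"beta": 0.0, "gamma": 0.01})
```

Keying the dispatch on the sub-config being present means no separate boolean flag has to be kept in sync with it.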

Tests (tests/test_kasa.py, CPU-only, tiny random nn.Linear, no downloads)

  • Config dispatch/dict-round-trip/alias/error cases.
  • lora_diag shape + learnable; B zero-init.
  • SVD-truncation rank check (exactly r singular values zeroed, principal values preserved); truncation changes base forward.
  • Merge/unmerge round-trip to the truncated base; delta-weight formula.
  • lora_diag in state_dict + save/load forward equivalence.
  • Reload onto the original base re-truncates deterministically (parametrized over low_cpu_mem_usage=False/True, idempotent across forwards).
  • Regularization closed-form (L2/L3), orthonormal vs non-orthonormal, gradients.
  • Plus wiring in tests/test_lora_variants.py.

Open questions for maintainers (honest)

  • Destructive base mutation breaks the usual "disable adapter == base" contract. Adding the adapter truncates W0 in place, disabling or unloading it does not restore W0, and merge/unmerge round-trips to the truncated weight. This is inherent to KaSA. Do you want the original
    weight stashed to allow a true unload, or are these semantics acceptable as-is (documented)?
  • Regularization location. The LoRA-variant API has no loss-return channel, so L2/L3 are exposed as get_kasa_regularization_loss for the user to add to their loss (the reference computes them in external training scripts). Is a
    free-function helper the API you want, or a method on the PEFT model? Without these terms the SVD interpretation is only approximate, so they are implemented and unit-tested rather than dropped. This was the correctness gap behind
    #2543 (Add KaSA implementation to layer.py) and #2698 ([WIP] Update LoraConfig for KaSA implementation).
  • Save/load story. Reloading the adapter onto the original base re-applies the deterministic truncation (verified for both the default and low_cpu_mem_usage paths). Reloading onto an already-truncated/merged base would double-truncate.
    Want a docs note / explicit guard?
  • lora_diag uses randn(r) per the reference; the paper says "randomly initialized without bias". Safe because B is zero-init. Confirm if ones is preferred.
  • Scope is nn.Linear only for now (quantized/Conv/MHA explicitly rejected). OK to land Linear-first?
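
To make the first and third questions concrete, here is a toy NumPy sketch of the merge/unmerge and re-truncation behavior. It is an illustration under assumptions, not the PR's code: `truncate` is a hypothetical helper, the matrices are random float64, and B is nonzero to mimic a trained adapter.

```python
import numpy as np


def truncate(W, r):
    """Drop the r smallest singular components of W (hypothetical helper)."""
    U, S, Vh = np.linalg.svd(W, full_matrices=False)
    keep = min(W.shape) - r
    return (U[:, :keep] * S[:keep]) @ Vh[:keep, :]


rng = np.random.default_rng(0)
out_features, in_features, r = 6, 10, 2
W0 = rng.standard_normal((out_features, in_features))  # original base weight
W_trunc = truncate(W0, r)                               # frozen base after init

# A nonzero update, standing in for a trained adapter (scaling omitted).
A = rng.standard_normal((r, in_features))
B = rng.standard_normal((out_features, r))
diag = rng.standard_normal(r)
delta = (B * diag) @ A

merged = W_trunc + delta
unmerged = merged - delta

# Unmerge returns to the truncated base, not to the original weight.
round_trip_hits_truncated = np.allclose(unmerged, W_trunc)
round_trip_hits_original = np.allclose(unmerged, W0)

# Re-truncating the un-merged truncated base only drops numerically zero
# components, while re-truncating a merged base discards real signal.
retruncate_truncated_is_noop = np.allclose(truncate(W_trunc, r), W_trunc)
retruncate_merged_is_noop = np.allclose(truncate(merged, r), merged)
```

In this toy setup the first and third flags come out True and the other two False, which suggests a guard may matter most for reloads onto a merged base. Whether that carries over to float32 or quantized weights is not checked here.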

No docs page added yet; happy to add one if you'd like it in this PR.

Implement KaSA (Knowledge-aware Singular-value Adaptation, arXiv:2412.06071)
using the LoRA-variant framework, following the SVD-based variants (CorDA/DoRA).

KaSA changes vanilla LoRA in two ways:
- A one-time, destructive SVD truncation of the frozen base weight that drops
  its r smallest singular components, leaving the rank-(k-r) approximation as
  the new frozen base (k = min(in_features, out_features)).
- A learnable diagonal of singular values (lora_diag) inserted between the LoRA
  A and B factors, so the update is ΔW = scaling * B @ diag(lora_diag) @ A.

- New KasaConfig sub-config (beta, gamma) and LoraConfig.kasa_config field;
  selection is driven by kasa_config being non-None via resolve_lora_variant,
  with explicit guards rejecting KaSA on embedding/conv/MHA/ParamWrapper and
  fan_in_fan_out layers.
- KasaLinearVariant implements init (SVD truncation + lora_diag), forward,
  merge_safe/merge_unsafe/unmerge. lora_diag is registered in
  adapter_layer_names so it is saved/loaded.
- get_kasa_regularization_loss helper exposes the paper's two auxiliary terms
  (L2 singular-value penalty + L3 orthogonal regularization), since the variant
  forward has no channel to inject an extra loss into the training loop.
- Tests in tests/test_kasa.py (SVD-truncation faithfulness, lora_diag shape,
  zero-init update, merge/unmerge round-trip, delta-weight formula, save/load,
  regularization closed-form checks) plus wiring in tests/test_lora_variants.py.

Faithfulness notes:
- The base-weight truncation is destructive; disabling/unloading does not
  restore the original weight and merge/unmerge round-trips to the truncated
  base. This is inherent to the method and documented.
- The paper's L2/L3 regularizers are required for the SVD interpretation to
  hold but cannot be auto-injected; users must add get_kasa_regularization_loss
  to their loss.