ENH Add KaSA (Knowledge-aware Singular-value Adaptation) as a LoRA variant#3298
Draft
robbiebusinessacc wants to merge 1 commit into
Implement KaSA (Knowledge-aware Singular-value Adaptation, arXiv:2412.06071) using the LoRA-variant framework, following the SVD-based variants (CorDA/DoRA).

KaSA changes vanilla LoRA in two ways:
- A one-time, destructive SVD truncation of the frozen base weight that drops its r smallest singular components, leaving the rank-(k-r) approximation as the new frozen base (k = min(in_features, out_features)).
- A learnable diagonal of singular values (lora_diag) inserted between the LoRA A and B factors, so the update is ΔW = scaling * B @ diag(lora_diag) @ A.

- New KasaConfig sub-config (beta, gamma) and LoraConfig.kasa_config field; selection is driven by kasa_config being non-None via resolve_lora_variant, with explicit guards rejecting KaSA on embedding/conv/MHA/ParamWrapper and fan_in_fan_out layers.
- KasaLinearVariant implements init (SVD truncation + lora_diag), forward, merge_safe/merge_unsafe/unmerge. lora_diag is registered in adapter_layer_names so it is saved/loaded.
- get_kasa_regularization_loss helper exposes the paper's two auxiliary terms (L2 singular-value penalty + L3 orthogonal regularization), since the variant forward has no channel to inject an extra loss into the training loop.
- Tests in tests/test_kasa.py (SVD-truncation faithfulness, lora_diag shape, zero-init update, merge/unmerge round-trip, delta-weight formula, save/load, regularization closed-form checks) plus wiring in tests/test_lora_variants.py.

Faithfulness notes:
- The base-weight truncation is destructive; disabling/unloading does not restore the original weight, and merge/unmerge round-trips to the truncated base. This is inherent to the method and documented.
- The paper's L2/L3 regularizers are required for the SVD interpretation to hold but cannot be auto-injected; users must add get_kasa_regularization_loss to their loss.
Closes part of #2516 (Call for contribution: KaSA).
Implements KaSA (Knowledge-aware Singular-value Adaptation, arXiv:2412.06071) using the LoRA-variant framework from #2443, following the SVD-based variants (CorDA/DoRA) and without adding if-branches to core LoRA logic. Reference
implementation: https://github.com/juyongjiang/KaSA.
Method
KaSA changes vanilla LoRA in two ways:
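- A one-time, destructive SVD truncation of the frozen base weight drops its r smallest singular components, leaving the rank-(k-r) approximation (k = min(in_features, out_features)) as the new frozen base. The trainable branch re-learns in the discarded residual subspace.
- A learnable diagonal of singular values, lora_diag (ΔΣ), is inserted between LoRA A and B: ΔW = scaling * B @ diag(ΔΣ) @ A. lora_diag is the only new per-layer parameter (an r-vector); B stays zero-init so the update is 0 at step 0.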
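The two changes above can be illustrated with a minimal PyTorch sketch. It is not the code from this PR: the helper names `truncate_base_weight` and `kasa_delta_weight` are hypothetical, the shapes are arbitrary, and the random values for `lora_diag` are for illustration only (the description does not state how it is initialized).

```python
import torch

torch.manual_seed(0)


def truncate_base_weight(weight: torch.Tensor, r: int) -> torch.Tensor:
    """Hypothetical helper: drop the r smallest singular components of `weight`."""
    k = min(weight.shape)
    if r >= k:
        raise ValueError(f"r={r} must be smaller than min(in, out)={k}")
    # torch.linalg.svd returns singular values in descending order.
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    keep = k - r
    return U[:, :keep] @ torch.diag(S[:keep]) @ Vh[:keep, :]


def kasa_delta_weight(A, B, lora_diag, scaling):
    """Hypothetical helper: delta W = scaling * B @ diag(lora_diag) @ A."""
    return scaling * (B @ torch.diag(lora_diag) @ A)


out_features, in_features, r = 6, 10, 2
k = min(out_features, in_features)

W = torch.randn(out_features, in_features, dtype=torch.float64)
W_trunc = truncate_base_weight(W, r)

# Singular values before and after truncation.
S_orig = torch.linalg.svdvals(W)
S_trunc = torch.linalg.svdvals(W_trunc)

A = torch.randn(r, in_features, dtype=torch.float64)
B = torch.zeros(out_features, r, dtype=torch.float64)  # zero-init
lora_diag = torch.randn(r, dtype=torch.float64)        # illustrative values only

delta = kasa_delta_weight(A, B, lora_diag, scaling=2.0)

print("smallest r singular values after truncation:", S_trunc[-r:])
print("update is zero at step 0:", bool(torch.count_nonzero(delta) == 0))
```

In this sketch the k - r principal singular values of the base weight are unchanged and the last r are numerically zero. Because B is all zeros, the update vanishes at step 0 regardless of the values in `lora_diag`.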
The paper additionally trains with two auxiliary regularizers: an L2 penalty on the singular values (sum(lora_diag**2)) and an orthogonal regularization ||B^T B - I||_F + ||A A^T - I||_F that softly enforces the semi-orthogonality the SVD parametrization assumes. The PEFT variant forward has no channel to inject a scalar into the training loss, so these are exposed via a get_kasa_regularization_loss(model) helper that the user adds to their task loss.
Integration
- KasaConfig sub-config (beta, gamma, both validated non-negative; reference GLUE defaults 1e-4 / 1e-3) and LoraConfig.kasa_config field. Selection is driven by kasa_config is not None in resolve_lora_variant (config-object pattern, mirrors velora/monteclora/arrow), with dict-coercion + TypeError guard in __post_init__.
- KasaLinearVariant(LoraVariant) implements init (SVD truncation + lora_diag), forward, merge_safe/merge_unsafe/unmerge, get_delta_weight. lora_diag is registered in adapter_layer_names so it is saved/loaded.
- svd_rank is changed from in_features - r to min(in_features, out_features) - r so that wide layers are handled correctly; raises ValueError if r >= min(in, out).
- KasaConfig and get_kasa_regularization_loss.
Tests (tests/test_kasa.py, CPU-only, tiny random nn.Linear, no downloads)
Config dispatch/dict-round-trip/alias/error cases; lora_diag shape+learnable; B zero-init; SVD-truncation rank check (exactly r singular values zeroed, principal values preserved); truncation changes base forward; merge/unmerge round-trip to
the truncated base; delta-weight formula; lora_diag in state_dict + save/load forward equivalence; reload onto the original base re-truncates deterministically (parametrized over low_cpu_mem_usage=False/True, idempotent across forwards);
regularization closed-form (L2/L3), orthonormal vs non-orthonormal, gradients. Plus wiring in tests/test_lora_variants.py.
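The closed-form regularization checks listed above can be sketched as follows. This is a standalone illustration, not the PR's `get_kasa_regularization_loss`: the function `kasa_regularization_terms` is hypothetical, and it returns the two terms unweighted because how beta and gamma are applied is not spelled out in this description.

```python
import torch

torch.manual_seed(0)


def kasa_regularization_terms(A, B, lora_diag):
    """Hypothetical helper returning the two auxiliary terms, unweighted.

    l2: sum(lora_diag**2)
    l3: ||B^T B - I||_F + ||A A^T - I||_F
    """
    r = A.shape[0]
    eye = torch.eye(r, dtype=A.dtype)
    l2 = torch.sum(lora_diag ** 2)
    l3 = (
        torch.linalg.norm(B.T @ B - eye, ord="fro")
        + torch.linalg.norm(A @ A.T - eye, ord="fro")
    )
    return l2, l3


out_features, in_features, r = 6, 10, 2

# Orthonormal factors via QR: B has orthonormal columns, A has orthonormal rows.
B, _ = torch.linalg.qr(torch.randn(out_features, r, dtype=torch.float64))
Q, _ = torch.linalg.qr(torch.randn(in_features, r, dtype=torch.float64))
A = Q.T
lora_diag = torch.tensor([3.0, 4.0], dtype=torch.float64)

l2, l3_orth = kasa_regularization_terms(A, B, lora_diag)

# Scaling B by 2 gives B^T B = 4I, so ||B^T B - I||_F = ||3I||_F = 3 * sqrt(r).
_, l3_scaled = kasa_regularization_terms(A, 2.0 * B, lora_diag)

print(float(l2), float(l3_orth), float(l3_scaled))
```

With orthonormal factors the orthogonal term is numerically zero, and scaling one factor produces a penalty that can be checked by hand, which is the kind of orthonormal versus non-orthonormal comparison the test list describes.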
Open questions for maintainers (honest)
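- The base-weight truncation is destructive: disabling or unloading the adapter does not restore the original weight. Should the original weight be stashed to allow a true unload, or are these semantics acceptable as-is (documented)?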
- The L2/L3 regularizers cannot be auto-injected, so they are exposed via get_kasa_regularization_loss for the user to add to their loss (the reference computes them in external training scripts). Is a free-function helper the API you want, or a method on the PEFT model? Without these terms the SVD interpretation is only approximate, so they are implemented and unit-tested rather than dropped; this was the correctness gap behind "Add KaSA implementation to layer.py" #2543 / "[WIP] Update LoraConfig for KaSA implementation" #2698.
- Want a docs note / explicit guard?
No docs page added yet; happy to add one if you'd like it in this PR.