Feature/add frod #3270
BenjaminBossan
left a comment
Thanks for the PR to add FRoD. This already looks quite good. In this review, I focused on the integration, so I haven't checked the docs, examples, and experiments yet.
| @dataclass | ||
| class FRODConfig(PeftConfig): |
Let's go with FrodConfig etc. It's easier to type and more consistent overall (e.g. we use LoraConfig, not LoRAConfig).
Updated to FrodConfig, FrodModel, and FrodLayer to match PEFT naming conventions. I kept PeftType.FROD uppercase because it follows the enum naming style.
| self.merged_adapters = [] | ||
| base_layer = self.get_base_layer() | ||
| if isinstance(base_layer, nn.Linear): |
Let's use peft.tuners.tuners_utils._get_in_out_features.
Updated FrodLayer to use peft.tuners.tuners_utils._get_in_out_features() instead of manually handling nn.Linear and Conv1D.
| self.kwargs = kwargs | ||
| @property | ||
| def merged(self) -> bool: |
Removed the duplicate merged property.
| init_weights, | ||
| ): | ||
| base_layer = self.get_base_layer() | ||
| weight = base_layer.weight.T if isinstance(base_layer, Conv1D) else base_layer.weight |
Let's use peft.utils.other.transpose.
Updated this to use transpose(base_layer.weight, self.fan_in_fan_out).
| import warnings | ||
| from collections import defaultdict | ||
| import numpy as np |
Again, I would prefer to use torch throughout instead of numpy.
Updated the FRoD initialization path to use torch.linalg throughout and removed the NumPy dependency.
| ) | ||
| if config.save_projection and not has_projection: | ||
| raise ValueError( | ||
| "Specified to load FRoD projection tensors from state dictionary however they were not present!" |
It would be good to add instructions on how to deal with that situation.
Added instructions to the error message: either load with save_projection=False to regenerate the projections from the base model weights, or re-save the adapter with save_projection=True.
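A minimal sketch of what such an actionable error might look like. The function name `check_frod_projection` and the `"frod_V"` key-matching rule are hypothetical and only for illustration; just the `save_projection` option and the two remedies are taken from the reply above.

```python
def check_frod_projection(save_projection: bool, state_dict_keys: list[str]) -> None:
    """Raise an actionable error when projection tensors are expected but missing.

    Hypothetical helper: matching on the "frod_V" substring is an assumption
    about how projection keys are named, not the PR's actual logic.
    """
    has_projection = any("frod_V" in key for key in state_dict_keys)
    if save_projection and not has_projection:
        raise ValueError(
            "Specified to load FRoD projection tensors from the state dictionary, "
            "but they were not present. Either load with `save_projection=False` to "
            "regenerate the projections from the base model weights, or re-save the "
            "adapter with `save_projection=True`."
        )
```

The message names both ways out, so a user hitting the error does not have to read the source to recover.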
| == mlp_same_prng.base_model.model.lin2.frod_s_indices["other"].data_ptr() | ||
| ) | ||
| def test_multiple_adapters_different_prng_raises(self): |
As mentioned above, let's put this into test_initialization.py.
Added this coverage to tests/test_initialization.py using the same structure as TestVeraInitialization.
| return peft_model | ||
| @staticmethod | ||
| def _make_second_adapter_different(peft_model): |
Shouldn't this be unnecessary with init_weights=False?
Thanks, that's really helpful. Removed the helper and its call; init_weights=False already makes the two adapters different for this test.
| ("Vanilla MLP 2 FRoD", "MLP", FRODConfig, {"target_modules": ["lin0"]}), | ||
| ("Vanilla MLP 3 FRoD", "MLP", FRODConfig, {"target_modules": ["lin1"]}), | ||
| ("Vanilla MLP 4 FRoD", "MLP", FRODConfig, {"target_modules": ["lin0", "lin1"]}), | ||
| ("Vanilla MLP 5 FRoD", "MLP", FRODConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}), |
Could you please also add FRoD to MULTIPLE_ACTIVE_ADAPTERS_TEST_CASES?
Added FRoD to MULTIPLE_ACTIVE_ADAPTERS_TEST_CASES with the same-target adapter case.
Updated the PR in 0e0d816 to address the review feedback.
The two new commits fix the FRoD sparse forward semantics and align the image-classification example with the CLIP-ViT setup used in our experiments. I also re-tested two of the examples: the Stanford Cars image-classification example and the SST-2 text-classification example. The CLIP-ViT Stanford Cars run converged as expected in the local 3-epoch test. Thank you for the review and suggestions. Looking forward to your feedback.
BenjaminBossan
left a comment
Thanks for the updates. I did another review focusing on the integration and still found a handful of issues. Please check.
| return result | ||
| def __repr__(self) -> str: | ||
| # Match PEFT tuner convention so printed models show FRoD-wrapped layers as `frod.*`. |
This comment can be removed.
Removed the unnecessary comment.
| rep = super().__repr__() | ||
| return "frod." + rep | ||
| def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optional[torch.device] = None) -> None: |
I'm talking about _move_adapter_to_device_of_base_layer, not __repr__
| weight = self.get_base_layer().weight | ||
| return self._get_delta_weight(adapter) | ||
| def _get_delta_weight(self, adapter, base_weight: Optional[torch.Tensor] = None) -> torch.Tensor: |
I don't see why we need _get_delta_weight with the option of passing base_weight. AFAICT, it's always self.get_base_layer().weight, so there is never a need to pass the argument explicitly.
Simplified get_delta_weight() to always use the base layer weight. merge() now precomputes adapter deltas before mutating the base weight, so the explicit base_weight argument is no longer needed.
| values = self.frod_lambda_s_values[active_adapter].to(device=x.device, dtype=target_dtype) | ||
| lambda_l = self.frod_lambda_l[active_adapter].to(device=x.device, dtype=target_dtype) | ||
| x = x.to(target_dtype) |
target_dtype is defined as x.dtype, so this call is useless.
| # F.linear(h, U @ (S + diag(lambda_l)) @ V.T). | ||
| # CUDA sparse fp16/bf16 kernels are less reliable, so use fp32 here and cast the update back below. | ||
| matmul_dtype = z_flat.dtype | ||
| if z_flat.is_cuda and matmul_dtype in (torch.float16, torch.bfloat16): |
Just wondering: Why is this limited to is_cuda? Wouldn't we expect the same issue on other devices?
Changed the fp16/bf16 sparse matmul fallback to use fp32 regardless of device.
| base_weight = transpose(self.get_base_layer().weight, self.fan_in_fan_out).to( | ||
| device=x.device, dtype=target_dtype | ||
| ) | ||
| base_out = F.linear(x, base_weight) | ||
| result = result - base_out + out_add |
This part I don't understand: For each FRoD adapter, we need to remove the base result from result? So my understanding is that the FRoD result already includes the result from the base model, so we need to subtract it to prevent it being included multiple times.
However, this seems very wasteful: First we calculate result = self.base_layer(x, *args, **kwargs), then we calculate base_out = F.linear(x, base_weight), which is basically the same thing (just without bias) and remove it from the result. And then repeat it for each active adapter.
What I would like to see instead is for the FRoD result to only contain the "delta" from FRoD without the base result, so that we can simply accumulate it. This would be the cleanest implementation IMO. It would also help to simplify get_delta_weight, where there is a similar logic to remove the base weight from the FRoD weight.
If, for some reason, it's not possible, we should instead count how many active adapters we have. Then we calculate the base result only once and remove it x times, once for each active adapter. This should be a lot cheaper than the current implementation.
Updated the forward pass so the FRoD branch explicitly computes and accumulates only the adapter delta. As described in the paper, FRoD directly reconstructs the adapted weight through the joint matrix factorization, so the base-weight contribution must be subtracted; otherwise the base model contribution would be included twice. The code now computes out_add - adapter_base_out and reuses the base output when dropout is identity.
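The accumulation described above can be sketched with scalars standing in for tensors. This is a toy illustration under stated assumptions (no dropout, a scalar "weight" per adapter, hypothetical function name), not the PR's forward implementation.

```python
def frod_forward(x, base_weight, bias, adapted_weights):
    """Toy scalar forward pass: base output plus one delta per active adapter."""
    base_out = base_weight * x  # computed once, reused for every adapter
    result = base_out + bias    # what the wrapped base layer would return
    for adapted_weight in adapted_weights:
        out_add = adapted_weight * x          # output of the reconstructed (full) weight
        result = result + out_add - base_out  # accumulate only the adapter delta
    return result


# Base weight 2.0, two adapters whose reconstructed weights are 2.5 and 3.0.
output = frod_forward(4.0, 2.0, 1.0, [2.5, 3.0])
```

With these numbers the base layer returns 9.0 and the two deltas are 2.0 and 4.0, so `output` is 15.0. Adding `out_add` without the subtraction would count the base contribution three times.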
BenjaminBossan
left a comment
Thanks for the new updates. The integration is getting there, but there were still some unclear parts of the code that I commented on. Please take a look.
| rep = super().__repr__() | ||
| return "frod." + rep | ||
| def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optional[torch.device] = None) -> None: |
Isn't _move_adapter_to_device_of_base_layer doing the exact same thing as the method from the parent class? Why does it need to be overridden?
| adapter_deltas = [] | ||
| for active_adapter in adapter_names: | ||
| if active_adapter in self.frod_lambda_l.keys(): | ||
| adapter_deltas.append((active_adapter, self.get_delta_weight(active_adapter))) |
This could all be done in a single loop, no? I don't see why we need to collect the adapter_deltas separately. A single loop should be easier to read and is probably also better for garbage collection.
For FRoD this needs to be computed before mutating the base weight. Unlike LoRA-style deltas, get_delta_weight() returns the difference between the reconstructed FRoD weight and the current base weight. If we merge in a single loop and update the base weight immediately, subsequent adapters would compute their delta against an already-merged weight. Precomputing the deltas keeps all adapter deltas relative to the original base weight, which is needed when merging multiple adapters.
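A small sketch of why the order matters, again using scalars in place of weight matrices. The function names are hypothetical; the only property taken from the reply is that the delta is defined as the reconstructed weight minus the current base weight.

```python
def get_delta_weight(reconstructed, base_weight):
    """FRoD-style delta: reconstructed weight minus the *current* base weight."""
    return reconstructed - base_weight


def merge_single_loop(base_weight, reconstructed_weights):
    # Each delta is computed against a base weight that earlier merges already changed.
    for reconstructed in reconstructed_weights:
        base_weight = base_weight + get_delta_weight(reconstructed, base_weight)
    return base_weight


def merge_precomputed(base_weight, reconstructed_weights):
    # All deltas are taken relative to the original base weight before any mutation.
    deltas = [get_delta_weight(r, base_weight) for r in reconstructed_weights]
    for delta in deltas:
        base_weight = base_weight + delta
    return base_weight


single = merge_single_loop(2.0, [2.5, 3.0])
precomputed = merge_precomputed(2.0, [2.5, 3.0])
```

Here `single` ends at 3.0, so only the last adapter survives, while `precomputed` ends at 3.5, the base plus both deltas (0.5 and 1.0).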
| if not hasattr(self, "frod_V"): | ||
| self.frod_V = nn.ModuleDict() | ||
| self.frod_s_indices = nn.ModuleDict() | ||
| self.frod_s_size = nn.ModuleDict() |
Okay, so for my better understanding, on the FrodModel, we have a frod_s_size that is a ModuleDict and it contains "categories", which are groups of similar modules (?). For each category, we have one BufferDict, which contains one value per adapter name. The values themselves are just sizes, so tuples of ints. On the FrodLayer, we only have the inner BufferDict, not the outer ModuleDict.
If this is correct, I would probably rename frod_s_size on the FrodModel to something like frod_categories_s_size or so to avoid confusion. I'm also not too fond of the name category, if you have a better idea, let's discuss that.
Finally, since we only have sizes in there, in theory a normal dictionary would be enough, right? We don't really need tensors. But as we may want to persist them, BufferDict is still helpful for that, is that understanding correct?
Yes, that understanding is correct. The outer ModuleDict is keyed by projection category, and each category stores a BufferDict keyed by adapter name. For frod_s_size, the stored value is only the sparse matrix size tensor.
I agree that frod_category_s_size would be more explicit. I originally kept frod_s_size to mirror frod_s_indices, since both are metadata for the same sparse S matrix and are passed together to FrodLayer.
A normal dict would be enough for computation, but BufferDict keeps the tensors registered and gives them the same configurable save_projection / state-dict behavior as the projection and index buffers.
| self.frod_V[adapter_name] = frod_V[adapter_name] | ||
| self.frod_s_indices[adapter_name] = frod_s_indices[adapter_name] | ||
| self.frod_s_size[adapter_name] = frod_s_size[adapter_name] |
I don't think this can be deleted. This follows the same pattern as VeRA: the shared projection buffers are created at the tuner level and then referenced by each wrapped layer. In VeRA, update_layer() assigns the shared vera_A / vera_B buffers to the layer; here FRoD attaches the shared category-level frod_V, frod_s_indices, and frod_s_size entries to the layer-local non-persistent buffers. These are shared references used by forward() and get_delta_weight(), not independent persistent copies.
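The two-level layout and the shared references can be pictured with plain dicts. This is only a stand-in: the real code uses nn.ModuleDict and BufferDict and stores size tensors, and the `(64, 64)` sizes and the `ToyLayer` class are made up for illustration.

```python
# Model level: category -> adapter name -> sparse matrix size.
model_frod_s_size = {
    "self_attn_q_proj": {"default": (64, 64)},
    "self_attn_v_proj": {"default": (64, 64)},
}


class ToyLayer:
    """Stand-in for a wrapped layer that only holds the inner, per-adapter mapping."""

    def __init__(self):
        self.frod_s_size = {}

    def update_layer(self, adapter_name, category_sizes):
        # Attach a reference to the shared entry instead of creating a copy.
        self.frod_s_size[adapter_name] = category_sizes[adapter_name]


layer = ToyLayer()
layer.update_layer("default", model_frod_s_size["self_attn_q_proj"])
shared = layer.frod_s_size["default"] is model_frod_s_size["self_attn_q_proj"]["default"]
```

`shared` is `True` because the layer entry and the model-level entry are the same object, which mirrors the "shared references, not independent copies" point.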
| if active_adapter not in self.frod_lambda_s_values: | ||
| continue | ||
| V = self.frod_V[active_adapter].to(device=x.device, dtype=target_dtype) |
The code here is partly exactly the same as in get_delta_weight. How about factoring out the common parts into a dedicated method?
There is some overlap in fetching/casting the FRoD tensors, but the two paths intentionally do different computations. get_delta_weight() materializes the dense delta weight for merge/unmerge, while forward() keeps the sparse activation-side computation and avoids constructing the full dense matrix.
I could factor out only the tensor-fetching part, but that helper would mostly return a long tuple of tensors and may make the forward path harder to read. I would prefer to keep the dense merge path and sparse forward path separate unless you think the small helper would still be clearer.
| else: | ||
| adapter_base_out = F.linear(h, base_weight) | ||
| result = result + out_add - adapter_base_out |
In theory, if no adapter uses dropout, we could subtract the adapter_base_out once outside the loop, multiplied by the number of active adapters. This would be more efficient if we deal with a large number of active adapters. As this is a very unlikely use case, I'm fine with keeping the code as is for simplicity (it's also numerically more stable), I just wanted to bring that option up.
Yes, that makes sense. I kept the current form because the common case is a small number of active adapters, and this also handles the dropout and non-dropout cases uniformly. When dropout is Identity, base_out is already computed only once and reused inside the loop, so the remaining overhead should be small.
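For reference, the two variants discussed here can be compared side by side. This is a scalar sketch with hypothetical function names, and it assumes no dropout so that every adapter sees the same base output.

```python
def subtract_per_adapter(result, base_out, adapter_outs):
    # Current form: remove the base contribution inside the loop.
    for out_add in adapter_outs:
        result = result + out_add - base_out
    return result


def subtract_once(result, base_out, adapter_outs):
    # Alternative: sum the adapter outputs, then remove the base contribution n times at once.
    return result + sum(adapter_outs) - len(adapter_outs) * base_out


per_adapter = subtract_per_adapter(9.0, 8.0, [10.0, 12.0])
once = subtract_once(9.0, 8.0, [10.0, 12.0])
```

Both give 15.0 on these exactly representable values. With real floating-point tensors the two orderings can differ slightly in rounding, which is the numerical-stability point raised above.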
| from .layer import FrodLayer, Linear | ||
| def _category_from_key(key: str) -> str: |
This probably relies on certain naming conventions. Could you please add a docstring to make those explicit? The docstring should also contain an example, as it's not quite obvious what is happening here.
Added a docstring explaining the projection-category naming convention, including BERT-style and CLIP/decoder-style examples.
| return category | ||
| def _layer_index_from_key(key: str, fallback: int) -> int: |
This probably relies on certain naming conventions. Could you please add a docstring to make those explicit? The docstring should also contain an example, as it's not quite obvious what is happening here.
Added a docstring explaining the layer-index parsing convention, including layers.<idx>, encoder layer.<idx>, and fallback examples.
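A sketch of how such layer-index parsing might work for the conventions named in the reply (`layers.<idx>`, encoder `layer.<idx>`, and a fallback). The regular expression and the function below are assumptions for illustration, not the PR's `_layer_index_from_key`.

```python
import re

# Hypothetical pattern: matches a "layers.<idx>" or "layer.<idx>" path segment.
_LAYER_INDEX_RE = re.compile(r"(?:^|\.)layers?\.(\d+)(?:\.|$)")


def layer_index_from_key(key: str, fallback: int) -> int:
    """Return the layer index encoded in a module key, or `fallback` if there is none.

    Examples (under the assumed pattern):
        "model.layers.3.self_attn.q_proj"            -> 3
        "bert.encoder.layer.11.attention.self.query" -> 11
        "classifier.dense" with fallback=7           -> 7
    """
    match = _LAYER_INDEX_RE.search(key)
    return int(match.group(1)) if match else fallback
```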
I wonder if there should be tests to check for how FRoD determines the "categories". We could use a tiny random model based on a real model architecture and configure FRoD to use common targets (e.g. q_proj, v_proj) and then check which categories were determined. WDYT?
Sure. We added a test for this using a tiny local LlamaForCausalLM config with common targets q_proj and v_proj. The test checks that FRoD determines the expected categories: self_attn_q_proj and self_attn_v_proj.
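One way a key could map to the categories named above is to drop everything up to the layer index and join the remaining path with underscores. This rule is a guess that happens to reproduce `self_attn_q_proj` and `self_attn_v_proj` for Llama-style keys; the PR's `_category_from_key` may work differently.

```python
import re


def category_from_key(key: str) -> str:
    """Hypothetical rule: strip the prefix up to the layer index, join the rest with "_"."""
    match = re.search(r"(?:^|\.)layers?\.\d+\.(.+)$", key)
    suffix = match.group(1) if match else key  # keys without a layer index are kept whole
    return suffix.replace(".", "_")


q_first = category_from_key("model.layers.0.self_attn.q_proj")
q_later = category_from_key("model.layers.5.self_attn.q_proj")
v_first = category_from_key("model.layers.0.self_attn.v_proj")
```

Under this rule the same projection in different layers lands in one category (`q_first == q_later`), while `q_proj` and `v_proj` stay separate.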
Summary
This PR implements the FRoD PEFT tuner proposed in #3244.
Examples
google-bert/bert-base-uncased and nyu-mll/glue.
google/vit-base-patch16-224 and tanganke/stanford_cars.
Tests
PYTHONPATH=src python -m ruff check ...
PYTHONPATH=src python -m py_compile examples/frod_finetuning/*.py
PYTHONPATH=src python -m pytest -o addopts='' tests/test_decoder_models.py tests/test_encoder_decoder_models.py tests/test_feature_extraction_models.py tests/test_seq_classifier.py tests/test_custom_models.py tests/test_frod.py -q -k "FROD or frod"
PYTHONPATH=src python -m pytest -o addopts='' tests/test_config.py -q -k "FROD or frod"