Enhance fine-tuning capabilities for foundation models #3003
Conversation
**Codecov Report** ❌ Patch coverage is …

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master    #3003      +/-   ##
==========================================
- Coverage   95.69%   95.48%   -0.22%
==========================================
  Files         154      156       +2
  Lines       16604    16753     +149
==========================================
+ Hits        15890    15996     +106
- Misses        714      757      +43
```
Hi @Kurokabe, thank you for this PR and your efforts at making fine-tuning work for foundation models! Here are my suggestions:

**Nested Model Attribute**

After reviewing the code, I have some worries about the nested model attribute. Even if we want PEFT support for foundation models, I wonder if we can do so without running into a nested model:

```python
class FoundationModel(MixedCovariatesTorchModel, ABC):
    @abstractmethod
    def _create_original_model(self, train_sample: TorchTrainingSample) -> PLForecastingModule:
        """Create the original PyTorch Lightning forecasting module without any PEFT adapters."""

    def _create_model(self, train_sample: TorchTrainingSample) -> PLForecastingModule:
        model = self._create_original_model(train_sample)
        if self._enable_finetuning and self.peft_config is not None:
            from peft import get_peft_model

            model = get_peft_model(model, self.peft_config)
        return model
```

We then override the `save` method:

```python
def save(
    self,
    path: Optional[str] = None,
    clean: bool = False,
) -> None:
    if self._enable_finetuning and self.peft_config is not None:
        self.model.merge_adapter()
    super().save(path=path, clean=clean)
```

That way, we could avoid implementing an additional `PeftCallback`. I also argue that we might not need adapter merging for training checkpoints, as it adds overhead and those checkpoints do not need to be compatible. Instead, we could suggest that users call `save()` after training, which merges the adapters.
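A brief usage sketch of that flow (hypothetical series and file names; `model_lora` is defined in the example below):

```python
# fine-tune with LoRA, then persist a merged, standard Darts model
model_lora.fit(train_series)           # `train_series`: a darts TimeSeries (assumed)
model_lora.save("chronos2_lora.pt")    # the `save` override above merges the adapters
loaded = Chronos2Model.load("chronos2_lora.pt")  # loads like a plain Darts model
```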
**Fine-tuning Hyperparameters**

Like I said in #2964, I recommend exposing the fine-tuning hyper-parameters to users rather than the callback. This allows direct control of fine-tuning behaviours:

```python
model_lora = Chronos2Model(
    input_chunk_length=24,
    output_chunk_length=6,
    enable_finetuning=True,
    n_epochs=50,
    unfreeze_patterns=unfreeze_patterns,
    peft_config=peft_config,
)
```

For partial fine-tuning, please also consider:
For PEFT fine-tuning, please consider:
Those are merely my suggestions for your consideration. Feel free to ignore them if you disagree. Many thanks.
**dennisbader** left a comment:
Thanks a lot @Kurokabe for this great PR and for showing all the possibilities we have to support fine-tuning! This will be a great addition to Darts 🚀
Also thanks @daidahao for your review, I agree with your suggestions.
How I see it now is that we can enable full and partial fine-tuning with relatively minimal effort (a few lines of code, no breaking changes) and even add support for it to ALL our existing torch models. This is huge, and should be the focus for now. For example, it would close the gap in our fine-tuning recommendations from here.

Adding another layer of model nesting is something I want to avoid, at least for the near future. Therefore, for now I would say we should not add PEFT support. If PEFT is something that users really need in the future, we can always come back to it.
Here are my suggestions:
- Let's revert the changes to model nesting, PEFT support, and callbacks
- Let's merge the `enable_fine_tuning` related parameters into one (as suggested in the comments), and move it into `TorchForecastingModel` to enable fine-tuning support for all torch models
- Let's handle the parameter freezing / unfreezing directly in the base model instead of a callback
Referenced code:

```python
def __init__(
    self,
    enable_finetuning: bool = False,
    freeze_patterns: list[str] | None = None,
```
We could combine the entire logic into a single parameter `enable_fine_tuning` to simplify things for the user. Here is an idea of how it could look:

```python
enable_fine_tuning: Optional[Union[bool, dict[str, list[str]]]] = None

# the parameter description could look something like:
"""
Parameters
----------
enable_fine_tuning
    Enables model fine-tuning. Only effective if not `None`.
    If a bool, specifies whether to perform full fine-tuning / training (all parameters are
    updated) or keep all parameters frozen.
    If a dict, specifies which parameters to fine-tune. Must only contain one key-value record.
    Can be used to:

    - Unfreeze specific parameters, while keeping everything else frozen:
      `{"unfreeze": ["patterns.to.unfreeze"]}`
    - Freeze specific parameters, while keeping everything else unfrozen:
      `{"freeze": ["patterns.to.freeze"]}`

    (TODO: add some words about the allowed patterns for parameter matching)
"""
```
Like this, we can support fine-tuning in general for all our torch models (not only the foundation models), and move the parameter and handling into `TorchForecastingModel`.

And then we should add the param description to all torch models.
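A minimal sketch of how this combined parameter could be normalized internally; the helper name `_parse_fine_tuning` and the mode labels are assumptions for illustration, not part of the PR:

```python
from typing import Optional, Union

FineTuneArg = Optional[Union[bool, dict[str, list[str]]]]

def _parse_fine_tuning(arg: FineTuneArg) -> tuple[str, list[str]]:
    """Map `enable_fine_tuning` onto a (mode, patterns) pair.

    mode is one of: "off" (fine-tuning disabled), "full" (all parameters
    trainable), "none" (all parameters frozen), "freeze", "unfreeze".
    """
    if arg is None:
        return "off", []
    if isinstance(arg, bool):
        return ("full" if arg else "none"), []
    if len(arg) != 1 or next(iter(arg)) not in ("freeze", "unfreeze"):
        raise ValueError(
            "`enable_fine_tuning` dict must contain exactly one key: 'freeze' or 'unfreeze'"
        )
    mode, patterns = next(iter(arg.items()))
    return mode, patterns
```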
I love this suggestion! Instead of enabling fine-tuning for foundation models only, we could enable it at the TFM level to benefit all torch models. This could deliver huge benefits to many training + fine-tuning pipelines.
Referenced code:

```python
else:
    callbacks = list(callbacks)
...
callbacks.append(
```
As @daidahao mentioned, the parameter freezing/unfreezing could (should) be performed without the need for a callback. You could move that logic into a function that is called somewhere after `_create_model(...)`.
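A minimal sketch of what such a function could look like, assuming glob-style matching of parameter names (the helper name and the `fnmatch` rule are assumptions, not the PR's implementation):

```python
import fnmatch

import torch.nn as nn

def _apply_finetuning_patterns(
    module: nn.Module,
    unfreeze_patterns: list[str] | None = None,
    freeze_patterns: list[str] | None = None,
) -> None:
    """Set `requires_grad` on parameters whose names match the given patterns."""
    if unfreeze_patterns is not None:
        # keep everything frozen except the matching parameters
        for name, param in module.named_parameters():
            param.requires_grad = any(fnmatch.fnmatch(name, p) for p in unfreeze_patterns)
    elif freeze_patterns is not None:
        # keep everything trainable except the matching parameters
        for name, param in module.named_parameters():
            param.requires_grad = not any(fnmatch.fnmatch(name, p) for p in freeze_patterns)
```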
For fine-tuning the foundation models, we should make sure that during training we use a `QuantileRegression(quantiles)` with all quantiles that the original weights were trained on.

The user should still be able to specify different quantiles when creating the model with `likelihood=QuantileRegression(other_quantiles)`. These quantiles will only be used for prediction.
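To illustrate the split between training and prediction quantiles, here is a small sketch; the `original_quantiles` values are placeholders, not the actual quantiles the pretrained weights were trained on:

```python
from darts.utils.likelihood_models import QuantileRegression

# quantiles the pretrained head was trained on (placeholder values)
original_quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# quantiles the user requested at model creation
user_quantiles = [0.1, 0.5, 0.9]

# during fine-tuning, train on the full original set so all quantile heads stay aligned
train_likelihood = QuantileRegression(quantiles=original_quantiles)

# at prediction time, only the user's quantiles would be reported
predict_quantiles = user_quantiles
```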
Very nice notebook, thanks a lot @Kurokabe!
Let's make sure we also update the notebook after having applied the code suggestions :)
Here are some suggestions for the notebook:
- Since the new full/partial fine-tuning support will be added to all torch models, we could rename it to "Torch & Foundation Model Fine-Tuning"
- Foundation model fine-tuning can be one section
- Regular torch model fine-tuning can be another section
  - Show the approach from here, including the new fine-tuning support
- No need to compare the output of a loaded model to the original one
- Remove the PEFT related parts
Using `output_chunk_length=6` with forecast horizon `n=24` will perform auto-regression (the model must produce 24 / 6 = 4 consecutive 6-step forecasts). Maybe it would be better to avoid this, since we should focus on the output window that the model was fine-tuned on. You can use `output_chunk_length=12` and a shorter validation set:

```python
train_passengers, val_passengers = data[:-12], data[-12:]
```
Also, I think it would be nice to use a `QuantileRegression` in the example, so we can show that the quantiles were also fine-tuned properly.

Here's the model setup I used, which gave some nice results:
```python
full_finetuned_model = Chronos2Model(
    input_chunk_length=24,
    output_chunk_length=12,
    use_reversible_instance_norm=True,
    likelihood=QuantileRegression([0.1, 0.5, 0.9]),
    enable_finetuning=True,
    random_state=42,
    n_epochs=100,
)

# ... later predict with predict_likelihood_parameters=True
pred_full_finetuned = full_finetuned_model.predict(
    n=len(val_passengers),
    series=train_passengers,
    predict_likelihood_parameters=True,
)

# ... metrics can still be computed against the median
mape(data, pred_full_finetuned, q=0.5)
```
Fixes #2964
Summary
This PR implements native support for full and partial fine-tuning of foundation models (e.g., `Chronos2Model`) and adds advanced integration capabilities for external libraries like `peft`.

**Foundation Model Enhancements:**

- Extended the `FoundationModel` base class to accept `enable_finetuning`, `freeze_patterns`, and `unfreeze_patterns`.
- Added a `LayerFreezeCallback` that is applied when fine-tuning is enabled with specific patterns.
- Added an `internal_model` property to provide direct access to the underlying `nn.Module`, facilitating advanced use cases like PEFT/LoRA.

**Callback Improvements:**

- The new `PeftCallback` correctly handles adapter merging during checkpointing, allowing models trained with LoRA to be saved and reloaded as standard Darts models.

**Documentation & Examples:**

- Added a new notebook, `26-Chronos-2-finetuning-examples.ipynb`, demonstrating full fine-tuning, partial fine-tuning with layer freezing, and LoRA integration.

**Testing:**

- Added tests in `test_foundation.py` covering all new fine-tuning scenarios and ensuring correct model state after saving/loading.

**How Has This Been Tested?**

- Unit tests for the `FoundationModel` fine-tuning logic.
- Unit tests for the `PeftCallback`.
- All tests in `test_foundation.py` pass.
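To round off, a short usage sketch of the features listed above, using the parameter names shown earlier in this PR (`enable_finetuning`, `unfreeze_patterns`); the pattern string and series name are hypothetical:

```python
from darts.models import Chronos2Model

# partial fine-tuning: only parameters matching the patterns are updated
model = Chronos2Model(
    input_chunk_length=24,
    output_chunk_length=12,
    enable_finetuning=True,
    unfreeze_patterns=["output_patch_embedding.*"],  # hypothetical pattern
    n_epochs=50,
)
model.fit(train_series)  # `train_series`: a darts TimeSeries (assumed)
```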