Skip to content

Comments

Enhance fine-tuning capabilities for foundation models#3003

Open
Kurokabe wants to merge 13 commits intomasterfrom
finetuning
Open

Enhance fine-tuning capabilities for foundation models#3003
Kurokabe wants to merge 13 commits intomasterfrom
finetuning

Conversation

@Kurokabe
Copy link
Collaborator

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Fixes #2964

Summary

This PR implements native support for full and partial fine-tuning of foundation models (e.g., Chronos2Model) and adds advanced integration capabilities for external libraries like peft.

  1. Foundation Model Enhancements:

    • Updated FoundationModel base class to accept enable_finetuning, freeze_patterns, and unfreeze_patterns.
    • Automatic injection of LayerFreezeCallback when fine-tuning is enabled with specific patterns.
    • Added internal_model property to provide direct access to the underlying nn.Module, facilitating advanced use cases like PEFT/LoRA.
  2. Callback Improvements:

    • Ensured PeftCallback correctly handles adapter merging during checkpointing, allowing models trained with LoRA to be saved and reloaded as standard Darts models.
  3. Documentation & Examples:

    • Added a new example notebook 26-Chronos-2-finetuning-examples.ipynb demonstrating full fine-tuning, partial fine-tuning with layer freezing, and LoRA integration.
    • Included performance evaluation and persistence (save/load) examples for each method.
  4. Testing:

    • Expanded tests in test_foundation.py covering all new fine-tuning scenarios and ensuring correct model state after saving/loading.

How Has This Been Tested?

  • Added unit tests for FoundationModel fine-tuning logic.
  • Verified LoRA integration and weight merging via PeftCallback.
  • Manual verification of the example notebook.
  • All added tests in test_foundation.py pass.

@Kurokabe Kurokabe requested a review from dennisbader as a code owner January 30, 2026 17:23
@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@codecov
Copy link

codecov bot commented Jan 30, 2026

Codecov Report

❌ Patch coverage is 72.29730% with 41 lines in your changes missing coverage. Please review.
✅ Project coverage is 95.48%. Comparing base (bc4d747) to head (a568a3d).
⚠️ Report is 1 commits behind head on master.

Files with missing lines Patch % Lines
darts/utils/callbacks/fine_tuning.py 57.31% 35 Missing ⚠️
darts/models/forecasting/foundation_model.py 82.85% 6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #3003      +/-   ##
==========================================
- Coverage   95.69%   95.48%   -0.22%     
==========================================
  Files         154      156       +2     
  Lines       16604    16753     +149     
==========================================
+ Hits        15890    15996     +106     
- Misses        714      757      +43     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@daidahao
Copy link
Contributor

daidahao commented Jan 31, 2026

Hi @Kurokabe Thank you for this PR and your efforts at making fine-tuning work for foundation models! Here are my suggestions:

Nested Model Attribute

After reviewing the code, I have some worries as to the nested model attribute of FoundationModel. From my perspective (having written FoundationModel and implemented Chronos-2 and TimesFM in Darts), I would raise two concerns:

  • It adds a new layer to new model implementation, e.g., FoundationModel -> FoundationPLModule -> nn.Module, and creates confusion for developers, with limited benefits, i.e., PEFT support.
  • It makes the model checkpoint (aka, ckpt file), incompatible with original checkpoints, because of the model.* prefix.

Even if we want PEFT support for foundation models, I wonder if we can do so without running into a nested model.model.model situation via more straightforward method overrides:

class FoundationModel(MixedCovariatesTorchModel, ABC):

    @abstractmethod
    def _create_original_model(self, train_sample: TorchTrainingSample) -> PLForecastingModule:
        """Create the original PyTorch Lightning forecasting module without any PEFT adapters."""

    def _create_model(self, train_sample: TorchTrainingSample) -> PLForecastingModule:
        model = self._create_original_model(train_sample)
        if self._enable_finetuning and self.peft_config is not None:
            from peft import get_peft_model
            model = get_peft_model(model, self.peft_config)
        return model

We then override the save() method to ensure the PEFT-merged checkpoint is being saved when called:

    def save(
        self,
        path: Optional[str] = None,
        clean: bool = False,
    ) -> None:
        if self._enable_finetuning and self.peft_config is not None:
            self.model.merge_adapter()
        super().save(path=path, clean=clean)

That way, we could avoid implementing additional ModelTransformCallback and PeftCallback which IMHO are a bit opaque to use and maintain.

I also argue that we might not need adapter merge for training checkpoints as it adds overheads and those checkpoints do not need to be compatible. Instead, we could suggest the users call save() at the end of training to get portable model weights.

Fine-tuning Hyperparameters

Like I said in #2964, I recommend exposing the fine-tuning hyper-parameters to users rather than the callback. This allows direct control of fine-tuning behaviours.

model_lora = Chronos2Model(
    input_chunk_length=24,
    output_chunk_length=6,
    enable_finetuning=True,
    n_epochs=50,
    unfreeze_patterns=unfreeze_patterns,
    peft_config=peft_config,
)

For partial fine-tuning, please also consider:

  • Removing freeze_patterns as it is redundant to unfreeze_patterns that is more common than the former.
  • Using fnmatch or suffix (.endswith()) to match model weights rather than prefix-only. Users might want match *.self_attention.q.weight rather than a prefix like encoder.block.0.layer.1.
  • Raising an error when any pattern is not matched in unfreeze_patterns to prevent silent fails.
  • Would it also be possible to combine enable_finetuning and unfreeze_patterns into one parameter enable_finetuning for shared semantics?

For PEFT fine-tuning, please consider:

  • Exposing peft_config as a model hyper-parameter to directly configure PEFT.

Those are merely my suggestions for your considerations. Feel free to ignore them if you disagree.

Many thanks.

@dennisbader dennisbader added this to darts Feb 3, 2026
@github-project-automation github-project-automation bot moved this to In review in darts Feb 3, 2026
Copy link
Collaborator

@dennisbader dennisbader left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot @Kurokabe for this great and showing all the possibilities we have to support fine tuning! This will be a great addition to Darts 🚀

Also thanks @daidahao for your review, I agree with your suggestions.

How I see it now is that we can enable full and partial fine-tuning with relatively minimal effort (few lines of code, no breaking changes) and even add support for it to ALL our existing torch models. This is huge, and should be the focus for now. For example, it would close the gap of our fine tuning recommendations from here.

Adding another layer of model nesting is something I want to avoid - at least for the near future. Therefore, for now I would say should not add PEFT support. If PEFT is something that the users really need in the future, we can always come back to it.

Here are my suggestions:

  • Let's revert the changes to model nesting, PEFT support, callbacks
  • Let's merge the enable_fine_tuning related parameters into one (as suggested in the comments), and move it into TorchForecastingModel to enable fine tuning support for all torch models
  • Let's handle the parameter freezing / unfreezing directly in the base model instead of a callback

def __init__(
self,
enable_finetuning: bool = False,
freeze_patterns: list[str] | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could combine the entire logic into a single parameter enable_fine_tuning to simplify things for the user. Here is an idea of how it could look like:

enable_fine_tuning: Optional[Union[bool, dict[str, list[str]]]] = None

# parameter description could look something like
"""
Parameters
----------
enable_fine_tuning
    Enables model fine-tuning. Only effective, if not `None`.
    If a bool, specifies whether to perform full fine tuning / training (all parameters are updated) or keep all parameters frozen.
    If a dict, specifies which parameters to fine tune. Must only contain one key-value record. Can be used to: 
    
    - Unfreeze specific parameters, while keeping everything else frozen: `{"unfreeze": ["patterns.to.freeze"]}`
    - Freeze specific parameters, while keeping everything else unfrozen: `{"freeze": ["patterns.to.freeze"]}`
    
    (TODO: add some words about the allowed patterns for parameter matching) 
"""

Like this, we can support fine-tuning in general for all our torch models (not only the foundation models), and move the parameter and handling into TorchForecastingModel.

And then we should add the param description to all torch models

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dennisbader

I love this suggestion! Instead of enabling fine-tuning for foundation models only, we could enable it at the TFM level to benefit all torch models. This could deliver huge benefits to many training + fine-tuning pipelines.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep :)

else:
callbacks = list(callbacks)

callbacks.append(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As @daidahao mentioned, the parameter freezing/unfreezing could (should) be performed without the need of a callback. You could move that logic into a function that is called somewhere after _create_model(...)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For fine-tuning the foundation models, we should make sure that during training we use a QuantileRegression(quantiles) with all quantiles that the original weights were trained on.

The user should still be able specify some different quantiles when creating the model with likelihood=QuantileRegression(other_quantiles). These quantile will only be used for prediction.

@@ -0,0 +1,957 @@
{
Copy link
Collaborator

@dennisbader dennisbader Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice notebook, thanks a lot @Kurokabe !

Let's make sure we also update the notebook after having applied the code suggestions :)

Here are some suggestions for the notebook:

  • Since the new full/partial fine tuning support will be added to all torch models, we could rename it to Torch & Foundation Model Fine-Tuning
  • Foundation model fine-tuning can be one section
  • Regular torch model fine-tuning can be another section
  • Show the approach from here including the new fine tuning support
  • No need to compare the output of a loaded model to the original one
  • Remove the PEFT related parts

Reply via ReviewNB

@@ -0,0 +1,957 @@
{
Copy link
Collaborator

@dennisbader dennisbader Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using output_chunk_length=6 and forecast horizon n=24 will perform auto-regression. Maybe it would be better to avoid this since we should focus on the output window that the model was fine-tuned on. You can use output_chunk_length=12 and use a shorter val set

train_passengers, val_passengers = data[:-12], data[-12:] 

Also, I think it would be nice to use a QuantileRegression in the example, so we can show that the quantiles were also fine-tuned properly.

Here's the model setup I used, which gave some nice results

full_finetuned_model = Chronos2Model(
    input_chunk_length=24,
    output_chunk_length=12,
    use_reversible_instance_norm=True,
    likelihood=QuantileRegression([0.1, 0.5, 0.9]),
    enable_finetuning=True,
    random_state=42,
    n_epochs=100,
)

# ... later predict with predict_likelihood_parameters=True
pred_full_finetuned = full_finetuned_model.predict(
    n=len(val_passengers),
    series=train_passengers,
    predict_likelihood_parameters=True,
)

# ... metrics can still be computed against the median
mape(data, pred_full_finetuned, q=0.5)

Reply via ReviewNB

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Status: In review

Development

Successfully merging this pull request may close these issues.

[Feature] Chronos-2 fine-tuning support

3 participants