Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/pruna/algorithms/base/pruna_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import functools
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Iterable

from transformers import Pipeline
Expand Down Expand Up @@ -355,8 +356,8 @@ def apply(self, model: Any, smash_config: SmashConfig) -> Any:
Any
The model after the algorithm has been applied.
"""
if self.save_fn == SAVE_FUNCTIONS.save_before_apply and smash_config._prepare_saving:
save_dir = smash_config.cache_dir / SAVE_BEFORE_SMASH_CACHE_DIR
if self.save_fn == SAVE_FUNCTIONS.save_before_apply and smash_config.prepare_saving:
save_dir = self.get_save_before_smash_dir(smash_config)
save_pruna_model(model, save_dir, smash_config)

# save algorithms to reapply after loading
Expand Down Expand Up @@ -447,6 +448,23 @@ def get_algorithms_to_run_after_disjointly(self) -> list[str]:
"""
return _expand_tags_into_algorithm_names(self.disjointly_compatible_after)

@staticmethod
def get_save_before_smash_dir(smash_config: SmashConfig) -> Path:
    """
    Resolve the directory used to cache a model snapshot before smashing.

    Parameters
    ----------
    smash_config : SmashConfig
        Configuration object providing the cache directory.

    Returns
    -------
    Path
        Absolute path to the "save before smash" cache directory
        (``cache_dir / SAVE_BEFORE_SMASH_CACHE_DIR``).
    """
    pre_smash_dir = smash_config.cache_dir / SAVE_BEFORE_SMASH_CACHE_DIR
    return pre_smash_dir.resolve()


def wrap_handle_imports(func):
"""
Expand Down
46 changes: 43 additions & 3 deletions src/pruna/algorithms/global_utils/recovery/perp_recoverer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
from typing import Any, Dict

import torch
Expand All @@ -23,15 +24,15 @@
from pruna.algorithms.global_utils.recovery.finetuners import PrunaFinetuner
from pruna.algorithms.global_utils.recovery.finetuners.diffusers.utils import get_denoiser_attr
from pruna.algorithms.global_utils.recovery.utils import get_trainable_parameters
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper
from pruna.engine.model_checks import (
is_causal_lm,
is_flux_pipeline,
is_sana_pipeline,
is_sd_pipeline,
is_sdxl_pipeline,
)
from pruna.engine.save import SAVE_FUNCTIONS
from pruna.engine.save import SAVE_FUNCTIONS, save_pruna_model
from pruna.logging.logger import pruna_logger


Expand All @@ -52,7 +53,7 @@ class PERPRecoverer(PrunaAlgorithmBase):
"""

group_tags: list[AlgorithmTag] = [AlgorithmTag.RECOVERER] # type: ignore[attr-defined]
save_fn = SAVE_FUNCTIONS.pickled
save_fn = None
references: dict[str, str] = {
"GitHub": "https://github.com/huggingface/peft",
"Paper": "https://arxiv.org/pdf/2312.15230",
Expand Down Expand Up @@ -181,6 +182,45 @@ def _pre_smash_hook(self, model: Any, smash_config: SmashConfigPrefixWrapper) ->
adapter_smash_config = SmashConfigPrefixWrapper(smash_config, adapter.adapter_prefix + "_")
adapter.pre_smash_hook(model_recovery, adapter_smash_config, seed=adapter_seed)

def apply(self, model: Any, smash_config: SmashConfig) -> Any:
    """
    Apply the recovery algorithm and refresh the save cache if needed.

    Recovery modifies weights in-place without changing the model's serialization
    format. If a prior algorithm used ``save_before_apply`` (caching the model before
    its transformation), the cached snapshot is now stale because recovery changed
    the weights. This override refreshes that cache so the already saved model includes
    the recovered weights.

    Parameters
    ----------
    model : Any
        The model to apply the algorithm to.
    smash_config : SmashConfig
        The SmashConfig object containing the save and load functions.

    Returns
    -------
    Any
        The model after recovery has been applied.
    """
    result = super().apply(model, smash_config)

    # Nothing to refresh when saving is not being prepared or no stale
    # save-before-apply snapshot exists.
    if not smash_config.prepare_saving:
        return result
    save_dir = self.get_save_before_smash_dir(smash_config)
    if not save_dir.exists():
        return result

    # Temporarily drop save_before_apply from save_fns while rewriting the
    # cache, then restore the original list afterwards.
    ori_save_fns = smash_config.save_fns[:]
    smash_config.save_fns = [fn for fn in smash_config.save_fns if fn != SAVE_FUNCTIONS.save_before_apply.name]
    try:
        # Re-save with recovered weights, replacing the stale snapshot.
        shutil.rmtree(save_dir, ignore_errors=True)
        save_dir.mkdir(parents=True)
        save_pruna_model(model, save_dir, smash_config)
    finally:
        # Restore save_fns even if re-saving fails, so the config is never
        # left permanently missing the save_before_apply entry.
        smash_config.save_fns = ori_save_fns
    return result

def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
"""
Recover performances from a given model with a given config.
Expand Down
5 changes: 5 additions & 0 deletions src/pruna/config/smash_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ def __init__(
raise ValueError(f"Unsupported configuration type: {type(configuration)}")
self.config_space: ConfigurationSpace = self._configuration.config_space

@property
def prepare_saving(self):
    """Read-only public accessor for the private ``_prepare_saving`` attribute set during initialization."""
    return self._prepare_saving

@classmethod
def from_list(
cls,
Expand Down
81 changes: 73 additions & 8 deletions tests/engine/test_save.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import os
import pytest
import torch
import shutil
from pathlib import Path
from unittest.mock import patch

import pytest
import torch
from diffusers import DiffusionPipeline
from transformers import AutoModelForCausalLM
from pruna.config.smash_config import SmashConfig

from pruna import smash
from pruna.engine.save import save_pruna_model
from pruna.engine.save import save_pruna_model_to_hub
from pruna.engine.save import SAVE_FUNCTIONS
from pruna.engine.load import load_pruna_model
from pruna.config.smash_config import SmashConfig
from diffusers import DiffusionPipeline
from pruna.engine.load import load_pruna_model
from pruna.engine.pruna_model import PrunaModel
from pruna.engine.save import SAVE_FUNCTIONS, save_pruna_model, save_pruna_model_to_hub


@pytest.mark.slow
Expand Down Expand Up @@ -160,3 +160,68 @@ def test_push_to_hub_path_types(tmp_path) -> None:
private=True
)
assert mock_upload.called


@pytest.mark.cpu
def test_recovery_save_fn_is_none() -> None:
    """Check that the PERP recoverer declares no save function of its own (save_fn is None), preserving the prior algorithm's save format."""
    from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer

    recoverer_cls = PERPRecoverer
    assert recoverer_cls.save_fn is None


@pytest.mark.cpu
def test_recovery_does_not_add_to_save_fns() -> None:
    """Test that recovery's apply() does not append to save_fns when save_fn is None.

    Mirrors the guard in ``PrunaAlgorithmBase.apply``: an algorithm's save_fn
    name is only appended when it is neither ``None`` nor ``reapply``.
    """
    # NOTE: the unused ``tmp_path`` fixture parameter was removed — this test
    # touches no filesystem state.
    config = SmashConfig()
    config.save_fns = ["hqq"]  # simulate a prior algorithm's save_fn

    # A recovery algorithm declares no save function of its own.
    save_fn = None

    # PrunaAlgorithmBase apply logic
    if save_fn is not None and save_fn != SAVE_FUNCTIONS.reapply:
        config.save_fns.append(save_fn.name)

    assert config.save_fns == ["hqq"], "Recovery should not add to save_fns"


@pytest.mark.cpu
def test_recovery_refresh_save_cache(tmp_path) -> None:
    """Test that recovery refreshes a stale save_before_apply cache with recovered weights."""
    from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase

    model = AutoModelForCausalLM.from_pretrained("yujiepan/opt-tiny-random")

    config = SmashConfig(device="cpu")

    # Simulate a save_before_apply algorithm having run before recovery:
    # 1. Save original (pre-transformation) model to cache.
    save_dir = PrunaAlgorithmBase.get_save_before_smash_dir(config)
    # exist_ok guards against leftovers in a shared cache dir from a prior run.
    save_dir.mkdir(parents=True, exist_ok=True)
    save_pruna_model(model, save_dir, config)

    # 2. Mark save_before_apply in save_fns (as the algorithm would).
    config.save_fns.append(SAVE_FUNCTIONS.save_before_apply.name)

    # 3. Simulate the transformation (e.g., half) + recovery modifying weights.
    model.lm_head.weight.data.fill_(0.99)  # "recovered" weights

    # 4. Simulate what recovery's apply() does: refresh the stale cache.
    ori_save_fns = config.save_fns[:]
    config.save_fns = [fn for fn in config.save_fns if fn != SAVE_FUNCTIONS.save_before_apply.name]
    try:
        shutil.rmtree(save_dir, ignore_errors=True)
        save_dir.mkdir(parents=True, exist_ok=True)
        save_pruna_model(model, save_dir, config)
    finally:
        # Restore save_fns even if the re-save fails, so step 5 always sees
        # the full list including save_before_apply.
        config.save_fns = ori_save_fns

    # 5. Verify the cache was refreshed: save_before_apply should copy updated files.
    save_path = tmp_path / "final_model"
    save_pruna_model(model, save_path, config)

    # Load and verify the recovered weights survived the round-trip.
    loaded_model, _ = load_pruna_model(save_path)
    loaded_model = loaded_model.cpu()
    assert torch.allclose(
        loaded_model.lm_head.weight, torch.full_like(loaded_model.lm_head.weight, 0.99)
    ), "Recovered weights should survive save/load through save_before_apply"
Loading