@staticmethod
def get_save_before_smash_dir(smash_config: SmashConfig) -> Path:
    """
    Locate the directory used to cache a model saved before an algorithm is applied.

    Parameters
    ----------
    smash_config : SmashConfig
        The SmashConfig whose ``cache_dir`` hosts the cache.

    Returns
    -------
    Path
        The resolved path of ``SAVE_BEFORE_SMASH_CACHE_DIR`` inside ``smash_config.cache_dir``.
    """
    cache_root = smash_config.cache_dir
    snapshot_dir = cache_root / SAVE_BEFORE_SMASH_CACHE_DIR
    return snapshot_dir.resolve()
def apply(self, model: Any, smash_config: SmashConfig) -> Any:
    """
    Apply the recovery algorithm and refresh the save cache if needed.

    Recovery modifies weights in-place without changing the model's serialization
    format. If a prior algorithm used ``save_before_apply`` (caching the model before
    its transformation), the cached snapshot is now stale because recovery changed
    the weights. This override refreshes that cache so the already saved model includes
    the recovered weights.

    While the cache is rewritten, ``save_before_apply`` is temporarily removed from
    ``smash_config.save_fns``; the original list is always restored afterwards, even
    if re-saving fails.

    Parameters
    ----------
    model : Any
        The model to apply the algorithm to.
    smash_config : SmashConfig
        The SmashConfig object containing the save and load functions.

    Returns
    -------
    Any
        The model after recovery has been applied.
    """
    result = super().apply(model, smash_config)

    if not smash_config.prepare_saving:
        return result

    save_dir = self.get_save_before_smash_dir(smash_config)
    if not save_dir.exists():
        # No earlier save_before_apply snapshot exists, so there is nothing to refresh.
        return result

    ori_save_fns = smash_config.save_fns[:]
    # NOTE(review): presumably save_before_apply is filtered out so that save_pruna_model
    # serializes the model itself instead of going through the cached-copy path - confirm
    # against pruna.engine.save.
    smash_config.save_fns = [fn for fn in ori_save_fns if fn != SAVE_FUNCTIONS.save_before_apply.name]
    try:
        # Re-save with recovered weights
        shutil.rmtree(save_dir, ignore_errors=True)
        save_dir.mkdir(parents=True)
        save_pruna_model(model, save_dir, smash_config)
    finally:
        # Restore save_fns unconditionally so a failed refresh cannot leave the
        # config without its save_before_apply entry.
        smash_config.save_fns = ori_save_fns
    return result
@pytest.mark.cpu
def test_recovery_save_fn_is_none() -> None:
    """Test that recovery algorithms use save_fn=None, preserving the prior algorithm's save format."""
    from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer

    assert PERPRecoverer.save_fn is None


@pytest.mark.cpu
def test_recovery_does_not_add_to_save_fns(tmp_path) -> None:
    """Test that the save_fns bookkeeping leaves save_fns untouched for the recoverer's save_fn."""
    from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer

    config = SmashConfig()
    config.save_fns = ["hqq"]  # simulate a prior algorithm's save_fn

    # Read the real class attribute instead of a hard-coded None, so the test
    # fails if the recoverer's save_fn regresses to a value that gets appended.
    save_fn = PERPRecoverer.save_fn

    # PrunaAlgorithmBase apply logic
    if save_fn is not None and save_fn != SAVE_FUNCTIONS.reapply:
        config.save_fns.append(save_fn.name)

    assert config.save_fns == ["hqq"], "Recovery should not add to save_fns"


@pytest.mark.cpu
def test_recovery_refresh_save_cache(tmp_path) -> None:
    """
    Test that recovery refreshes a stale save_before_apply cache with recovered weights.

    The refresh sequence is replayed by hand here (it mirrors the steps of the
    recoverer's ``apply`` override) rather than invoked through the algorithm.
    """
    from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase

    model = AutoModelForCausalLM.from_pretrained("yujiepan/opt-tiny-random")

    config = SmashConfig(device="cpu")

    # Simulate a save_before_apply algorithm having run before recovery:
    # 1. Save original (pre-transformation) model to cache
    save_dir = PrunaAlgorithmBase.get_save_before_smash_dir(config)
    save_dir.mkdir(parents=True)
    save_pruna_model(model, save_dir, config)

    # 2. Mark save_before_apply in save_fns (as the algorithm would)
    config.save_fns.append(SAVE_FUNCTIONS.save_before_apply.name)

    # 3. Simulate the transformation (e.g., half) + recovery modifying weights
    model.lm_head.weight.data.fill_(0.99)  # "recovered" weights

    # 4. Simulate what recovery's apply() does: refresh the stale cache
    ori_save_fns = config.save_fns[:]
    config.save_fns = [fn for fn in config.save_fns if fn != SAVE_FUNCTIONS.save_before_apply.name]
    shutil.rmtree(save_dir, ignore_errors=True)
    save_dir.mkdir(parents=True)
    save_pruna_model(model, save_dir, config)
    config.save_fns = ori_save_fns

    # 5. Verify the cache was refreshed: save_before_apply should copy updated files
    save_path = tmp_path / "final_model"
    save_pruna_model(model, save_path, config)

    # Load and verify the recovered weights survived the round-trip
    loaded_model, _ = load_pruna_model(save_path)
    loaded_model = loaded_model.cpu()
    assert torch.allclose(
        loaded_model.lm_head.weight, torch.full_like(loaded_model.lm_head.weight, 0.99)
    ), "Recovered weights should survive save/load through save_before_apply"