From cea0bb9a53045ead400fac8f2d9396a8d19c3c88 Mon Sep 17 00:00:00 2001 From: Zihao Xue Date: Mon, 6 Apr 2026 23:06:01 -0600 Subject: [PATCH 1/5] refactor: add prepare_saving as config property and get util function: get_save_before_smash_dir --- src/pruna/algorithms/base/pruna_base.py | 10 ++++++++-- src/pruna/config/smash_config.py | 5 +++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/pruna/algorithms/base/pruna_base.py b/src/pruna/algorithms/base/pruna_base.py index 0784069b..0cce1e68 100644 --- a/src/pruna/algorithms/base/pruna_base.py +++ b/src/pruna/algorithms/base/pruna_base.py @@ -16,6 +16,7 @@ import functools from abc import ABC, abstractmethod +from pathlib import Path from typing import Any, Dict, Iterable from transformers import Pipeline @@ -355,8 +356,8 @@ def apply(self, model: Any, smash_config: SmashConfig) -> Any: Any The model after the algorithm has been applied. """ - if self.save_fn == SAVE_FUNCTIONS.save_before_apply and smash_config._prepare_saving: - save_dir = smash_config.cache_dir / SAVE_BEFORE_SMASH_CACHE_DIR + if self.save_fn == SAVE_FUNCTIONS.save_before_apply and smash_config.prepare_saving: + save_dir = self.get_save_before_smash_dir(smash_config) save_pruna_model(model, save_dir, smash_config) # save algorithms to reapply after loading @@ -447,6 +448,11 @@ def get_algorithms_to_run_after_disjointly(self) -> list[str]: """ return _expand_tags_into_algorithm_names(self.disjointly_compatible_after) + @staticmethod + def get_save_before_smash_dir(smash_config: SmashConfig) -> Path: + """Get the save directory for the algorithm caches.""" + return smash_config.cache_dir / SAVE_BEFORE_SMASH_CACHE_DIR + def wrap_handle_imports(func): """ diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index 0acc1e12..81429b48 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -119,6 +119,11 @@ def __init__( raise ValueError(f"Unsupported configuration type: 
{type(configuration)}") self.config_space: ConfigurationSpace = self._configuration.config_space + @property + def prepare_saving(self): + """Getter of _prepare_saving as an object's internal data.""" + return self._prepare_saving + @classmethod def from_list( cls, From 67880c11facbd1bf719cc6723368722cc096abc3 Mon Sep 17 00:00:00 2001 From: Zihao Xue Date: Mon, 6 Apr 2026 23:28:06 -0600 Subject: [PATCH 2/5] refactor: save_fns won't append any new save_fn; save recovered model's weights for "save_before_apply" algos --- .../global_utils/recovery/perp_recoverer.py | 46 +++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py b/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py index 4b65d9df..8d7941c0 100644 --- a/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py +++ b/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import shutil from typing import Any, Dict import torch @@ -23,7 +24,7 @@ from pruna.algorithms.global_utils.recovery.finetuners import PrunaFinetuner from pruna.algorithms.global_utils.recovery.finetuners.diffusers.utils import get_denoiser_attr from pruna.algorithms.global_utils.recovery.utils import get_trainable_parameters -from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper from pruna.engine.model_checks import ( is_causal_lm, is_flux_pipeline, @@ -31,7 +32,7 @@ is_sd_pipeline, is_sdxl_pipeline, ) -from pruna.engine.save import SAVE_FUNCTIONS +from pruna.engine.save import SAVE_FUNCTIONS, save_pruna_model from pruna.logging.logger import pruna_logger @@ -52,7 +53,7 @@ class PERPRecoverer(PrunaAlgorithmBase): """ group_tags: list[AlgorithmTag] = [AlgorithmTag.RECOVERER] # type: ignore[attr-defined] - save_fn = SAVE_FUNCTIONS.pickled + save_fn = None references: dict[str, str] = { "GitHub": "https://github.com/huggingface/peft", "Paper": "https://arxiv.org/pdf/2312.15230", @@ -181,6 +182,45 @@ def _pre_smash_hook(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> adapter_smash_config = SmashConfigPrefixWrapper(smash_config, adapter.adapter_prefix + "_") adapter.pre_smash_hook(model_recovery, adapter_smash_config, seed=adapter_seed) + def apply(self, model: Any, smash_config: SmashConfig) -> Any: + """ + Apply the recovery algorithm and refresh the save cache if needed. + + Recovery modifies weights in-place without changing the model's serialization + format. If a prior algorithm used ``save_before_apply`` (caching the model before + its transformation), the cached snapshot is now stale because recovery changed + the weights. This override refreshes that cache so the already saved model includes + the recovered weights. + + Parameters + ---------- + model : Any + The model to apply the algorithm to. 
+ smash_config : SmashConfig + The SmashConfig object containing the save and load functions. + + Returns + ------- + Any + The model after recovery has been applied. + """ + result = super().apply(model, smash_config) + + if smash_config.prepare_saving: + save_dir = self.get_save_before_smash_dir(smash_config) + if not save_dir.exists(): + return result + + ori_save_fns = smash_config.save_fns[:] + smash_config.save_fns = [fn for fn in smash_config.save_fns if fn != SAVE_FUNCTIONS.save_before_apply.name] + # Re-save with recovered weights + shutil.rmtree(save_dir, ignore_errors=True) + save_dir.mkdir(parents=True) + save_pruna_model(model, save_dir, smash_config) + # Restore save_fns + smash_config.save_fns = ori_save_fns + return result + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: """ Recover performances from a given model with a given config. From c86401d648b17848c819be0c91b75c9b84e915c0 Mon Sep 17 00:00:00 2001 From: Zihao Xue Date: Mon, 6 Apr 2026 23:31:43 -0600 Subject: [PATCH 3/5] test: add unit tests to verify main ideas --- tests/engine/test_save.py | 81 +++++++++++++++++++++++++++++++++++---- 1 file changed, 73 insertions(+), 8 deletions(-) diff --git a/tests/engine/test_save.py b/tests/engine/test_save.py index 2cd6f3e1..1191836a 100644 --- a/tests/engine/test_save.py +++ b/tests/engine/test_save.py @@ -1,18 +1,18 @@ import os -import pytest -import torch +import shutil from pathlib import Path from unittest.mock import patch + +import pytest +import torch +from diffusers import DiffusionPipeline from transformers import AutoModelForCausalLM -from pruna.config.smash_config import SmashConfig + from pruna import smash -from pruna.engine.save import save_pruna_model -from pruna.engine.save import save_pruna_model_to_hub -from pruna.engine.save import SAVE_FUNCTIONS -from pruna.engine.load import load_pruna_model from pruna.config.smash_config import SmashConfig -from diffusers import DiffusionPipeline +from 
pruna.engine.load import PICKLED_FILE_NAME, load_pruna_model from pruna.engine.pruna_model import PrunaModel +from pruna.engine.save import SAVE_FUNCTIONS, save_pruna_model, save_pruna_model_to_hub @pytest.mark.slow @@ -160,3 +160,68 @@ def test_push_to_hub_path_types(tmp_path) -> None: private=True ) assert mock_upload.called + + +@pytest.mark.cpu +def test_recovery_save_fn_is_none() -> None: + """Test that recovery algorithms use save_fn=None, preserving the prior algorithm's save format.""" + from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer + + assert PERPRecoverer.save_fn is None + + +@pytest.mark.cpu +def test_recovery_does_not_add_to_save_fns(tmp_path) -> None: + """Test that recovery's apply() does not append to save_fns when save_fn is None.""" + + config = SmashConfig() + config.save_fns = ["hqq"] # simulate a prior algorithm's save_fn + + save_fn = None + + # PrunaAlgorithmBase apply logic + if save_fn is not None and save_fn != SAVE_FUNCTIONS.reapply: + config.save_fns.append(save_fn.name) + + assert config.save_fns == ["hqq"], "Recovery should not add to save_fns" + + +@pytest.mark.cpu +def test_recovery_refresh_save_cache(tmp_path) -> None: + """Test that recovery refreshes a stale save_before_apply cache with recovered weights.""" + from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase + + model = AutoModelForCausalLM.from_pretrained("yujiepan/opt-tiny-random") + + config = SmashConfig(device="cpu") + + # Simulate a save_before_apply algorithm having run before recovery: + # 1. Save original (pre-transformation) model to cache + save_dir = PrunaAlgorithmBase.get_save_before_smash_dir(config) + save_dir.mkdir(parents=True) + save_pruna_model(model, save_dir, config) + + # 2. Mark save_before_apply in save_fns (as the algorithm would) + config.save_fns.append(SAVE_FUNCTIONS.save_before_apply.name) + + # 3. 
Simulate the transformation (e.g., half) + recovery modifying weights + model.lm_head.weight.data.fill_(0.99) # "recovered" weights + + # 4. Simulate what recovery's apply() does: refresh the stale cache + ori_save_fns = config.save_fns[:] + config.save_fns = [fn for fn in config.save_fns if fn != SAVE_FUNCTIONS.save_before_apply.name] + shutil.rmtree(save_dir, ignore_errors=True) + save_dir.mkdir(parents=True) + save_pruna_model(model, save_dir, config) + config.save_fns = ori_save_fns + + # 5. Verify the cache was refreshed: save_before_apply should copy updated files + save_path = tmp_path / "final_model" + save_pruna_model(model, save_path, config) + + # Load and verify the recovered weights survived the round-trip + loaded_model, _ = load_pruna_model(save_path) + loaded_model = loaded_model.cpu() + assert torch.allclose( + loaded_model.lm_head.weight, torch.full_like(loaded_model.lm_head.weight, 0.99) + ), "Recovered weights should survive save/load through save_before_apply" From 10627a395d30bb95ad7a785546d763853f0cec1a Mon Sep 17 00:00:00 2001 From: Zihao Xue Date: Mon, 6 Apr 2026 23:39:30 -0600 Subject: [PATCH 4/5] fix: remove unused imports from test_save.py --- tests/engine/test_save.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/engine/test_save.py b/tests/engine/test_save.py index 1191836a..34fd30b6 100644 --- a/tests/engine/test_save.py +++ b/tests/engine/test_save.py @@ -10,7 +10,7 @@ from pruna import smash from pruna.config.smash_config import SmashConfig -from pruna.engine.load import PICKLED_FILE_NAME, load_pruna_model +from pruna.engine.load import load_pruna_model from pruna.engine.pruna_model import PrunaModel from pruna.engine.save import SAVE_FUNCTIONS, save_pruna_model, save_pruna_model_to_hub From e5aaccd2e2a918da39d82334439640aac0a34ce4 Mon Sep 17 00:00:00 2001 From: Zihao Xue Date: Wed, 8 Apr 2026 09:31:00 -0600 Subject: [PATCH 5/5] docs: add docstring for get_save_before_smash_dir --- 
 src/pruna/algorithms/base/pruna_base.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/pruna/algorithms/base/pruna_base.py b/src/pruna/algorithms/base/pruna_base.py
index 0cce1e68..fa9a0405 100644
--- a/src/pruna/algorithms/base/pruna_base.py
+++ b/src/pruna/algorithms/base/pruna_base.py
@@ -450,8 +450,20 @@ def get_algorithms_to_run_after_disjointly(self) -> list[str]:
 
     @staticmethod
     def get_save_before_smash_dir(smash_config: SmashConfig) -> Path:
-        """Get the save directory for the algorithm caches."""
-        return smash_config.cache_dir / SAVE_BEFORE_SMASH_CACHE_DIR
+        """
+        Get the save directory for the algorithm caches.
+
+        Parameters
+        ----------
+        smash_config : SmashConfig
+            The SmashConfig whose cache directory is used as the base path.
+
+        Returns
+        -------
+        Path
+            The resolved absolute path of the save-before-smash cache directory.
+        """
+        return (smash_config.cache_dir / SAVE_BEFORE_SMASH_CACHE_DIR).resolve()
 
 
 def wrap_handle_imports(func):