-
Notifications
You must be signed in to change notification settings - Fork 89
feat: integrate KVPress for KV cache compression (#366) #623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
kschwethelm
wants to merge
10
commits into
PrunaAI:main
Choose a base branch
from
kschwethelm:feat/kvpress
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
abe8245
feat: integrate KVPress for KV cache compression
kschwethelm 54440b6
feat: bump kvpress to >=0.5.2, add FastKVzipPress
kschwethelm 2cd50b7
feat: add press_kwargs for press-specific parameters
kschwethelm f55489b
fix: compatibility, press_kwargs, unit tests, remove wrappers
kschwethelm 9c434e3
feat: add KV_CACHER tag, replace explicit kvpress references
kschwethelm 5250c84
refactor: rename KV_CACHER tag to KV_COMPRESSOR, improve docstrings
kschwethelm ebe9960
docs: document excluded wrapper presses in kvpress docstring
kschwethelm da9199e
ci: gate kvpress tests behind requires_kvpress marker
kschwethelm a537ff4
refactor: remove KV_COMPRESSOR tag, reference kvpress by name
kschwethelm a700a2b
test: remove kvpress API smoke tests and CI scaffolding
kschwethelm File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,181 @@ | ||
| # Copyright 2025 - Pruna AI GmbH. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import functools | ||
| from collections.abc import Iterable | ||
| from typing import Any, Dict | ||
|
|
||
| from ConfigSpace import CategoricalHyperparameter, UniformFloatHyperparameter | ||
|
|
||
| from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase | ||
| from pruna.algorithms.base.tags import AlgorithmTag as tags | ||
| from pruna.config.hyperparameters import UnconstrainedHyperparameter | ||
| from pruna.config.smash_config import SmashConfigPrefixWrapper | ||
| from pruna.engine.model_checks import is_causal_lm, is_transformers_pipeline_with_causal_lm | ||
| from pruna.engine.save import SAVE_FUNCTIONS | ||
|
|
||
| PRESS_TYPES = [ | ||
| "CompactorPress", | ||
| "CURPress", | ||
| "ExpectedAttentionPress", | ||
| "ExpectedAttentionStatsPress", | ||
| "FastKVzipPress", | ||
| "FinchPress", | ||
| "KnormPress", | ||
| "KVzapPress", | ||
| "KVzipPress", | ||
| "KeyDiffPress", | ||
| "LagKVPress", | ||
| "LeverageScorePress", | ||
| "NonCausalAttnPress", | ||
| "ObservedAttentionPress", | ||
| "PyramidKVPress", | ||
| "QFilterPress", | ||
| "RandomPress", | ||
| "SnapKVPress", | ||
| "StreamingLLMPress", | ||
| "TOVAPress", | ||
| ] | ||
|
|
||
|
|
||
| class KVPress(PrunaAlgorithmBase): | ||
| """ | ||
| Compress the KV cache of causal language models using KVPress. | ||
|
|
||
| KVPress is a library by NVIDIA that provides over 20 compression strategies (presses) for | ||
| reducing the memory footprint of the key-value cache during long-context inference. Each press | ||
| scores and prunes KV pairs after the prefill phase according to a chosen importance criterion. | ||
|
|
||
| This integration covers all scorer and standalone presses. Wrapper presses (e.g., ChunkPress, | ||
| AdaKVPress, PerLayerCompressionPress) that require a nested scorer press as input are not | ||
| included, as well as ThinKPress which compresses along the channel dimension with a different | ||
| parameter interface. | ||
| """ | ||
|
|
||
| algorithm_name: str = "kvpress" | ||
| group_tags: list[tags] = [tags.PRUNER] | ||
| save_fn: SAVE_FUNCTIONS = SAVE_FUNCTIONS.reapply | ||
| references: dict[str, str] = { | ||
| "GitHub": "https://github.com/NVIDIA/kvpress", | ||
| "Article": "https://huggingface.co/blog/nvidia/kvpress", | ||
| } | ||
| required_install: str = "pip install pruna[kvpress]" | ||
| tokenizer_required: bool = False | ||
| processor_required: bool = False | ||
| dataset_required: bool = False | ||
| runs_on: list[str] = ["cuda"] | ||
| compatible_before: Iterable[str] = [ | ||
| "awq", "gptq", "half", "hqq", "llm_int8", | ||
| "quanto", "sage_attn", "torchao", "moe_kernel_tuner", | ||
| ] | ||
| compatible_after: Iterable[str] = ["torch_compile", "moe_kernel_tuner"] | ||
|
|
||
| def get_hyperparameters(self) -> list: | ||
| """ | ||
| Configure all algorithm-specific hyperparameters with ConfigSpace. | ||
|
|
||
| Returns | ||
| ------- | ||
| list | ||
| The hyperparameters. | ||
| """ | ||
| return [ | ||
| CategoricalHyperparameter( | ||
| "press_type", | ||
| choices=PRESS_TYPES, | ||
| default_value="ExpectedAttentionPress", | ||
| meta={"desc": "The KV cache compression strategy to use."}, | ||
| ), | ||
| UniformFloatHyperparameter( | ||
| "compression_ratio", | ||
| lower=0.0, | ||
| upper=1.0, | ||
| default_value=0.5, | ||
| meta={"desc": "Fraction of KV pairs to remove. 0.0 means no compression."}, | ||
| ), | ||
| UnconstrainedHyperparameter( | ||
| "press_kwargs", | ||
| default_value=None, | ||
| meta={"desc": "Additional keyword arguments passed to the press constructor."}, | ||
| ), | ||
| ] | ||
|
|
||
| def model_check_fn(self, model: Any) -> bool: | ||
| """ | ||
| Check if the model is a causal language model or a pipeline wrapping one. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| model : Any | ||
| The model to check. | ||
|
|
||
| Returns | ||
| ------- | ||
| bool | ||
| True if the model is compatible with KV cache compression, False otherwise. | ||
| """ | ||
| return is_causal_lm(model) or is_transformers_pipeline_with_causal_lm(model) | ||
|
|
||
| def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: | ||
| """ | ||
| Wrap the model's generate method to apply KV cache compression via a press context manager. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| model : Any | ||
| The causal language model to compress. | ||
| smash_config : SmashConfigPrefixWrapper | ||
| The algorithm-prefixed configuration containing press_type, compression_ratio, and press_kwargs. | ||
|
|
||
| Returns | ||
| ------- | ||
| Any | ||
| The model with its generate method wrapped to compress the KV cache on each call. | ||
| """ | ||
| imported_modules = self.import_algorithm_packages() | ||
|
|
||
| press_type = smash_config["press_type"] | ||
| compression_ratio = smash_config["compression_ratio"] | ||
| press_kwargs = smash_config["press_kwargs"] or {} | ||
|
|
||
| press_cls = imported_modules[press_type] | ||
| press = press_cls(compression_ratio=compression_ratio, **press_kwargs) | ||
|
|
||
| original_generate = model.generate | ||
|
|
||
| @functools.wraps(original_generate) | ||
| def generate_with_press(*args, **kwargs): | ||
| with press(model): | ||
| return original_generate(*args, **kwargs) | ||
|
|
||
| model.generate = generate_with_press | ||
| model._kvpress_original_generate = original_generate | ||
| model._kvpress_press = press | ||
|
|
||
| return model | ||
|
|
||
| def import_algorithm_packages(self) -> Dict[str, Any]: | ||
| """ | ||
| Lazily import kvpress and collect all supported press classes. | ||
|
|
||
| Returns | ||
| ------- | ||
| Dict[str, Any] | ||
| A dictionary mapping press class names to their classes. | ||
| """ | ||
| import kvpress | ||
|
|
||
| return {name: getattr(kvpress, name) for name in PRESS_TYPES} | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| from pruna import PrunaModel | ||
| from pruna.algorithms.kvpress import KVPress | ||
|
|
||
| from .base_tester import AlgorithmTesterBase | ||
|
|
||
|
|
||
| class TestKVPress(AlgorithmTesterBase): | ||
| """Test the KVPress KV cache compression algorithm with default settings.""" | ||
|
|
||
| models = ["llama_3_tiny_random"] | ||
| reject_models = ["sd_tiny_random"] | ||
| allow_pickle_files = False | ||
| algorithm_class = KVPress | ||
| metrics = ["perplexity"] | ||
|
|
||
| def post_smash_hook(self, model: PrunaModel) -> None: | ||
| """Verify that the press was applied to the model.""" | ||
| assert hasattr(model, "_kvpress_press") | ||
| assert hasattr(model, "_kvpress_original_generate") | ||
|
|
||
|
|
||
| class TestKVPressSnapKV(AlgorithmTesterBase): | ||
| """Test the KVPress algorithm with SnapKV and custom press_kwargs.""" | ||
|
|
||
| models = ["llama_3_tiny_random"] | ||
| reject_models = ["sd_tiny_random"] | ||
| allow_pickle_files = False | ||
| algorithm_class = KVPress | ||
| metrics = ["perplexity"] | ||
| hyperparameters = { | ||
| "kvpress_press_type": "SnapKVPress", | ||
| "kvpress_compression_ratio": 0.3, | ||
| "kvpress_press_kwargs": {"window_size": 32, "kernel_size": 3}, | ||
| } | ||
|
|
||
| def post_smash_hook(self, model: PrunaModel) -> None: | ||
| """Verify that SnapKV press was applied with correct parameters.""" | ||
| assert hasattr(model, "_kvpress_press") | ||
| press = model._kvpress_press | ||
| assert type(press).__name__ == "SnapKVPress" | ||
| assert press.window_size == 32 | ||
| assert press.kernel_size == 3 |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reading the original repo they recommend wrapping the call with the context manager instead of the generate.
Is there a specific reason why you chose this way of implementing it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the comment. I will investigate this.