Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ conflicts = [
[{ extra = "intel" }, { extra = "stable-fast" }, { extra = "stable-fast-extraindex" }],
[{ extra = "intel" }, { extra = "full" }, { extra = "stable-fast-extraindex" }],
[{ extra = "intel" }, { extra = "vllm" }],
[{ extra = "kvpress" }, { extra = "vbench" }],
]

[tool.uv.sources]
Expand Down Expand Up @@ -234,6 +235,9 @@ intel = [
"torch>=2.7.0,<2.9.0",
"torchvision>=0.22.0,<0.24.0",
]
kvpress = [
"kvpress>=0.5.2",
]

[build-system]
requires = ["hatchling"]
Expand Down
2 changes: 1 addition & 1 deletion src/pruna/algorithms/gptq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class GPTQ(PrunaAlgorithmBase):
processor_required: bool = False
runs_on: list[str] = ["cuda"]
dataset_required: bool = True
compatible_after: Iterable[str] = ["torch_compile", "sage_attn"]
compatible_after: Iterable[str] = ["torch_compile", "sage_attn", "kvpress"]
required_install: str = (
"You must first install the base package with ``pip install pruna`` "
"before installing the GPTQ extension with ``pip install pruna[gptq] --extra-index-url https://prunaai.pythonanywhere.com/``"
Expand Down
1 change: 1 addition & 0 deletions src/pruna/algorithms/half.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class Half(PrunaAlgorithmBase):
"stable_fast",
"torch_compile",
"ifw",
"kvpress",
"whisper_s2t",
"sage_attn",
"hyper",
Expand Down
2 changes: 1 addition & 1 deletion src/pruna/algorithms/hqq.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class HQQ(PrunaAlgorithmBase):
runs_on: list[str] = ["cuda"]
dataset_required: bool = False
compatible_before: Iterable[str] = ["torch_structured", "moe_kernel_tuner"]
compatible_after: Iterable[str] = ["torch_compile", "sage_attn", "moe_kernel_tuner"]
compatible_after: Iterable[str] = ["torch_compile", "sage_attn", "kvpress", "moe_kernel_tuner"]
disjointly_compatible_before: Iterable[str] = []
disjointly_compatible_after: Iterable[str] = ["torchao"]

Expand Down
2 changes: 1 addition & 1 deletion src/pruna/algorithms/huggingface_llm_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class LLMInt8(PrunaAlgorithmBase):
runs_on: list[str] = ["cuda", "accelerate"]
save_fn: None = None
compatible_before: Iterable[str] = ["moe_kernel_tuner"]
compatible_after: Iterable[str] = ["torch_compile", "sage_attn", "moe_kernel_tuner"]
compatible_after: Iterable[str] = ["torch_compile", "sage_attn", "kvpress", "moe_kernel_tuner"]

def get_hyperparameters(self) -> list:
"""
Expand Down
181 changes: 181 additions & 0 deletions src/pruna/algorithms/kvpress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools
from collections.abc import Iterable
from typing import Any, Dict

from ConfigSpace import CategoricalHyperparameter, UniformFloatHyperparameter

from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.base.tags import AlgorithmTag as tags
from pruna.config.hyperparameters import UnconstrainedHyperparameter
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.engine.model_checks import is_causal_lm, is_transformers_pipeline_with_causal_lm
from pruna.engine.save import SAVE_FUNCTIONS

# Names of every kvpress press class this integration exposes as a `press_type`
# choice. Only scorer / standalone presses are listed; wrapper presses and
# ThinKPress are deliberately excluded (see the KVPress class docstring).
PRESS_TYPES = (
    "CompactorPress CURPress ExpectedAttentionPress ExpectedAttentionStatsPress "
    "FastKVzipPress FinchPress KnormPress KVzapPress KVzipPress KeyDiffPress "
    "LagKVPress LeverageScorePress NonCausalAttnPress ObservedAttentionPress "
    "PyramidKVPress QFilterPress RandomPress SnapKVPress StreamingLLMPress "
    "TOVAPress"
).split()


class KVPress(PrunaAlgorithmBase):
    """
    Compress the KV cache of causal language models using KVPress.

    KVPress is a library by NVIDIA that provides over 20 compression strategies (presses) for
    reducing the memory footprint of the key-value cache during long-context inference. Each press
    scores and prunes KV pairs after the prefill phase according to a chosen importance criterion.

    This integration covers all scorer and standalone presses. Wrapper presses (e.g., ChunkPress,
    AdaKVPress, PerLayerCompressionPress) that require a nested scorer press as input are not
    included, as well as ThinKPress which compresses along the channel dimension with a different
    parameter interface.
    """

    algorithm_name: str = "kvpress"
    group_tags: list[tags] = [tags.PRUNER]
    # The press wraps ``generate`` at runtime rather than changing weights,
    # so saving is handled by re-applying the algorithm on load.
    save_fn: SAVE_FUNCTIONS = SAVE_FUNCTIONS.reapply
    references: dict[str, str] = {
        "GitHub": "https://github.com/NVIDIA/kvpress",
        "Article": "https://huggingface.co/blog/nvidia/kvpress",
    }
    required_install: str = "pip install pruna[kvpress]"
    tokenizer_required: bool = False
    processor_required: bool = False
    dataset_required: bool = False
    runs_on: list[str] = ["cuda"]
    compatible_before: Iterable[str] = [
        "awq", "gptq", "half", "hqq", "llm_int8",
        "quanto", "sage_attn", "torchao", "moe_kernel_tuner",
    ]
    compatible_after: Iterable[str] = ["torch_compile", "moe_kernel_tuner"]

    def get_hyperparameters(self) -> list:
        """
        Configure all algorithm-specific hyperparameters with ConfigSpace.

        Returns
        -------
        list
            The hyperparameters.
        """
        return [
            CategoricalHyperparameter(
                "press_type",
                choices=PRESS_TYPES,
                default_value="ExpectedAttentionPress",
                meta={"desc": "The KV cache compression strategy to use."},
            ),
            UniformFloatHyperparameter(
                "compression_ratio",
                lower=0.0,
                upper=1.0,
                default_value=0.5,
                meta={"desc": "Fraction of KV pairs to remove. 0.0 means no compression."},
            ),
            # Press-specific options (e.g. window_size for SnapKVPress) are passed
            # through untyped since each press class accepts different kwargs.
            UnconstrainedHyperparameter(
                "press_kwargs",
                default_value=None,
                meta={"desc": "Additional keyword arguments passed to the press constructor."},
            ),
        ]

    def model_check_fn(self, model: Any) -> bool:
        """
        Check if the model is a causal language model or a pipeline wrapping one.

        Parameters
        ----------
        model : Any
            The model to check.

        Returns
        -------
        bool
            True if the model is compatible with KV cache compression, False otherwise.
        """
        return is_causal_lm(model) or is_transformers_pipeline_with_causal_lm(model)

    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
        """
        Wrap the model's generate method to apply KV cache compression via a press context manager.

        Parameters
        ----------
        model : Any
            The causal language model to compress.
        smash_config : SmashConfigPrefixWrapper
            The algorithm-prefixed configuration containing press_type, compression_ratio, and press_kwargs.

        Returns
        -------
        Any
            The model with its generate method wrapped to compress the KV cache on each call.
        """
        imported_modules = self.import_algorithm_packages()

        press_type = smash_config["press_type"]
        compression_ratio = smash_config["compression_ratio"]
        # press_kwargs defaults to None in the config space; treat that as "no extras".
        press_kwargs = smash_config["press_kwargs"] or {}

        press_cls = imported_modules[press_type]
        press = press_cls(compression_ratio=compression_ratio, **press_kwargs)

        # NOTE(review): model_check_fn also accepts transformers pipelines, but this
        # accesses ``model.generate`` directly — confirm pipelines are unwrapped to
        # the underlying causal LM before _apply is called.
        original_generate = model.generate

        # Each generate call runs inside the press context manager, which hooks the
        # attention layers to prune KV pairs after prefill, per the kvpress usage docs.
        @functools.wraps(original_generate)
        def generate_with_press(*args, **kwargs):
            with press(model):
                return original_generate(*args, **kwargs)

        model.generate = generate_with_press
        # Keep handles to the unwrapped generate and the press so the wrapping can be
        # inspected or undone later (also asserted by the algorithm tests).
        model._kvpress_original_generate = original_generate
        model._kvpress_press = press

        return model

    def import_algorithm_packages(self) -> Dict[str, Any]:
        """
        Lazily import kvpress and collect all supported press classes.

        Returns
        -------
        Dict[str, Any]
            A dictionary mapping press class names to their classes.
        """
        import kvpress

        return {name: getattr(kvpress, name) for name in PRESS_TYPES}
2 changes: 1 addition & 1 deletion src/pruna/algorithms/llm_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class LLMCompressor(PrunaAlgorithmBase):
dataset_required: bool = True
runs_on: list[str] = ["cuda"]
compatible_before: Iterable[str] = ["moe_kernel_tuner"]
compatible_after: Iterable[str] = ["sage_attn", "moe_kernel_tuner"]
compatible_after: Iterable[str] = ["sage_attn", "kvpress", "moe_kernel_tuner"]
required_install = "``uv pip install 'pruna[awq]'``"

def get_hyperparameters(self) -> list:
Expand Down
4 changes: 2 additions & 2 deletions src/pruna/algorithms/moe_kernel_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ class MoeKernelTuner(PrunaAlgorithmBase):
dataset_required: bool = False
compatible_before: Iterable[str] = [
"awq", "deepcache", "diffusers_int8", "fastercache", "flash_attn3",
"fora", "hqq", "hqq_diffusers", "llm_int8", "pab", "padding_pruning",
"fora", "hqq", "hqq_diffusers", "kvpress", "llm_int8", "pab", "padding_pruning",
"qkv_diffusers", "quanto", "reduce_noe", "ring_attn", "sage_attn",
"torch_compile", "torchao",
]
compatible_after: Iterable[str] = [
"awq", "deepcache", "diffusers_int8", "fastercache", "flash_attn3",
"fora", "hqq", "hqq_diffusers", "llm_int8", "pab", "padding_pruning",
"fora", "hqq", "hqq_diffusers", "kvpress", "llm_int8", "pab", "padding_pruning",
"qkv_diffusers", "quanto", "ring_attn", "sage_attn",
"torch_compile", "torchao",
]
Expand Down
1 change: 1 addition & 0 deletions src/pruna/algorithms/quanto.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class Quanto(PrunaAlgorithmBase):
compatible_before: Iterable[str] = ["qkv_diffusers", "moe_kernel_tuner"]
compatible_after: Iterable[str] = [
"deepcache",
"kvpress",
"sage_attn",
"text_to_image_distillation_inplace_perp",
"text_to_image_distillation_lora",
Expand Down
2 changes: 1 addition & 1 deletion src/pruna/algorithms/sage_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class SageAttn(PrunaAlgorithmBase):
runs_on: list[str] = ["cuda", "accelerate"]
dataset_required: bool = False
compatible_before: Iterable[str | tags] = [tags.QUANTIZER, "moe_kernel_tuner"]
compatible_after: Iterable[str | tags] = ["torch_compile", tags.CACHER, "moe_kernel_tuner"]
compatible_after: Iterable[str | tags] = ["torch_compile", tags.CACHER, "kvpress", "moe_kernel_tuner"]

def model_check_fn(self, model: Any) -> bool:
"""
Expand Down
1 change: 1 addition & 0 deletions src/pruna/algorithms/torch_compile/torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class TorchCompile(PrunaAlgorithmBase):
"hyper",
"padding_pruning",
"ring_attn",
"kvpress",
"text_to_image_distillation_inplace_perp",
"text_to_image_distillation_lora",
"text_to_image_distillation_perp",
Expand Down
1 change: 1 addition & 0 deletions src/pruna/algorithms/torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class Torchao(PrunaAlgorithmBase):
"fora",
"torch_compile",
"sage_attn",
"kvpress",
"img2img_denoise",
"realesrgan_upscale",
"moe_kernel_tuner",
Expand Down
5 changes: 3 additions & 2 deletions src/pruna/config/smash_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,9 @@ def add(self, request: str | list[str] | dict[str, Any]) -> None:
# request wants to activate a dictionary of algorithms and their hyperparameters
elif isinstance(request, dict):
for key, value in request.items():
# target modules are a special case, as they are a hyperparameter but their value is a dict
if isinstance(value, dict) and "target_module" not in key:
# if the key is an algorithm name and the value is a dict, treat it as
# algorithm activation + hyperparameter setting (e.g. {"hqq": {"weight_bits": 4}})
if isinstance(value, dict) and key in SMASH_SPACE.get_all_algorithms():
self._configuration[key] = True
for k, v in value.items():
if not k.startswith(key):
Expand Down
42 changes: 42 additions & 0 deletions tests/algorithms/testers/kvpress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from pruna import PrunaModel
from pruna.algorithms.kvpress import KVPress

from .base_tester import AlgorithmTesterBase


class TestKVPress(AlgorithmTesterBase):
    """Exercise the KVPress KV cache compression algorithm with its default settings."""

    models = ["llama_3_tiny_random"]
    reject_models = ["sd_tiny_random"]
    allow_pickle_files = False
    algorithm_class = KVPress
    metrics = ["perplexity"]

    def post_smash_hook(self, model: PrunaModel) -> None:
        """Check that smashing attached the press and preserved the original generate."""
        for attribute in ("_kvpress_press", "_kvpress_original_generate"):
            assert hasattr(model, attribute)


class TestKVPressSnapKV(AlgorithmTesterBase):
    """Exercise the KVPress algorithm with the SnapKV press and explicit press_kwargs."""

    models = ["llama_3_tiny_random"]
    reject_models = ["sd_tiny_random"]
    allow_pickle_files = False
    algorithm_class = KVPress
    metrics = ["perplexity"]
    hyperparameters = {
        "kvpress_press_type": "SnapKVPress",
        "kvpress_compression_ratio": 0.3,
        "kvpress_press_kwargs": {"window_size": 32, "kernel_size": 3},
    }

    def post_smash_hook(self, model: PrunaModel) -> None:
        """Check that the SnapKV press was constructed with the requested kwargs."""
        assert hasattr(model, "_kvpress_press")
        press = model._kvpress_press
        assert press.__class__.__name__ == "SnapKVPress"
        assert (press.window_size, press.kernel_size) == (32, 3)