From 6b913dadd9c26a422856f61991d088cb4c4f80aa Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Wed, 11 Feb 2026 00:09:56 +0100 Subject: [PATCH 01/19] feat: add Token Merging algorithm --- src/pruna/algorithms/token_merging.py | 528 ++++++++++++++++++++++++++ 1 file changed, 528 insertions(+) create mode 100644 src/pruna/algorithms/token_merging.py diff --git a/src/pruna/algorithms/token_merging.py b/src/pruna/algorithms/token_merging.py new file mode 100644 index 00000000..ed8c3369 --- /dev/null +++ b/src/pruna/algorithms/token_merging.py @@ -0,0 +1,528 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import math +from typing import Any, Callable + +import torch +from ConfigSpace import UniformIntegerHyperparameter +from transformers import ImageClassificationPipeline + +from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase +from pruna.algorithms.base.tags import AlgorithmTag as tags +from pruna.config.hyperparameters import Boolean +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.engine.save import SAVE_FUNCTIONS +from pruna.logging.logger import pruna_logger + +# --------------------------------------------------------------------------- +# Token merging utility functions (adapted from facebook/ToMe) +# --------------------------------------------------------------------------- + + +def _do_nothing(x: torch.Tensor, mode: str | None = None) -> torch.Tensor: + """Identity function used as a no-op merge / unmerge.""" + return x + + +def _parse_r(num_layers: int, r: int | list[int] | tuple[int, float]) -> list[int]: + """ + Process a constant *r* or *r* schedule into a per-layer list. + + Parameters + ---------- + num_layers : int + Number of transformer blocks. + r : int | list[int] | tuple[int, float] + Token reduction amount. Can be a constant ``int``, a ``(r, inflection)`` + tuple, or an explicit per-layer list. + Inflection describes the trend of the r value over layers. + It can increase (+1), decrease (-1), or stay constant (0). + Any value between -1 and +1 is accepted. + + Returns + ------- + list[int] + A list of length ``num_layers`` with the number of tokens to merge in + each layer. 
+ """ + inflect = 0 + if isinstance(r, list): + if len(r) < num_layers: + r = r + [0] * (num_layers - len(r)) + return list(r) + elif isinstance(r, tuple): + r, inflect = r + + min_val = int(r * (1.0 - inflect)) + max_val = 2 * r - min_val + step = (max_val - min_val) / (num_layers - 1) + return [int(min_val + step * i) for i in range(num_layers)] + + +def _bipartite_soft_matching( + tokens: torch.Tensor, + r: int, + class_token: bool = False, + distill_token: bool = False, +) -> tuple[Callable, Callable]: + """ + Apply ToMe with a balanced matching set (50 %, 50 %). + + Parameters + ---------- + tokens : torch.Tensor + Token tensor of shape ``[batch, tokens, channels]``. + r : int + Number of tokens to remove (at most 50 % of tokens). + class_token : bool + Whether a class token is present (will not be merged). + distill_token : bool + Whether a distillation token is present (will not be merged). + + Returns + ------- + tuple[Callable, Callable] + ``(merge, unmerge)`` callables. + """ + protected = int(class_token) + int(distill_token) + t = tokens.shape[1] + r = min(r, (t - protected) // 2) + + if r <= 0: + return _do_nothing, _do_nothing + + with torch.no_grad(): + tokens = tokens / tokens.norm(dim=-1, keepdim=True) + a, b = tokens[..., ::2, :], tokens[..., 1::2, :] + scores = a @ b.transpose(-1, -2) + + if class_token: + scores[..., 0, :] = -math.inf + if distill_token: + scores[..., :, 0] = -math.inf + + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged tokens + src_idx = edge_idx[..., :r, :] # Merged tokens + dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx) + + if class_token: + unm_idx = unm_idx.sort(dim=1)[0] + + def merge(x: torch.Tensor, mode: str = "mean") -> torch.Tensor: + """Merge tokens by scattering sources into their matched destinations.""" + src, dst = x[..., ::2, :], x[..., 1::2, :] + n, t1, c = src.shape + unm = src.gather(dim=-2, 
index=unm_idx.expand(n, t1 - r, c)) + src = src.gather(dim=-2, index=src_idx.expand(n, r, c)) + dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode) + + if distill_token: + return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1) + return torch.cat([unm, dst], dim=1) + + def unmerge(x: torch.Tensor) -> torch.Tensor: + """Reverse a previous merge operation (approximate inverse).""" + unm_len = unm_idx.shape[1] + unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] + n, _, c = unm.shape + + src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c)) + + out = torch.zeros(n, tokens.shape[1], c, device=x.device, dtype=x.dtype) + out[..., 1::2, :] = dst + out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm) + out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), src=src) + + return out + + return merge, unmerge + + +def _merge_wavg(merge: Callable, x: torch.Tensor, size: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]: + """ + Merge via weighted average based on token size. + + Parameters + ---------- + merge : Callable + The merge function returned by ``_bipartite_soft_matching``. + x : torch.Tensor + Token tensor to merge. + size : torch.Tensor | None + Current token sizes. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + ``(merged_x, new_size)``. + """ + if size is None: + size = torch.ones_like(x[..., 0, None]) + + x = merge(x * size, mode="sum") + size = merge(size, mode="sum") + x = x / size + return x, size + + +def _merge_source(merge: Callable, x: torch.Tensor, source: torch.Tensor | None = None) -> torch.Tensor: + """ + Track merge sources as an adjacency matrix. + + Parameters + ---------- + merge : Callable + The merge function returned by ``_bipartite_soft_matching``. + x : torch.Tensor + Token tensor (used to infer shape when *source* is ``None``). + source : torch.Tensor | None + Existing source adjacency matrix, or ``None`` to initialise. 
+ + Returns + ------- + torch.Tensor + Updated source adjacency matrix. + """ + if source is None: + n, t, _ = x.shape + source = torch.eye(t, device=x.device)[None, ...].expand(n, t, t) + return merge(source, mode="amax") + + +# --------------------------------------------------------------------------- +# ToMe-aware HuggingFace ViT modules (module-level for picklability) +# --------------------------------------------------------------------------- + +try: + from transformers.models.vit.modeling_vit import ViTLayer as _HFViTLayer + from transformers.models.vit.modeling_vit import ViTSelfAttention as _HFViTSelfAttention + + class ToMeViTSelfAttention(_HFViTSelfAttention): + """ + Self-attention with proportional attention and key-metric side-output for ToMe. + + Modifications over the base HuggingFace ``ViTSelfAttention``: + - Uses eager attention to inject proportional attention weighting when + ``self._tome_info["prop_attn"]`` is ``True`` and a token ``size`` is available. + - Stores the mean of *k* over heads in ``self._tome_info["metric"]`` so that + the enclosing ``ToMeViTLayer`` can use it for bipartite matching without + requiring changes to the intermediate ``ViTAttention`` wrapper. + """ + + _tome_info: dict[str, Any] + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass with proportional attention and key-metric storage.""" + batch_size = hidden_states.shape[0] + new_shape = (batch_size, -1, self.num_attention_heads, self.attention_head_size) + + key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) + query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) + + # Eager attention so we can inject the proportional-attention term. 
+ attn_weights = (query_layer @ key_layer.transpose(-2, -1)) * self.scaling + + # Proportional attention: bias scores by log(token_size). + if self._tome_info["prop_attn"] and self._tome_info["size"] is not None: + attn_weights = attn_weights + self._tome_info["size"].log()[:, None, None, :, 0] + + if head_mask is not None: + attn_weights = attn_weights + head_mask + + attn_weights = attn_weights.softmax(dim=-1) + attn_probs = torch.nn.functional.dropout(attn_weights, p=self.dropout_prob if self.training else 0.0) + + context_layer = (attn_probs @ value_layer).transpose(1, 2) + context_layer = context_layer.reshape(batch_size, -1, self.all_head_size) + + # Store the key mean as the similarity metric for token merging. + self._tome_info["metric"] = key_layer.mean(1) + + return context_layer, attn_weights + + class ToMeViTLayer(_HFViTLayer): + """ + ViT encoder layer that applies Token Merging between attention and MLP. + + After the attention sub-layer and its residual connection, this layer + performs bipartite soft matching on the key-metric stored in + ``self._tome_info["metric"]`` and merges the ``r`` most similar token + pairs before proceeding to the MLP sub-layer. 
+ """ + + _tome_info: dict[str, Any] + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward pass with token merging between attention and MLP.""" + # --- self-attention + first residual --- + hidden_states_norm = self.layernorm_before(hidden_states) + attention_output = self.attention(hidden_states_norm, head_mask) + hidden_states = attention_output + hidden_states + + # --- token merging --- + r = self._tome_info["r"].pop(0) + if r > 0: + metric = self._tome_info["metric"] + merge, _ = _bipartite_soft_matching( + metric, + r, + self._tome_info["class_token"], + self._tome_info["distill_token"], + ) + if self._tome_info["trace_source"]: + self._tome_info["source"] = _merge_source(merge, hidden_states, self._tome_info["source"]) + hidden_states, self._tome_info["size"] = _merge_wavg(merge, hidden_states, self._tome_info["size"]) + + # --- MLP + second residual --- + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output, hidden_states) + + return layer_output + +except ImportError: + ToMeViTSelfAttention = None + ToMeViTLayer = None + + +# --------------------------------------------------------------------------- +# Picklable model wrapper +# --------------------------------------------------------------------------- + + +class ToMeModelWrapper(torch.nn.Module): + """ + Wrapper that initialises ``_tome_info`` on every forward call. + + This class is defined at module level so that the wrapped model can be + pickled and unpickled without issues. On each forward pass it resets the + per-layer ``r`` schedule and clears any accumulated token-size / source + state before delegating to the underlying model. + + Parameters + ---------- + model : torch.nn.Module + A HuggingFace ViT model (already patched with ``ToMeViTLayer`` / + ``ToMeViTSelfAttention``). + r : int + The number of tokens to merge per layer. 
+ tome_info : dict + The shared mutable state dict read/written by all ``ToMeViTLayer`` + instances. + num_layers : int + The number of transformer layers in the model (used by ``_parse_r``). + """ + + def __init__( + self, + model: torch.nn.Module, + r: int, + tome_info: dict, + num_layers: int, + ) -> None: + super().__init__() + self.model = model + self.r = r + self._tome_info = tome_info + self.num_layers = num_layers + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """Initialise ToMe state and forward through the wrapped model.""" + self._tome_info["r"] = _parse_r(self.num_layers, self.r) + self._tome_info["size"] = None + self._tome_info["source"] = None + self._tome_info["metric"] = None + return self.model(*args, **kwargs) + + def __getattr__(self, name: str) -> Any: + """Delegate attribute access to the wrapped model for convenience.""" + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.model, name) + + +# --------------------------------------------------------------------------- +# Algorithm +# --------------------------------------------------------------------------- + + +class TokenMerging(PrunaAlgorithmBase): + """ + Apply Token Merging (ToMe) to HuggingFace Vision Transformer models. + + Token Merging progressively merges similar tokens between the attention + and MLP stages of each transformer block, reducing the total number of + tokens and therefore speeding up inference with minimal quality loss. 
+ """ + + algorithm_name: str = "token_merging" + group_tags: list[str] = [tags.KERNEL] + save_fn = SAVE_FUNCTIONS.reapply + references: dict[str, str] = { + "Paper": "https://arxiv.org/abs/2210.09461", + "GitHub": "https://github.com/facebookresearch/ToMe", + } + tokenizer_required: bool = False + processor_required: bool = False + runs_on: list[str] = ["cpu", "cuda"] + dataset_required: bool = False + + def model_check_fn(self, model: Any) -> bool: + """ + Check whether *model* contains HuggingFace ``ViTLayer`` blocks. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + ``True`` if the model contains at least one ``ViTLayer``. + """ + try: + from transformers.models.vit.modeling_vit import ViTLayer + except ImportError: + pruna_logger.warning("Transformers library not found. Token merging will not be applied.") + return False + + return any(isinstance(m, ViTLayer) for m in model.model.modules()) + + def import_algorithm_packages(self) -> dict[str, Any]: + """ + Import the required HuggingFace ViT classes. + + Returns + ------- + Dict[str, Any] + Dictionary with ``ViTLayer`` and ``ViTSelfAttention``. + """ + from transformers.models.vit.modeling_vit import ViTLayer, ViTSelfAttention + + return dict(ViTLayer=ViTLayer, ViTSelfAttention=ViTSelfAttention) + + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: + """ + Apply Token Merging to a HuggingFace ViT model. + + For every ``ViTLayer`` in the model, swaps its class to + ``ToMeViTLayer`` (which performs bipartite token merging after + self-attention). For every ``ViTSelfAttention``, swaps its class to + ``ToMeViTSelfAttention`` (which uses eager attention with proportional + weighting and stores the key metric). The model is then wrapped in a + ``ToMeModelWrapper`` that resets the shared ``_tome_info`` state + before every forward pass. + + Parameters + ---------- + model : Any + A HuggingFace ViT model (e.g. 
``ViTForImageClassification`` + or ``ViTModel``). + smash_config : SmashConfigPrefixWrapper + Algorithm configuration providing ``r``, ``trace_source``, and + ``prop_attn``. + + Returns + ------- + ToMeModelWrapper + The wrapped model with Token Merging applied. + """ + if isinstance(model, ImageClassificationPipeline): + model = model.model + + imported = self.import_algorithm_packages() + vit_layer_cls = imported["ViTLayer"] + vit_self_attn_cls = imported["ViTSelfAttention"] + + r = smash_config["r"] + trace_source = smash_config["trace_source"] + prop_attn = smash_config["prop_attn"] + + # Shared mutable state dict – every ToMe module reads from / writes to this. + tome_info: dict[str, Any] = { + "r": r, + "size": None, + "source": None, + "metric": None, + "trace_source": trace_source, + "prop_attn": prop_attn, + "class_token": True, + "distill_token": False, + } + + # Swap every ViTLayer / ViTSelfAttention to the ToMe-aware variants. + num_layers = 0 + for module in model.modules(): + if isinstance(module, vit_layer_cls): + module.__class__ = ToMeViTLayer + module._tome_info = tome_info + num_layers += 1 + elif isinstance(module, vit_self_attn_cls): + module.__class__ = ToMeViTSelfAttention + module._tome_info = tome_info + + return ToMeModelWrapper(model, r, tome_info, num_layers) + + def get_hyperparameters(self) -> list: + """ + Return the algorithm-specific hyperparameters. + + Returns + ------- + list + A list containing: + - ``r`` – number of tokens to merge per layer (int, 0–128). + - ``trace_source`` – whether to track merge provenance (bool). + - ``prop_attn`` – whether to use proportional attention (bool). + """ + return [ + UniformIntegerHyperparameter( + "r", + lower=0, + upper=128, + default_value=16, + meta=dict( + desc=( + "Number of tokens to merge per transformer layer. " + "Higher values speed up inference but may reduce accuracy." 
+ ) + ), + ), + Boolean( + name="trace_source", + default=False, + meta=dict(desc="Track the source of each merged token (useful for visualisation)."), + ), + Boolean( + name="prop_attn", + default=True, + meta=dict(desc="Use proportional attention weights based on token size."), + ), + ] From addab931bae76428de5e194467842540b3ef39d5 Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Wed, 11 Feb 2026 00:12:20 +0100 Subject: [PATCH 02/19] test: add Token Merging test class --- tests/algorithms/testers/token_merging.py | 48 +++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/algorithms/testers/token_merging.py diff --git a/tests/algorithms/testers/token_merging.py b/tests/algorithms/testers/token_merging.py new file mode 100644 index 00000000..390e987a --- /dev/null +++ b/tests/algorithms/testers/token_merging.py @@ -0,0 +1,48 @@ +from typing import Any + +import torch +from PIL import Image +from transformers import ImageClassificationPipeline + +from pruna import PrunaModel +from pruna.algorithms.token_merging import TokenMerging +from pruna.engine.utils import get_device + +from .base_tester import AlgorithmTesterBase + + +class TestTokenMerging(AlgorithmTesterBase): + """Test the token merging algorithm.""" + + models = ["vit_base", "vit_large"] + reject_models = [] + hyperparameters = {"token_merging_r": 16} + allow_pickle_files = True + algorithm_class = TokenMerging + metrics = ["total_macs", "latency"] + + def pre_smash_hook(self, model: Any) -> None: + """Hook to modify the model before smashing.""" + # Store original model info + self.input_image = Image.open("husky.png") + # Necessary to set the device to the same device as the model + model.device = torch.device(get_device(model)) + self.original_pred = model(self.input_image) + self.original_pred = [p["label"] for p in self.original_pred] + if isinstance(model, ImageClassificationPipeline): + self.input_image = model.preprocess(self.input_image)["pixel_values"] + 
self.input_image = self.input_image.to(model.device) + + def post_smash_hook(self, model: PrunaModel) -> None: + """Hook to modify the model after smashing.""" + # Verify that token merging was applied + print(model.__class__) + print(model.model.__class__) + assert hasattr(model, "_tome_info"), "Model should have _tome_info attribute" + + output = model(self.input_image) + pred_labels = [model.config.id2label[p] for p in output[0].topk(5).indices[0].tolist()] + print("Output: ", pred_labels) + print("Original: ", self.original_pred) + assert model._tome_info["size"] is not None, "Size should be set" + assert pred_labels[0] == self.original_pred[0], "Most likely class should remain the same" From 39d5cda41129513e852b2b6d945595f989a58961 Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Wed, 11 Feb 2026 00:13:01 +0100 Subject: [PATCH 03/19] test: add fixtures for ViT models --- tests/fixtures.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/fixtures.py b/tests/fixtures.py index 5710747e..5b62ec1d 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -123,6 +123,15 @@ def get_automodel_transformers(model_id: str, **kwargs: dict[str, Any]) -> tuple return model, smash_config +def get_vit_pipeline_for_specific_task(model_id: str, task: str, **kwargs: dict[str, Any]) -> tuple[Any, SmashConfig]: + """Get a transformers pipeline for specific task.""" + model = pipeline(task, model=model_id, **kwargs) + smash_config = SmashConfig() + + smash_config.add_data("ImageNet") + return model, smash_config + + def get_transformers_pipeline_for_specific_task( model_id: str, task: str, **kwargs: dict[str, Any] ) -> tuple[Any, SmashConfig]: @@ -183,6 +192,10 @@ def get_autoregressive_text_to_image_model(model_id: str) -> tuple[Any, SmashCon "shufflenet": partial(get_torchvision_model, "shufflenet_v2_x0_5"), "mobilenet_v2": partial(get_torchvision_model, "mobilenet_v2"), "resnet_18": partial(get_torchvision_model, "resnet18"), + "vit_base": 
partial(get_vit_pipeline_for_specific_task, "google/vit-base-patch16-224", task="image-classification"), + "vit_large": partial( + get_vit_pipeline_for_specific_task, "google/vit-large-patch16-224", task="image-classification" + ), # image generation models "stable_diffusion_v1_4": partial(get_diffusers_model, "CompVis/stable-diffusion-v1-4"), "stable_diffusion_3_medium_diffusers": partial( From 71af5fa8ce3b289c8a881a3c3cc1699e6d46bccf Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Tue, 17 Feb 2026 00:51:31 +0100 Subject: [PATCH 04/19] fix: adapt tome function signatures to HF classes --- src/pruna/algorithms/token_merging.py | 30 +++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/pruna/algorithms/token_merging.py b/src/pruna/algorithms/token_merging.py index ed8c3369..993ea959 100644 --- a/src/pruna/algorithms/token_merging.py +++ b/src/pruna/algorithms/token_merging.py @@ -14,7 +14,7 @@ from __future__ import annotations import math -from typing import Any, Callable +from typing import Any, Callable, Optional, Tuple, Union import torch from ConfigSpace import UniformIntegerHyperparameter @@ -230,8 +230,9 @@ class ToMeViTSelfAttention(_HFViTSelfAttention): def forward( self, hidden_states: torch.Tensor, - head_mask: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: """Forward pass with proportional attention and key-metric storage.""" batch_size = hidden_states.shape[0] new_shape = (batch_size, -1, self.num_attention_heads, self.attention_head_size) @@ -256,10 +257,12 @@ def forward( context_layer = (attn_probs @ value_layer).transpose(1, 2) context_layer = context_layer.reshape(batch_size, -1, self.all_head_size) + outputs = (context_layer, attn_probs) if output_attentions else (context_layer,) + # Store the key mean as the similarity metric for 
token merging. self._tome_info["metric"] = key_layer.mean(1) - return context_layer, attn_weights + return outputs class ToMeViTLayer(_HFViTLayer): """ @@ -276,12 +279,19 @@ class ToMeViTLayer(_HFViTLayer): def forward( self, hidden_states: torch.Tensor, - head_mask: torch.Tensor | None = None, - ) -> torch.Tensor: + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: """Forward pass with token merging between attention and MLP.""" # --- self-attention + first residual --- - hidden_states_norm = self.layernorm_before(hidden_states) - attention_output = self.attention(hidden_states_norm, head_mask) + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + hidden_states = attention_output + hidden_states # --- token merging --- @@ -303,7 +313,9 @@ def forward( layer_output = self.intermediate(layer_output) layer_output = self.output(layer_output, hidden_states) - return layer_output + outputs = (layer_output,) + outputs + + return outputs except ImportError: ToMeViTSelfAttention = None From 88b521b159cafd70d98472033adcdad3eefa4750 Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Fri, 6 Mar 2026 18:14:36 +0100 Subject: [PATCH 05/19] chore: remove print statements from test code --- tests/algorithms/testers/token_merging.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/algorithms/testers/token_merging.py b/tests/algorithms/testers/token_merging.py index 390e987a..be04361f 100644 --- a/tests/algorithms/testers/token_merging.py +++ b/tests/algorithms/testers/token_merging.py @@ -36,13 +36,9 @@ def pre_smash_hook(self, model: Any) -> None: def post_smash_hook(self, model: PrunaModel) -> None: """Hook to modify the model after smashing.""" # 
Verify that token merging was applied - print(model.__class__) - print(model.model.__class__) assert hasattr(model, "_tome_info"), "Model should have _tome_info attribute" output = model(self.input_image) pred_labels = [model.config.id2label[p] for p in output[0].topk(5).indices[0].tolist()] - print("Output: ", pred_labels) - print("Original: ", self.original_pred) assert model._tome_info["size"] is not None, "Size should be set" assert pred_labels[0] == self.original_pred[0], "Most likely class should remain the same" From 4a4706fff2e8d9830bcbb2a5b1aab1c82a74f750 Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Fri, 6 Mar 2026 18:16:54 +0100 Subject: [PATCH 06/19] fix: add division by zero safeguard in _parse_r --- src/pruna/algorithms/token_merging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pruna/algorithms/token_merging.py b/src/pruna/algorithms/token_merging.py index 993ea959..1a8c96c6 100644 --- a/src/pruna/algorithms/token_merging.py +++ b/src/pruna/algorithms/token_merging.py @@ -68,7 +68,7 @@ def _parse_r(num_layers: int, r: int | list[int] | tuple[int, float]) -> list[in min_val = int(r * (1.0 - inflect)) max_val = 2 * r - min_val - step = (max_val - min_val) / (num_layers - 1) + step = (max_val - min_val) / (num_layers - 1) if num_layers > 1 else 0 return [int(min_val + step * i) for i in range(num_layers)] From 82e0ed488b7dbc09f057eedcd1b2c73844cc4d34 Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Fri, 6 Mar 2026 18:20:24 +0100 Subject: [PATCH 07/19] fix: change model_check_fn to support HF models and pipelines --- src/pruna/algorithms/token_merging.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pruna/algorithms/token_merging.py b/src/pruna/algorithms/token_merging.py index 1a8c96c6..8ab85909 100644 --- a/src/pruna/algorithms/token_merging.py +++ b/src/pruna/algorithms/token_merging.py @@ -26,7 +26,7 @@ from pruna.config.smash_config import SmashConfigPrefixWrapper from 
pruna.engine.save import SAVE_FUNCTIONS from pruna.logging.logger import pruna_logger - +from pruna.engine.model_checks import is_vit, is_transformers_pipeline_with_vit # --------------------------------------------------------------------------- # Token merging utility functions (adapted from facebook/ToMe) # --------------------------------------------------------------------------- @@ -425,7 +425,7 @@ def model_check_fn(self, model: Any) -> bool: pruna_logger.warning("Transformers library not found. Token merging will not be applied.") return False - return any(isinstance(m, ViTLayer) for m in model.model.modules()) + return is_vit(model) or is_transformers_pipeline_with_vit(model) def import_algorithm_packages(self) -> dict[str, Any]: """ @@ -466,8 +466,8 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: ToMeModelWrapper The wrapped model with Token Merging applied. """ - if isinstance(model, ImageClassificationPipeline): - model = model.model + if is_transformers_pipeline_with_vit(model): + return self._apply_to_model_within_transformers_pipeline(model, smash_config) imported = self.import_algorithm_packages() vit_layer_cls = imported["ViTLayer"] From 2a8b29201cd1a70c90794f9a4256f19c44e8b048 Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Fri, 6 Mar 2026 18:21:13 +0100 Subject: [PATCH 08/19] fix: multiply head_mask instead of adding it --- src/pruna/algorithms/token_merging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pruna/algorithms/token_merging.py b/src/pruna/algorithms/token_merging.py index 8ab85909..dbfde0a4 100644 --- a/src/pruna/algorithms/token_merging.py +++ b/src/pruna/algorithms/token_merging.py @@ -249,7 +249,7 @@ def forward( attn_weights = attn_weights + self._tome_info["size"].log()[:, None, None, :, 0] if head_mask is not None: - attn_weights = attn_weights + head_mask + attn_weights = attn_weights * head_mask attn_weights = attn_weights.softmax(dim=-1) attn_probs = 
torch.nn.functional.dropout(attn_weights, p=self.dropout_prob if self.training else 0.0) From b8543ec58c3e62e5e2e96263f41865dcb3c8a091 Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Fri, 6 Mar 2026 18:22:00 +0100 Subject: [PATCH 09/19] perf: pre-compute parsed r at init --- src/pruna/algorithms/token_merging.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pruna/algorithms/token_merging.py b/src/pruna/algorithms/token_merging.py index dbfde0a4..b9293c9c 100644 --- a/src/pruna/algorithms/token_merging.py +++ b/src/pruna/algorithms/token_merging.py @@ -362,10 +362,11 @@ def __init__( self.r = r self._tome_info = tome_info self.num_layers = num_layers + self.parsed_r = _parse_r(self.num_layers, self.r) def forward(self, *args: Any, **kwargs: Any) -> Any: """Initialise ToMe state and forward through the wrapped model.""" - self._tome_info["r"] = _parse_r(self.num_layers, self.r) + self._tome_info["r"] = self.parsed_r self._tome_info["size"] = None self._tome_info["source"] = None self._tome_info["metric"] = None From ac34a2fe6c827a7551f354d98ae3514907ba1279 Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Tue, 7 Apr 2026 18:20:20 +0200 Subject: [PATCH 10/19] fix: remove output_attentions from forward call --- src/pruna/algorithms/token_merging.py | 31 ++++++++------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/src/pruna/algorithms/token_merging.py b/src/pruna/algorithms/token_merging.py index b9293c9c..2c34b05d 100644 --- a/src/pruna/algorithms/token_merging.py +++ b/src/pruna/algorithms/token_merging.py @@ -14,7 +14,7 @@ from __future__ import annotations import math -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple import torch from ConfigSpace import UniformIntegerHyperparameter @@ -231,8 +231,7 @@ def forward( self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> 
Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass with proportional attention and key-metric storage.""" batch_size = hidden_states.shape[0] new_shape = (batch_size, -1, self.num_attention_heads, self.attention_head_size) @@ -252,17 +251,17 @@ def forward( attn_weights = attn_weights * head_mask attn_weights = attn_weights.softmax(dim=-1) - attn_probs = torch.nn.functional.dropout(attn_weights, p=self.dropout_prob if self.training else 0.0) + attn_probs = torch.nn.functional.dropout( + attn_weights, p=self.dropout_prob if self.training else 0.0 + ) context_layer = (attn_probs @ value_layer).transpose(1, 2) context_layer = context_layer.reshape(batch_size, -1, self.all_head_size) - outputs = (context_layer, attn_probs) if output_attentions else (context_layer,) - # Store the key mean as the similarity metric for token merging. self._tome_info["metric"] = key_layer.mean(1) - return outputs + return context_layer, attn_probs class ToMeViTLayer(_HFViTLayer): """ @@ -280,18 +279,13 @@ def forward( self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + ) -> torch.Tensor: """Forward pass with token merging between attention and MLP.""" # --- self-attention + first residual --- - self_attention_outputs = self.attention( + attention_output = self.attention( self.layernorm_before(hidden_states), head_mask, - output_attentions=output_attentions, ) - attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - hidden_states = attention_output + hidden_states # --- token merging --- @@ -313,9 +307,7 @@ def forward( layer_output = self.intermediate(layer_output) layer_output = self.output(layer_output, hidden_states) - outputs = (layer_output,) + outputs - - return outputs + return layer_output except 
ImportError: ToMeViTSelfAttention = None @@ -420,11 +412,6 @@ def model_check_fn(self, model: Any) -> bool: bool ``True`` if the model contains at least one ``ViTLayer``. """ - try: - from transformers.models.vit.modeling_vit import ViTLayer - except ImportError: - pruna_logger.warning("Transformers library not found. Token merging will not be applied.") - return False return is_vit(model) or is_transformers_pipeline_with_vit(model) From 10231b3b56de6a1701399255f44966b0a5bdb69d Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Tue, 7 Apr 2026 18:20:48 +0200 Subject: [PATCH 11/19] fix: add is_vit and is_transformers_pipeline_with_vit --- src/pruna/engine/model_checks.py | 38 +++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/src/pruna/engine/model_checks.py b/src/pruna/engine/model_checks.py index fa5fb763..d1b69bd4 100644 --- a/src/pruna/engine/model_checks.py +++ b/src/pruna/engine/model_checks.py @@ -24,7 +24,10 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, ) -from transformers.pipelines.automatic_speech_recognition import AutomaticSpeechRecognitionPipeline +from transformers.pipelines.automatic_speech_recognition import ( + AutomaticSpeechRecognitionPipeline, +) +from transformers.pipelines.image_classification import ImageClassificationPipeline from transformers.pipelines.text2text_generation import Text2TextGenerationPipeline from transformers.pipelines.text_generation import TextGenerationPipeline @@ -124,6 +127,39 @@ def is_moe_lm(model: Any) -> bool: return hasattr(getattr(model, "config", None), "num_experts") +def is_vit(model: Any) -> bool: + """ + Check if the model is a ViT model. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if the model is a ViT model, False otherwise. 
+    """
+    return model.__class__.__name__ == "ViTForImageClassification"
+
+
+def is_transformers_pipeline_with_vit(model: Any) -> bool:
+    """
+    Check if the model is a transformers pipeline with a ViT model.
+
+    Parameters
+    ----------
+    model : Any
+        The model to check.
+
+    Returns
+    -------
+    bool
+        True if the model is an image-classification pipeline wrapping a ViT model, False otherwise.
+    """
+    return isinstance(model, ImageClassificationPipeline) and is_vit(getattr(model, "model", None))
+
 def is_transformers_pipeline_with_causal_lm(model: Any) -> bool:
     """
     Check if the model is a transformers pipeline (for tasks like text generation, classification, etc.).

From 55bbf274c1e7de48f0eb8a7042d87bff151ca26a Mon Sep 17 00:00:00 2001
From: Renato Sortino
Date: Tue, 7 Apr 2026 23:21:57 +0200
Subject: [PATCH 12/19] test: cleanup tests

---
 tests/algorithms/testers/token_merging.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/tests/algorithms/testers/token_merging.py b/tests/algorithms/testers/token_merging.py
index be04361f..533a222c 100644
--- a/tests/algorithms/testers/token_merging.py
+++ b/tests/algorithms/testers/token_merging.py
@@ -2,7 +2,6 @@
 
 import torch
 from PIL import Image
-from transformers import ImageClassificationPipeline
 
 from pruna import PrunaModel
 from pruna.algorithms.token_merging import TokenMerging
@@ -16,29 +15,28 @@ class TestTokenMerging(AlgorithmTesterBase):
     models = ["vit_base", "vit_large"]
     reject_models = []
-    hyperparameters = {"token_merging_r": 16}
+    hyperparameters = {"token_merging_r": 8}
     allow_pickle_files = True
     algorithm_class = TokenMerging
-    metrics = ["total_macs", "latency"]
+    metrics = []
 
     def pre_smash_hook(self, model: Any) -> None:
         """Hook to modify the model before smashing."""
-        # Store original model info
         self.input_image = Image.open("husky.png")
         # Necessary to set the device to the same device as the model
         model.device = torch.device(get_device(model))
         self.original_pred = model(self.input_image)
self.original_pred = [p["label"] for p in self.original_pred] - if isinstance(model, ImageClassificationPipeline): - self.input_image = model.preprocess(self.input_image)["pixel_values"] - self.input_image = self.input_image.to(model.device) def post_smash_hook(self, model: PrunaModel) -> None: """Hook to modify the model after smashing.""" - # Verify that token merging was applied - assert hasattr(model, "_tome_info"), "Model should have _tome_info attribute" + # The _tome_info lives on the ToMeModelWrapper inside the pipeline's .model + inner_model = model.model.model if hasattr(model.model, "model") else model.model + assert hasattr(inner_model, "_tome_info"), "Inner model should have _tome_info attribute" + # Pass the PIL image so the pipeline preprocesses it output = model(self.input_image) - pred_labels = [model.config.id2label[p] for p in output[0].topk(5).indices[0].tolist()] - assert model._tome_info["size"] is not None, "Size should be set" - assert pred_labels[0] == self.original_pred[0], "Most likely class should remain the same" + pred_labels = [p["label"] for p in output] + assert inner_model._tome_info["size"] is not None, "Size should be set" + # Check that the original top-1 is still in the top-5 after merging + assert self.original_pred[0] in pred_labels[:5], "Original top-1 should remain in top-5" From 07126e9f1a48d6de5923379b93bd72b0e73e6d09 Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Tue, 7 Apr 2026 23:42:54 +0200 Subject: [PATCH 13/19] style: sort imports and remove unused imports --- src/pruna/algorithms/token_merging.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/pruna/algorithms/token_merging.py b/src/pruna/algorithms/token_merging.py index 2c34b05d..0db1cfc0 100644 --- a/src/pruna/algorithms/token_merging.py +++ b/src/pruna/algorithms/token_merging.py @@ -12,21 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import annotations - import math from typing import Any, Callable, Optional, Tuple -import torch from ConfigSpace import UniformIntegerHyperparameter -from transformers import ImageClassificationPipeline +import torch from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase from pruna.algorithms.base.tags import AlgorithmTag as tags from pruna.config.hyperparameters import Boolean from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.engine.model_checks import is_transformers_pipeline_with_vit, is_vit from pruna.engine.save import SAVE_FUNCTIONS -from pruna.logging.logger import pruna_logger -from pruna.engine.model_checks import is_vit, is_transformers_pipeline_with_vit # --------------------------------------------------------------------------- # Token merging utility functions (adapted from facebook/ToMe) # --------------------------------------------------------------------------- From a85d18a527ba298f6348d6a4f0b50dabfc1482db Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Tue, 7 Apr 2026 23:47:55 +0200 Subject: [PATCH 14/19] fix: make a copy of the list in the forward for supporting multiple passes --- src/pruna/algorithms/token_merging.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pruna/algorithms/token_merging.py b/src/pruna/algorithms/token_merging.py index 0db1cfc0..7250d30e 100644 --- a/src/pruna/algorithms/token_merging.py +++ b/src/pruna/algorithms/token_merging.py @@ -355,7 +355,8 @@ def __init__( def forward(self, *args: Any, **kwargs: Any) -> Any: """Initialise ToMe state and forward through the wrapped model.""" - self._tome_info["r"] = self.parsed_r + # Make a copy of the list to avoid modifying the original + self._tome_info["r"] = list(self.parsed_r) self._tome_info["size"] = None self._tome_info["source"] = None self._tome_info["metric"] = None From e771b571b04bb22c6db0f256c467977a737eea6b Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Tue, 7 Apr 2026 23:53:19 
+0200 Subject: [PATCH 15/19] fix: read test images from HF dataset --- tests/algorithms/testers/token_merging.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/algorithms/testers/token_merging.py b/tests/algorithms/testers/token_merging.py index 533a222c..0bb3bc7a 100644 --- a/tests/algorithms/testers/token_merging.py +++ b/tests/algorithms/testers/token_merging.py @@ -1,7 +1,7 @@ from typing import Any +from datasets import load_dataset import torch -from PIL import Image from pruna import PrunaModel from pruna.algorithms.token_merging import TokenMerging @@ -22,7 +22,9 @@ class TestTokenMerging(AlgorithmTesterBase): def pre_smash_hook(self, model: Any) -> None: """Hook to modify the model before smashing.""" - self.input_image = Image.open("husky.png") + dataset = load_dataset("timm/mini-imagenet", split="test") + sample = dataset[2] + self.input_image = sample["image"] # Necessary to set the device to the same device as the model model.device = torch.device(get_device(model)) self.original_pred = model(self.input_image) From dbcd5fdd3f9ed49f7c0c0930bd15a8268d7d1c65 Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Wed, 8 Apr 2026 00:12:35 +0200 Subject: [PATCH 16/19] style: ruff formatting --- src/pruna/algorithms/token_merging.py | 4 +--- src/pruna/engine/model_checks.py | 3 ++- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/pruna/algorithms/token_merging.py b/src/pruna/algorithms/token_merging.py index 7250d30e..dc07ff5d 100644 --- a/src/pruna/algorithms/token_merging.py +++ b/src/pruna/algorithms/token_merging.py @@ -248,9 +248,7 @@ def forward( attn_weights = attn_weights * head_mask attn_weights = attn_weights.softmax(dim=-1) - attn_probs = torch.nn.functional.dropout( - attn_weights, p=self.dropout_prob if self.training else 0.0 - ) + attn_probs = torch.nn.functional.dropout(attn_weights, p=self.dropout_prob if self.training else 0.0) context_layer = (attn_probs @ value_layer).transpose(1, 2) 
context_layer = context_layer.reshape(batch_size, -1, self.all_head_size) diff --git a/src/pruna/engine/model_checks.py b/src/pruna/engine/model_checks.py index d1b69bd4..55670a55 100644 --- a/src/pruna/engine/model_checks.py +++ b/src/pruna/engine/model_checks.py @@ -140,7 +140,7 @@ def is_vit(model: Any) -> bool: ------- bool True if the model is a ViT model, False otherwise. - """ + """ return model.__class__.__name__ == "ViTForImageClassification" @@ -160,6 +160,7 @@ def is_transformers_pipeline_with_vit(model: Any) -> bool: """ return isinstance(model, ImageClassificationPipeline) and is_vit(getattr(model, "model", None)) + def is_transformers_pipeline_with_causal_lm(model: Any) -> bool: """ Check if the model is a transformers pipeline (for tasks like text generation, classification, etc.). From c8102733099dfb37e2575dc74c3fc891e17519c0 Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Wed, 8 Apr 2026 00:19:41 +0200 Subject: [PATCH 17/19] fix: apply residual formatting --- src/pruna/algorithms/token_merging.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pruna/algorithms/token_merging.py b/src/pruna/algorithms/token_merging.py index dc07ff5d..5a392dc6 100644 --- a/src/pruna/algorithms/token_merging.py +++ b/src/pruna/algorithms/token_merging.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import annotations + import math from typing import Any, Callable, Optional, Tuple -from ConfigSpace import UniformIntegerHyperparameter import torch +from ConfigSpace import UniformIntegerHyperparameter from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase from pruna.algorithms.base.tags import AlgorithmTag as tags @@ -24,6 +25,7 @@ from pruna.config.smash_config import SmashConfigPrefixWrapper from pruna.engine.model_checks import is_transformers_pipeline_with_vit, is_vit from pruna.engine.save import SAVE_FUNCTIONS + # --------------------------------------------------------------------------- # Token merging utility functions (adapted from facebook/ToMe) # --------------------------------------------------------------------------- @@ -408,7 +410,6 @@ def model_check_fn(self, model: Any) -> bool: bool ``True`` if the model contains at least one ``ViTLayer``. """ - return is_vit(model) or is_transformers_pipeline_with_vit(model) def import_algorithm_packages(self) -> dict[str, Any]: From 0c5e5477da8820c6de211ae36bb8dd533bc4e8be Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Wed, 8 Apr 2026 10:20:08 +0200 Subject: [PATCH 18/19] style: fix ty checks --- src/pruna/algorithms/token_merging.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/pruna/algorithms/token_merging.py b/src/pruna/algorithms/token_merging.py index 5a392dc6..50361e8b 100644 --- a/src/pruna/algorithms/token_merging.py +++ b/src/pruna/algorithms/token_merging.py @@ -307,8 +307,8 @@ def forward( return layer_output except ImportError: - ToMeViTSelfAttention = None - ToMeViTLayer = None + ToMeViTSelfAttention: type | None = None + ToMeViTLayer: type | None = None # --------------------------------------------------------------------------- @@ -385,7 +385,7 @@ class TokenMerging(PrunaAlgorithmBase): """ algorithm_name: str = "token_merging" - group_tags: list[str] = [tags.KERNEL] + group_tags: list[tags] = [tags.KERNEL] save_fn = 
SAVE_FUNCTIONS.reapply references: dict[str, str] = { "Paper": "https://arxiv.org/abs/2210.09461", @@ -505,12 +505,12 @@ def get_hyperparameters(self) -> list: lower=0, upper=128, default_value=16, - meta=dict( - desc=( + meta={ + "desc": ( "Number of tokens to merge per transformer layer. " "Higher values speed up inference but may reduce accuracy." ) - ), + }, ), Boolean( name="trace_source", From 4b562ab92838f377fe8e97bb36c97f6d45e9d0d0 Mon Sep 17 00:00:00 2001 From: Renato Sortino Date: Wed, 8 Apr 2026 11:09:02 +0200 Subject: [PATCH 19/19] docs: fix docstrings --- src/pruna/algorithms/token_merging.py | 58 +++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/src/pruna/algorithms/token_merging.py b/src/pruna/algorithms/token_merging.py index 50361e8b..dd6cbc57 100644 --- a/src/pruna/algorithms/token_merging.py +++ b/src/pruna/algorithms/token_merging.py @@ -222,6 +222,11 @@ class ToMeViTSelfAttention(_HFViTSelfAttention): - Stores the mean of *k* over heads in ``self._tome_info["metric"]`` so that the enclosing ``ToMeViTLayer`` can use it for bipartite matching without requiring changes to the intermediate ``ViTAttention`` wrapper. + + Parameters + ---------- + config : object + The ViT model configuration. """ _tome_info: dict[str, Any] @@ -231,7 +236,21 @@ def forward( hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Forward pass with proportional attention and key-metric storage.""" + """ + Forward pass with proportional attention and key-metric storage. + + Parameters + ---------- + hidden_states : torch.Tensor + Input token tensor of shape ``[batch, tokens, channels]``. + head_mask : torch.Tensor, optional + Mask for attention heads. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + Context layer and attention probabilities. 
+ """ batch_size = hidden_states.shape[0] new_shape = (batch_size, -1, self.num_attention_heads, self.attention_head_size) @@ -268,6 +287,11 @@ class ToMeViTLayer(_HFViTLayer): performs bipartite soft matching on the key-metric stored in ``self._tome_info["metric"]`` and merges the ``r`` most similar token pairs before proceeding to the MLP sub-layer. + + Parameters + ---------- + config : object + The ViT model configuration. """ _tome_info: dict[str, Any] @@ -277,7 +301,21 @@ def forward( hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Forward pass with token merging between attention and MLP.""" + """ + Forward pass with token merging between attention and MLP. + + Parameters + ---------- + hidden_states : torch.Tensor + Input token tensor of shape ``[batch, tokens, channels]``. + head_mask : torch.Tensor, optional + Mask for attention heads. + + Returns + ------- + torch.Tensor + Output tensor after attention, token merging, and MLP. + """ # --- self-attention + first residual --- attention_output = self.attention( self.layernorm_before(hidden_states), @@ -354,7 +392,21 @@ def __init__( self.parsed_r = _parse_r(self.num_layers, self.r) def forward(self, *args: Any, **kwargs: Any) -> Any: - """Initialise ToMe state and forward through the wrapped model.""" + """ + Initialise ToMe state and forward through the wrapped model. + + Parameters + ---------- + *args : Any + Positional arguments forwarded to the wrapped model. + **kwargs : Any + Keyword arguments forwarded to the wrapped model. + + Returns + ------- + Any + The output of the wrapped model's forward pass. + """ # Make a copy of the list to avoid modifying the original self._tome_info["r"] = list(self.parsed_r) self._tome_info["size"] = None