From dd006fe98e97655b27ff93c3376acd1eec7dcbb3 Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Tue, 10 Mar 2026 14:34:31 +0900 Subject: [PATCH 1/3] [utils] Improve cache utils to support layer-based caches Let's improve its coverage. TICO-DCO-1.0-Signed-off-by: Dayoung Lee --- .../TinyLlamaWithDynamicCache/__init__.py | 1 + .../model/TinyLlamaWithDynamicCache/model.py | 100 +++ .../requirements.txt | 1 + .../unit_test/utils_test/test_pytree_utils.py | 397 ++++++++++++ tico/utils/pytree_utils.py | 572 ++++++++++++++---- 5 files changed, 961 insertions(+), 110 deletions(-) create mode 100644 test/modules/model/TinyLlamaWithDynamicCache/__init__.py create mode 100644 test/modules/model/TinyLlamaWithDynamicCache/model.py create mode 100644 test/modules/model/TinyLlamaWithDynamicCache/requirements.txt create mode 100644 test/unit_test/utils_test/test_pytree_utils.py diff --git a/test/modules/model/TinyLlamaWithDynamicCache/__init__.py b/test/modules/model/TinyLlamaWithDynamicCache/__init__.py new file mode 100644 index 00000000..0c29109f --- /dev/null +++ b/test/modules/model/TinyLlamaWithDynamicCache/__init__.py @@ -0,0 +1 @@ +# DO NOT REMOVE THIS FILE diff --git a/test/modules/model/TinyLlamaWithDynamicCache/model.py b/test/modules/model/TinyLlamaWithDynamicCache/model.py new file mode 100644 index 00000000..15db288c --- /dev/null +++ b/test/modules/model/TinyLlamaWithDynamicCache/model.py @@ -0,0 +1,100 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""E2E test: TinyLlama in decode mode with a DynamicCache. + +Scenario +-------- +Simulates the token-generation (decode) step where a previously-computed +key/value cache is fed back into the model alongside a single new token. + +register_dynamic_cache() selects the correct pytree flatten strategy +automatically based on the installed transformers version: + +* transformers with DynamicLayer (newer): Layer-based layout (cache.layers) +* transformers without DynamicLayer (e.g. 4.52.x): legacy layout + (cache.key_cache / cache.value_cache) + +register_dynamic_layer() is also called so that if the Layer-based layout is +in use, DynamicLayer objects inside the cache are also pytree-traversable. +It is a safe no-op when DynamicLayer does not exist in the installed +transformers version. +""" + +import torch + +from tico.utils.pytree_utils import register_dynamic_cache, register_dynamic_layer +from transformers import AutoModelForCausalLM +from transformers.cache_utils import DynamicCache + +from test.modules.base import TestModuleBase + +# Number of previously-processed tokens to pre-fill into the cache. +_PAST_SEQ_LEN = 5 + +# To suppress warning: +# _pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint. +# (1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. 
Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source. +# (2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message. +# WeightsUnpickler error: Unsupported global: GLOBAL transformers.cache_utils.DynamicCache was not an allowed global by default. Please use `torch.serialization.add_safe_globals([transformers.cache_utils.DynamicCache])` or the `torch.serialization.safe_globals([transformers.cache_utils.DynamicCache])` context manager to allowlist this global if you trust this class/function. +torch.serialization.add_safe_globals([DynamicCache]) + + +class TinyLlamaWithDynamicCache(TestModuleBase): + """TinyLlama decode step with a pre-populated DynamicCache.""" + + def __init__(self): + super().__init__() + self.model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0").to( + "cpu" + ) + self.cfg = self.model.config + self.rtol = 1e-4 + self.atol = 1e-4 + + # register_dynamic_cache picks the right flatten strategy for the + # installed transformers version automatically. + # register_dynamic_layer is a no-op when DynamicLayer doesn't exist. + register_dynamic_cache() + register_dynamic_layer() + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def get_example_inputs(self): + cfg = self.cfg + num_layers = cfg.num_hidden_layers + num_kv_heads = getattr(cfg, "num_key_value_heads", cfg.num_attention_heads) + head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads) + + # Single new token (decode step). + input_ids = torch.tensor([[869]], dtype=torch.long) # token id for '▁.' + attention_mask = torch.ones(1, _PAST_SEQ_LEN + 1, dtype=torch.long) + position_ids = torch.tensor([[_PAST_SEQ_LEN]], dtype=torch.long) + + # Build a DynamicCache pre-filled with random past KV pairs. + past_key_values = DynamicCache() + for layer_idx in range(num_layers): + past_key_values.update( + torch.randn(1, num_kv_heads, _PAST_SEQ_LEN, head_dim), + torch.randn(1, num_kv_heads, _PAST_SEQ_LEN, head_dim), + layer_idx, + ) + + return ( + input_ids, + attention_mask, + position_ids, + past_key_values, + ), {} diff --git a/test/modules/model/TinyLlamaWithDynamicCache/requirements.txt b/test/modules/model/TinyLlamaWithDynamicCache/requirements.txt new file mode 100644 index 00000000..1e4043b8 --- /dev/null +++ b/test/modules/model/TinyLlamaWithDynamicCache/requirements.txt @@ -0,0 +1 @@ +transformers==4.52.4 diff --git a/test/unit_test/utils_test/test_pytree_utils.py b/test/unit_test/utils_test/test_pytree_utils.py new file mode 100644 index 00000000..ffe9b68d --- /dev/null +++ b/test/unit_test/utils_test/test_pytree_utils.py @@ -0,0 +1,397 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for tico/utils/pytree_utils.py. + +Each test class covers one cache type. 
All tests are skipped when +transformers is not installed so the suite stays green in minimal +environments. +""" + +import unittest + +import torch +import torch.utils._pytree as pytree + +from tico.utils.installed_packages import is_transformers_installed + +_SKIP = not is_transformers_installed() +_SKIP_REASON = "transformers is not installed" + + +def _make_tensor(*shape): + return torch.randn(*shape) + + +# --------------------------------------------------------------------------- +# Helper: round-trip a registered pytree node through flatten → unflatten +# --------------------------------------------------------------------------- + + +def _roundtrip(obj): + """Flatten obj with torch pytree, then unflatten and return.""" + leaves, treespec = pytree.tree_flatten(obj) + return pytree.tree_unflatten(leaves, treespec) + + +# --------------------------------------------------------------------------- +# DynamicCache +# --------------------------------------------------------------------------- + + +@unittest.skipIf(_SKIP, _SKIP_REASON) +class TestRegisterDynamicCache(unittest.TestCase): + def setUp(self): + from tico.utils.pytree_utils import register_dynamic_cache + + register_dynamic_cache() + + def _make_cache(self): + import transformers + from packaging.version import Version + from transformers.cache_utils import DynamicCache + + cache = DynamicCache() + if Version(transformers.__version__) < Version("4.54.0"): + # Legacy attribute-based structure + cache.key_cache = [_make_tensor(1, 4, 8, 16)] + cache.value_cache = [_make_tensor(1, 4, 8, 16)] + else: + # Layer-based structure — populate via standard update call + # so that cache.layers is initialised correctly. + k = _make_tensor(1, 4, 8, 16) + v = _make_tensor(1, 4, 8, 16) + cache.update(k, v, layer_idx=0) + return cache + + def test_roundtrip_leaves_preserved(self): + """Flatten → unflatten keeps all tensor data intact.""" + cache = self._make_cache() + restored = _roundtrip(cache) + + import transformers + from packaging.version import Version + + if Version(transformers.__version__) < Version("4.54.0"): + self.assertEqual(len(restored.key_cache), len(cache.key_cache)) + torch.testing.assert_close(restored.key_cache[0], cache.key_cache[0]) + torch.testing.assert_close(restored.value_cache[0], cache.value_cache[0]) + else: + self.assertEqual(len(restored.layers), len(cache.layers)) + + def test_idempotent_registration(self): + """Calling register_dynamic_cache a second time must not raise.""" + from tico.utils.pytree_utils import register_dynamic_cache + + register_dynamic_cache() # second call — should be a silent no-op + + def test_flatten_returns_tensors(self): + """Flattened leaves must all be tensors.""" + cache = self._make_cache() + leaves, _ = pytree.tree_flatten(cache) + for leaf in leaves: + self.assertIsInstance(leaf, torch.Tensor) + + +# --------------------------------------------------------------------------- +# StaticCache +# --------------------------------------------------------------------------- + + +@unittest.skipIf(_SKIP, _SKIP_REASON) +class TestRegisterStaticCache(unittest.TestCase): + def setUp(self): + from tico.utils.pytree_utils import register_static_cache + + register_static_cache() + + def _make_cache(self): + import transformers + from packaging.version import Version + + if Version(transformers.__version__) < Version("4.54.0"): + self.skipTest("StaticCache with layers API requires transformers >= 4.54.0") + + from transformers import AutoConfig + from transformers.cache_utils import StaticCache 
+ + # Build a minimal config + cfg = AutoConfig.for_model("llama") + cfg.num_hidden_layers = 2 + cfg.num_attention_heads = 4 + cfg.num_key_value_heads = 4 + cfg.head_dim = 8 + cfg.max_position_embeddings = 32 + return StaticCache(config=cfg, max_batch_size=1, max_cache_len=16) + + def test_roundtrip_layers_preserved(self): + cache = self._make_cache() + n_layers_before = len(cache.layers) + restored = _roundtrip(cache) + self.assertEqual(len(restored.layers), n_layers_before) + + def test_idempotent_registration(self): + from tico.utils.pytree_utils import register_static_cache + + register_static_cache() + + +# --------------------------------------------------------------------------- +# StaticLayer +# --------------------------------------------------------------------------- + + +@unittest.skipIf(_SKIP, _SKIP_REASON) +class TestRegisterStaticLayer(unittest.TestCase): + def setUp(self): + from tico.utils.pytree_utils import register_static_layer + + register_static_layer() + + def _make_layer(self): + import transformers + from packaging.version import Version + + if Version(transformers.__version__) < Version("4.54.0"): + self.skipTest("StaticLayer requires transformers >= 4.54.0") + + from transformers.cache_utils import StaticLayer + + layer = StaticLayer(max_cache_len=16, batch_size=1, num_heads=4, head_dim=4) + layer.is_initialized = True + layer.keys = _make_tensor(1, 4, 16, 8) + layer.values = _make_tensor(1, 4, 16, 8) + layer.dtype = layer.keys.dtype + layer.device = layer.keys.device + layer.max_batch_size = 1 + layer.num_heads = 4 + layer.head_dim = 8 + return layer + + def test_roundtrip_tensors_preserved(self): + layer = self._make_layer() + restored = _roundtrip(layer) + torch.testing.assert_close(restored.keys, layer.keys) + torch.testing.assert_close(restored.values, layer.values) + + def test_roundtrip_metadata_preserved(self): + layer = self._make_layer() + restored = _roundtrip(layer) + self.assertEqual(restored.max_cache_len, layer.max_cache_len) + self.assertEqual(restored.num_heads, layer.num_heads) + self.assertEqual(restored.head_dim, layer.head_dim) + + def test_uninitialised_layer_raises(self): + import transformers + from packaging.version import Version + + if Version(transformers.__version__) < Version("4.54.0"): + self.skipTest("StaticLayer requires transformers >= 4.54.0") + + from tico.utils.pytree_utils import _flatten_static_layer + from transformers.cache_utils import StaticLayer + + layer = StaticLayer(max_cache_len=16, batch_size=1, num_heads=2, head_dim=4) + layer.is_initialized = False + with self.assertRaises(ValueError): + _flatten_static_layer(layer) + + def test_idempotent_registration(self): + from tico.utils.pytree_utils import register_static_layer + + register_static_layer() + + +# --------------------------------------------------------------------------- +# DynamicLayer +# --------------------------------------------------------------------------- + + +@unittest.skipIf(_SKIP, _SKIP_REASON) +class TestRegisterDynamicLayer(unittest.TestCase): + def setUp(self): + from tico.utils.pytree_utils import register_dynamic_layer + + register_dynamic_layer() + + def _make_layer(self): + import transformers + from packaging.version import Version + + if Version(transformers.__version__) < Version("4.54.0"): + self.skipTest("DynamicLayer requires transformers >= 4.54.0") + + from transformers.cache_utils import DynamicLayer + + layer = DynamicLayer() + layer.is_initialized = True + layer.keys = _make_tensor(1, 4, 8, 16) + layer.values = _make_tensor(1, 
4, 8, 16) + layer.dtype = layer.keys.dtype + layer.device = layer.keys.device + return layer + + def test_roundtrip_tensors_preserved(self): + layer = self._make_layer() + restored = _roundtrip(layer) + torch.testing.assert_close(restored.keys, layer.keys) + torch.testing.assert_close(restored.values, layer.values) + + def test_roundtrip_metadata_preserved(self): + layer = self._make_layer() + restored = _roundtrip(layer) + self.assertEqual(restored.is_initialized, layer.is_initialized) + self.assertEqual(restored.dtype, layer.dtype) + self.assertEqual(restored.device, layer.device) + + def test_uninitialised_layer_raises(self): + import transformers + from packaging.version import Version + + if Version(transformers.__version__) < Version("4.54.0"): + self.skipTest("DynamicLayer requires transformers >= 4.54.0") + + from tico.utils.pytree_utils import _flatten_dynamic_layer + from transformers.cache_utils import DynamicLayer + + layer = DynamicLayer() + layer.is_initialized = False + with self.assertRaises(ValueError): + _flatten_dynamic_layer(layer) + + def test_idempotent_registration(self): + from tico.utils.pytree_utils import register_dynamic_layer + + register_dynamic_layer() + + +# --------------------------------------------------------------------------- +# EncoderDecoderCache +# --------------------------------------------------------------------------- + + +@unittest.skipIf(_SKIP, _SKIP_REASON) +class TestRegisterEncoderDecoderCache(unittest.TestCase): + def setUp(self): + from tico.utils.pytree_utils import ( + register_dynamic_cache, + register_encoder_decoder_cache, + ) + + register_dynamic_cache() + register_encoder_decoder_cache() + + def _make_cache(self): + import transformers + from packaging.version import Version + + if Version(transformers.__version__) < Version("4.54.0"): + self.skipTest( + "EncoderDecoderCache with Layer-based internals requires transformers >= 4.54.0" + ) + + from transformers.cache_utils import DynamicCache, EncoderDecoderCache + + self_cache = DynamicCache() + cross_cache = DynamicCache() + k = _make_tensor(1, 4, 8, 16) + v = _make_tensor(1, 4, 8, 16) + self_cache.update(k, v, layer_idx=0) + cross_cache.update(k.clone(), v.clone(), layer_idx=0) + return EncoderDecoderCache(self_cache, cross_cache) + + def test_roundtrip_self_and_cross_caches_preserved(self): + cache = self._make_cache() + n_self = len(cache.self_attention_cache.layers) + n_cross = len(cache.cross_attention_cache.layers) + restored = _roundtrip(cache) + self.assertEqual(len(restored.self_attention_cache.layers), n_self) + self.assertEqual(len(restored.cross_attention_cache.layers), n_cross) + + def test_idempotent_registration(self): + from tico.utils.pytree_utils import register_encoder_decoder_cache + + register_encoder_decoder_cache() + + +# --------------------------------------------------------------------------- +# Consistent flatten key paths +# --------------------------------------------------------------------------- + + +@unittest.skipIf(_SKIP, _SKIP_REASON) +class TestFlattenKeyPaths(unittest.TestCase): + """The _flatten_with_keys_* helpers must return keys that match the + children produced by the main _flatten_* function.""" + + def test_dynamic_layer_keys_match_children(self): + import transformers + from packaging.version import Version + + if Version(transformers.__version__) < Version("4.54.0"): + self.skipTest("DynamicLayer requires transformers >= 4.54.0") + + from tico.utils.pytree_utils import ( + _flatten_dynamic_layer, + _flatten_with_keys_dynamic_layer, 
+ ) + from transformers.cache_utils import DynamicLayer + + layer = DynamicLayer() + layer.is_initialized = True + layer.keys = _make_tensor(1, 2, 4, 8) + layer.values = _make_tensor(1, 2, 4, 8) + layer.dtype = layer.keys.dtype + layer.device = layer.keys.device + + children, _ = _flatten_dynamic_layer(layer) + keyed, _ = _flatten_with_keys_dynamic_layer(layer) + + self.assertEqual(len(keyed), len(children)) + for (_, tensor_keyed), tensor_plain in zip(keyed, children): + torch.testing.assert_close(tensor_keyed, tensor_plain) + + def test_static_layer_keys_match_children(self): + import transformers + from packaging.version import Version + + if Version(transformers.__version__) < Version("4.54.0"): + self.skipTest("StaticLayer requires transformers >= 4.54.0") + + from tico.utils.pytree_utils import ( + _flatten_static_layer, + _flatten_with_keys_static_layer, + ) + from transformers.cache_utils import StaticLayer + + layer = StaticLayer(max_cache_len=8, batch_size=1, num_heads=2, head_dim=4) + layer.is_initialized = True + layer.keys = _make_tensor(1, 2, 8, 4) + layer.values = _make_tensor(1, 2, 8, 4) + layer.dtype = layer.keys.dtype + layer.device = layer.keys.device + # layer.max_batch_size = 1 + # layer.num_heads = 2 + # layer.head_dim = 4 + + children, _ = _flatten_static_layer(layer) + keyed, _ = _flatten_with_keys_static_layer(layer) + + self.assertEqual(len(keyed), len(children)) + for (_, tensor_keyed), tensor_plain in zip(keyed, children): + torch.testing.assert_close(tensor_keyed, tensor_plain) + + +if __name__ == "__main__": + unittest.main() diff --git a/tico/utils/pytree_utils.py b/tico/utils/pytree_utils.py index 7645a675..8117b584 100644 --- a/tico/utils/pytree_utils.py +++ b/tico/utils/pytree_utils.py @@ -1,134 +1,486 @@ -import threading +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import Any, Dict, Tuple
 
 import torch
-from packaging.version import Version
+import torch.fx._pytree as fx_pytree
+import torch.utils._pytree as pytree
 
 from tico.utils import logging
 from tico.utils.installed_packages import is_transformers_installed
 
-__all__ = ["register_dynamic_cache"]
+__all__ = [
+    "register_dynamic_cache",
+    "register_static_cache",
+    "register_dynamic_layer",
+    "register_static_layer",
+    "register_encoder_decoder_cache",
+]
+
+##################################################################################
+# [Dynamic/Static Cache Registration]
+# Reference: https://github.com/Samsung/TICO/issues/417
+#
+# {transformer-version}: {status (implementation detail)}
+# 4.51.0: ✅ Pre-registered (list-based)
+# 4.51.1: ✅ Pre-registered (list-based)
+# 4.52.0: ✅ Pre-registered (list-based)
+# 4.52.1: ✅ Pre-registered (list-based)
+# 4.53.0: ✅ Pre-registered (list-based)
+# 4.54.0: ✅ Pre-registered (layers-based)
+# 4.54.1: ✅ Pre-registered (layers-based)
+# 4.55.0: ✅ Pre-registered (layers-based)
+# 4.55.1: ✅ Pre-registered (layers-based)
+# 4.56.0: ❌ Not registered (layers-based)
+# 4.56.1: ❌ Not registered (layers-based)
+# 4.57.0: ❌ Not registered (layers-based)
+# 4.57.1: ❌ Not registered (layers-based)
+# 4.57.2: ❌ Not registered (layers-based)
+# 4.57.3: ❌ Not registered (layers-based)
+# 4.57.4: ❌ Not registered (layers-based)
+# 4.57.5: ❌ Not registered (layers-based)
+# 4.57.6: ❌ Not registered (layers-based)
+# 5.0.0: ❌ Not registered (layers-based)
+#
+##################################################################################
+
+##################################################################################
+#
+# All _flatten_* / _unflatten_* helpers are defined at module scope (not inside
+# functions) so that torch pytree serialization can locate them by name.
+#
+# Convention for every cache type:
+#   _flatten_<type>           -> (children, aux_data) [main pytree API]
+#   _unflatten_<type>         -> reconstructed object
+#   _flatten_with_keys_<type> -> keyed children list  [for pytree.register_pytree_node]
+#   _flatten_<type>_for_fx    -> flat list            [for fx_pytree.register_pytree_flatten_spec]
+#
+##################################################################################
+
+
+# ---------------------------------------------------------------------------
+# StaticCache
+# ---------------------------------------------------------------------------
+
+
+def _flatten_static_cache(cache) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+    children = (cache.layers,)
+    aux_data = {
+        "layer_class_to_replicate": getattr(cache, "layer_class_to_replicate", None),
+        "offloading": getattr(cache, "offloading", False),
+    }
+    return children, aux_data
+
+
+def _unflatten_static_cache(children: Tuple[Any, ...], aux_data: Dict[str, Any]):
+    from transformers.cache_utils import StaticCache
+
+    instance = StaticCache.__new__(StaticCache)
+    (instance.layers,) = children
+    for key, value in aux_data.items():
+        setattr(instance, key, value)
+    return instance
+
+
+def _flatten_with_keys_static_cache(cache):
+    children, aux_data = _flatten_static_cache(cache)
+    return [(pytree.GetAttrKey("layers"), children[0])], aux_data
+
+
+def _flatten_static_cache_for_fx(cache, spec):
+    children, _ = _flatten_static_cache(cache)
+    return list(children)
+
+
+def register_static_cache():
+    # StaticCache uses a layers-based structure only when StaticLayer is available
+    # (transformers >= ~4.57). 
On older versions StaticCache has no
+    # '.layers' attribute and transformers pre-registers its own pytree support, so we skip registration.
+    try:
+        from transformers.cache_utils import StaticCache, StaticLayer  # noqa: F401
+    except ImportError:
+        logger = logging.getLogger(__name__)
+        logger.debug(
+            "StaticCache / StaticLayer not available in this transformers version; "
+            "skipping StaticCache pytree registration."
+        )
+        return
+
+    try:
+        pytree.register_pytree_node(
+            StaticCache,
+            _flatten_static_cache,
+            _unflatten_static_cache,
+            serialized_type_name=f"{StaticCache.__module__}.{StaticCache.__name__}",
+            flatten_with_keys_fn=_flatten_with_keys_static_cache,
+        )
+        fx_pytree.register_pytree_flatten_spec(
+            StaticCache, _flatten_static_cache_for_fx
+        )
+    except ValueError as e:
+        logger = logging.getLogger(__name__)
+        logger.debug(f"StaticCache is already registered as pytree flattenable. {e}")
+
+
+# ---------------------------------------------------------------------------
+# StaticLayer
+# ---------------------------------------------------------------------------
+
+
+def _flatten_static_layer(layer) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+    """Split a StaticLayer into (tensor children, static metadata)."""
+    if not layer.is_initialized:
+        raise ValueError(
+            f"{layer} cannot be flattened. StaticLayer must be initialized "
+            "with tensors of a specific shape before use with torch.export."
+        )
+    children = (layer.keys, layer.values)
+    aux_data: Dict[str, Any] = {
+        "max_cache_len": layer.max_cache_len,
+        "is_initialized": layer.is_initialized,
+        "dtype": layer.keys.dtype,
+        "device": layer.keys.device,
+        "max_batch_size": layer.max_batch_size,
+        "num_heads": layer.num_heads,
+        "head_dim": layer.head_dim,
+    }
+    return children, aux_data
+
+
+def _unflatten_static_layer(children: Tuple[Any, ...], aux_data: Dict[str, Any]):
+    """Reconstruct a StaticLayer from flattened data."""
+    from transformers.cache_utils import StaticLayer
+
+    keys, values = children
+    obj = StaticLayer(
+        max_cache_len=aux_data["max_cache_len"],
+        batch_size=aux_data["max_batch_size"],
+        num_heads=aux_data["num_heads"],
+        head_dim=aux_data["head_dim"],
+    )
+    obj.is_initialized = aux_data["is_initialized"]
+    obj.keys = keys
+    obj.values = values
+    obj.dtype = aux_data["dtype"]
+    obj.device = aux_data["device"]
+    obj.max_batch_size = aux_data["max_batch_size"]
+    obj.num_heads = aux_data["num_heads"]
+    obj.head_dim = aux_data["head_dim"]
+    return obj
+
+
+def _flatten_with_keys_static_layer(layer):
+    children, aux_data = _flatten_static_layer(layer)
+    return [
+        (pytree.GetAttrKey("keys"), children[0]),
+        (pytree.GetAttrKey("values"), children[1]),
+    ], aux_data
+
+
+def _flatten_static_layer_for_fx(layer, spec):
+    children, _ = _flatten_static_layer(layer)
+    return list(children)
+
+
+def register_static_layer():
+    try:
+        from transformers.cache_utils import StaticLayer
+    except ImportError:
+        logger = logging.getLogger(__name__)
+        logger.debug(
+            "StaticLayer not available in this transformers version; "
+            "skipping StaticLayer pytree registration." 
+ ) + return + + try: + pytree.register_pytree_node( + StaticLayer, + _flatten_static_layer, + _unflatten_static_layer, + serialized_type_name=f"{StaticLayer.__module__}.{StaticLayer.__name__}", + flatten_with_keys_fn=_flatten_with_keys_static_layer, + ) + fx_pytree.register_pytree_flatten_spec( + StaticLayer, _flatten_static_layer_for_fx + ) + except ValueError as e: + logger = logging.getLogger(__name__) + logger.debug(f"StaticLayer is already registered as pytree flattenable. {e}") + + +# --------------------------------------------------------------------------- +# DynamicLayer +# --------------------------------------------------------------------------- + + +def _flatten_dynamic_layer(layer) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + if not layer.is_initialized: + raise ValueError( + f"{layer} cannot be flattened. DynamicLayer must be initialized " + "with tensors of a specific shape before use with torch.export." + ) + children = (layer.keys, layer.values) + aux_data: Dict[str, Any] = { + "is_initialized": layer.is_initialized, + "dtype": layer.keys.dtype, + "device": layer.keys.device, + } + return children, aux_data + + +def _unflatten_dynamic_layer(children: Tuple[Any, ...], aux_data: Dict[str, Any]): + from transformers.cache_utils import DynamicLayer + + keys, values = children + obj = DynamicLayer() + obj.keys = keys + obj.values = values + obj.is_initialized = aux_data["is_initialized"] + obj.dtype = aux_data["dtype"] + obj.device = aux_data["device"] + return obj + + +def _flatten_with_keys_dynamic_layer(layer): + children, aux_data = _flatten_dynamic_layer(layer) + return [ + (pytree.GetAttrKey("keys"), children[0]), + (pytree.GetAttrKey("values"), children[1]), + ], aux_data + + +def _flatten_dynamic_layer_for_fx(layer, spec): + children, _ = _flatten_dynamic_layer(layer) + return list(children) + + +def register_dynamic_layer(): + try: + from transformers.cache_utils import DynamicLayer + except ImportError: + logger = logging.getLogger(__name__) + logger.debug( + "DynamicLayer not available in this transformers version; " + "skipping DynamicLayer pytree registration." + ) + return + + try: + pytree.register_pytree_node( + DynamicLayer, + _flatten_dynamic_layer, + _unflatten_dynamic_layer, + serialized_type_name=f"{DynamicLayer.__module__}.{DynamicLayer.__name__}", + flatten_with_keys_fn=_flatten_with_keys_dynamic_layer, + ) + fx_pytree.register_pytree_flatten_spec( + DynamicLayer, _flatten_dynamic_layer_for_fx + ) + except ValueError as e: + logger = logging.getLogger(__name__) + logger.debug(f"DynamicLayer is already registered as pytree flattenable. 
{e}") + + +# --------------------------------------------------------------------------- +# DynamicCache +# --------------------------------------------------------------------------- + + +def _flatten_dynamic_cache(cache) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + children = (cache.layers,) + aux_data = { + "layer_class_to_replicate": getattr(cache, "layer_class_to_replicate", None), + "offloading": getattr(cache, "offloading", False), + } + return children, aux_data + + +def _unflatten_dynamic_cache(children: Tuple[Any, ...], aux_data: Dict[str, Any]): + from transformers.cache_utils import DynamicCache + + instance = DynamicCache.__new__(DynamicCache) + (instance.layers,) = children + for key, value in aux_data.items(): + setattr(instance, key, value) + return instance + + +def _flatten_with_keys_dynamic_cache(cache): + children, aux_data = _flatten_dynamic_cache(cache) + return [(pytree.GetAttrKey("layers"), children[0])], aux_data + + +def _flatten_dynamic_cache_for_fx(cache, spec): + children, _ = _flatten_dynamic_cache(cache) + return list(children) + + +# Legacy flatten/unflatten for transformers versions that do not have +# DynamicLayer (e.g. <= 4.52.x), which store tensors directly in +# key_cache / value_cache instead of the Layer-based cache.layers structure. +def _flatten_dynamic_cache_legacy(cache) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + children = (cache.key_cache, cache.value_cache) + aux_data: Dict[str, Any] = {} + return children, aux_data + + +def _unflatten_dynamic_cache_legacy( + children: Tuple[Any, ...], aux_data: Dict[str, Any] +): + from transformers.cache_utils import DynamicCache + + key_cache, value_cache = children + cache = DynamicCache() + cache.key_cache = key_cache + cache.value_cache = value_cache + return cache + + +def _flatten_with_keys_dynamic_cache_legacy(cache): + children, aux_data = _flatten_dynamic_cache_legacy(cache) + return [ + (pytree.GetAttrKey("key_cache"), children[0]), + (pytree.GetAttrKey("value_cache"), children[1]), + ], aux_data + + +def _flatten_dynamic_cache_for_fx_legacy(cache, spec): + children, _ = _flatten_dynamic_cache_legacy(cache) + return list(children) def register_dynamic_cache(): - PyTreeRegistryHelper().register_dynamic_cache() + """Register DynamicCache as a pytree node. + Two layouts exist across transformers versions: -class PyTreeRegistryHelper: - """ - Thread-safe singleton helper class for registering custom PyTree nodes. + * **Layer-based** (newer, requires ``DynamicLayer`` to be importable): + ``cache.layers`` is a list of ``DynamicLayer`` objects; each layer + holds ``keys`` and ``values`` tensors. Both ``DynamicCache`` and + ``DynamicLayer`` must be registered as pytree nodes for + ``torch.export`` to trace through the cache. - This class provides functionality to register DynamicCache as a PyTree node - for torch.export compatibility. This registration is only needed for - transformers versions below 4.50.0. + * **Legacy** (older, e.g. transformers <= 4.52.x): + The cache stores tensors directly in ``cache.key_cache`` and + ``cache.value_cache`` lists. No ``DynamicLayer`` class exists. - Thread Safety: - - Uses a class-level threading.Lock() to ensure thread-safe singleton instantiation - - Uses the same lock to protect the registration process from concurrent calls + The correct layout is detected by checking whether ``DynamicLayer`` can + be imported rather than by comparing version strings, which have proven + unreliable across patch releases. 
""" + if not is_transformers_installed: # type: ignore[truthy-function] + raise ImportError("transformers package is not installed") + + from transformers.cache_utils import DynamicCache + + # Feature-detect the Layer-based layout. + try: + from transformers.cache_utils import DynamicLayer as _DL # noqa: F401 + + _has_dynamic_layer = True + except ImportError: + _has_dynamic_layer = False - _instance = None # Class variable to hold the singleton instance - _has_called = False # Flag to track if registration has been performed - _lock = threading.Lock() # Class-level lock for thread-safe operations - - def __init__(self): - """Private constructor to prevent direct instantiation""" - pass - - def __new__(cls, *args, **kwargs): - """ - Thread-safe singleton instance creation using double-checked locking pattern. - - Returns: - PyTreeRegistryHelper: The singleton instance of this class - """ - if not cls._instance: - with cls._lock: # Acquire lock for thread-safe instantiation - if not cls._instance: # Double-check after acquiring lock - cls._instance = super().__new__(cls) - return cls._instance - - def register_dynamic_cache(self): - """ - Registers DynamicCache as a PyTree node for torch.export compatibility. - - This method is thread-safe and idempotent - it will only perform the - registration once, even if called multiple times from different threads. - - Note: - This registration is only needed for transformers versions below 4.50.0. - - Raises: - ImportError: If transformers package is not installed - """ - with self._lock: # Acquire lock for thread-safe registration - if self.__class__._has_called: - logger = logging.getLogger(__name__) - logger.debug("register_dynamic_cache already called, skipping") - return - - self.__class__._has_called = True + if not _has_dynamic_layer: + # Legacy layout: flatten key_cache / value_cache directly. + try: + pytree.register_pytree_node( + DynamicCache, + _flatten_dynamic_cache_legacy, + _unflatten_dynamic_cache_legacy, + serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", + flatten_with_keys_fn=_flatten_with_keys_dynamic_cache_legacy, + ) + fx_pytree.register_pytree_flatten_spec( + DynamicCache, _flatten_dynamic_cache_for_fx_legacy + ) + except ValueError as e: logger = logging.getLogger(__name__) - logger.info("Registering DynamicCache PyTree node") - - if not is_transformers_installed: # type: ignore[truthy-function] - raise ImportError("transformers package is not installed") - - import transformers - - HAS_TRANSFORMERS_LESS_4_50_0 = Version(transformers.__version__) < Version( - "4.50.0" - ) - if not HAS_TRANSFORMERS_LESS_4_50_0: - return - - from transformers.cache_utils import DynamicCache - - def _flatten_dynamic_cache(dynamic_cache: DynamicCache): - if not isinstance(dynamic_cache, DynamicCache): - raise RuntimeError( - "This pytree flattening function should only be applied to DynamicCache" - ) - HAS_TORCH_2_6_0 = Version(torch.__version__) >= Version("2.6.0") - if not HAS_TORCH_2_6_0: - logger = logging.getLogger(__name__) - logger.warning_once( - "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." 
- ) - dictionary = { - "key_cache": getattr(dynamic_cache, "key_cache"), - "value_cache": getattr(dynamic_cache, "value_cache"), - } - return torch.utils._pytree._dict_flatten(dictionary) - - def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): - dictionary = { - "key_cache": getattr(dynamic_cache, "key_cache"), - "value_cache": getattr(dynamic_cache, "value_cache"), - } - return torch.utils._pytree._dict_flatten_with_keys(dictionary) - - def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context): - dictionary = torch.utils._pytree._dict_unflatten(values, context) - cache = DynamicCache() - for k, v in dictionary.items(): - setattr(cache, k, v) - return cache - - def _flatten_dynamic_cache_for_fx(cache, spec): - dictionary = { - "key_cache": getattr(cache, "key_cache"), - "value_cache": getattr(cache, "value_cache"), - } - return torch.fx._pytree._dict_flatten_spec(dictionary, spec) - - torch.utils._pytree.register_pytree_node( + logger.debug( + f"DynamicCache is already registered as pytree flattenable. {e}" + ) + return + + try: + pytree.register_pytree_node( DynamicCache, _flatten_dynamic_cache, _unflatten_dynamic_cache, serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", flatten_with_keys_fn=_flatten_with_keys_dynamic_cache, ) - # TODO: This won't be needed in torch 2.7+. - torch.fx._pytree.register_pytree_flatten_spec( + fx_pytree.register_pytree_flatten_spec( DynamicCache, _flatten_dynamic_cache_for_fx ) + except ValueError as e: + logger = logging.getLogger(__name__) + logger.debug(f"DynamicCache is already registered as pytree flattenable. {e}") + + +# --------------------------------------------------------------------------- +# EncoderDecoderCache +# --------------------------------------------------------------------------- + + +def _flatten_encoder_decoder_cache(cache) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + children = (cache.self_attention_cache, cache.cross_attention_cache) + aux_data: Dict[str, Any] = {} + return children, aux_data + + +def _unflatten_encoder_decoder_cache( + children: Tuple[Any, ...], aux_data: Dict[str, Any] +): + from transformers.cache_utils import EncoderDecoderCache + + self_cache, cross_cache = children + return EncoderDecoderCache(self_cache, cross_cache) + + +def _flatten_with_keys_encoder_decoder_cache(cache): + children, aux_data = _flatten_encoder_decoder_cache(cache) + return [ + (pytree.GetAttrKey("self_attention_cache"), children[0]), + (pytree.GetAttrKey("cross_attention_cache"), children[1]), + ], aux_data + + +def _flatten_encoder_decoder_cache_for_fx(cache, spec): + children, _ = _flatten_encoder_decoder_cache(cache) + return list(children) + + +def register_encoder_decoder_cache(): + from transformers.cache_utils import EncoderDecoderCache + + try: + pytree.register_pytree_node( + EncoderDecoderCache, + _flatten_encoder_decoder_cache, + _unflatten_encoder_decoder_cache, + serialized_type_name=( + f"{EncoderDecoderCache.__module__}.{EncoderDecoderCache.__name__}" + ), + flatten_with_keys_fn=_flatten_with_keys_encoder_decoder_cache, + ) + fx_pytree.register_pytree_flatten_spec( + EncoderDecoderCache, _flatten_encoder_decoder_cache_for_fx + ) + except ValueError as e: + logger = logging.getLogger(__name__) + logger.debug( + f"EncoderDecoderCache is already registered as pytree flattenable. 
{e}" + ) From 349cea70e0f25f2e12c2a73117379ba9f3a894b4 Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Tue, 10 Mar 2026 17:32:15 +0900 Subject: [PATCH 2/3] fix --- .../unit_test/utils_test/test_pytree_utils.py | 25 +++++--- tico/utils/pytree_utils.py | 59 ++++++++++++------- 2 files changed, 55 insertions(+), 29 deletions(-) diff --git a/test/unit_test/utils_test/test_pytree_utils.py b/test/unit_test/utils_test/test_pytree_utils.py index ffe9b68d..ca0faa33 100644 --- a/test/unit_test/utils_test/test_pytree_utils.py +++ b/test/unit_test/utils_test/test_pytree_utils.py @@ -53,9 +53,13 @@ def _roundtrip(obj): @unittest.skipIf(_SKIP, _SKIP_REASON) class TestRegisterDynamicCache(unittest.TestCase): def setUp(self): - from tico.utils.pytree_utils import register_dynamic_cache + from tico.utils.pytree_utils import ( + register_dynamic_cache, + register_dynamic_layer, + ) register_dynamic_cache() + register_dynamic_layer() def _make_cache(self): import transformers @@ -168,7 +172,10 @@ def _make_layer(self): from transformers.cache_utils import StaticLayer - layer = StaticLayer(max_cache_len=16, batch_size=1, num_heads=4, head_dim=4) + if Version(torch.__version__) >= Version("2.10.0"): + layer = StaticLayer(max_cache_len=16, batch_size=1, num_heads=4, head_dim=4) + else: + layer = StaticLayer(max_cache_len=16) layer.is_initialized = True layer.keys = _make_tensor(1, 4, 16, 8) layer.values = _make_tensor(1, 4, 16, 8) @@ -202,7 +209,10 @@ def test_uninitialised_layer_raises(self): from tico.utils.pytree_utils import _flatten_static_layer from transformers.cache_utils import StaticLayer - layer = StaticLayer(max_cache_len=16, batch_size=1, num_heads=2, head_dim=4) + if Version(torch.__version__) >= Version("2.10.0"): + layer = StaticLayer(max_cache_len=16, batch_size=1, num_heads=2, head_dim=4) + else: + layer = StaticLayer(max_cache_len=16) layer.is_initialized = False with self.assertRaises(ValueError): _flatten_static_layer(layer) @@ -375,15 +385,16 @@ def test_static_layer_keys_match_children(self): ) from transformers.cache_utils import StaticLayer - layer = StaticLayer(max_cache_len=8, batch_size=1, num_heads=2, head_dim=4) + if Version(torch.__version__) >= Version("2.10.0"): + layer = StaticLayer(max_cache_len=8, batch_size=1, num_heads=2, head_dim=4) + else: + layer = StaticLayer(max_cache_len=8) + layer.is_initialized = True layer.keys = _make_tensor(1, 2, 8, 4) layer.values = _make_tensor(1, 2, 8, 4) layer.dtype = layer.keys.dtype layer.device = layer.keys.device - # layer.max_batch_size = 1 - # layer.num_heads = 2 - # layer.head_dim = 4 children, _ = _flatten_static_layer(layer) keyed, _ = _flatten_with_keys_static_layer(layer) diff --git a/tico/utils/pytree_utils.py b/tico/utils/pytree_utils.py index 8117b584..03dc21b5 100644 --- a/tico/utils/pytree_utils.py +++ b/tico/utils/pytree_utils.py @@ -141,20 +141,26 @@ def register_static_cache(): def _flatten_static_layer(layer) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: """Split a StaticLayer into (tensor children, static metadata).""" - if not layer.is_initialized: + if not getattr(layer, "is_initialized", False): raise ValueError( f"{layer} cannot be flattened. StaticLayer must be initialized " "with tensors of a specific shape before use with torch.export." ) + + # 1. convert tracing tensors children = (layer.keys, layer.values) + + # 2. 
extract aux data, with getattr fallbacks for cross-version compatibility
     aux_data: Dict[str, Any] = {
-        "max_cache_len": layer.max_cache_len,
-        "is_initialized": layer.is_initialized,
-        "dtype": layer.keys.dtype,
-        "device": layer.keys.device,
-        "max_batch_size": layer.max_batch_size,
-        "num_heads": layer.num_heads,
-        "head_dim": layer.head_dim,
+        "max_cache_len": getattr(layer, "max_cache_len", 0),
+        "is_initialized": getattr(layer, "is_initialized", False),
+        "dtype": getattr(layer, "dtype", None),
+        "device": getattr(layer, "device", None),
+        "max_batch_size": getattr(
+            layer, "max_batch_size", getattr(layer, "batch_size", 0)
+        ),
+        "num_heads": getattr(layer, "num_heads", 0),
+        "head_dim": getattr(layer, "head_dim", 0),
     }
     return children, aux_data
 
@@ -164,28 +170,37 @@ def _unflatten_static_layer(children: Tuple[Any, ...], aux_data: Dict[str, Any])
     from transformers.cache_utils import StaticLayer
 
     keys, values = children
-    obj = StaticLayer(
-        max_cache_len=aux_data["max_cache_len"],
-        batch_size=aux_data["max_batch_size"],
-        num_heads=aux_data["num_heads"],
-        head_dim=aux_data["head_dim"],
-    )
-    obj.is_initialized = aux_data["is_initialized"]
+
+    # Bypass __init__ to avoid signature mismatches across versions.
+    obj = object.__new__(StaticLayer)
+
+    for key, value in aux_data.items():
+        if value is not None:
+            setattr(obj, key, value)
+
+    # Compatibility: mirror max_batch_size onto the older batch_size attribute.
+    if not hasattr(obj, "batch_size") and "max_batch_size" in aux_data:
+        setattr(obj, "batch_size", aux_data["max_batch_size"])
+
     obj.keys = keys
     obj.values = values
-    obj.dtype = aux_data["dtype"]
-    obj.device = aux_data["device"]
-    obj.max_batch_size = aux_data["max_batch_size"]
-    obj.num_heads = aux_data["num_heads"]
-    obj.head_dim = aux_data["head_dim"]
+
    return obj
 
 
 def _flatten_with_keys_static_layer(layer):
     children, aux_data = _flatten_static_layer(layer)
+
+    # In case GetAttrKey does not exist (older torch versions), fall back to MappingKey.
+    KeyClass = getattr(pytree, "GetAttrKey", getattr(pytree, "MappingKey", None))
+
+    if KeyClass is None:
+        # Very old torch versions: fall back to plain string keys.
+        return [("keys", children[0]), ("values", children[1])], aux_data
+
     return [
-        (pytree.GetAttrKey("keys"), children[0]),
-        (pytree.GetAttrKey("values"), children[1]),
+        (KeyClass("keys"), children[0]),
+        (KeyClass("values"), children[1]),
     ], aux_data
 

From 52c1cf4f9a5bded91e5c09a168efde3e3a89df87 Mon Sep 17 00:00:00 2001
From: Dayoung Lee
Date: Tue, 10 Mar 2026 18:18:05 +0900
Subject: [PATCH 3/3] fix

---
 .../unit_test/utils_test/test_pytree_utils.py | 41 +++++++++++++------
 1 file changed, 29 insertions(+), 12 deletions(-)

diff --git a/test/unit_test/utils_test/test_pytree_utils.py b/test/unit_test/utils_test/test_pytree_utils.py
index ca0faa33..c7f8da84 100644
--- a/test/unit_test/utils_test/test_pytree_utils.py
+++ b/test/unit_test/utils_test/test_pytree_utils.py
@@ -155,6 +155,26 @@ def test_idempotent_registration(self):
 # StaticLayer
 # ---------------------------------------------------------------------------
 
+import inspect
+
+
+def _create_static_layer(**potential_kwargs):
+    # torch ver | required StaticLayer.__init__ kwargs
+    # ----------|------
+    # 2.6.0     | requires max_cache_len
+    # ... 
+ # 2.10.0 | requires max_cache_len, batch_size, num_heads, head_dim + # 2.12.0.dev| requires max_cache_len + from transformers.cache_utils import StaticLayer + + sig = inspect.signature(StaticLayer.__init__) + valid_params = set(sig.parameters.keys()) + + init_kwargs = {k: v for k, v in potential_kwargs.items() if k in valid_params} + + obj = StaticLayer(**init_kwargs) + return obj + @unittest.skipIf(_SKIP, _SKIP_REASON) class TestRegisterStaticLayer(unittest.TestCase): @@ -172,10 +192,9 @@ def _make_layer(self): from transformers.cache_utils import StaticLayer - if Version(torch.__version__) >= Version("2.10.0"): - layer = StaticLayer(max_cache_len=16, batch_size=1, num_heads=4, head_dim=4) - else: - layer = StaticLayer(max_cache_len=16) + layer = _create_static_layer( + max_cache_len=16, batch_size=1, num_heads=4, head_dim=4 + ) layer.is_initialized = True layer.keys = _make_tensor(1, 4, 16, 8) layer.values = _make_tensor(1, 4, 16, 8) @@ -209,10 +228,9 @@ def test_uninitialised_layer_raises(self): from tico.utils.pytree_utils import _flatten_static_layer from transformers.cache_utils import StaticLayer - if Version(torch.__version__) >= Version("2.10.0"): - layer = StaticLayer(max_cache_len=16, batch_size=1, num_heads=2, head_dim=4) - else: - layer = StaticLayer(max_cache_len=16) + layer = _create_static_layer( + max_cache_len=16, batch_size=1, num_heads=2, head_dim=4 + ) layer.is_initialized = False with self.assertRaises(ValueError): _flatten_static_layer(layer) @@ -385,10 +403,9 @@ def test_static_layer_keys_match_children(self): ) from transformers.cache_utils import StaticLayer - if Version(torch.__version__) >= Version("2.10.0"): - layer = StaticLayer(max_cache_len=8, batch_size=1, num_heads=2, head_dim=4) - else: - layer = StaticLayer(max_cache_len=8) + layer = _create_static_layer( + max_cache_len=8, batch_size=1, num_heads=2, head_dim=4 + ) layer.is_initialized = True layer.keys = _make_tensor(1, 2, 8, 4)