From 81851e25814229f3a17958a1d2bf8d69ba171520 Mon Sep 17 00:00:00 2001 From: g-benmizrahi Date: Wed, 18 Mar 2026 12:45:19 +0000 Subject: [PATCH 01/21] feat: Add FallbackPlugin for transparent model fallback on specific HTTP errors. --- src/google/adk_community/plugins/__init__.py | 19 ++ .../adk_community/plugins/fallback_plugin.py | 213 ++++++++++++++++++ tests/unittests/plugins/__init__.py | 13 ++ .../unittests/plugins/test_fallback_plugin.py | 166 ++++++++++++++ 4 files changed, 411 insertions(+) create mode 100644 src/google/adk_community/plugins/__init__.py create mode 100644 src/google/adk_community/plugins/fallback_plugin.py create mode 100644 tests/unittests/plugins/__init__.py create mode 100644 tests/unittests/plugins/test_fallback_plugin.py diff --git a/src/google/adk_community/plugins/__init__.py b/src/google/adk_community/plugins/__init__.py new file mode 100644 index 00000000..0b4e3431 --- /dev/null +++ b/src/google/adk_community/plugins/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Community plugins for ADK.""" + +from .fallback_plugin import FallbackPlugin + +__all__ = ["FallbackPlugin"] diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py new file mode 100644 index 00000000..1017b759 --- /dev/null +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -0,0 +1,213 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Optional + +from opentelemetry import trace + +from google.adk.agents.callback_context import CallbackContext +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from ..version import __version__ +from google.adk.plugins.base_plugin import BasePlugin + +logger: logging.Logger = logging.getLogger("google_adk." + __name__) +tracer = trace.get_tracer("google.adk.plugins.fallback_plugin", __version__) + +_FALLBACK_ATTEMPTS_MAX_SIZE = 100 +_FALLBACK_ATTEMPTS_PRUNE_COUNT = 50 + + +class FallbackPlugin(BasePlugin): + """Plugin that implements transparent model fallback on specific HTTP errors. + + This plugin intercepts LLM requests and responses to provide automatic model + fallback when the primary model returns a configured error code (e.g., rate + limit or timeout). Fallback is **non-persistent**: every new request always + starts with the ``root_model``; only that particular request may be retried + with the ``fallback_model``. + + The plugin itself does not re-issue the request. The actual retry must be + handled by the underlying model implementation (e.g. LiteLlm's ``fallbacks`` + parameter). This plugin is responsible for: + + - Resetting the model to ``root_model`` at the start of each request so that + fallback state does not leak across turns. + - Detecting error responses whose ``error_code`` is in ``error_status`` and + annotating the ``LlmResponse`` with structured fallback metadata so that + the caller or the model layer can take remedial action. + - Tracking the number of fallback attempts per request context and + pruning the tracking dictionary to avoid unbounded memory growth. + + Example: + >>> from google.adk.plugins.fallback_plugin import FallbackPlugin + >>> fallback_plugin = FallbackPlugin( + ... root_model="gemini-2.0-flash", + ... fallback_model="gemini-1.5-pro", + ... error_status=[429, 504], + ... ) + >>> runner = Runner( + ... agents=[my_agent], + ... plugins=[fallback_plugin], + ... ) + """ + + def __init__( + self, + name: str = "fallback_plugin", + root_model: Optional[str] = None, + fallback_model: Optional[str] = None, + error_status: Optional[list[int]] = None, # noqa: B006 + ) -> None: + """Initializes the FallbackPlugin. + + Args: + name: The name of the plugin. Defaults to ``"fallback_plugin"``. + root_model: The primary model identifier that every request should start + with. When ``None`` the plugin does not override the model set on the + request. + fallback_model: The backup model identifier to record in the response + metadata when an error matching ``error_status`` is detected. When + ``None`` the plugin logs a warning but does not write any metadata. + error_status: A list of HTTP-style numeric status codes that should be + treated as retriable failures and trigger fallback tracking. Defaults + to ``[429, 504]``. + """ + super().__init__(name=name) + self.root_model = root_model + self.fallback_model = fallback_model + self.error_status = error_status if error_status is not None else [429, 504] + self._error_status_set = {str(s) for s in self.error_status} + + # Maps id(callback_context) -> number of fallback attempts for that context. + self._fallback_attempts: dict[int, int] = {} + + async def before_model_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + ) -> Optional[LlmResponse]: + """Resets the request model to ``root_model`` before each LLM call. + + This callback is invoked before every LLM request. It ensures non-persistent + fallback behaviour by unconditionally resetting the model to ``root_model`` + whenever no fallback attempt is currently in progress for this context, + so that a fallback from a previous turn cannot bleed into a new one. + + Args: + callback_context: The context for the current agent call. Used as the key + for tracking per-request fallback state. + llm_request: The prepared request object about to be sent to the model. + Its ``model`` field may be mutated to enforce the ``root_model``. + + Returns: + ``None`` always, so that normal LLM processing continues. + """ + context_id = id(callback_context) + + # Initialise the attempt counter for this context on first contact. + if context_id not in self._fallback_attempts: + self._fallback_attempts[context_id] = 0 + + # Only reset to root_model when we are NOT mid-fallback. + if self.root_model and self._fallback_attempts.get(context_id, 0) == 0: + if hasattr(llm_request, "model") and llm_request.model != self.root_model: + logger.info( + "Resetting model from %s to root model: %s", + llm_request.model, + self.root_model, + ) + llm_request.model = self.root_model + + return await super().before_model_callback( + callback_context=callback_context, llm_request=llm_request + ) + + async def after_model_callback( + self, + *, + callback_context: CallbackContext, + llm_response: LlmResponse, + ) -> Optional[LlmResponse]: + """Detects retriable errors and annotates the response with fallback metadata. + + This callback is invoked after every LLM response. When the response + carries an ``error_code`` that matches one of the configured ``error_status`` + codes **and** a ``fallback_model`` is configured, the plugin writes the + following keys into ``llm_response.custom_metadata``: + + - ``fallback_triggered`` (``bool``): Always ``True``. + - ``original_model`` (``str``): The value of ``root_model``. + - ``fallback_model`` (``str``): The value of ``fallback_model``. + - ``fallback_attempt`` (``int``): The cumulative attempt count for this + context. + - ``error_code`` (``str``): The string representation of the error code. + + The tracking dictionary is pruned to at most 50 entries whenever its size + exceeds 100, to prevent unbounded memory growth in long-running processes. + + Args: + callback_context: The context for the current agent call. Used as the key + for tracking per-request fallback state. + llm_response: The response received from the model. Its + ``custom_metadata`` field may be populated with fallback tracking data. + + Returns: + ``None`` always, so that normal post-model processing continues. + """ + context_id = id(callback_context) + + if llm_response.error_code and str(llm_response.error_code) in self._error_status_set: + logger.warning( + "Model call failed with error code %s. Error message: %s", + llm_response.error_code, + llm_response.error_message, + ) + + self._fallback_attempts[context_id] = ( + self._fallback_attempts.get(context_id, 0) + 1 + ) + + if self.fallback_model: + logger.info( + "Fallback triggered: %s -> %s (attempt %d)", + self.root_model, + self.fallback_model, + self._fallback_attempts[context_id], + ) + if not llm_response.custom_metadata: + llm_response.custom_metadata = {} + llm_response.custom_metadata["fallback_triggered"] = True + llm_response.custom_metadata["original_model"] = self.root_model + llm_response.custom_metadata["fallback_model"] = self.fallback_model + llm_response.custom_metadata["fallback_attempt"] = ( + self._fallback_attempts[context_id] + ) + llm_response.custom_metadata["error_code"] = str(llm_response.error_code) + else: + logger.warning("No fallback model configured, cannot retry.") + + # Prune the tracking dict to avoid unbounded memory growth. + if len(self._fallback_attempts) > _FALLBACK_ATTEMPTS_MAX_SIZE: + oldest_keys = list(self._fallback_attempts.keys())[:_FALLBACK_ATTEMPTS_PRUNE_COUNT] + for key in oldest_keys: + del self._fallback_attempts[key] + + return await super().after_model_callback( + callback_context=callback_context, llm_response=llm_response + ) \ No newline at end of file diff --git a/tests/unittests/plugins/__init__.py b/tests/unittests/plugins/__init__.py new file mode 100644 index 00000000..0a2669d7 --- /dev/null +++ b/tests/unittests/plugins/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/plugins/test_fallback_plugin.py b/tests/unittests/plugins/test_fallback_plugin.py new file mode 100644 index 00000000..f37462b0 --- /dev/null +++ b/tests/unittests/plugins/test_fallback_plugin.py @@ -0,0 +1,166 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import MagicMock + +from google.adk_community.plugins.fallback_plugin import FallbackPlugin + +class TestFallbackPlugin: + """Test cases for FallbackPlugin.""" + + def test_init_defaults(self): + """Test initialization with default values.""" + plugin = FallbackPlugin() + assert plugin.root_model is None + assert plugin.fallback_model is None + assert plugin.error_status == [429, 504] + assert plugin._error_status_set == {"429", "504"} + assert plugin._fallback_attempts == {} + + def test_init_custom(self): + """Test initialization with custom values.""" + plugin = FallbackPlugin( + root_model="gemini-2.0-flash", + fallback_model="gemini-1.5-pro", + error_status=[400, 500], + ) + assert plugin.root_model == "gemini-2.0-flash" + assert plugin.fallback_model == "gemini-1.5-pro" + assert plugin.error_status == [400, 500] + assert plugin._error_status_set == {"400", "500"} + + @pytest.mark.asyncio + async def test_before_model_callback_initializes_context(self): + """Test that before_model_callback initializes context in fallback attempts dict.""" + plugin = FallbackPlugin() + mock_context = MagicMock() + mock_request = MagicMock() + + await plugin.before_model_callback( + callback_context=mock_context, llm_request=mock_request + ) + + context_id = id(mock_context) + assert context_id in plugin._fallback_attempts + assert plugin._fallback_attempts[context_id] == 0 + + @pytest.mark.asyncio + async def test_before_model_callback_resets_model(self): + """Test that before_model_callback resets model to root_model when attempt is 0.""" + plugin = FallbackPlugin(root_model="root-model") + mock_context = MagicMock() + mock_request = MagicMock(model="current-model") + + await plugin.before_model_callback( + callback_context=mock_context, llm_request=mock_request + ) + + assert mock_request.model == "root-model" + + @pytest.mark.asyncio + async def test_before_model_callback_no_reset_mid_fallback(self): + """Test that before_model_callback does not reset model when attempt > 0.""" + plugin = FallbackPlugin(root_model="root-model") + mock_context = MagicMock() + mock_request = MagicMock(model="fallback-model") + + context_id = id(mock_context) + plugin._fallback_attempts[context_id] = 1 + + await plugin.before_model_callback( + callback_context=mock_context, llm_request=mock_request + ) + + assert mock_request.model == "fallback-model" + + @pytest.mark.asyncio + async def test_after_model_callback_annotates_on_error(self): + """Test that after_model_callback annotates response on error status.""" + plugin = FallbackPlugin(root_model="root-model", fallback_model="fallback-model") + mock_context = MagicMock() + mock_response = MagicMock() + mock_response.error_code = 429 + mock_response.error_message = "Rate limit" + mock_response.custom_metadata = {} + + await plugin.after_model_callback( + callback_context=mock_context, llm_response=mock_response + ) + + assert mock_response.custom_metadata["fallback_triggered"] is True + assert mock_response.custom_metadata["original_model"] == "root-model" + assert mock_response.custom_metadata["fallback_model"] == "fallback-model" + assert mock_response.custom_metadata["fallback_attempt"] == 1 + assert mock_response.custom_metadata["error_code"] == "429" + + @pytest.mark.asyncio + async def test_after_model_callback_no_annotate_on_non_error(self): + """Test that after_model_callback does not annotate on success or non-configured error.""" + plugin = FallbackPlugin(root_model="root-model", fallback_model="fallback-model") + mock_context = MagicMock() + mock_response = MagicMock() + mock_response.error_code = None + mock_response.error_message = None + mock_response.custom_metadata = {} + + await plugin.after_model_callback( + callback_context=mock_context, llm_response=mock_response + ) + + assert "fallback_triggered" not in mock_response.custom_metadata + + @pytest.mark.asyncio + async def test_after_model_callback_no_annotate_no_fallback_model(self): + """Test that after_model_callback does not annotate when fallback_model is None.""" + plugin = FallbackPlugin(root_model="root-model") + mock_context = MagicMock() + mock_response = MagicMock() + mock_response.error_code = 429 + mock_response.error_message = "Rate limit" + mock_response.custom_metadata = {} + + await plugin.after_model_callback( + callback_context=mock_context, llm_response=mock_response + ) + + assert "fallback_triggered" not in mock_response.custom_metadata + + @pytest.mark.asyncio + async def test_after_model_callback_prunes_dict(self): + """Test that after_model_callback prunes the tracking dict when it exceeds max size.""" + plugin = FallbackPlugin() + + # Use a large number of context IDs to trigger pruning + # The limit is 100, prune 50. + for i in range(101): + plugin._fallback_attempts[i] = 1 + + assert len(plugin._fallback_attempts) == 101 + + mock_context = MagicMock() + mock_response = MagicMock() + # Trigger an error to enter the prune condition check + mock_response.error_code = 429 + mock_response.error_message = "Rate limit" + mock_response.custom_metadata = {} + + await plugin.after_model_callback( + callback_context=mock_context, llm_response=mock_response + ) + + # After prune, it should be 102 (if context was new) - 50 = 52 + # Or if context was old, it just prunes 50. + # In any case, it should be <= 100. + assert len(plugin._fallback_attempts) <= 100 From e3916c179fd5cc4062c0e39b70562578cbee015b Mon Sep 17 00:00:00 2001 From: g-benmizrahi Date: Wed, 18 Mar 2026 13:28:57 +0000 Subject: [PATCH 02/21] feat: Replace fallback attempts dictionary with WeakKeyDictionary for automatic garbage collection of contexts and remove manual pruning logic. --- .../adk_community/plugins/fallback_plugin.py | 29 ++++++-------- .../unittests/plugins/test_fallback_plugin.py | 38 +++++++++---------- 2 files changed, 28 insertions(+), 39 deletions(-) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index 1017b759..08b0b16b 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -15,6 +15,7 @@ from __future__ import annotations import logging +import weakref from typing import Optional from opentelemetry import trace @@ -93,8 +94,8 @@ def __init__( self.error_status = error_status if error_status is not None else [429, 504] self._error_status_set = {str(s) for s in self.error_status} - # Maps id(callback_context) -> number of fallback attempts for that context. - self._fallback_attempts: dict[int, int] = {} + # Maps callback_context -> number of fallback attempts for that context. + self._fallback_attempts: weakref.WeakKeyDictionary[CallbackContext, int] = weakref.WeakKeyDictionary() async def before_model_callback( self, @@ -118,14 +119,12 @@ async def before_model_callback( Returns: ``None`` always, so that normal LLM processing continues. """ - context_id = id(callback_context) - # Initialise the attempt counter for this context on first contact. - if context_id not in self._fallback_attempts: - self._fallback_attempts[context_id] = 0 + if callback_context not in self._fallback_attempts: + self._fallback_attempts[callback_context] = 0 # Only reset to root_model when we are NOT mid-fallback. - if self.root_model and self._fallback_attempts.get(context_id, 0) == 0: + if self.root_model and self._fallback_attempts.get(callback_context, 0) == 0: if hasattr(llm_request, "model") and llm_request.model != self.root_model: logger.info( "Resetting model from %s to root model: %s", @@ -170,8 +169,6 @@ async def after_model_callback( Returns: ``None`` always, so that normal post-model processing continues. """ - context_id = id(callback_context) - if llm_response.error_code and str(llm_response.error_code) in self._error_status_set: logger.warning( "Model call failed with error code %s. Error message: %s", @@ -179,8 +176,8 @@ async def after_model_callback( llm_response.error_message, ) - self._fallback_attempts[context_id] = ( - self._fallback_attempts.get(context_id, 0) + 1 + self._fallback_attempts[callback_context] = ( + self._fallback_attempts.get(callback_context, 0) + 1 ) if self.fallback_model: @@ -188,7 +185,7 @@ async def after_model_callback( "Fallback triggered: %s -> %s (attempt %d)", self.root_model, self.fallback_model, - self._fallback_attempts[context_id], + self._fallback_attempts[callback_context], ) if not llm_response.custom_metadata: llm_response.custom_metadata = {} @@ -196,17 +193,13 @@ async def after_model_callback( llm_response.custom_metadata["original_model"] = self.root_model llm_response.custom_metadata["fallback_model"] = self.fallback_model llm_response.custom_metadata["fallback_attempt"] = ( - self._fallback_attempts[context_id] + self._fallback_attempts[callback_context] ) llm_response.custom_metadata["error_code"] = str(llm_response.error_code) else: logger.warning("No fallback model configured, cannot retry.") - # Prune the tracking dict to avoid unbounded memory growth. - if len(self._fallback_attempts) > _FALLBACK_ATTEMPTS_MAX_SIZE: - oldest_keys = list(self._fallback_attempts.keys())[:_FALLBACK_ATTEMPTS_PRUNE_COUNT] - for key in oldest_keys: - del self._fallback_attempts[key] + return await super().after_model_callback( callback_context=callback_context, llm_response=llm_response diff --git a/tests/unittests/plugins/test_fallback_plugin.py b/tests/unittests/plugins/test_fallback_plugin.py index f37462b0..ebdc6eb2 100644 --- a/tests/unittests/plugins/test_fallback_plugin.py +++ b/tests/unittests/plugins/test_fallback_plugin.py @@ -52,9 +52,8 @@ async def test_before_model_callback_initializes_context(self): callback_context=mock_context, llm_request=mock_request ) - context_id = id(mock_context) - assert context_id in plugin._fallback_attempts - assert plugin._fallback_attempts[context_id] == 0 + assert mock_context in plugin._fallback_attempts + assert plugin._fallback_attempts[mock_context] == 0 @pytest.mark.asyncio async def test_before_model_callback_resets_model(self): @@ -76,8 +75,7 @@ async def test_before_model_callback_no_reset_mid_fallback(self): mock_context = MagicMock() mock_request = MagicMock(model="fallback-model") - context_id = id(mock_context) - plugin._fallback_attempts[context_id] = 1 + plugin._fallback_attempts[mock_context] = 1 await plugin.before_model_callback( callback_context=mock_context, llm_request=mock_request @@ -138,29 +136,27 @@ async def test_after_model_callback_no_annotate_no_fallback_model(self): assert "fallback_triggered" not in mock_response.custom_metadata @pytest.mark.asyncio - async def test_after_model_callback_prunes_dict(self): - """Test that after_model_callback prunes the tracking dict when it exceeds max size.""" + async def test_after_model_callback_automatic_pruning(self): + """Test that after_model_callback entries are automatically pruned when context is GC'd.""" + import gc plugin = FallbackPlugin() - # Use a large number of context IDs to trigger pruning - # The limit is 100, prune 50. - for i in range(101): - plugin._fallback_attempts[i] = 1 - - assert len(plugin._fallback_attempts) == 101 - - mock_context = MagicMock() + class CustomContext: + pass + + context = CustomContext() mock_response = MagicMock() - # Trigger an error to enter the prune condition check mock_response.error_code = 429 mock_response.error_message = "Rate limit" mock_response.custom_metadata = {} await plugin.after_model_callback( - callback_context=mock_context, llm_response=mock_response + callback_context=context, llm_response=mock_response ) - # After prune, it should be 102 (if context was new) - 50 = 52 - # Or if context was old, it just prunes 50. - # In any case, it should be <= 100. - assert len(plugin._fallback_attempts) <= 100 + assert context in plugin._fallback_attempts + + del context + gc.collect() # Force GC + + assert len(plugin._fallback_attempts) == 0 From 9c79e71e3dfa1a7dbdadc5be0d85629fe777f046 Mon Sep 17 00:00:00 2001 From: g-benmizrahi Date: Thu, 19 Mar 2026 09:18:01 +0000 Subject: [PATCH 03/21] I have resolved the inconsistency in fallback_plugin.py by choosing the second approach: updating the docstrings and removing the unused constants. Since a previous refactor introduced weakref.WeakKeyDictionary to handle memory management via Python's garbage collection, maintaining manual size-based pruning would be redundant and potentially risky (as it might drop tracking state for active, live contexts). Relying on the garbage collector is a cleaner and more robust solution for this use case. Here are the changes that were made: 1. Updated fallback_plugin.py The unused constants were removed, and the docstrings were updated to reflect that pruning is handled automatically as keys are garbage-collected. diff --- src/google/adk_community/plugins/fallback_plugin.py +++ src/google/adk_community/plugins/fallback_plugin.py @@ -32,2 +32,1 @@ -_FALLBACK_ATTEMPTS_MAX_SIZE = 100 -_FALLBACK_ATTEMPTS_PRUNE_COUNT = 50 + @@ -54,2 +53,2 @@ - - Tracking the number of fallback attempts per request context and - pruning the tracking dictionary to avoid unbounded memory growth. + - Tracking the number of fallback attempts per request context using + weak references to prevent unbounded memory growth. @@ -160,2 +158,2 @@ - The tracking dictionary is pruned to at most 50 entries whenever its size - exceeds 100, to prevent unbounded memory growth in long-running processes. + The tracking dictionary uses weak references and is pruned automatically + when contexts are garbage collected, preventing unbounded memory growth. 2. Verification I ran the unit tests for the fallback plugin from the project's virtual environment, and all tests passed: bash /usr/local/google/home/benmizrahi/Documents/adk-python-community/.venv/bin/pytest /usr/local/google/home/benmizrahi/Documents/adk-python-community/tests/unittests/plugins/test_fallback_plugin.py Output: ============================== 9 passed in 5.67s =============================== --- src/google/adk_community/plugins/fallback_plugin.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index 08b0b16b..1302962a 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -29,9 +29,6 @@ logger: logging.Logger = logging.getLogger("google_adk." + __name__) tracer = trace.get_tracer("google.adk.plugins.fallback_plugin", __version__) -_FALLBACK_ATTEMPTS_MAX_SIZE = 100 -_FALLBACK_ATTEMPTS_PRUNE_COUNT = 50 - class FallbackPlugin(BasePlugin): """Plugin that implements transparent model fallback on specific HTTP errors. @@ -51,8 +48,8 @@ class FallbackPlugin(BasePlugin): - Detecting error responses whose ``error_code`` is in ``error_status`` and annotating the ``LlmResponse`` with structured fallback metadata so that the caller or the model layer can take remedial action. - - Tracking the number of fallback attempts per request context and - pruning the tracking dictionary to avoid unbounded memory growth. + - Tracking the number of fallback attempts per request context using + weak references to prevent unbounded memory growth. Example: >>> from google.adk.plugins.fallback_plugin import FallbackPlugin @@ -157,8 +154,8 @@ async def after_model_callback( context. - ``error_code`` (``str``): The string representation of the error code. - The tracking dictionary is pruned to at most 50 entries whenever its size - exceeds 100, to prevent unbounded memory growth in long-running processes. + The tracking dictionary uses weak references and is pruned automatically + when contexts are garbage collected, preventing unbounded memory growth. Args: callback_context: The context for the current agent call. Used as the key From 68ce77dceabb6d55acf5a34c735a56c428165cd1 Mon Sep 17 00:00:00 2001 From: g-benmizrahi Date: Thu, 19 Mar 2026 09:20:43 +0000 Subject: [PATCH 04/21] refactor: use setdefault to initialize and access the fallback attempt counter --- src/google/adk_community/plugins/fallback_plugin.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index 1302962a..83b9261d 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -116,12 +116,10 @@ async def before_model_callback( Returns: ``None`` always, so that normal LLM processing continues. """ - # Initialise the attempt counter for this context on first contact. - if callback_context not in self._fallback_attempts: - self._fallback_attempts[callback_context] = 0 + attempt_count = self._fallback_attempts.setdefault(callback_context, 0) # Only reset to root_model when we are NOT mid-fallback. - if self.root_model and self._fallback_attempts.get(callback_context, 0) == 0: + if self.root_model and attempt_count == 0: if hasattr(llm_request, "model") and llm_request.model != self.root_model: logger.info( "Resetting model from %s to root model: %s", From 86e89133e421ebc9f69502f504766499eec931f1 Mon Sep 17 00:00:00 2001 From: g-benmizrahi Date: Thu, 19 Mar 2026 09:21:39 +0000 Subject: [PATCH 05/21] refactor: Move `gc` import to the top of `test_fallback_plugin.py`. --- tests/unittests/plugins/test_fallback_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/plugins/test_fallback_plugin.py b/tests/unittests/plugins/test_fallback_plugin.py index ebdc6eb2..b2b86c85 100644 --- a/tests/unittests/plugins/test_fallback_plugin.py +++ b/tests/unittests/plugins/test_fallback_plugin.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc import pytest from unittest.mock import MagicMock @@ -138,7 +139,6 @@ async def test_after_model_callback_no_annotate_no_fallback_model(self): @pytest.mark.asyncio async def test_after_model_callback_automatic_pruning(self): """Test that after_model_callback entries are automatically pruned when context is GC'd.""" - import gc plugin = FallbackPlugin() class CustomContext: From 83239b6bcf3d916b682e7ca275adfa57e0f78019 Mon Sep 17 00:00:00 2001 From: g-benmizrahi Date: Thu, 19 Mar 2026 09:28:14 +0000 Subject: [PATCH 06/21] refactor: simplify model comparison and refine custom metadata initialization in FallbackPlugin. --- src/google/adk_community/plugins/fallback_plugin.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index 83b9261d..ee0ab3ba 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -69,7 +69,7 @@ def __init__( name: str = "fallback_plugin", root_model: Optional[str] = None, fallback_model: Optional[str] = None, - error_status: Optional[list[int]] = None, # noqa: B006 + error_status: Optional[list[int]] = None, ) -> None: """Initializes the FallbackPlugin. @@ -120,7 +120,7 @@ async def before_model_callback( # Only reset to root_model when we are NOT mid-fallback. if self.root_model and attempt_count == 0: - if hasattr(llm_request, "model") and llm_request.model != self.root_model: + if llm_request.model != self.root_model: logger.info( "Resetting model from %s to root model: %s", llm_request.model, @@ -182,7 +182,7 @@ async def after_model_callback( self.fallback_model, self._fallback_attempts[callback_context], ) - if not llm_response.custom_metadata: + if llm_response.custom_metadata is None: llm_response.custom_metadata = {} llm_response.custom_metadata["fallback_triggered"] = True llm_response.custom_metadata["original_model"] = self.root_model @@ -193,9 +193,7 @@ async def after_model_callback( llm_response.custom_metadata["error_code"] = str(llm_response.error_code) else: logger.warning("No fallback model configured, cannot retry.") - - - + return await super().after_model_callback( callback_context=callback_context, llm_response=llm_response ) \ No newline at end of file From f9478ceb87540cd0e160abc5ae5cf5ee31055835 Mon Sep 17 00:00:00 2001 From: g-benmizrahi Date: Thu, 19 Mar 2026 09:33:23 +0000 Subject: [PATCH 07/21] style: improve readability in `fallback_plugin.py` by splitting a long line. --- src/google/adk_community/plugins/fallback_plugin.py | 5 +++-- tests/unittests/plugins/__init__.py | 1 + tests/unittests/plugins/test_fallback_plugin.py | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index ee0ab3ba..0d49f0a1 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -193,7 +193,8 @@ async def after_model_callback( llm_response.custom_metadata["error_code"] = str(llm_response.error_code) else: logger.warning("No fallback model configured, cannot retry.") - + return await super().after_model_callback( callback_context=callback_context, llm_response=llm_response - ) \ No newline at end of file + ) + \ No newline at end of file diff --git a/tests/unittests/plugins/__init__.py b/tests/unittests/plugins/__init__.py index 0a2669d7..36a1e8d7 100644 --- a/tests/unittests/plugins/__init__.py +++ b/tests/unittests/plugins/__init__.py @@ -11,3 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + diff --git a/tests/unittests/plugins/test_fallback_plugin.py b/tests/unittests/plugins/test_fallback_plugin.py index b2b86c85..91a202f0 100644 --- a/tests/unittests/plugins/test_fallback_plugin.py +++ b/tests/unittests/plugins/test_fallback_plugin.py @@ -160,3 +160,4 @@ class CustomContext: gc.collect() # Force GC assert len(plugin._fallback_attempts) == 0 + \ No newline at end of file From 92fb4f9e275bf7679e4ef5fcc062f0710f04e7cc Mon Sep 17 00:00:00 2001 From: g-benmizrahi Date: Thu, 19 Mar 2026 09:42:12 +0000 Subject: [PATCH 08/21] refactor: Improve fallback attempt tracking clarity and update version import path in `FallbackPlugin`. --- .../adk_community/plugins/fallback_plugin.py | 16 +++++++--------- tests/unittests/plugins/test_fallback_plugin.py | 3 ++- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index 0d49f0a1..35ceefdf 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -23,7 +23,7 @@ from google.adk.agents.callback_context import CallbackContext from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse -from ..version import __version__ +from google.adk_community.version import __version__ from google.adk.plugins.base_plugin import BasePlugin logger: logging.Logger = logging.getLogger("google_adk." + __name__) @@ -171,25 +171,22 @@ async def after_model_callback( llm_response.error_message, ) - self._fallback_attempts[callback_context] = ( - self._fallback_attempts.get(callback_context, 0) + 1 - ) + attempt_count = self._fallback_attempts.get(callback_context, 0) + 1 + self._fallback_attempts[callback_context] = attempt_count if self.fallback_model: logger.info( "Fallback triggered: %s -> %s (attempt %d)", self.root_model, self.fallback_model, - self._fallback_attempts[callback_context], + attempt_count, ) if llm_response.custom_metadata is None: llm_response.custom_metadata = {} llm_response.custom_metadata["fallback_triggered"] = True llm_response.custom_metadata["original_model"] = self.root_model llm_response.custom_metadata["fallback_model"] = self.fallback_model - llm_response.custom_metadata["fallback_attempt"] = ( - self._fallback_attempts[callback_context] - ) + llm_response.custom_metadata["fallback_attempt"] = attempt_count llm_response.custom_metadata["error_code"] = str(llm_response.error_code) else: logger.warning("No fallback model configured, cannot retry.") @@ -197,4 +194,5 @@ async def after_model_callback( return await super().after_model_callback( callback_context=callback_context, llm_response=llm_response ) - \ No newline at end of file + + diff --git a/tests/unittests/plugins/test_fallback_plugin.py b/tests/unittests/plugins/test_fallback_plugin.py index 91a202f0..8801873a 100644 --- a/tests/unittests/plugins/test_fallback_plugin.py +++ b/tests/unittests/plugins/test_fallback_plugin.py @@ -160,4 +160,5 @@ class CustomContext: gc.collect() # Force GC assert len(plugin._fallback_attempts) == 0 - \ No newline at end of file + + From 799fa14748fcc43d4f0f908f1562ffcd3de06e98 Mon Sep 17 00:00:00 2001 From: g-benmizrahi Date: Thu, 19 Mar 2026 09:51:13 +0000 Subject: [PATCH 09/21] feat: track and include the original model of a request chain in fallback metadata. --- src/google/adk_community/plugins/fallback_plugin.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index 35ceefdf..77b47055 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -93,6 +93,8 @@ def __init__( # Maps callback_context -> number of fallback attempts for that context. self._fallback_attempts: weakref.WeakKeyDictionary[CallbackContext, int] = weakref.WeakKeyDictionary() + # Maps callback_context -> original model for that context's request chain. + self._original_models: weakref.WeakKeyDictionary[CallbackContext, str] = weakref.WeakKeyDictionary() async def before_model_callback( self, @@ -118,6 +120,11 @@ async def before_model_callback( """ attempt_count = self._fallback_attempts.setdefault(callback_context, 0) + if attempt_count == 0: + # First attempt for this context. Record the original model for the chain. + original_model = self.root_model or llm_request.model + self._original_models[callback_context] = original_model + # Only reset to root_model when we are NOT mid-fallback. if self.root_model and attempt_count == 0: if llm_request.model != self.root_model: @@ -184,7 +191,7 @@ async def after_model_callback( if llm_response.custom_metadata is None: llm_response.custom_metadata = {} llm_response.custom_metadata["fallback_triggered"] = True - llm_response.custom_metadata["original_model"] = self.root_model + llm_response.custom_metadata["original_model"] = self._original_models.get(callback_context) llm_response.custom_metadata["fallback_model"] = self.fallback_model llm_response.custom_metadata["fallback_attempt"] = attempt_count llm_response.custom_metadata["error_code"] = str(llm_response.error_code) From 16c9960bc042f64d21250075b1d5493e32f8d64d Mon Sep 17 00:00:00 2001 From: Ben Mizrahi Date: Thu, 19 Mar 2026 11:58:25 +0200 Subject: [PATCH 10/21] Update tests/unittests/plugins/test_fallback_plugin.py code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/unittests/plugins/test_fallback_plugin.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unittests/plugins/test_fallback_plugin.py b/tests/unittests/plugins/test_fallback_plugin.py index 8801873a..8a13e6d5 100644 --- a/tests/unittests/plugins/test_fallback_plugin.py +++ b/tests/unittests/plugins/test_fallback_plugin.py @@ -89,11 +89,15 @@ async def test_after_model_callback_annotates_on_error(self): """Test that after_model_callback annotates response on error status.""" plugin = FallbackPlugin(root_model="root-model", fallback_model="fallback-model") mock_context = MagicMock() + mock_request = MagicMock(model="any-model") mock_response = MagicMock() mock_response.error_code = 429 mock_response.error_message = "Rate limit" mock_response.custom_metadata = {} + await plugin.before_model_callback( + callback_context=mock_context, llm_request=mock_request + ) await plugin.after_model_callback( callback_context=mock_context, llm_response=mock_response ) From f380090698d742230f5c6fffe856383ad3bf7405 Mon Sep 17 00:00:00 2001 From: Ben Mizrahi Date: Thu, 19 Mar 2026 11:58:43 +0200 Subject: [PATCH 11/21] Update src/google/adk_community/plugins/fallback_plugin.py code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/google/adk_community/plugins/fallback_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index 77b47055..e196e86b 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -153,7 +153,7 @@ async def after_model_callback( following keys into ``llm_response.custom_metadata``: - ``fallback_triggered`` (``bool``): Always ``True``. - - ``original_model`` (``str``): The value of ``root_model``. + - ``original_model`` (``str``): The model used for the initial request. - ``fallback_model`` (``str``): The value of ``fallback_model``. - ``fallback_attempt`` (``int``): The cumulative attempt count for this context. From 9359905a75ab7602d05277b8859fe86257aeac3e Mon Sep 17 00:00:00 2001 From: Ben Mizrahi Date: Thu, 19 Mar 2026 11:58:59 +0200 Subject: [PATCH 12/21] Update src/google/adk_community/plugins/fallback_plugin.py code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/google/adk_community/plugins/fallback_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index e196e86b..1c100bbe 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -184,7 +184,7 @@ async def after_model_callback( if self.fallback_model: logger.info( "Fallback triggered: %s -> %s (attempt %d)", - self.root_model, + self._original_models.get(callback_context), self.fallback_model, attempt_count, ) From 2c3615d17c05548c6cd99ec6ebd1afb96fbc0dca Mon Sep 17 00:00:00 2001 From: g-benmizrahi Date: Thu, 19 Mar 2026 10:00:25 +0000 Subject: [PATCH 13/21] test: Verify `_original_models` cleanup by garbage collection in the fallback plugin test. --- tests/unittests/plugins/test_fallback_plugin.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unittests/plugins/test_fallback_plugin.py b/tests/unittests/plugins/test_fallback_plugin.py index 8a13e6d5..b8544142 100644 --- a/tests/unittests/plugins/test_fallback_plugin.py +++ b/tests/unittests/plugins/test_fallback_plugin.py @@ -149,20 +149,26 @@ class CustomContext: pass context = CustomContext() + mock_request = MagicMock(model="any-model") mock_response = MagicMock() mock_response.error_code = 429 mock_response.error_message = "Rate limit" mock_response.custom_metadata = {} + await plugin.before_model_callback( + callback_context=context, llm_request=mock_request + ) await plugin.after_model_callback( callback_context=context, llm_response=mock_response ) assert context in plugin._fallback_attempts + assert context in plugin._original_models del context gc.collect() # Force GC assert len(plugin._fallback_attempts) == 0 + assert len(plugin._original_models) == 0 From 3ee8206b9ff3f59b81237f5792c1ffab3dab8866 Mon Sep 17 00:00:00 2001 From: g-benmizrahi Date: Thu, 19 Mar 2026 10:05:50 +0000 Subject: [PATCH 14/21] chore: Remove trailing blank line from test_fallback_plugin.py. --- tests/unittests/plugins/test_fallback_plugin.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unittests/plugins/test_fallback_plugin.py b/tests/unittests/plugins/test_fallback_plugin.py index b8544142..a859fe1e 100644 --- a/tests/unittests/plugins/test_fallback_plugin.py +++ b/tests/unittests/plugins/test_fallback_plugin.py @@ -170,5 +170,3 @@ class CustomContext: assert len(plugin._fallback_attempts) == 0 assert len(plugin._original_models) == 0 - - From 207d9a58cf5ee00454065bf21acf97f34dd98a3b Mon Sep 17 00:00:00 2001 From: Ben Mizrahi Date: Thu, 19 Mar 2026 12:12:27 +0200 Subject: [PATCH 15/21] Update src/google/adk_community/plugins/fallback_plugin.py code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/google/adk_community/plugins/fallback_plugin.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index 1c100bbe..815c894d 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -190,11 +190,13 @@ async def after_model_callback( ) if llm_response.custom_metadata is None: llm_response.custom_metadata = {} - llm_response.custom_metadata["fallback_triggered"] = True - llm_response.custom_metadata["original_model"] = self._original_models.get(callback_context) - llm_response.custom_metadata["fallback_model"] = self.fallback_model - llm_response.custom_metadata["fallback_attempt"] = attempt_count - llm_response.custom_metadata["error_code"] = str(llm_response.error_code) + llm_response.custom_metadata.update({ + "fallback_triggered": True, + "original_model": self._original_models.get(callback_context), + "fallback_model": self.fallback_model, + "fallback_attempt": attempt_count, + "error_code": str(llm_response.error_code), + }) else: logger.warning("No fallback model configured, cannot retry.") From 4ca018332c390f4c757184686f376d5ea1d2a10c Mon Sep 17 00:00:00 2001 From: g-benmizrahi Date: Thu, 19 Mar 2026 10:13:14 +0000 Subject: [PATCH 16/21] style: remove trailing blank line. --- src/google/adk_community/plugins/fallback_plugin.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index 815c894d..83530580 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -203,5 +203,3 @@ async def after_model_callback( return await super().after_model_callback( callback_context=callback_context, llm_response=llm_response ) - - From b2da6d13848f5a7454c4b5fd585a509aa4b2d719 Mon Sep 17 00:00:00 2001 From: Ben Mizrahi Date: Thu, 19 Mar 2026 12:17:47 +0200 Subject: [PATCH 17/21] Update src/google/adk_community/plugins/fallback_plugin.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/google/adk_community/plugins/fallback_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index 83530580..e3eed128 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -16,7 +16,7 @@ import logging import weakref -from typing import Optional +from typing import Optional, Sequence from opentelemetry import trace From b3a31931797f837a4df535a82d66fe9988c5ad4e Mon Sep 17 00:00:00 2001 From: Ben Mizrahi Date: Thu, 19 Mar 2026 12:17:58 +0200 Subject: [PATCH 18/21] Update src/google/adk_community/plugins/fallback_plugin.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/google/adk_community/plugins/fallback_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index e3eed128..e2e624b3 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -69,7 +69,7 @@ def __init__( name: str = "fallback_plugin", root_model: Optional[str] = None, fallback_model: Optional[str] = None, - error_status: Optional[list[int]] = None, + error_status: Optional[Sequence[int]] = None, ) -> None: """Initializes the FallbackPlugin. From 308a2b610a0f8907bc39f341cfbb72418370a644 Mon Sep 17 00:00:00 2001 From: Ben Mizrahi Date: Thu, 19 Mar 2026 12:18:08 +0200 Subject: [PATCH 19/21] Update src/google/adk_community/plugins/fallback_plugin.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/google/adk_community/plugins/fallback_plugin.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index e2e624b3..d1a841cd 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -125,9 +125,8 @@ async def before_model_callback( original_model = self.root_model or llm_request.model self._original_models[callback_context] = original_model - # Only reset to root_model when we are NOT mid-fallback. - if self.root_model and attempt_count == 0: - if llm_request.model != self.root_model: + # Reset to root_model if it's not already set. + if self.root_model and llm_request.model != self.root_model: logger.info( "Resetting model from %s to root model: %s", llm_request.model, From 1018faaaca2067276ee90c814e47ddc800eff721 Mon Sep 17 00:00:00 2001 From: g-benmizrahi Date: Thu, 19 Mar 2026 10:39:18 +0000 Subject: [PATCH 20/21] fix: Clear fallback state for a context after successful or non-retriable model calls. --- src/google/adk_community/plugins/fallback_plugin.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index d1a841cd..d74ca47b 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -198,6 +198,13 @@ async def after_model_callback( }) else: logger.warning("No fallback model configured, cannot retry.") + else: + # On success or non-retriable error, the fallback sequence is complete. + # Clear the state to ensure the next request for this context is fresh. + if callback_context in self._fallback_attempts: + del self._fallback_attempts[callback_context] + if callback_context in self._original_models: + del self._original_models[callback_context] return await super().after_model_callback( callback_context=callback_context, llm_response=llm_response From aeae065030b1a87bc8b1cd3832f75dc8d53d4085 Mon Sep 17 00:00:00 2001 From: Ben Mizrahi Date: Thu, 19 Mar 2026 12:44:16 +0200 Subject: [PATCH 21/21] Update src/google/adk_community/plugins/fallback_plugin.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/google/adk_community/plugins/fallback_plugin.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/google/adk_community/plugins/fallback_plugin.py b/src/google/adk_community/plugins/fallback_plugin.py index d74ca47b..9171ac04 100644 --- a/src/google/adk_community/plugins/fallback_plugin.py +++ b/src/google/adk_community/plugins/fallback_plugin.py @@ -201,10 +201,8 @@ async def after_model_callback( else: # On success or non-retriable error, the fallback sequence is complete. # Clear the state to ensure the next request for this context is fresh. - if callback_context in self._fallback_attempts: - del self._fallback_attempts[callback_context] - if callback_context in self._original_models: - del self._original_models[callback_context] + self._fallback_attempts.pop(callback_context, None) + self._original_models.pop(callback_context, None) return await super().after_model_callback( callback_context=callback_context, llm_response=llm_response