diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index dec85690b3..94b2ed9a33 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -293,9 +293,20 @@ async def run_async( if ctx.end_invocation: return - async with Aclosing(self._run_async_impl(ctx)) as agen: - async for event in agen: - yield event + try: + async with Aclosing(self._run_async_impl(ctx)) as agen: + async for event in agen: + yield event + except Exception as agent_error: + # Notify plugins that this agent run failed before re-raising. + # after_agent_callback is intentionally skipped so plugins can + # distinguish a clean completion from a fatal failure. + await ctx.plugin_manager.run_on_agent_error_callback( + agent=self, + callback_context=CallbackContext(ctx), + error=agent_error, + ) + raise if ctx.end_invocation: return @@ -326,9 +337,18 @@ async def run_live( if ctx.end_invocation: return - async with Aclosing(self._run_live_impl(ctx)) as agen: - async for event in agen: - yield event + try: + async with Aclosing(self._run_live_impl(ctx)) as agen: + async for event in agen: + yield event + except Exception as agent_error: + # Notify plugins that this live agent run failed before re-raising. + await ctx.plugin_manager.run_on_agent_error_callback( + agent=self, + callback_context=CallbackContext(ctx), + error=agent_error, + ) + raise if event := await self._handle_after_agent_callback(ctx): yield event diff --git a/src/google/adk/plugins/base_plugin.py b/src/google/adk/plugins/base_plugin.py index 3639f61aa2..e39104b351 100644 --- a/src/google/adk/plugins/base_plugin.py +++ b/src/google/adk/plugins/base_plugin.py @@ -370,3 +370,63 @@ async def on_tool_error_callback( allows the original error to be raised. """ pass + + async def on_agent_error_callback( + self, + *, + agent: BaseAgent, + callback_context: CallbackContext, + error: Exception, + ) -> None: + """Callback executed when an unhandled exception escapes an agent's run. + + This callback fires when an exception propagates out of the agent's + ``_run_async_impl`` or ``_run_live_impl`` before ``after_agent_callback`` + has had a chance to execute. It is intended purely for observability + (logging, metrics, tracing) — the original exception is always re-raised + after all registered plugins have been notified. + + Unlike ``on_tool_error_callback`` and ``on_model_error_callback``, this + callback cannot swallow or replace the error; it always returns ``None``. + + Args: + agent: The agent instance whose execution raised the exception. + callback_context: The callback context for the failed agent invocation. + error: The exception that was raised. + + Returns: + None. The return value is ignored; the exception is re-raised by the + framework regardless. + """ + pass + + async def on_run_error_callback( + self, + *, + invocation_context: InvocationContext, + error: Exception, + ) -> None: + """Callback executed when an unhandled exception escapes a runner invocation. + + This callback fires when an exception propagates out of the runner's main + execution loop before ``after_run_callback`` has had a chance to execute. + It is intended purely for observability (logging, metrics, tracing) — the + original exception is always re-raised after all registered plugins have + been notified. + + This fills the gap where a fatal error (e.g. an unrecoverable model crash + or tool exception) would otherwise cause the invocation to disappear from + observability sinks without ever emitting a terminal event. + + Unlike ``on_tool_error_callback`` and ``on_model_error_callback``, this + callback cannot swallow or replace the error; it always returns ``None``. + + Args: + invocation_context: The context for the entire invocation. + error: The exception that escaped the runner's execution loop. + + Returns: + None. The return value is ignored; the exception is re-raised by the + framework regardless. + """ + pass diff --git a/src/google/adk/plugins/plugin_manager.py b/src/google/adk/plugins/plugin_manager.py index c781e8fa4e..496e9df192 100644 --- a/src/google/adk/plugins/plugin_manager.py +++ b/src/google/adk/plugins/plugin_manager.py @@ -52,6 +52,8 @@ "after_model_callback", "on_tool_error_callback", "on_model_error_callback", + "on_agent_error_callback", + "on_run_error_callback", ] logger = logging.getLogger("google_adk." + __name__) @@ -257,6 +259,46 @@ async def run_on_tool_error_callback( error=error, ) + async def run_on_agent_error_callback( + self, + *, + agent: BaseAgent, + callback_context: CallbackContext, + error: Exception, + ) -> None: + """Runs the ``on_agent_error_callback`` for all plugins. + + All registered plugins are notified even if an earlier plugin raises — + failures in individual plugins are logged but do not prevent subsequent + plugins from being called. The original agent error is never suppressed + by this method. + """ + await self._run_error_callbacks( + "on_agent_error_callback", + agent=agent, + callback_context=callback_context, + error=error, + ) + + async def run_on_run_error_callback( + self, + *, + invocation_context: InvocationContext, + error: Exception, + ) -> None: + """Runs the ``on_run_error_callback`` for all plugins. + + All registered plugins are notified even if an earlier plugin raises — + failures in individual plugins are logged but do not prevent subsequent + plugins from being called. The original runner error is never suppressed + by this method. + """ + await self._run_error_callbacks( + "on_run_error_callback", + invocation_context=invocation_context, + error=error, + ) + async def _run_callbacks( self, callback_name: PluginCallbackName, **kwargs: Any ) -> Optional[Any]: @@ -306,6 +348,41 @@ async def _run_callbacks( return None + async def _run_error_callbacks( + self, callback_name: PluginCallbackName, **kwargs: Any + ) -> None: + """Executes an error-notification callback for **all** registered plugins. + + Unlike ``_run_callbacks``, this method does **not** stop on the first + non-``None`` return value. Error callbacks are pure observers — every + plugin deserves a chance to record the failure even if an earlier plugin + in the chain itself encounters an error. + + Individual plugin failures are logged but do not prevent subsequent + plugins from being called, and they do not propagate to the caller. The + underlying framework error that triggered this notification is always + re-raised by the caller independently. + + Args: + callback_name: The name of the error callback method to execute. + **kwargs: Keyword arguments forwarded to each plugin's callback. + """ + for plugin in self.plugins: + callback_method = getattr(plugin, callback_name) + try: + await callback_method(**kwargs) + except Exception as e: + # Log but continue — a broken observability plugin must not hide the + # original error from the framework or prevent other plugins from + # receiving the notification. + logger.error( + "Error in plugin '%s' during '%s' callback: %s", + plugin.name, + callback_name, + e, + exc_info=True, + ) + async def close(self) -> None: """Calls the close method on all registered plugins concurrently. diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 8e352794a4..f60ecbad6b 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -849,73 +849,84 @@ async def _exec_with_plugin( buffered_events: list[Event] = [] is_transcribing: bool = False - async with Aclosing(execute_fn(invocation_context)) as agen: - async for event in agen: - _apply_run_config_custom_metadata( - event, invocation_context.run_config - ) - if is_live_call: - if event.partial and _is_transcription(event): - is_transcribing = True - if is_transcribing and _is_tool_call_or_response(event): - # only buffer function call and function response event which is - # non-partial - buffered_events.append(event) - continue - # Note for live/bidi: for audio response, it's considered as - # non-partial event(event.partial=None) - # event.partial=False and event.partial=None are considered as - # non-partial event; event.partial=True is considered as partial - # event. - if event.partial is not True: - if _is_transcription(event) and ( - _has_non_empty_transcription_text(event.input_transcription) - or _has_non_empty_transcription_text( - event.output_transcription + try: + async with Aclosing(execute_fn(invocation_context)) as agen: + async for event in agen: + _apply_run_config_custom_metadata( + event, invocation_context.run_config + ) + if is_live_call: + if event.partial and _is_transcription(event): + is_transcribing = True + if is_transcribing and _is_tool_call_or_response(event): + # only buffer function call and function response event which is + # non-partial + buffered_events.append(event) + continue + # Note for live/bidi: for audio response, it's considered as + # non-partial event(event.partial=None) + # event.partial=False and event.partial=None are considered as + # non-partial event; event.partial=True is considered as partial + # event. + if event.partial is not True: + if _is_transcription(event) and ( + _has_non_empty_transcription_text(event.input_transcription) + or _has_non_empty_transcription_text( + event.output_transcription + ) + ): + # transcription end signal, append buffered events + is_transcribing = False + logger.debug( + 'Appending transcription finished event: %s', event ) - ): - # transcription end signal, append buffered events - is_transcribing = False - logger.debug( - 'Appending transcription finished event: %s', event + if self._should_append_event(event, is_live_call): + await self.session_service.append_event( + session=session, event=event + ) + + for buffered_event in buffered_events: + logger.debug('Appending buffered event: %s', buffered_event) + await self.session_service.append_event( + session=session, event=buffered_event + ) + yield buffered_event # yield buffered events to caller + buffered_events = [] + else: + # non-transcription event or empty transcription event, for + # example, event that stores blob reference, should be appended. + if self._should_append_event(event, is_live_call): + logger.debug('Appending non-buffered event: %s', event) + await self.session_service.append_event( + session=session, event=event + ) + else: + if event.partial is not True: + await self.session_service.append_event( + session=session, event=event ) - if self._should_append_event(event, is_live_call): - await self.session_service.append_event( - session=session, event=event - ) - - for buffered_event in buffered_events: - logger.debug('Appending buffered event: %s', buffered_event) - await self.session_service.append_event( - session=session, event=buffered_event - ) - yield buffered_event # yield buffered events to caller - buffered_events = [] - else: - # non-transcription event or empty transcription event, for - # example, event that stores blob reference, should be appended. - if self._should_append_event(event, is_live_call): - logger.debug('Appending non-buffered event: %s', event) - await self.session_service.append_event( - session=session, event=event - ) - else: - if event.partial is not True: - await self.session_service.append_event( - session=session, event=event - ) - # Step 3: Run the on_event callbacks to optionally modify the event. - modified_event = await plugin_manager.run_on_event_callback( - invocation_context=invocation_context, event=event - ) - if modified_event: - _apply_run_config_custom_metadata( - modified_event, invocation_context.run_config + # Step 3: Run the on_event callbacks to optionally modify the event. + modified_event = await plugin_manager.run_on_event_callback( + invocation_context=invocation_context, event=event ) - yield modified_event - else: - yield event + if modified_event: + _apply_run_config_custom_metadata( + modified_event, invocation_context.run_config + ) + yield modified_event + else: + yield event + except Exception as run_error: + # Step 3b: Notify all plugins that this invocation failed. + # The callback is fire-and-forget — it cannot suppress the error. + # after_run_callback is intentionally skipped on the error path so + # that plugins can distinguish clean completions from fatal failures. + await plugin_manager.run_on_run_error_callback( + invocation_context=invocation_context, + error=run_error, + ) + raise # Step 4: Run the after_run callbacks to perform global cleanup tasks or # finalizing logs and metrics data. diff --git a/tests/unittests/agents/test_agent_error_callbacks.py b/tests/unittests/agents/test_agent_error_callbacks.py new file mode 100644 index 0000000000..395a9e652d --- /dev/null +++ b/tests/unittests/agents/test_agent_error_callbacks.py @@ -0,0 +1,266 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for on_agent_error_callback in BaseAgent lifecycle. + +These tests verify that: + - on_agent_error_callback is called when _run_async_impl raises. + - on_agent_error_callback is called when _run_live_impl raises. + - after_agent_callback is NOT called on the error path. + - The original exception is re-raised unchanged after notification. + - on_agent_error_callback receives the agent, callback_context, and error. + - on_agent_error_callback is NOT called on a successful run. +""" + +from __future__ import annotations + +from typing import AsyncGenerator +from typing import ClassVar +from unittest.mock import Mock + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.plugins.plugin_manager import PluginManager +from google.genai import types +import pytest +from typing_extensions import override + +from .. import testing_utils + +# --------------------------------------------------------------------------- +# Concrete agent implementations +# --------------------------------------------------------------------------- + + +class _SuccessAgent(BaseAgent): + + def __init__(self): + super().__init__(name="success_agent") + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + invocation_id=ctx.invocation_id, + author=self.name, + content=types.Content(parts=[types.Part.from_text(text="done")]), + ) + + +class _FailingAgent(BaseAgent): + BOOM: ClassVar[RuntimeError] = RuntimeError("agent impl exploded") + + def __init__(self): + super().__init__(name="failing_agent") + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + raise _FailingAgent.BOOM + yield # pragma: no cover + + +class _FailingLiveAgent(BaseAgent): + BOOM: ClassVar[RuntimeError] = RuntimeError("live agent impl exploded") + + def __init__(self): + super().__init__(name="failing_live_agent") + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield # pragma: no cover + + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + raise _FailingLiveAgent.BOOM + yield # pragma: no cover + + +# --------------------------------------------------------------------------- +# Tracking plugin +# --------------------------------------------------------------------------- + + +class TrackingPlugin(BasePlugin): + __test__ = False + + def __init__(self, name: str = "tracker"): + super().__init__(name) + self.after_agent_called = False + self.agent_error_calls: list[dict] = [] + + async def after_agent_callback(self, *, agent, callback_context, **kwargs): + self.after_agent_called = True + + async def on_agent_error_callback( + self, *, agent, callback_context, error, **kwargs + ) -> None: + self.agent_error_calls.append( + {"agent": agent, "callback_context": callback_context, "error": error} + ) + + +# --------------------------------------------------------------------------- +# Helper to drive run_async +# --------------------------------------------------------------------------- + + +async def _collect_events( + agent: BaseAgent, plugins: list[BasePlugin] +) -> list[Event]: + inv_ctx = await testing_utils.create_invocation_context( + agent=agent, plugins=plugins + ) + events = [] + async for event in agent.run_async(inv_ctx): + events.append(event) + return events + + +async def _collect_live_events( + agent: BaseAgent, plugins: list[BasePlugin] +) -> list[Event]: + inv_ctx = await testing_utils.create_invocation_context( + agent=agent, plugins=plugins + ) + events = [] + async for event in agent.run_live(inv_ctx): + events.append(event) + return events + + +# --------------------------------------------------------------------------- +# Tests — run_async path +# --------------------------------------------------------------------------- + + +class TestAgentOnAgentErrorCallbackAsync: + + @pytest.mark.asyncio + async def test_on_agent_error_callback_called_when_impl_raises(self): + tracker = TrackingPlugin() + with pytest.raises(RuntimeError, match="agent impl exploded"): + await _collect_events(_FailingAgent(), [tracker]) + + assert len(tracker.agent_error_calls) == 1 + assert tracker.agent_error_calls[0]["error"] is _FailingAgent.BOOM + + @pytest.mark.asyncio + async def test_on_agent_error_callback_receives_correct_agent(self): + tracker = TrackingPlugin() + agent = _FailingAgent() + + with pytest.raises(RuntimeError): + await _collect_events(agent, [tracker]) + + assert tracker.agent_error_calls[0]["agent"] is agent + + @pytest.mark.asyncio + async def test_on_agent_error_callback_receives_callback_context(self): + tracker = TrackingPlugin() + + with pytest.raises(RuntimeError): + await _collect_events(_FailingAgent(), [tracker]) + + cb_ctx = tracker.agent_error_calls[0]["callback_context"] + assert isinstance(cb_ctx, CallbackContext) + + @pytest.mark.asyncio + async def test_original_exception_reraised_after_notification(self): + tracker = TrackingPlugin() + + with pytest.raises(RuntimeError) as exc_info: + await _collect_events(_FailingAgent(), [tracker]) + + assert exc_info.value is _FailingAgent.BOOM + + @pytest.mark.asyncio + async def test_after_agent_callback_not_called_on_error(self): + tracker = TrackingPlugin() + + with pytest.raises(RuntimeError): + await _collect_events(_FailingAgent(), [tracker]) + + assert not tracker.after_agent_called + + @pytest.mark.asyncio + async def test_on_agent_error_callback_not_called_on_success(self): + tracker = TrackingPlugin() + events = await _collect_events(_SuccessAgent(), [tracker]) + + assert len(events) >= 1 + assert len(tracker.agent_error_calls) == 0 + + @pytest.mark.asyncio + async def test_after_agent_callback_still_called_on_success(self): + tracker = TrackingPlugin() + await _collect_events(_SuccessAgent(), [tracker]) + + assert tracker.after_agent_called + + @pytest.mark.asyncio + async def test_multiple_plugins_all_notified_on_agent_error(self): + tracker_a = TrackingPlugin("a") + tracker_b = TrackingPlugin("b") + + with pytest.raises(RuntimeError): + await _collect_events(_FailingAgent(), [tracker_a, tracker_b]) + + assert len(tracker_a.agent_error_calls) == 1 + assert len(tracker_b.agent_error_calls) == 1 + + +# --------------------------------------------------------------------------- +# Tests — run_live path +# --------------------------------------------------------------------------- + + +class TestAgentOnAgentErrorCallbackLive: + + @pytest.mark.asyncio + async def test_on_agent_error_callback_called_when_live_impl_raises(self): + tracker = TrackingPlugin() + + with pytest.raises(RuntimeError, match="live agent impl exploded"): + await _collect_live_events(_FailingLiveAgent(), [tracker]) + + assert len(tracker.agent_error_calls) == 1 + assert tracker.agent_error_calls[0]["error"] is _FailingLiveAgent.BOOM + + @pytest.mark.asyncio + async def test_after_agent_callback_not_called_on_live_error(self): + tracker = TrackingPlugin() + + with pytest.raises(RuntimeError): + await _collect_live_events(_FailingLiveAgent(), [tracker]) + + assert not tracker.after_agent_called + + @pytest.mark.asyncio + async def test_original_live_exception_reraised_unchanged(self): + tracker = TrackingPlugin() + + with pytest.raises(RuntimeError) as exc_info: + await _collect_live_events(_FailingLiveAgent(), [tracker]) + + assert exc_info.value is _FailingLiveAgent.BOOM diff --git a/tests/unittests/plugins/test_lifecycle_error_callbacks.py b/tests/unittests/plugins/test_lifecycle_error_callbacks.py new file mode 100644 index 0000000000..c95d55c772 --- /dev/null +++ b/tests/unittests/plugins/test_lifecycle_error_callbacks.py @@ -0,0 +1,315 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the new on_run_error_callback and on_agent_error_callback +lifecycle hooks added to BasePlugin and PluginManager (issue #4774). + +These tests focus on: + 1. PluginManager._run_error_callbacks — fan-out semantics (all plugins + notified even when one raises). + 2. PluginManager.run_on_run_error_callback and + run_on_agent_error_callback — correct argument forwarding. + 3. BasePlugin default no-op implementations return None. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.invocation_context import InvocationContext +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.plugins.plugin_manager import PluginManager +import pytest + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class ObservabilityPlugin(BasePlugin): + """A test plugin that records every on_*_error_callback invocation.""" + + __test__ = False + + def __init__(self, name: str, *, raise_on_error: bool = False): + super().__init__(name) + self.agent_error_calls: list[dict] = [] + self.run_error_calls: list[dict] = [] + self._raise_on_error = raise_on_error + + async def on_agent_error_callback( + self, *, agent, callback_context, error + ) -> None: + self.agent_error_calls.append( + {"agent": agent, "callback_context": callback_context, "error": error} + ) + if self._raise_on_error: + raise RuntimeError( + f"{self.name} intentionally raised in on_agent_error_callback" + ) + + async def on_run_error_callback(self, *, invocation_context, error) -> None: + self.run_error_calls.append( + {"invocation_context": invocation_context, "error": error} + ) + if self._raise_on_error: + raise RuntimeError( + f"{self.name} intentionally raised in on_run_error_callback" + ) + + +# --------------------------------------------------------------------------- +# BasePlugin default no-op tests +# --------------------------------------------------------------------------- + + +class TestBasePluginDefaults: + """Verifies that the default BasePlugin implementations are safe no-ops.""" + + @pytest.mark.asyncio + async def test_on_agent_error_callback_returns_none_by_default(self): + """Default on_agent_error_callback must return None without raising.""" + + class MinimalPlugin(BasePlugin): + pass + + plugin = MinimalPlugin(name="minimal") + result = await plugin.on_agent_error_callback( + agent=Mock(), + callback_context=Mock(), + error=ValueError("boom"), + ) + assert result is None + + @pytest.mark.asyncio + async def test_on_run_error_callback_returns_none_by_default(self): + """Default on_run_error_callback must return None without raising.""" + + class MinimalPlugin(BasePlugin): + pass + + plugin = MinimalPlugin(name="minimal") + result = await plugin.on_run_error_callback( + invocation_context=Mock(), + error=RuntimeError("boom"), + ) + assert result is None + + +# --------------------------------------------------------------------------- +# PluginManager._run_error_callbacks fan-out semantics +# --------------------------------------------------------------------------- + + +class TestRunErrorCallbacksFanOut: + """Verifies that _run_error_callbacks notifies ALL plugins regardless of + earlier plugin failures (unlike the early-exit _run_callbacks).""" + + @pytest.mark.asyncio + async def test_all_plugins_notified_when_no_failures(self): + """All plugins must be called when none raise.""" + plugin_a = ObservabilityPlugin("a") + plugin_b = ObservabilityPlugin("b") + manager = PluginManager(plugins=[plugin_a, plugin_b]) + + inv_ctx = Mock(spec=InvocationContext) + error = ValueError("run exploded") + + await manager.run_on_run_error_callback( + invocation_context=inv_ctx, error=error + ) + + assert len(plugin_a.run_error_calls) == 1 + assert len(plugin_b.run_error_calls) == 1 + assert plugin_a.run_error_calls[0]["error"] is error + assert plugin_b.run_error_calls[0]["error"] is error + + @pytest.mark.asyncio + async def test_subsequent_plugins_called_even_when_first_raises(self): + """If plugin_a raises during on_run_error_callback, plugin_b must still + be called — a broken observability plugin must not silence others.""" + plugin_a = ObservabilityPlugin("a", raise_on_error=True) + plugin_b = ObservabilityPlugin("b") + manager = PluginManager(plugins=[plugin_a, plugin_b]) + + inv_ctx = Mock(spec=InvocationContext) + error = ValueError("original error") + + # Should NOT raise even though plugin_a raises internally + await manager.run_on_run_error_callback( + invocation_context=inv_ctx, error=error + ) + + assert len(plugin_b.run_error_calls) == 1 + assert plugin_b.run_error_calls[0]["error"] is error + + @pytest.mark.asyncio + async def test_agent_error_all_plugins_notified_when_no_failures(self): + """All plugins must be called for on_agent_error_callback.""" + plugin_a = ObservabilityPlugin("a") + plugin_b = ObservabilityPlugin("b") + manager = PluginManager(plugins=[plugin_a, plugin_b]) + + agent = Mock(spec=BaseAgent) + cb_ctx = Mock(spec=CallbackContext) + error = RuntimeError("agent died") + + await manager.run_on_agent_error_callback( + agent=agent, callback_context=cb_ctx, error=error + ) + + assert len(plugin_a.agent_error_calls) == 1 + assert len(plugin_b.agent_error_calls) == 1 + assert plugin_a.agent_error_calls[0]["agent"] is agent + assert plugin_b.agent_error_calls[0]["error"] is error + + @pytest.mark.asyncio + async def test_agent_error_subsequent_plugins_called_even_when_first_raises( + self, + ): + """Broken plugin in on_agent_error_callback must not block others.""" + plugin_a = ObservabilityPlugin("a", raise_on_error=True) + plugin_b = ObservabilityPlugin("b") + manager = PluginManager(plugins=[plugin_a, plugin_b]) + + agent = Mock(spec=BaseAgent) + cb_ctx = Mock(spec=CallbackContext) + error = RuntimeError("agent died") + + await manager.run_on_agent_error_callback( + agent=agent, callback_context=cb_ctx, error=error + ) + + assert len(plugin_b.agent_error_calls) == 1 + + @pytest.mark.asyncio + async def test_no_plugins_registered_does_not_raise(self): + """run_on_run_error_callback with zero plugins must be a safe no-op.""" + manager = PluginManager() + await manager.run_on_run_error_callback( + invocation_context=Mock(), error=ValueError("x") + ) + + @pytest.mark.asyncio + async def test_no_plugins_registered_agent_does_not_raise(self): + """run_on_agent_error_callback with zero plugins must be a safe no-op.""" + manager = PluginManager() + await manager.run_on_agent_error_callback( + agent=Mock(), callback_context=Mock(), error=ValueError("x") + ) + + +# --------------------------------------------------------------------------- +# PluginManager — argument forwarding +# --------------------------------------------------------------------------- + + +class TestArgumentForwarding: + """Verifies exact argument passing to each plugin callback.""" + + @pytest.mark.asyncio + async def test_run_on_run_error_callback_passes_correct_kwargs(self): + plugin = ObservabilityPlugin("p") + manager = PluginManager(plugins=[plugin]) + + inv_ctx = Mock(spec=InvocationContext) + error = KeyError("missing_key") + + await manager.run_on_run_error_callback( + invocation_context=inv_ctx, error=error + ) + + call = plugin.run_error_calls[0] + assert call["invocation_context"] is inv_ctx + assert call["error"] is error + + @pytest.mark.asyncio + async def test_run_on_agent_error_callback_passes_correct_kwargs(self): + plugin = ObservabilityPlugin("p") + manager = PluginManager(plugins=[plugin]) + + agent = Mock(spec=BaseAgent) + cb_ctx = Mock(spec=CallbackContext) + error = TypeError("bad_type") + + await manager.run_on_agent_error_callback( + agent=agent, callback_context=cb_ctx, error=error + ) + + call = plugin.agent_error_calls[0] + assert call["agent"] is agent + assert call["callback_context"] is cb_ctx + assert call["error"] is error + + +# --------------------------------------------------------------------------- +# Contrast with _run_callbacks early-exit semantics +# --------------------------------------------------------------------------- + + +class TestErrorCallbacksDoNotEarlyExit: + """Confirms error callbacks ignore non-None returns (no early exit).""" + + @pytest.mark.asyncio + async def test_on_run_error_return_value_is_ignored(self): + """on_run_error_callback returning a value must not stop other plugins.""" + + class ReturningPlugin(BasePlugin): + + def __init__(self, name): + super().__init__(name) + self.called = False + + async def on_run_error_callback(self, **kwargs): + self.called = True + return "non-none-value-that-should-be-ignored" + + p1 = ReturningPlugin("p1") + p2 = ReturningPlugin("p2") + manager = PluginManager(plugins=[p1, p2]) + + await manager.run_on_run_error_callback( + invocation_context=Mock(), error=ValueError("x") + ) + + assert p1.called + assert p2.called # must NOT be short-circuited + + @pytest.mark.asyncio + async def test_on_agent_error_return_value_is_ignored(self): + """on_agent_error_callback returning a value must not stop other plugins.""" + + class ReturningPlugin(BasePlugin): + + def __init__(self, name): + super().__init__(name) + self.called = False + + async def on_agent_error_callback(self, **kwargs): + self.called = True + return "non-none-value" + + p1 = ReturningPlugin("p1") + p2 = ReturningPlugin("p2") + manager = PluginManager(plugins=[p1, p2]) + + await manager.run_on_agent_error_callback( + agent=Mock(), callback_context=Mock(), error=ValueError("x") + ) + + assert p1.called + assert p2.called diff --git a/tests/unittests/runners/test_runner_error_callbacks.py b/tests/unittests/runners/test_runner_error_callbacks.py new file mode 100644 index 0000000000..226567dee3 --- /dev/null +++ b/tests/unittests/runners/test_runner_error_callbacks.py @@ -0,0 +1,283 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for on_run_error_callback in the Runner lifecycle. + +These tests verify that: + - on_run_error_callback is called when the execution generator raises. + - after_run_callback is NOT called on the error path. + - The original exception is re-raised unchanged. + - on_run_error_callback receives the correct invocation_context and error. + - on_run_error_callback is NOT called on successful runs. +""" + +from __future__ import annotations + +from typing import AsyncGenerator +from typing import ClassVar +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.plugins.plugin_manager import PluginManager +from google.adk.runners import InMemoryRunner +from google.genai import types +import pytest +from typing_extensions import override + +from .. import testing_utils + +# --------------------------------------------------------------------------- +# Concrete agent implementations for testing +# --------------------------------------------------------------------------- + + +class _SuccessAgent(BaseAgent): + """Agent that yields one event then completes successfully.""" + + def __init__(self): + super().__init__(name="success_agent") + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + invocation_id=ctx.invocation_id, + author=self.name, + content=types.Content(parts=[types.Part.from_text(text="ok")]), + ) + + +class _FailingAgent(BaseAgent): + """Agent that raises a RuntimeError during execution.""" + + BOOM: ClassVar[RuntimeError] = RuntimeError("agent exploded mid-run") + + def __init__(self): + super().__init__(name="failing_agent") + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + raise _FailingAgent.BOOM + yield # make this an async generator + + +# --------------------------------------------------------------------------- +# Tracking plugin +# --------------------------------------------------------------------------- + + +class TrackingPlugin(BasePlugin): + """Records lifecycle callback invocations for assertions.""" + + __test__ = False + + def __init__(self, name: str = "tracker"): + super().__init__(name) + self.after_run_called = False + self.run_error_calls: list[dict] = [] + + async def after_run_callback(self, *, invocation_context, **kwargs) -> None: + self.after_run_called = True + + async def on_run_error_callback( + self, *, invocation_context, error, **kwargs + ) -> None: + self.run_error_calls.append( + {"invocation_context": invocation_context, "error": error} + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestRunnerOnRunErrorCallback: + + def _make_runner(self, agent: BaseAgent, plugins: list[BasePlugin]): + return InMemoryRunner(agent=agent, plugins=plugins) + + @pytest.mark.asyncio + async def test_on_run_error_callback_called_when_agent_raises(self): + """on_run_error_callback must fire when the agent execution raises.""" + tracker = TrackingPlugin() + runner = self._make_runner(_FailingAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u1" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + with pytest.raises(RuntimeError, match="agent exploded mid-run"): + async for _ in runner.run_async( + user_id="u1", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert len(tracker.run_error_calls) == 1 + assert tracker.run_error_calls[0]["error"] is _FailingAgent.BOOM + + @pytest.mark.asyncio + async def test_after_run_callback_not_called_on_error(self): + """after_run_callback must NOT be called when execution raises.""" + tracker = TrackingPlugin() + runner = self._make_runner(_FailingAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u2" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + with pytest.raises(RuntimeError): + async for _ in runner.run_async( + user_id="u2", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert not tracker.after_run_called + + @pytest.mark.asyncio + async def test_original_exception_is_reraised_unchanged(self): + """The exact original exception must propagate to the caller.""" + tracker = TrackingPlugin() + runner = self._make_runner(_FailingAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u3" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + with pytest.raises(RuntimeError) as exc_info: + async for _ in runner.run_async( + user_id="u3", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert exc_info.value is _FailingAgent.BOOM + + @pytest.mark.asyncio + async def test_on_run_error_callback_receives_correct_invocation_context( + self, + ): + """The invocation_context passed to on_run_error_callback must be the + same one used for the run.""" + tracker = TrackingPlugin() + runner = self._make_runner(_FailingAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u4" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + with pytest.raises(RuntimeError): + async for _ in runner.run_async( + user_id="u4", + session_id=session.id, + new_message=user_msg, + ): + pass + + inv_ctx = tracker.run_error_calls[0]["invocation_context"] + assert isinstance(inv_ctx, InvocationContext) + assert inv_ctx.session.id == session.id + + @pytest.mark.asyncio + async def test_on_run_error_callback_not_called_on_success(self): + """on_run_error_callback must NOT be called for a successful run.""" + tracker = TrackingPlugin() + runner = self._make_runner(_SuccessAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u5" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + async for _ in runner.run_async( + user_id="u5", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert len(tracker.run_error_calls) == 0 + + @pytest.mark.asyncio + async def test_after_run_callback_called_on_success(self): + """after_run_callback must still be called for a successful run.""" + tracker = TrackingPlugin() + runner = self._make_runner(_SuccessAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u6" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + async for _ in runner.run_async( + user_id="u6", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert tracker.after_run_called + + @pytest.mark.asyncio + async def test_multiple_plugins_all_notified_on_error(self): + """All registered plugins must receive on_run_error_callback on failure.""" + tracker_a = TrackingPlugin("tracker_a") + tracker_b = TrackingPlugin("tracker_b") + runner = self._make_runner(_FailingAgent(), [tracker_a, tracker_b]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u7" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + with pytest.raises(RuntimeError): + async for _ in runner.run_async( + user_id="u7", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert len(tracker_a.run_error_calls) == 1 + assert len(tracker_b.run_error_calls) == 1