From 7cdef7c867b0b6af400b8a32d00de9f6bfbd3b93 Mon Sep 17 00:00:00 2001 From: STHITAPRAJNAS Date: Tue, 24 Mar 2026 03:31:44 +0000 Subject: [PATCH 1/2] feat(plugins): add on_agent_error_callback and on_run_error_callback lifecycle hooks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #4774 When an unhandled exception propagates out of an agent's _run_async_impl / _run_live_impl, or out of the runner's execution loop, the existing after_agent_callback / after_run_callback were silently skipped. This made fatal failures invisible to observability plugins (e.g. BigQuery analytics), inflating success rates and losing failure events entirely. Changes: - BasePlugin: add on_agent_error_callback(agent, callback_context, error) and on_run_error_callback(invocation_context, error) with safe no-op defaults. - PluginManager: add run_on_agent_error_callback / run_on_run_error_callback dispatch methods backed by a new _run_error_callbacks helper that fans out to ALL plugins (no early-exit) and logs — but does not propagate — individual plugin failures. - base_agent.py: wrap run_async / run_live generator loops in try/except; call run_on_agent_error_callback before re-raising. - runners.py: wrap the execute_fn generator loop in try/except; call run_on_run_error_callback before re-raising. after_run_callback is intentionally skipped on the error path so plugins can distinguish clean completions from fatal failures. Tests (30 new, all passing): - tests/unittests/plugins/test_lifecycle_error_callbacks.py - tests/unittests/runners/test_runner_error_callbacks.py - tests/unittests/agents/test_agent_error_callbacks.py --- src/google/adk/agents/base_agent.py | 32 +- src/google/adk/plugins/base_plugin.py | 60 ++++ src/google/adk/plugins/plugin_manager.py | 77 +++++ src/google/adk/runners.py | 137 ++++---- .../agents/test_agent_error_callbacks.py | 259 +++++++++++++++ .../plugins/test_lifecycle_error_callbacks.py | 305 ++++++++++++++++++ .../runners/test_runner_error_callbacks.py | 279 ++++++++++++++++ 7 files changed, 1080 insertions(+), 69 deletions(-) create mode 100644 tests/unittests/agents/test_agent_error_callbacks.py create mode 100644 tests/unittests/plugins/test_lifecycle_error_callbacks.py create mode 100644 tests/unittests/runners/test_runner_error_callbacks.py diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index dec85690b3..94b2ed9a33 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -293,9 +293,20 @@ async def run_async( if ctx.end_invocation: return - async with Aclosing(self._run_async_impl(ctx)) as agen: - async for event in agen: - yield event + try: + async with Aclosing(self._run_async_impl(ctx)) as agen: + async for event in agen: + yield event + except Exception as agent_error: + # Notify plugins that this agent run failed before re-raising. + # after_agent_callback is intentionally skipped so plugins can + # distinguish a clean completion from a fatal failure. + await ctx.plugin_manager.run_on_agent_error_callback( + agent=self, + callback_context=CallbackContext(ctx), + error=agent_error, + ) + raise if ctx.end_invocation: return @@ -326,9 +337,18 @@ async def run_live( if ctx.end_invocation: return - async with Aclosing(self._run_live_impl(ctx)) as agen: - async for event in agen: - yield event + try: + async with Aclosing(self._run_live_impl(ctx)) as agen: + async for event in agen: + yield event + except Exception as agent_error: + # Notify plugins that this live agent run failed before re-raising. + await ctx.plugin_manager.run_on_agent_error_callback( + agent=self, + callback_context=CallbackContext(ctx), + error=agent_error, + ) + raise if event := await self._handle_after_agent_callback(ctx): yield event diff --git a/src/google/adk/plugins/base_plugin.py b/src/google/adk/plugins/base_plugin.py index 3639f61aa2..e39104b351 100644 --- a/src/google/adk/plugins/base_plugin.py +++ b/src/google/adk/plugins/base_plugin.py @@ -370,3 +370,63 @@ async def on_tool_error_callback( allows the original error to be raised. """ pass + + async def on_agent_error_callback( + self, + *, + agent: BaseAgent, + callback_context: CallbackContext, + error: Exception, + ) -> None: + """Callback executed when an unhandled exception escapes an agent's run. + + This callback fires when an exception propagates out of the agent's + ``_run_async_impl`` or ``_run_live_impl`` before ``after_agent_callback`` + has had a chance to execute. It is intended purely for observability + (logging, metrics, tracing) — the original exception is always re-raised + after all registered plugins have been notified. + + Unlike ``on_tool_error_callback`` and ``on_model_error_callback``, this + callback cannot swallow or replace the error; it always returns ``None``. + + Args: + agent: The agent instance whose execution raised the exception. + callback_context: The callback context for the failed agent invocation. + error: The exception that was raised. + + Returns: + None. The return value is ignored; the exception is re-raised by the + framework regardless. + """ + pass + + async def on_run_error_callback( + self, + *, + invocation_context: InvocationContext, + error: Exception, + ) -> None: + """Callback executed when an unhandled exception escapes a runner invocation. + + This callback fires when an exception propagates out of the runner's main + execution loop before ``after_run_callback`` has had a chance to execute. + It is intended purely for observability (logging, metrics, tracing) — the + original exception is always re-raised after all registered plugins have + been notified. + + This fills the gap where a fatal error (e.g. an unrecoverable model crash + or tool exception) would otherwise cause the invocation to disappear from + observability sinks without ever emitting a terminal event. + + Unlike ``on_tool_error_callback`` and ``on_model_error_callback``, this + callback cannot swallow or replace the error; it always returns ``None``. + + Args: + invocation_context: The context for the entire invocation. + error: The exception that escaped the runner's execution loop. + + Returns: + None. The return value is ignored; the exception is re-raised by the + framework regardless. + """ + pass diff --git a/src/google/adk/plugins/plugin_manager.py b/src/google/adk/plugins/plugin_manager.py index c781e8fa4e..496e9df192 100644 --- a/src/google/adk/plugins/plugin_manager.py +++ b/src/google/adk/plugins/plugin_manager.py @@ -52,6 +52,8 @@ "after_model_callback", "on_tool_error_callback", "on_model_error_callback", + "on_agent_error_callback", + "on_run_error_callback", ] logger = logging.getLogger("google_adk." + __name__) @@ -257,6 +259,46 @@ async def run_on_tool_error_callback( error=error, ) + async def run_on_agent_error_callback( + self, + *, + agent: BaseAgent, + callback_context: CallbackContext, + error: Exception, + ) -> None: + """Runs the ``on_agent_error_callback`` for all plugins. + + All registered plugins are notified even if an earlier plugin raises — + failures in individual plugins are logged but do not prevent subsequent + plugins from being called. The original agent error is never suppressed + by this method. + """ + await self._run_error_callbacks( + "on_agent_error_callback", + agent=agent, + callback_context=callback_context, + error=error, + ) + + async def run_on_run_error_callback( + self, + *, + invocation_context: InvocationContext, + error: Exception, + ) -> None: + """Runs the ``on_run_error_callback`` for all plugins. + + All registered plugins are notified even if an earlier plugin raises — + failures in individual plugins are logged but do not prevent subsequent + plugins from being called. The original runner error is never suppressed + by this method. + """ + await self._run_error_callbacks( + "on_run_error_callback", + invocation_context=invocation_context, + error=error, + ) + async def _run_callbacks( self, callback_name: PluginCallbackName, **kwargs: Any ) -> Optional[Any]: @@ -306,6 +348,41 @@ async def _run_callbacks( return None + async def _run_error_callbacks( + self, callback_name: PluginCallbackName, **kwargs: Any + ) -> None: + """Executes an error-notification callback for **all** registered plugins. + + Unlike ``_run_callbacks``, this method does **not** stop on the first + non-``None`` return value. Error callbacks are pure observers — every + plugin deserves a chance to record the failure even if an earlier plugin + in the chain itself encounters an error. + + Individual plugin failures are logged but do not prevent subsequent + plugins from being called, and they do not propagate to the caller. The + underlying framework error that triggered this notification is always + re-raised by the caller independently. + + Args: + callback_name: The name of the error callback method to execute. + **kwargs: Keyword arguments forwarded to each plugin's callback. + """ + for plugin in self.plugins: + callback_method = getattr(plugin, callback_name) + try: + await callback_method(**kwargs) + except Exception as e: + # Log but continue — a broken observability plugin must not hide the + # original error from the framework or prevent other plugins from + # receiving the notification. + logger.error( + "Error in plugin '%s' during '%s' callback: %s", + plugin.name, + callback_name, + e, + exc_info=True, + ) + async def close(self) -> None: """Calls the close method on all registered plugins concurrently. diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 8e352794a4..f60ecbad6b 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -849,73 +849,84 @@ async def _exec_with_plugin( buffered_events: list[Event] = [] is_transcribing: bool = False - async with Aclosing(execute_fn(invocation_context)) as agen: - async for event in agen: - _apply_run_config_custom_metadata( - event, invocation_context.run_config - ) - if is_live_call: - if event.partial and _is_transcription(event): - is_transcribing = True - if is_transcribing and _is_tool_call_or_response(event): - # only buffer function call and function response event which is - # non-partial - buffered_events.append(event) - continue - # Note for live/bidi: for audio response, it's considered as - # non-partial event(event.partial=None) - # event.partial=False and event.partial=None are considered as - # non-partial event; event.partial=True is considered as partial - # event. - if event.partial is not True: - if _is_transcription(event) and ( - _has_non_empty_transcription_text(event.input_transcription) - or _has_non_empty_transcription_text( - event.output_transcription + try: + async with Aclosing(execute_fn(invocation_context)) as agen: + async for event in agen: + _apply_run_config_custom_metadata( + event, invocation_context.run_config + ) + if is_live_call: + if event.partial and _is_transcription(event): + is_transcribing = True + if is_transcribing and _is_tool_call_or_response(event): + # only buffer function call and function response event which is + # non-partial + buffered_events.append(event) + continue + # Note for live/bidi: for audio response, it's considered as + # non-partial event(event.partial=None) + # event.partial=False and event.partial=None are considered as + # non-partial event; event.partial=True is considered as partial + # event. + if event.partial is not True: + if _is_transcription(event) and ( + _has_non_empty_transcription_text(event.input_transcription) + or _has_non_empty_transcription_text( + event.output_transcription + ) + ): + # transcription end signal, append buffered events + is_transcribing = False + logger.debug( + 'Appending transcription finished event: %s', event ) - ): - # transcription end signal, append buffered events - is_transcribing = False - logger.debug( - 'Appending transcription finished event: %s', event + if self._should_append_event(event, is_live_call): + await self.session_service.append_event( + session=session, event=event + ) + + for buffered_event in buffered_events: + logger.debug('Appending buffered event: %s', buffered_event) + await self.session_service.append_event( + session=session, event=buffered_event + ) + yield buffered_event # yield buffered events to caller + buffered_events = [] + else: + # non-transcription event or empty transcription event, for + # example, event that stores blob reference, should be appended. + if self._should_append_event(event, is_live_call): + logger.debug('Appending non-buffered event: %s', event) + await self.session_service.append_event( + session=session, event=event + ) + else: + if event.partial is not True: + await self.session_service.append_event( + session=session, event=event ) - if self._should_append_event(event, is_live_call): - await self.session_service.append_event( - session=session, event=event - ) - - for buffered_event in buffered_events: - logger.debug('Appending buffered event: %s', buffered_event) - await self.session_service.append_event( - session=session, event=buffered_event - ) - yield buffered_event # yield buffered events to caller - buffered_events = [] - else: - # non-transcription event or empty transcription event, for - # example, event that stores blob reference, should be appended. - if self._should_append_event(event, is_live_call): - logger.debug('Appending non-buffered event: %s', event) - await self.session_service.append_event( - session=session, event=event - ) - else: - if event.partial is not True: - await self.session_service.append_event( - session=session, event=event - ) - # Step 3: Run the on_event callbacks to optionally modify the event. - modified_event = await plugin_manager.run_on_event_callback( - invocation_context=invocation_context, event=event - ) - if modified_event: - _apply_run_config_custom_metadata( - modified_event, invocation_context.run_config + # Step 3: Run the on_event callbacks to optionally modify the event. + modified_event = await plugin_manager.run_on_event_callback( + invocation_context=invocation_context, event=event ) - yield modified_event - else: - yield event + if modified_event: + _apply_run_config_custom_metadata( + modified_event, invocation_context.run_config + ) + yield modified_event + else: + yield event + except Exception as run_error: + # Step 3b: Notify all plugins that this invocation failed. + # The callback is fire-and-forget — it cannot suppress the error. + # after_run_callback is intentionally skipped on the error path so + # that plugins can distinguish clean completions from fatal failures. + await plugin_manager.run_on_run_error_callback( + invocation_context=invocation_context, + error=run_error, + ) + raise # Step 4: Run the after_run callbacks to perform global cleanup tasks or # finalizing logs and metrics data. diff --git a/tests/unittests/agents/test_agent_error_callbacks.py b/tests/unittests/agents/test_agent_error_callbacks.py new file mode 100644 index 0000000000..01f708c694 --- /dev/null +++ b/tests/unittests/agents/test_agent_error_callbacks.py @@ -0,0 +1,259 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for on_agent_error_callback in BaseAgent lifecycle. + +These tests verify that: + - on_agent_error_callback is called when _run_async_impl raises. + - on_agent_error_callback is called when _run_live_impl raises. + - after_agent_callback is NOT called on the error path. + - The original exception is re-raised unchanged after notification. + - on_agent_error_callback receives the agent, callback_context, and error. + - on_agent_error_callback is NOT called on a successful run. +""" + +from __future__ import annotations + +from typing import AsyncGenerator +from typing import ClassVar +from unittest.mock import Mock + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.plugins.plugin_manager import PluginManager +from google.genai import types +import pytest +from typing_extensions import override + +from .. import testing_utils + + +# --------------------------------------------------------------------------- +# Concrete agent implementations +# --------------------------------------------------------------------------- + +class _SuccessAgent(BaseAgent): + def __init__(self): + super().__init__(name="success_agent") + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + invocation_id=ctx.invocation_id, + author=self.name, + content=types.Content(parts=[types.Part.from_text(text="done")]), + ) + + +class _FailingAgent(BaseAgent): + BOOM: ClassVar[RuntimeError] = RuntimeError("agent impl exploded") + + def __init__(self): + super().__init__(name="failing_agent") + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + raise _FailingAgent.BOOM + yield # pragma: no cover + + +class _FailingLiveAgent(BaseAgent): + BOOM: ClassVar[RuntimeError] = RuntimeError("live agent impl exploded") + + def __init__(self): + super().__init__(name="failing_live_agent") + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield # pragma: no cover + + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + raise _FailingLiveAgent.BOOM + yield # pragma: no cover + + +# --------------------------------------------------------------------------- +# Tracking plugin +# --------------------------------------------------------------------------- + +class TrackingPlugin(BasePlugin): + __test__ = False + + def __init__(self, name: str = "tracker"): + super().__init__(name) + self.after_agent_called = False + self.agent_error_calls: list[dict] = [] + + async def after_agent_callback(self, *, agent, callback_context, **kwargs): + self.after_agent_called = True + + async def on_agent_error_callback( + self, *, agent, callback_context, error, **kwargs + ) -> None: + self.agent_error_calls.append( + {"agent": agent, "callback_context": callback_context, "error": error} + ) + + +# --------------------------------------------------------------------------- +# Helper to drive run_async +# --------------------------------------------------------------------------- + +async def _collect_events(agent: BaseAgent, plugins: list[BasePlugin]) -> list[Event]: + inv_ctx = await testing_utils.create_invocation_context( + agent=agent, plugins=plugins + ) + events = [] + async for event in agent.run_async(inv_ctx): + events.append(event) + return events + + +async def _collect_live_events( + agent: BaseAgent, plugins: list[BasePlugin] +) -> list[Event]: + inv_ctx = await testing_utils.create_invocation_context( + agent=agent, plugins=plugins + ) + events = [] + async for event in agent.run_live(inv_ctx): + events.append(event) + return events + + +# --------------------------------------------------------------------------- +# Tests — run_async path +# --------------------------------------------------------------------------- + +class TestAgentOnAgentErrorCallbackAsync: + + @pytest.mark.asyncio + async def test_on_agent_error_callback_called_when_impl_raises(self): + tracker = TrackingPlugin() + with pytest.raises(RuntimeError, match="agent impl exploded"): + await _collect_events(_FailingAgent(), [tracker]) + + assert len(tracker.agent_error_calls) == 1 + assert tracker.agent_error_calls[0]["error"] is _FailingAgent.BOOM + + @pytest.mark.asyncio + async def test_on_agent_error_callback_receives_correct_agent(self): + tracker = TrackingPlugin() + agent = _FailingAgent() + + with pytest.raises(RuntimeError): + await _collect_events(agent, [tracker]) + + assert tracker.agent_error_calls[0]["agent"] is agent + + @pytest.mark.asyncio + async def test_on_agent_error_callback_receives_callback_context(self): + tracker = TrackingPlugin() + + with pytest.raises(RuntimeError): + await _collect_events(_FailingAgent(), [tracker]) + + cb_ctx = tracker.agent_error_calls[0]["callback_context"] + assert isinstance(cb_ctx, CallbackContext) + + @pytest.mark.asyncio + async def test_original_exception_reraised_after_notification(self): + tracker = TrackingPlugin() + + with pytest.raises(RuntimeError) as exc_info: + await _collect_events(_FailingAgent(), [tracker]) + + assert exc_info.value is _FailingAgent.BOOM + + @pytest.mark.asyncio + async def test_after_agent_callback_not_called_on_error(self): + tracker = TrackingPlugin() + + with pytest.raises(RuntimeError): + await _collect_events(_FailingAgent(), [tracker]) + + assert not tracker.after_agent_called + + @pytest.mark.asyncio + async def test_on_agent_error_callback_not_called_on_success(self): + tracker = TrackingPlugin() + events = await _collect_events(_SuccessAgent(), [tracker]) + + assert len(events) >= 1 + assert len(tracker.agent_error_calls) == 0 + + @pytest.mark.asyncio + async def test_after_agent_callback_still_called_on_success(self): + tracker = TrackingPlugin() + await _collect_events(_SuccessAgent(), [tracker]) + + assert tracker.after_agent_called + + @pytest.mark.asyncio + async def test_multiple_plugins_all_notified_on_agent_error(self): + tracker_a = TrackingPlugin("a") + tracker_b = TrackingPlugin("b") + + with pytest.raises(RuntimeError): + await _collect_events(_FailingAgent(), [tracker_a, tracker_b]) + + assert len(tracker_a.agent_error_calls) == 1 + assert len(tracker_b.agent_error_calls) == 1 + + +# --------------------------------------------------------------------------- +# Tests — run_live path +# --------------------------------------------------------------------------- + +class TestAgentOnAgentErrorCallbackLive: + + @pytest.mark.asyncio + async def test_on_agent_error_callback_called_when_live_impl_raises(self): + tracker = TrackingPlugin() + + with pytest.raises(RuntimeError, match="live agent impl exploded"): + await _collect_live_events(_FailingLiveAgent(), [tracker]) + + assert len(tracker.agent_error_calls) == 1 + assert tracker.agent_error_calls[0]["error"] is _FailingLiveAgent.BOOM + + @pytest.mark.asyncio + async def test_after_agent_callback_not_called_on_live_error(self): + tracker = TrackingPlugin() + + with pytest.raises(RuntimeError): + await _collect_live_events(_FailingLiveAgent(), [tracker]) + + assert not tracker.after_agent_called + + @pytest.mark.asyncio + async def test_original_live_exception_reraised_unchanged(self): + tracker = TrackingPlugin() + + with pytest.raises(RuntimeError) as exc_info: + await _collect_live_events(_FailingLiveAgent(), [tracker]) + + assert exc_info.value is _FailingLiveAgent.BOOM diff --git a/tests/unittests/plugins/test_lifecycle_error_callbacks.py b/tests/unittests/plugins/test_lifecycle_error_callbacks.py new file mode 100644 index 0000000000..8f0cd757b0 --- /dev/null +++ b/tests/unittests/plugins/test_lifecycle_error_callbacks.py @@ -0,0 +1,305 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the new on_run_error_callback and on_agent_error_callback +lifecycle hooks added to BasePlugin and PluginManager (issue #4774). + +These tests focus on: + 1. PluginManager._run_error_callbacks — fan-out semantics (all plugins + notified even when one raises). + 2. PluginManager.run_on_run_error_callback and + run_on_agent_error_callback — correct argument forwarding. + 3. BasePlugin default no-op implementations return None. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.plugins.plugin_manager import PluginManager +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class ObservabilityPlugin(BasePlugin): + """A test plugin that records every on_*_error_callback invocation.""" + + __test__ = False + + def __init__(self, name: str, *, raise_on_error: bool = False): + super().__init__(name) + self.agent_error_calls: list[dict] = [] + self.run_error_calls: list[dict] = [] + self._raise_on_error = raise_on_error + + async def on_agent_error_callback( + self, *, agent, callback_context, error + ) -> None: + self.agent_error_calls.append( + {"agent": agent, "callback_context": callback_context, "error": error} + ) + if self._raise_on_error: + raise RuntimeError(f"{self.name} intentionally raised in on_agent_error_callback") + + async def on_run_error_callback( + self, *, invocation_context, error + ) -> None: + self.run_error_calls.append( + {"invocation_context": invocation_context, "error": error} + ) + if self._raise_on_error: + raise RuntimeError(f"{self.name} intentionally raised in on_run_error_callback") + + +# --------------------------------------------------------------------------- +# BasePlugin default no-op tests +# --------------------------------------------------------------------------- + +class TestBasePluginDefaults: + """Verifies that the default BasePlugin implementations are safe no-ops.""" + + @pytest.mark.asyncio + async def test_on_agent_error_callback_returns_none_by_default(self): + """Default on_agent_error_callback must return None without raising.""" + class MinimalPlugin(BasePlugin): + pass + + plugin = MinimalPlugin(name="minimal") + result = await plugin.on_agent_error_callback( + agent=Mock(), + callback_context=Mock(), + error=ValueError("boom"), + ) + assert result is None + + @pytest.mark.asyncio + async def test_on_run_error_callback_returns_none_by_default(self): + """Default on_run_error_callback must return None without raising.""" + class MinimalPlugin(BasePlugin): + pass + + plugin = MinimalPlugin(name="minimal") + result = await plugin.on_run_error_callback( + invocation_context=Mock(), + error=RuntimeError("boom"), + ) + assert result is None + + +# --------------------------------------------------------------------------- +# PluginManager._run_error_callbacks fan-out semantics +# --------------------------------------------------------------------------- + +class TestRunErrorCallbacksFanOut: + """Verifies that _run_error_callbacks notifies ALL plugins regardless of + earlier plugin failures (unlike the early-exit _run_callbacks).""" + + @pytest.mark.asyncio + async def test_all_plugins_notified_when_no_failures(self): + """All plugins must be called when none raise.""" + plugin_a = ObservabilityPlugin("a") + plugin_b = ObservabilityPlugin("b") + manager = PluginManager(plugins=[plugin_a, plugin_b]) + + inv_ctx = Mock(spec=InvocationContext) + error = ValueError("run exploded") + + await manager.run_on_run_error_callback( + invocation_context=inv_ctx, error=error + ) + + assert len(plugin_a.run_error_calls) == 1 + assert len(plugin_b.run_error_calls) == 1 + assert plugin_a.run_error_calls[0]["error"] is error + assert plugin_b.run_error_calls[0]["error"] is error + + @pytest.mark.asyncio + async def test_subsequent_plugins_called_even_when_first_raises(self): + """If plugin_a raises during on_run_error_callback, plugin_b must still + be called — a broken observability plugin must not silence others.""" + plugin_a = ObservabilityPlugin("a", raise_on_error=True) + plugin_b = ObservabilityPlugin("b") + manager = PluginManager(plugins=[plugin_a, plugin_b]) + + inv_ctx = Mock(spec=InvocationContext) + error = ValueError("original error") + + # Should NOT raise even though plugin_a raises internally + await manager.run_on_run_error_callback( + invocation_context=inv_ctx, error=error + ) + + assert len(plugin_b.run_error_calls) == 1 + assert plugin_b.run_error_calls[0]["error"] is error + + @pytest.mark.asyncio + async def test_agent_error_all_plugins_notified_when_no_failures(self): + """All plugins must be called for on_agent_error_callback.""" + plugin_a = ObservabilityPlugin("a") + plugin_b = ObservabilityPlugin("b") + manager = PluginManager(plugins=[plugin_a, plugin_b]) + + agent = Mock(spec=BaseAgent) + cb_ctx = Mock(spec=CallbackContext) + error = RuntimeError("agent died") + + await manager.run_on_agent_error_callback( + agent=agent, callback_context=cb_ctx, error=error + ) + + assert len(plugin_a.agent_error_calls) == 1 + assert len(plugin_b.agent_error_calls) == 1 + assert plugin_a.agent_error_calls[0]["agent"] is agent + assert plugin_b.agent_error_calls[0]["error"] is error + + @pytest.mark.asyncio + async def test_agent_error_subsequent_plugins_called_even_when_first_raises( + self, + ): + """Broken plugin in on_agent_error_callback must not block others.""" + plugin_a = ObservabilityPlugin("a", raise_on_error=True) + plugin_b = ObservabilityPlugin("b") + manager = PluginManager(plugins=[plugin_a, plugin_b]) + + agent = Mock(spec=BaseAgent) + cb_ctx = Mock(spec=CallbackContext) + error = RuntimeError("agent died") + + await manager.run_on_agent_error_callback( + agent=agent, callback_context=cb_ctx, error=error + ) + + assert len(plugin_b.agent_error_calls) == 1 + + @pytest.mark.asyncio + async def test_no_plugins_registered_does_not_raise(self): + """run_on_run_error_callback with zero plugins must be a safe no-op.""" + manager = PluginManager() + await manager.run_on_run_error_callback( + invocation_context=Mock(), error=ValueError("x") + ) + + @pytest.mark.asyncio + async def test_no_plugins_registered_agent_does_not_raise(self): + """run_on_agent_error_callback with zero plugins must be a safe no-op.""" + manager = PluginManager() + await manager.run_on_agent_error_callback( + agent=Mock(), callback_context=Mock(), error=ValueError("x") + ) + + +# --------------------------------------------------------------------------- +# PluginManager — argument forwarding +# --------------------------------------------------------------------------- + +class TestArgumentForwarding: + """Verifies exact argument passing to each plugin callback.""" + + @pytest.mark.asyncio + async def test_run_on_run_error_callback_passes_correct_kwargs(self): + plugin = ObservabilityPlugin("p") + manager = PluginManager(plugins=[plugin]) + + inv_ctx = Mock(spec=InvocationContext) + error = KeyError("missing_key") + + await manager.run_on_run_error_callback( + invocation_context=inv_ctx, error=error + ) + + call = plugin.run_error_calls[0] + assert call["invocation_context"] is inv_ctx + assert call["error"] is error + + @pytest.mark.asyncio + async def test_run_on_agent_error_callback_passes_correct_kwargs(self): + plugin = ObservabilityPlugin("p") + manager = PluginManager(plugins=[plugin]) + + agent = Mock(spec=BaseAgent) + cb_ctx = Mock(spec=CallbackContext) + error = TypeError("bad_type") + + await manager.run_on_agent_error_callback( + agent=agent, callback_context=cb_ctx, error=error + ) + + call = plugin.agent_error_calls[0] + assert call["agent"] is agent + assert call["callback_context"] is cb_ctx + assert call["error"] is error + + +# --------------------------------------------------------------------------- +# Contrast with _run_callbacks early-exit semantics +# --------------------------------------------------------------------------- + +class TestErrorCallbacksDoNotEarlyExit: + """Confirms error callbacks ignore non-None returns (no early exit).""" + + @pytest.mark.asyncio + async def test_on_run_error_return_value_is_ignored(self): + """on_run_error_callback returning a value must not stop other plugins.""" + + class ReturningPlugin(BasePlugin): + def __init__(self, name): + super().__init__(name) + self.called = False + + async def on_run_error_callback(self, **kwargs): + self.called = True + return "non-none-value-that-should-be-ignored" + + p1 = ReturningPlugin("p1") + p2 = ReturningPlugin("p2") + manager = PluginManager(plugins=[p1, p2]) + + await manager.run_on_run_error_callback( + invocation_context=Mock(), error=ValueError("x") + ) + + assert p1.called + assert p2.called # must NOT be short-circuited + + @pytest.mark.asyncio + async def test_on_agent_error_return_value_is_ignored(self): + """on_agent_error_callback returning a value must not stop other plugins.""" + + class ReturningPlugin(BasePlugin): + def __init__(self, name): + super().__init__(name) + self.called = False + + async def on_agent_error_callback(self, **kwargs): + self.called = True + return "non-none-value" + + p1 = ReturningPlugin("p1") + p2 = ReturningPlugin("p2") + manager = PluginManager(plugins=[p1, p2]) + + await manager.run_on_agent_error_callback( + agent=Mock(), callback_context=Mock(), error=ValueError("x") + ) + + assert p1.called + assert p2.called diff --git a/tests/unittests/runners/test_runner_error_callbacks.py b/tests/unittests/runners/test_runner_error_callbacks.py new file mode 100644 index 0000000000..ca2e723d84 --- /dev/null +++ b/tests/unittests/runners/test_runner_error_callbacks.py @@ -0,0 +1,279 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for on_run_error_callback in the Runner lifecycle. + +These tests verify that: + - on_run_error_callback is called when the execution generator raises. + - after_run_callback is NOT called on the error path. + - The original exception is re-raised unchanged. + - on_run_error_callback receives the correct invocation_context and error. + - on_run_error_callback is NOT called on successful runs. +""" + +from __future__ import annotations + +from typing import AsyncGenerator +from typing import ClassVar +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.plugins.plugin_manager import PluginManager +from google.adk.runners import InMemoryRunner +from google.genai import types +import pytest +from typing_extensions import override + +from .. import testing_utils + + +# --------------------------------------------------------------------------- +# Concrete agent implementations for testing +# --------------------------------------------------------------------------- + +class _SuccessAgent(BaseAgent): + """Agent that yields one event then completes successfully.""" + + def __init__(self): + super().__init__(name="success_agent") + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + invocation_id=ctx.invocation_id, + author=self.name, + content=types.Content( + parts=[types.Part.from_text(text="ok")] + ), + ) + + +class _FailingAgent(BaseAgent): + """Agent that raises a RuntimeError during execution.""" + + BOOM: ClassVar[RuntimeError] = RuntimeError("agent exploded mid-run") + + def __init__(self): + super().__init__(name="failing_agent") + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + raise _FailingAgent.BOOM + yield # make this an async generator + + +# --------------------------------------------------------------------------- +# Tracking plugin +# --------------------------------------------------------------------------- + +class TrackingPlugin(BasePlugin): + """Records lifecycle callback invocations for assertions.""" + + __test__ = False + + def __init__(self, name: str = "tracker"): + super().__init__(name) + self.after_run_called = False + self.run_error_calls: list[dict] = [] + + async def after_run_callback(self, *, invocation_context, **kwargs) -> None: + self.after_run_called = True + + async def on_run_error_callback(self, *, invocation_context, error, **kwargs) -> None: + self.run_error_calls.append( + {"invocation_context": invocation_context, "error": error} + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestRunnerOnRunErrorCallback: + + def _make_runner(self, agent: BaseAgent, plugins: list[BasePlugin]): + return InMemoryRunner(agent=agent, plugins=plugins) + + @pytest.mark.asyncio + async def test_on_run_error_callback_called_when_agent_raises(self): + """on_run_error_callback must fire when the agent execution raises.""" + tracker = TrackingPlugin() + runner = self._make_runner(_FailingAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u1" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + with pytest.raises(RuntimeError, match="agent exploded mid-run"): + async for _ in runner.run_async( + user_id="u1", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert len(tracker.run_error_calls) == 1 + assert tracker.run_error_calls[0]["error"] is _FailingAgent.BOOM + + @pytest.mark.asyncio + async def test_after_run_callback_not_called_on_error(self): + """after_run_callback must NOT be called when execution raises.""" + tracker = TrackingPlugin() + runner = self._make_runner(_FailingAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u2" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + with pytest.raises(RuntimeError): + async for _ in runner.run_async( + user_id="u2", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert not tracker.after_run_called + + @pytest.mark.asyncio + async def test_original_exception_is_reraised_unchanged(self): + """The exact original exception must propagate to the caller.""" + tracker = TrackingPlugin() + runner = self._make_runner(_FailingAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u3" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + with pytest.raises(RuntimeError) as exc_info: + async for _ in runner.run_async( + user_id="u3", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert exc_info.value is _FailingAgent.BOOM + + @pytest.mark.asyncio + async def test_on_run_error_callback_receives_correct_invocation_context(self): + """The invocation_context passed to on_run_error_callback must be the + same one used for the run.""" + tracker = TrackingPlugin() + runner = self._make_runner(_FailingAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u4" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + with pytest.raises(RuntimeError): + async for _ in runner.run_async( + user_id="u4", + session_id=session.id, + new_message=user_msg, + ): + pass + + inv_ctx = tracker.run_error_calls[0]["invocation_context"] + assert isinstance(inv_ctx, InvocationContext) + assert inv_ctx.session.id == session.id + + @pytest.mark.asyncio + async def test_on_run_error_callback_not_called_on_success(self): + """on_run_error_callback must NOT be called for a successful run.""" + tracker = TrackingPlugin() + runner = self._make_runner(_SuccessAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u5" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + async for _ in runner.run_async( + user_id="u5", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert len(tracker.run_error_calls) == 0 + + @pytest.mark.asyncio + async def test_after_run_callback_called_on_success(self): + """after_run_callback must still be called for a successful run.""" + tracker = TrackingPlugin() + runner = self._make_runner(_SuccessAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u6" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + async for _ in runner.run_async( + user_id="u6", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert tracker.after_run_called + + @pytest.mark.asyncio + async def test_multiple_plugins_all_notified_on_error(self): + """All registered plugins must receive on_run_error_callback on failure.""" + tracker_a = TrackingPlugin("tracker_a") + tracker_b = TrackingPlugin("tracker_b") + runner = self._make_runner(_FailingAgent(), [tracker_a, tracker_b]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u7" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + with pytest.raises(RuntimeError): + async for _ in runner.run_async( + user_id="u7", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert len(tracker_a.run_error_calls) == 1 + assert len(tracker_b.run_error_calls) == 1 From cd4ac4efcfc9746b7762ade7c34bc59e0799583c Mon Sep 17 00:00:00 2001 From: STHITAPRAJNAS Date: Wed, 25 Mar 2026 01:19:41 +0000 Subject: [PATCH 2/2] style: apply pyink + isort formatting to new test files --- .../agents/test_agent_error_callbacks.py | 275 +++++------ .../plugins/test_lifecycle_error_callbacks.py | 442 +++++++++--------- .../runners/test_runner_error_callbacks.py | 412 ++++++++-------- 3 files changed, 575 insertions(+), 554 deletions(-) diff --git a/tests/unittests/agents/test_agent_error_callbacks.py b/tests/unittests/agents/test_agent_error_callbacks.py index 01f708c694..395a9e652d 100644 --- a/tests/unittests/agents/test_agent_error_callbacks.py +++ b/tests/unittests/agents/test_agent_error_callbacks.py @@ -41,219 +41,226 @@ from .. import testing_utils - # --------------------------------------------------------------------------- # Concrete agent implementations # --------------------------------------------------------------------------- + class _SuccessAgent(BaseAgent): - def __init__(self): - super().__init__(name="success_agent") - @override - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - yield Event( - invocation_id=ctx.invocation_id, - author=self.name, - content=types.Content(parts=[types.Part.from_text(text="done")]), - ) + def __init__(self): + super().__init__(name="success_agent") + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + invocation_id=ctx.invocation_id, + author=self.name, + content=types.Content(parts=[types.Part.from_text(text="done")]), + ) class _FailingAgent(BaseAgent): - BOOM: ClassVar[RuntimeError] = RuntimeError("agent impl exploded") + BOOM: ClassVar[RuntimeError] = RuntimeError("agent impl exploded") - def __init__(self): - super().__init__(name="failing_agent") + def __init__(self): + super().__init__(name="failing_agent") - @override - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - raise _FailingAgent.BOOM - yield # pragma: no cover + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + raise _FailingAgent.BOOM + yield # pragma: no cover class _FailingLiveAgent(BaseAgent): - BOOM: ClassVar[RuntimeError] = RuntimeError("live agent impl exploded") + BOOM: ClassVar[RuntimeError] = RuntimeError("live agent impl exploded") - def __init__(self): - super().__init__(name="failing_live_agent") + def __init__(self): + super().__init__(name="failing_live_agent") - @override - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - yield # pragma: no cover + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield # pragma: no cover - @override - async def _run_live_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - raise _FailingLiveAgent.BOOM - yield # pragma: no cover + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + raise _FailingLiveAgent.BOOM + yield # pragma: no cover # --------------------------------------------------------------------------- # Tracking plugin # --------------------------------------------------------------------------- + class TrackingPlugin(BasePlugin): - __test__ = False + __test__ = False - def __init__(self, name: str = "tracker"): - super().__init__(name) - self.after_agent_called = False - self.agent_error_calls: list[dict] = [] + def __init__(self, name: str = "tracker"): + super().__init__(name) + self.after_agent_called = False + self.agent_error_calls: list[dict] = [] - async def after_agent_callback(self, *, agent, callback_context, **kwargs): - self.after_agent_called = True + async def after_agent_callback(self, *, agent, callback_context, **kwargs): + self.after_agent_called = True - async def on_agent_error_callback( - self, *, agent, callback_context, error, **kwargs - ) -> None: - self.agent_error_calls.append( - {"agent": agent, "callback_context": callback_context, "error": error} - ) + async def on_agent_error_callback( + self, *, agent, callback_context, error, **kwargs + ) -> None: + self.agent_error_calls.append( + {"agent": agent, "callback_context": callback_context, "error": error} + ) # --------------------------------------------------------------------------- # Helper to drive run_async # --------------------------------------------------------------------------- -async def _collect_events(agent: BaseAgent, plugins: list[BasePlugin]) -> list[Event]: - inv_ctx = await testing_utils.create_invocation_context( - agent=agent, plugins=plugins - ) - events = [] - async for event in agent.run_async(inv_ctx): - events.append(event) - return events + +async def _collect_events( + agent: BaseAgent, plugins: list[BasePlugin] +) -> list[Event]: + inv_ctx = await testing_utils.create_invocation_context( + agent=agent, plugins=plugins + ) + events = [] + async for event in agent.run_async(inv_ctx): + events.append(event) + return events async def _collect_live_events( agent: BaseAgent, plugins: list[BasePlugin] ) -> list[Event]: - inv_ctx = await testing_utils.create_invocation_context( - agent=agent, plugins=plugins - ) - events = [] - async for event in agent.run_live(inv_ctx): - events.append(event) - return events + inv_ctx = await testing_utils.create_invocation_context( + agent=agent, plugins=plugins + ) + events = [] + async for event in agent.run_live(inv_ctx): + events.append(event) + return events # --------------------------------------------------------------------------- # Tests — run_async path # --------------------------------------------------------------------------- + class TestAgentOnAgentErrorCallbackAsync: - @pytest.mark.asyncio - async def test_on_agent_error_callback_called_when_impl_raises(self): - tracker = TrackingPlugin() - with pytest.raises(RuntimeError, match="agent impl exploded"): - await _collect_events(_FailingAgent(), [tracker]) + @pytest.mark.asyncio + async def test_on_agent_error_callback_called_when_impl_raises(self): + tracker = TrackingPlugin() + with pytest.raises(RuntimeError, match="agent impl exploded"): + await _collect_events(_FailingAgent(), [tracker]) - assert len(tracker.agent_error_calls) == 1 - assert tracker.agent_error_calls[0]["error"] is _FailingAgent.BOOM + assert len(tracker.agent_error_calls) == 1 + assert tracker.agent_error_calls[0]["error"] is _FailingAgent.BOOM - @pytest.mark.asyncio - async def test_on_agent_error_callback_receives_correct_agent(self): - tracker = TrackingPlugin() - agent = _FailingAgent() + @pytest.mark.asyncio + async def test_on_agent_error_callback_receives_correct_agent(self): + tracker = TrackingPlugin() + agent = _FailingAgent() - with pytest.raises(RuntimeError): - await _collect_events(agent, [tracker]) + with pytest.raises(RuntimeError): + await _collect_events(agent, [tracker]) - assert tracker.agent_error_calls[0]["agent"] is agent + assert tracker.agent_error_calls[0]["agent"] is agent - @pytest.mark.asyncio - async def test_on_agent_error_callback_receives_callback_context(self): - tracker = TrackingPlugin() + @pytest.mark.asyncio + async def test_on_agent_error_callback_receives_callback_context(self): + tracker = TrackingPlugin() - with pytest.raises(RuntimeError): - await _collect_events(_FailingAgent(), [tracker]) + with pytest.raises(RuntimeError): + await _collect_events(_FailingAgent(), [tracker]) - cb_ctx = tracker.agent_error_calls[0]["callback_context"] - assert isinstance(cb_ctx, CallbackContext) + cb_ctx = tracker.agent_error_calls[0]["callback_context"] + assert isinstance(cb_ctx, CallbackContext) - @pytest.mark.asyncio - async def test_original_exception_reraised_after_notification(self): - tracker = TrackingPlugin() + @pytest.mark.asyncio + async def test_original_exception_reraised_after_notification(self): + tracker = TrackingPlugin() - with pytest.raises(RuntimeError) as exc_info: - await _collect_events(_FailingAgent(), [tracker]) + with pytest.raises(RuntimeError) as exc_info: + await _collect_events(_FailingAgent(), [tracker]) - assert exc_info.value is _FailingAgent.BOOM + assert exc_info.value is _FailingAgent.BOOM - @pytest.mark.asyncio - async def test_after_agent_callback_not_called_on_error(self): - tracker = TrackingPlugin() + @pytest.mark.asyncio + async def test_after_agent_callback_not_called_on_error(self): + tracker = TrackingPlugin() - with pytest.raises(RuntimeError): - await _collect_events(_FailingAgent(), [tracker]) + with pytest.raises(RuntimeError): + await _collect_events(_FailingAgent(), [tracker]) - assert not tracker.after_agent_called + assert not tracker.after_agent_called - @pytest.mark.asyncio - async def test_on_agent_error_callback_not_called_on_success(self): - tracker = TrackingPlugin() - events = await _collect_events(_SuccessAgent(), [tracker]) + @pytest.mark.asyncio + async def test_on_agent_error_callback_not_called_on_success(self): + tracker = TrackingPlugin() + events = await _collect_events(_SuccessAgent(), [tracker]) - assert len(events) >= 1 - assert len(tracker.agent_error_calls) == 0 + assert len(events) >= 1 + assert len(tracker.agent_error_calls) == 0 - @pytest.mark.asyncio - async def test_after_agent_callback_still_called_on_success(self): - tracker = TrackingPlugin() - await _collect_events(_SuccessAgent(), [tracker]) + @pytest.mark.asyncio + async def test_after_agent_callback_still_called_on_success(self): + tracker = TrackingPlugin() + await _collect_events(_SuccessAgent(), [tracker]) - assert tracker.after_agent_called + assert tracker.after_agent_called - @pytest.mark.asyncio - async def test_multiple_plugins_all_notified_on_agent_error(self): - tracker_a = TrackingPlugin("a") - tracker_b = TrackingPlugin("b") + @pytest.mark.asyncio + async def test_multiple_plugins_all_notified_on_agent_error(self): + tracker_a = TrackingPlugin("a") + tracker_b = TrackingPlugin("b") - with pytest.raises(RuntimeError): - await _collect_events(_FailingAgent(), [tracker_a, tracker_b]) + with pytest.raises(RuntimeError): + await _collect_events(_FailingAgent(), [tracker_a, tracker_b]) - assert len(tracker_a.agent_error_calls) == 1 - assert len(tracker_b.agent_error_calls) == 1 + assert len(tracker_a.agent_error_calls) == 1 + assert len(tracker_b.agent_error_calls) == 1 # --------------------------------------------------------------------------- # Tests — run_live path # --------------------------------------------------------------------------- + class TestAgentOnAgentErrorCallbackLive: - @pytest.mark.asyncio - async def test_on_agent_error_callback_called_when_live_impl_raises(self): - tracker = TrackingPlugin() + @pytest.mark.asyncio + async def test_on_agent_error_callback_called_when_live_impl_raises(self): + tracker = TrackingPlugin() - with pytest.raises(RuntimeError, match="live agent impl exploded"): - await _collect_live_events(_FailingLiveAgent(), [tracker]) + with pytest.raises(RuntimeError, match="live agent impl exploded"): + await _collect_live_events(_FailingLiveAgent(), [tracker]) - assert len(tracker.agent_error_calls) == 1 - assert tracker.agent_error_calls[0]["error"] is _FailingLiveAgent.BOOM + assert len(tracker.agent_error_calls) == 1 + assert tracker.agent_error_calls[0]["error"] is _FailingLiveAgent.BOOM - @pytest.mark.asyncio - async def test_after_agent_callback_not_called_on_live_error(self): - tracker = TrackingPlugin() + @pytest.mark.asyncio + async def test_after_agent_callback_not_called_on_live_error(self): + tracker = TrackingPlugin() - with pytest.raises(RuntimeError): - await _collect_live_events(_FailingLiveAgent(), [tracker]) + with pytest.raises(RuntimeError): + await _collect_live_events(_FailingLiveAgent(), [tracker]) - assert not tracker.after_agent_called + assert not tracker.after_agent_called - @pytest.mark.asyncio - async def test_original_live_exception_reraised_unchanged(self): - tracker = TrackingPlugin() + @pytest.mark.asyncio + async def test_original_live_exception_reraised_unchanged(self): + tracker = TrackingPlugin() - with pytest.raises(RuntimeError) as exc_info: - await _collect_live_events(_FailingLiveAgent(), [tracker]) + with pytest.raises(RuntimeError) as exc_info: + await _collect_live_events(_FailingLiveAgent(), [tracker]) - assert exc_info.value is _FailingLiveAgent.BOOM + assert exc_info.value is _FailingLiveAgent.BOOM diff --git a/tests/unittests/plugins/test_lifecycle_error_callbacks.py b/tests/unittests/plugins/test_lifecycle_error_callbacks.py index 8f0cd757b0..c95d55c772 100644 --- a/tests/unittests/plugins/test_lifecycle_error_callbacks.py +++ b/tests/unittests/plugins/test_lifecycle_error_callbacks.py @@ -28,278 +28,288 @@ from unittest.mock import AsyncMock from unittest.mock import Mock -from google.adk.agents.callback_context import CallbackContext from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.callback_context import CallbackContext from google.adk.agents.invocation_context import InvocationContext from google.adk.plugins.base_plugin import BasePlugin from google.adk.plugins.plugin_manager import PluginManager import pytest - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- -class ObservabilityPlugin(BasePlugin): - """A test plugin that records every on_*_error_callback invocation.""" - - __test__ = False - def __init__(self, name: str, *, raise_on_error: bool = False): - super().__init__(name) - self.agent_error_calls: list[dict] = [] - self.run_error_calls: list[dict] = [] - self._raise_on_error = raise_on_error - - async def on_agent_error_callback( - self, *, agent, callback_context, error - ) -> None: - self.agent_error_calls.append( - {"agent": agent, "callback_context": callback_context, "error": error} - ) - if self._raise_on_error: - raise RuntimeError(f"{self.name} intentionally raised in on_agent_error_callback") - - async def on_run_error_callback( - self, *, invocation_context, error - ) -> None: - self.run_error_calls.append( - {"invocation_context": invocation_context, "error": error} - ) - if self._raise_on_error: - raise RuntimeError(f"{self.name} intentionally raised in on_run_error_callback") +class ObservabilityPlugin(BasePlugin): + """A test plugin that records every on_*_error_callback invocation.""" + + __test__ = False + + def __init__(self, name: str, *, raise_on_error: bool = False): + super().__init__(name) + self.agent_error_calls: list[dict] = [] + self.run_error_calls: list[dict] = [] + self._raise_on_error = raise_on_error + + async def on_agent_error_callback( + self, *, agent, callback_context, error + ) -> None: + self.agent_error_calls.append( + {"agent": agent, "callback_context": callback_context, "error": error} + ) + if self._raise_on_error: + raise RuntimeError( + f"{self.name} intentionally raised in on_agent_error_callback" + ) + + async def on_run_error_callback(self, *, invocation_context, error) -> None: + self.run_error_calls.append( + {"invocation_context": invocation_context, "error": error} + ) + if self._raise_on_error: + raise RuntimeError( + f"{self.name} intentionally raised in on_run_error_callback" + ) # --------------------------------------------------------------------------- # BasePlugin default no-op tests # --------------------------------------------------------------------------- + class TestBasePluginDefaults: - """Verifies that the default BasePlugin implementations are safe no-ops.""" - - @pytest.mark.asyncio - async def test_on_agent_error_callback_returns_none_by_default(self): - """Default on_agent_error_callback must return None without raising.""" - class MinimalPlugin(BasePlugin): - pass - - plugin = MinimalPlugin(name="minimal") - result = await plugin.on_agent_error_callback( - agent=Mock(), - callback_context=Mock(), - error=ValueError("boom"), - ) - assert result is None - - @pytest.mark.asyncio - async def test_on_run_error_callback_returns_none_by_default(self): - """Default on_run_error_callback must return None without raising.""" - class MinimalPlugin(BasePlugin): - pass - - plugin = MinimalPlugin(name="minimal") - result = await plugin.on_run_error_callback( - invocation_context=Mock(), - error=RuntimeError("boom"), - ) - assert result is None + """Verifies that the default BasePlugin implementations are safe no-ops.""" + + @pytest.mark.asyncio + async def test_on_agent_error_callback_returns_none_by_default(self): + """Default on_agent_error_callback must return None without raising.""" + + class MinimalPlugin(BasePlugin): + pass + + plugin = MinimalPlugin(name="minimal") + result = await plugin.on_agent_error_callback( + agent=Mock(), + callback_context=Mock(), + error=ValueError("boom"), + ) + assert result is None + + @pytest.mark.asyncio + async def test_on_run_error_callback_returns_none_by_default(self): + """Default on_run_error_callback must return None without raising.""" + + class MinimalPlugin(BasePlugin): + pass + + plugin = MinimalPlugin(name="minimal") + result = await plugin.on_run_error_callback( + invocation_context=Mock(), + error=RuntimeError("boom"), + ) + assert result is None # --------------------------------------------------------------------------- # PluginManager._run_error_callbacks fan-out semantics # --------------------------------------------------------------------------- + class TestRunErrorCallbacksFanOut: - """Verifies that _run_error_callbacks notifies ALL plugins regardless of - earlier plugin failures (unlike the early-exit _run_callbacks).""" - - @pytest.mark.asyncio - async def test_all_plugins_notified_when_no_failures(self): - """All plugins must be called when none raise.""" - plugin_a = ObservabilityPlugin("a") - plugin_b = ObservabilityPlugin("b") - manager = PluginManager(plugins=[plugin_a, plugin_b]) - - inv_ctx = Mock(spec=InvocationContext) - error = ValueError("run exploded") - - await manager.run_on_run_error_callback( - invocation_context=inv_ctx, error=error - ) - - assert len(plugin_a.run_error_calls) == 1 - assert len(plugin_b.run_error_calls) == 1 - assert plugin_a.run_error_calls[0]["error"] is error - assert plugin_b.run_error_calls[0]["error"] is error - - @pytest.mark.asyncio - async def test_subsequent_plugins_called_even_when_first_raises(self): - """If plugin_a raises during on_run_error_callback, plugin_b must still - be called — a broken observability plugin must not silence others.""" - plugin_a = ObservabilityPlugin("a", raise_on_error=True) - plugin_b = ObservabilityPlugin("b") - manager = PluginManager(plugins=[plugin_a, plugin_b]) - - inv_ctx = Mock(spec=InvocationContext) - error = ValueError("original error") - - # Should NOT raise even though plugin_a raises internally - await manager.run_on_run_error_callback( - invocation_context=inv_ctx, error=error - ) - - assert len(plugin_b.run_error_calls) == 1 - assert plugin_b.run_error_calls[0]["error"] is error - - @pytest.mark.asyncio - async def test_agent_error_all_plugins_notified_when_no_failures(self): - """All plugins must be called for on_agent_error_callback.""" - plugin_a = ObservabilityPlugin("a") - plugin_b = ObservabilityPlugin("b") - manager = PluginManager(plugins=[plugin_a, plugin_b]) - - agent = Mock(spec=BaseAgent) - cb_ctx = Mock(spec=CallbackContext) - error = RuntimeError("agent died") - - await manager.run_on_agent_error_callback( - agent=agent, callback_context=cb_ctx, error=error - ) - - assert len(plugin_a.agent_error_calls) == 1 - assert len(plugin_b.agent_error_calls) == 1 - assert plugin_a.agent_error_calls[0]["agent"] is agent - assert plugin_b.agent_error_calls[0]["error"] is error - - @pytest.mark.asyncio - async def test_agent_error_subsequent_plugins_called_even_when_first_raises( - self, - ): - """Broken plugin in on_agent_error_callback must not block others.""" - plugin_a = ObservabilityPlugin("a", raise_on_error=True) - plugin_b = ObservabilityPlugin("b") - manager = PluginManager(plugins=[plugin_a, plugin_b]) - - agent = Mock(spec=BaseAgent) - cb_ctx = Mock(spec=CallbackContext) - error = RuntimeError("agent died") - - await manager.run_on_agent_error_callback( - agent=agent, callback_context=cb_ctx, error=error - ) - - assert len(plugin_b.agent_error_calls) == 1 - - @pytest.mark.asyncio - async def test_no_plugins_registered_does_not_raise(self): - """run_on_run_error_callback with zero plugins must be a safe no-op.""" - manager = PluginManager() - await manager.run_on_run_error_callback( - invocation_context=Mock(), error=ValueError("x") - ) - - @pytest.mark.asyncio - async def test_no_plugins_registered_agent_does_not_raise(self): - """run_on_agent_error_callback with zero plugins must be a safe no-op.""" - manager = PluginManager() - await manager.run_on_agent_error_callback( - agent=Mock(), callback_context=Mock(), error=ValueError("x") - ) + """Verifies that _run_error_callbacks notifies ALL plugins regardless of + earlier plugin failures (unlike the early-exit _run_callbacks).""" + + @pytest.mark.asyncio + async def test_all_plugins_notified_when_no_failures(self): + """All plugins must be called when none raise.""" + plugin_a = ObservabilityPlugin("a") + plugin_b = ObservabilityPlugin("b") + manager = PluginManager(plugins=[plugin_a, plugin_b]) + + inv_ctx = Mock(spec=InvocationContext) + error = ValueError("run exploded") + + await manager.run_on_run_error_callback( + invocation_context=inv_ctx, error=error + ) + + assert len(plugin_a.run_error_calls) == 1 + assert len(plugin_b.run_error_calls) == 1 + assert plugin_a.run_error_calls[0]["error"] is error + assert plugin_b.run_error_calls[0]["error"] is error + + @pytest.mark.asyncio + async def test_subsequent_plugins_called_even_when_first_raises(self): + """If plugin_a raises during on_run_error_callback, plugin_b must still + be called — a broken observability plugin must not silence others.""" + plugin_a = ObservabilityPlugin("a", raise_on_error=True) + plugin_b = ObservabilityPlugin("b") + manager = PluginManager(plugins=[plugin_a, plugin_b]) + + inv_ctx = Mock(spec=InvocationContext) + error = ValueError("original error") + + # Should NOT raise even though plugin_a raises internally + await manager.run_on_run_error_callback( + invocation_context=inv_ctx, error=error + ) + + assert len(plugin_b.run_error_calls) == 1 + assert plugin_b.run_error_calls[0]["error"] is error + + @pytest.mark.asyncio + async def test_agent_error_all_plugins_notified_when_no_failures(self): + """All plugins must be called for on_agent_error_callback.""" + plugin_a = ObservabilityPlugin("a") + plugin_b = ObservabilityPlugin("b") + manager = PluginManager(plugins=[plugin_a, plugin_b]) + + agent = Mock(spec=BaseAgent) + cb_ctx = Mock(spec=CallbackContext) + error = RuntimeError("agent died") + + await manager.run_on_agent_error_callback( + agent=agent, callback_context=cb_ctx, error=error + ) + + assert len(plugin_a.agent_error_calls) == 1 + assert len(plugin_b.agent_error_calls) == 1 + assert plugin_a.agent_error_calls[0]["agent"] is agent + assert plugin_b.agent_error_calls[0]["error"] is error + + @pytest.mark.asyncio + async def test_agent_error_subsequent_plugins_called_even_when_first_raises( + self, + ): + """Broken plugin in on_agent_error_callback must not block others.""" + plugin_a = ObservabilityPlugin("a", raise_on_error=True) + plugin_b = ObservabilityPlugin("b") + manager = PluginManager(plugins=[plugin_a, plugin_b]) + + agent = Mock(spec=BaseAgent) + cb_ctx = Mock(spec=CallbackContext) + error = RuntimeError("agent died") + + await manager.run_on_agent_error_callback( + agent=agent, callback_context=cb_ctx, error=error + ) + + assert len(plugin_b.agent_error_calls) == 1 + + @pytest.mark.asyncio + async def test_no_plugins_registered_does_not_raise(self): + """run_on_run_error_callback with zero plugins must be a safe no-op.""" + manager = PluginManager() + await manager.run_on_run_error_callback( + invocation_context=Mock(), error=ValueError("x") + ) + + @pytest.mark.asyncio + async def test_no_plugins_registered_agent_does_not_raise(self): + """run_on_agent_error_callback with zero plugins must be a safe no-op.""" + manager = PluginManager() + await manager.run_on_agent_error_callback( + agent=Mock(), callback_context=Mock(), error=ValueError("x") + ) # --------------------------------------------------------------------------- # PluginManager — argument forwarding # --------------------------------------------------------------------------- + class TestArgumentForwarding: - """Verifies exact argument passing to each plugin callback.""" + """Verifies exact argument passing to each plugin callback.""" - @pytest.mark.asyncio - async def test_run_on_run_error_callback_passes_correct_kwargs(self): - plugin = ObservabilityPlugin("p") - manager = PluginManager(plugins=[plugin]) + @pytest.mark.asyncio + async def test_run_on_run_error_callback_passes_correct_kwargs(self): + plugin = ObservabilityPlugin("p") + manager = PluginManager(plugins=[plugin]) - inv_ctx = Mock(spec=InvocationContext) - error = KeyError("missing_key") + inv_ctx = Mock(spec=InvocationContext) + error = KeyError("missing_key") - await manager.run_on_run_error_callback( - invocation_context=inv_ctx, error=error - ) + await manager.run_on_run_error_callback( + invocation_context=inv_ctx, error=error + ) - call = plugin.run_error_calls[0] - assert call["invocation_context"] is inv_ctx - assert call["error"] is error + call = plugin.run_error_calls[0] + assert call["invocation_context"] is inv_ctx + assert call["error"] is error - @pytest.mark.asyncio - async def test_run_on_agent_error_callback_passes_correct_kwargs(self): - plugin = ObservabilityPlugin("p") - manager = PluginManager(plugins=[plugin]) + @pytest.mark.asyncio + async def test_run_on_agent_error_callback_passes_correct_kwargs(self): + plugin = ObservabilityPlugin("p") + manager = PluginManager(plugins=[plugin]) - agent = Mock(spec=BaseAgent) - cb_ctx = Mock(spec=CallbackContext) - error = TypeError("bad_type") + agent = Mock(spec=BaseAgent) + cb_ctx = Mock(spec=CallbackContext) + error = TypeError("bad_type") - await manager.run_on_agent_error_callback( - agent=agent, callback_context=cb_ctx, error=error - ) + await manager.run_on_agent_error_callback( + agent=agent, callback_context=cb_ctx, error=error + ) - call = plugin.agent_error_calls[0] - assert call["agent"] is agent - assert call["callback_context"] is cb_ctx - assert call["error"] is error + call = plugin.agent_error_calls[0] + assert call["agent"] is agent + assert call["callback_context"] is cb_ctx + assert call["error"] is error # --------------------------------------------------------------------------- # Contrast with _run_callbacks early-exit semantics # --------------------------------------------------------------------------- + class TestErrorCallbacksDoNotEarlyExit: - """Confirms error callbacks ignore non-None returns (no early exit).""" + """Confirms error callbacks ignore non-None returns (no early exit).""" + + @pytest.mark.asyncio + async def test_on_run_error_return_value_is_ignored(self): + """on_run_error_callback returning a value must not stop other plugins.""" - @pytest.mark.asyncio - async def test_on_run_error_return_value_is_ignored(self): - """on_run_error_callback returning a value must not stop other plugins.""" + class ReturningPlugin(BasePlugin): - class ReturningPlugin(BasePlugin): - def __init__(self, name): - super().__init__(name) - self.called = False + def __init__(self, name): + super().__init__(name) + self.called = False - async def on_run_error_callback(self, **kwargs): - self.called = True - return "non-none-value-that-should-be-ignored" + async def on_run_error_callback(self, **kwargs): + self.called = True + return "non-none-value-that-should-be-ignored" - p1 = ReturningPlugin("p1") - p2 = ReturningPlugin("p2") - manager = PluginManager(plugins=[p1, p2]) + p1 = ReturningPlugin("p1") + p2 = ReturningPlugin("p2") + manager = PluginManager(plugins=[p1, p2]) - await manager.run_on_run_error_callback( - invocation_context=Mock(), error=ValueError("x") - ) + await manager.run_on_run_error_callback( + invocation_context=Mock(), error=ValueError("x") + ) - assert p1.called - assert p2.called # must NOT be short-circuited + assert p1.called + assert p2.called # must NOT be short-circuited - @pytest.mark.asyncio - async def test_on_agent_error_return_value_is_ignored(self): - """on_agent_error_callback returning a value must not stop other plugins.""" + @pytest.mark.asyncio + async def test_on_agent_error_return_value_is_ignored(self): + """on_agent_error_callback returning a value must not stop other plugins.""" - class ReturningPlugin(BasePlugin): - def __init__(self, name): - super().__init__(name) - self.called = False + class ReturningPlugin(BasePlugin): + + def __init__(self, name): + super().__init__(name) + self.called = False - async def on_agent_error_callback(self, **kwargs): - self.called = True - return "non-none-value" + async def on_agent_error_callback(self, **kwargs): + self.called = True + return "non-none-value" - p1 = ReturningPlugin("p1") - p2 = ReturningPlugin("p2") - manager = PluginManager(plugins=[p1, p2]) + p1 = ReturningPlugin("p1") + p2 = ReturningPlugin("p2") + manager = PluginManager(plugins=[p1, p2]) - await manager.run_on_agent_error_callback( - agent=Mock(), callback_context=Mock(), error=ValueError("x") - ) + await manager.run_on_agent_error_callback( + agent=Mock(), callback_context=Mock(), error=ValueError("x") + ) - assert p1.called - assert p2.called + assert p1.called + assert p2.called diff --git a/tests/unittests/runners/test_runner_error_callbacks.py b/tests/unittests/runners/test_runner_error_callbacks.py index ca2e723d84..226567dee3 100644 --- a/tests/unittests/runners/test_runner_error_callbacks.py +++ b/tests/unittests/runners/test_runner_error_callbacks.py @@ -41,239 +41,243 @@ from .. import testing_utils - # --------------------------------------------------------------------------- # Concrete agent implementations for testing # --------------------------------------------------------------------------- + class _SuccessAgent(BaseAgent): - """Agent that yields one event then completes successfully.""" + """Agent that yields one event then completes successfully.""" - def __init__(self): - super().__init__(name="success_agent") + def __init__(self): + super().__init__(name="success_agent") - @override - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - yield Event( - invocation_id=ctx.invocation_id, - author=self.name, - content=types.Content( - parts=[types.Part.from_text(text="ok")] - ), - ) + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + invocation_id=ctx.invocation_id, + author=self.name, + content=types.Content(parts=[types.Part.from_text(text="ok")]), + ) class _FailingAgent(BaseAgent): - """Agent that raises a RuntimeError during execution.""" + """Agent that raises a RuntimeError during execution.""" - BOOM: ClassVar[RuntimeError] = RuntimeError("agent exploded mid-run") + BOOM: ClassVar[RuntimeError] = RuntimeError("agent exploded mid-run") - def __init__(self): - super().__init__(name="failing_agent") + def __init__(self): + super().__init__(name="failing_agent") - @override - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - raise _FailingAgent.BOOM - yield # make this an async generator + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + raise _FailingAgent.BOOM + yield # make this an async generator # --------------------------------------------------------------------------- # Tracking plugin # --------------------------------------------------------------------------- + class TrackingPlugin(BasePlugin): - """Records lifecycle callback invocations for assertions.""" + """Records lifecycle callback invocations for assertions.""" - __test__ = False + __test__ = False - def __init__(self, name: str = "tracker"): - super().__init__(name) - self.after_run_called = False - self.run_error_calls: list[dict] = [] + def __init__(self, name: str = "tracker"): + super().__init__(name) + self.after_run_called = False + self.run_error_calls: list[dict] = [] - async def after_run_callback(self, *, invocation_context, **kwargs) -> None: - self.after_run_called = True + async def after_run_callback(self, *, invocation_context, **kwargs) -> None: + self.after_run_called = True - async def on_run_error_callback(self, *, invocation_context, error, **kwargs) -> None: - self.run_error_calls.append( - {"invocation_context": invocation_context, "error": error} - ) + async def on_run_error_callback( + self, *, invocation_context, error, **kwargs + ) -> None: + self.run_error_calls.append( + {"invocation_context": invocation_context, "error": error} + ) # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- + class TestRunnerOnRunErrorCallback: - def _make_runner(self, agent: BaseAgent, plugins: list[BasePlugin]): - return InMemoryRunner(agent=agent, plugins=plugins) - - @pytest.mark.asyncio - async def test_on_run_error_callback_called_when_agent_raises(self): - """on_run_error_callback must fire when the agent execution raises.""" - tracker = TrackingPlugin() - runner = self._make_runner(_FailingAgent(), [tracker]) - - session = await runner.session_service.create_session( - app_name=runner.app_name, user_id="u1" - ) - user_msg = types.Content( - parts=[types.Part.from_text(text="hello")], role="user" - ) - - with pytest.raises(RuntimeError, match="agent exploded mid-run"): - async for _ in runner.run_async( - user_id="u1", - session_id=session.id, - new_message=user_msg, - ): - pass - - assert len(tracker.run_error_calls) == 1 - assert tracker.run_error_calls[0]["error"] is _FailingAgent.BOOM - - @pytest.mark.asyncio - async def test_after_run_callback_not_called_on_error(self): - """after_run_callback must NOT be called when execution raises.""" - tracker = TrackingPlugin() - runner = self._make_runner(_FailingAgent(), [tracker]) - - session = await runner.session_service.create_session( - app_name=runner.app_name, user_id="u2" - ) - user_msg = types.Content( - parts=[types.Part.from_text(text="hello")], role="user" - ) - - with pytest.raises(RuntimeError): - async for _ in runner.run_async( - user_id="u2", - session_id=session.id, - new_message=user_msg, - ): - pass - - assert not tracker.after_run_called - - @pytest.mark.asyncio - async def test_original_exception_is_reraised_unchanged(self): - """The exact original exception must propagate to the caller.""" - tracker = TrackingPlugin() - runner = self._make_runner(_FailingAgent(), [tracker]) - - session = await runner.session_service.create_session( - app_name=runner.app_name, user_id="u3" - ) - user_msg = types.Content( - parts=[types.Part.from_text(text="hello")], role="user" - ) - - with pytest.raises(RuntimeError) as exc_info: - async for _ in runner.run_async( - user_id="u3", - session_id=session.id, - new_message=user_msg, - ): - pass - - assert exc_info.value is _FailingAgent.BOOM - - @pytest.mark.asyncio - async def test_on_run_error_callback_receives_correct_invocation_context(self): - """The invocation_context passed to on_run_error_callback must be the - same one used for the run.""" - tracker = TrackingPlugin() - runner = self._make_runner(_FailingAgent(), [tracker]) - - session = await runner.session_service.create_session( - app_name=runner.app_name, user_id="u4" - ) - user_msg = types.Content( - parts=[types.Part.from_text(text="hello")], role="user" - ) - - with pytest.raises(RuntimeError): - async for _ in runner.run_async( - user_id="u4", - session_id=session.id, - new_message=user_msg, - ): - pass - - inv_ctx = tracker.run_error_calls[0]["invocation_context"] - assert isinstance(inv_ctx, InvocationContext) - assert inv_ctx.session.id == session.id - - @pytest.mark.asyncio - async def test_on_run_error_callback_not_called_on_success(self): - """on_run_error_callback must NOT be called for a successful run.""" - tracker = TrackingPlugin() - runner = self._make_runner(_SuccessAgent(), [tracker]) - - session = await runner.session_service.create_session( - app_name=runner.app_name, user_id="u5" - ) - user_msg = types.Content( - parts=[types.Part.from_text(text="hello")], role="user" - ) - - async for _ in runner.run_async( - user_id="u5", - session_id=session.id, - new_message=user_msg, - ): - pass - - assert len(tracker.run_error_calls) == 0 - - @pytest.mark.asyncio - async def test_after_run_callback_called_on_success(self): - """after_run_callback must still be called for a successful run.""" - tracker = TrackingPlugin() - runner = self._make_runner(_SuccessAgent(), [tracker]) - - session = await runner.session_service.create_session( - app_name=runner.app_name, user_id="u6" - ) - user_msg = types.Content( - parts=[types.Part.from_text(text="hello")], role="user" - ) - - async for _ in runner.run_async( - user_id="u6", - session_id=session.id, - new_message=user_msg, - ): - pass - - assert tracker.after_run_called - - @pytest.mark.asyncio - async def test_multiple_plugins_all_notified_on_error(self): - """All registered plugins must receive on_run_error_callback on failure.""" - tracker_a = TrackingPlugin("tracker_a") - tracker_b = TrackingPlugin("tracker_b") - runner = self._make_runner(_FailingAgent(), [tracker_a, tracker_b]) - - session = await runner.session_service.create_session( - app_name=runner.app_name, user_id="u7" - ) - user_msg = types.Content( - parts=[types.Part.from_text(text="hello")], role="user" - ) - - with pytest.raises(RuntimeError): - async for _ in runner.run_async( - user_id="u7", - session_id=session.id, - new_message=user_msg, - ): - pass - - assert len(tracker_a.run_error_calls) == 1 - assert len(tracker_b.run_error_calls) == 1 + def _make_runner(self, agent: BaseAgent, plugins: list[BasePlugin]): + return InMemoryRunner(agent=agent, plugins=plugins) + + @pytest.mark.asyncio + async def test_on_run_error_callback_called_when_agent_raises(self): + """on_run_error_callback must fire when the agent execution raises.""" + tracker = TrackingPlugin() + runner = self._make_runner(_FailingAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u1" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + with pytest.raises(RuntimeError, match="agent exploded mid-run"): + async for _ in runner.run_async( + user_id="u1", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert len(tracker.run_error_calls) == 1 + assert tracker.run_error_calls[0]["error"] is _FailingAgent.BOOM + + @pytest.mark.asyncio + async def test_after_run_callback_not_called_on_error(self): + """after_run_callback must NOT be called when execution raises.""" + tracker = TrackingPlugin() + runner = self._make_runner(_FailingAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u2" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + with pytest.raises(RuntimeError): + async for _ in runner.run_async( + user_id="u2", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert not tracker.after_run_called + + @pytest.mark.asyncio + async def test_original_exception_is_reraised_unchanged(self): + """The exact original exception must propagate to the caller.""" + tracker = TrackingPlugin() + runner = self._make_runner(_FailingAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u3" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + with pytest.raises(RuntimeError) as exc_info: + async for _ in runner.run_async( + user_id="u3", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert exc_info.value is _FailingAgent.BOOM + + @pytest.mark.asyncio + async def test_on_run_error_callback_receives_correct_invocation_context( + self, + ): + """The invocation_context passed to on_run_error_callback must be the + same one used for the run.""" + tracker = TrackingPlugin() + runner = self._make_runner(_FailingAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u4" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + with pytest.raises(RuntimeError): + async for _ in runner.run_async( + user_id="u4", + session_id=session.id, + new_message=user_msg, + ): + pass + + inv_ctx = tracker.run_error_calls[0]["invocation_context"] + assert isinstance(inv_ctx, InvocationContext) + assert inv_ctx.session.id == session.id + + @pytest.mark.asyncio + async def test_on_run_error_callback_not_called_on_success(self): + """on_run_error_callback must NOT be called for a successful run.""" + tracker = TrackingPlugin() + runner = self._make_runner(_SuccessAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u5" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + async for _ in runner.run_async( + user_id="u5", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert len(tracker.run_error_calls) == 0 + + @pytest.mark.asyncio + async def test_after_run_callback_called_on_success(self): + """after_run_callback must still be called for a successful run.""" + tracker = TrackingPlugin() + runner = self._make_runner(_SuccessAgent(), [tracker]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u6" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + async for _ in runner.run_async( + user_id="u6", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert tracker.after_run_called + + @pytest.mark.asyncio + async def test_multiple_plugins_all_notified_on_error(self): + """All registered plugins must receive on_run_error_callback on failure.""" + tracker_a = TrackingPlugin("tracker_a") + tracker_b = TrackingPlugin("tracker_b") + runner = self._make_runner(_FailingAgent(), [tracker_a, tracker_b]) + + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u7" + ) + user_msg = types.Content( + parts=[types.Part.from_text(text="hello")], role="user" + ) + + with pytest.raises(RuntimeError): + async for _ in runner.run_async( + user_id="u7", + session_id=session.id, + new_message=user_msg, + ): + pass + + assert len(tracker_a.run_error_calls) == 1 + assert len(tracker_b.run_error_calls) == 1