diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index dec85690b3..a73192f7b2 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -181,8 +181,12 @@ def _load_agent_state( """ if ctx.agent_states is None or self.name not in ctx.agent_states: return None - else: - return state_type.model_validate(ctx.agent_states.get(self.name)) + + raw_state = ctx.agent_states.get(self.name) + if raw_state == {} and state_type is not BaseAgentState: + return None + + return state_type.model_validate(raw_state) def _create_agent_state_event( self, diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index cd9e88f718..a879de385c 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -1020,6 +1020,26 @@ async def test_load_agent_state_with_resume(): assert state == persisted_state +@pytest.mark.asyncio +async def test_load_agent_state_ignores_base_placeholder_for_custom_state(): + agent = BaseAgent(name='test_agent') + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + ctx = InvocationContext( + invocation_id='test_invocation', + agent=agent, + session=session, + session_service=session_service, + resumability_config=ResumabilityConfig(is_resumable=True), + ) + ctx.agent_states[agent.name] = BaseAgentState().model_dump(mode='json') + + state = agent._load_agent_state(ctx, _TestAgentState) + assert state is None + + @pytest.mark.asyncio async def test_create_agent_state_event(): agent = BaseAgent(name='test_agent') diff --git a/tests/unittests/agents/test_sequential_agent.py b/tests/unittests/agents/test_sequential_agent.py index 85523d2dca..b40e2d139b 100644 --- a/tests/unittests/agents/test_sequential_agent.py +++ b/tests/unittests/agents/test_sequential_agent.py @@ -17,6 +17,7 @@ from typing import AsyncGenerator from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.base_agent import BaseAgentState from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.sequential_agent import SequentialAgent from google.adk.agents.sequential_agent import SequentialAgentState @@ -180,6 +181,43 @@ async def test_resume_async(request: pytest.FixtureRequest): assert events[1].actions.end_of_agent +@pytest.mark.asyncio +async def test_resume_async_ignores_base_placeholder_state( + request: pytest.FixtureRequest, +): + agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1') + agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2') + sequential_agent = SequentialAgent( + name=f'{request.function.__name__}_test_agent', + sub_agents=[ + agent_1, + agent_2, + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, sequential_agent, resumable=True + ) + parent_ctx.agent_states[sequential_agent.name] = BaseAgentState().model_dump( + mode='json' + ) + + events = [e async for e in sequential_agent.run_async(parent_ctx)] + + assert len(events) == 5 + assert events[0].author == sequential_agent.name + assert events[0].actions.agent_state == SequentialAgentState( + current_sub_agent=agent_1.name + ).model_dump(mode='json') + assert events[1].author == agent_1.name + assert events[2].author == sequential_agent.name + assert events[2].actions.agent_state == SequentialAgentState( + current_sub_agent=agent_2.name + ).model_dump(mode='json') + assert events[3].author == agent_2.name + assert events[4].author == sequential_agent.name + assert events[4].actions.end_of_agent + + @pytest.mark.asyncio async def test_run_live(request: pytest.FixtureRequest): agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1')