From e1913a6b411aec9e8774ca92ea39531b085c43f0 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 3 Apr 2026 15:45:41 -0700 Subject: [PATCH] feat: Add custom session id functionality to vertex ai session service PiperOrigin-RevId: 894285971 --- .../adk/sessions/vertex_ai_session_service.py | 9 ++---- .../test_vertex_ai_session_service.py | 32 +++++++++++++------ 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 7ded1fc205..8025821975 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -115,16 +115,11 @@ async def create_session( Returns: The created session. """ - - if session_id: - raise ValueError( - 'User-provided Session id is not supported for' - ' VertexAISessionService.' - ) - reasoning_engine_id = self._get_reasoning_engine_id(app_name) config = {'session_state': state} if state else {} + if session_id: + config['session_id'] = session_id config.update(kwargs) async with self._get_api_client() as api_client: api_response = await api_client.agent_engines.sessions.create( diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 8f2b44c68f..1156cad916 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -324,7 +324,10 @@ async def _create_session( self, name: str, user_id: str, config: dict[str, Any] ): self.last_create_session_config = config - new_session_id = '4' + if 'session_id' in config: + new_session_id = config['session_id'] + else: + new_session_id = '4' self.session_dict[new_session_id] = { 'name': ( 'projects/test-project/locations/test-location/' @@ -343,7 +346,7 @@ async def _create_session( + '/operations/111' ), 'done': True, - 'response': self.session_dict['4'], + 'response': self.session_dict[new_session_id], }) async def _list_events(self, name: str, **kwargs): @@ -769,15 +772,26 @@ async def test_create_session(): @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') -async def test_create_session_with_custom_session_id(): +@pytest.mark.parametrize('session_id', ['1', 'abc123']) +async def test_create_session_with_custom_session_id( + mock_api_client_instance: MockAsyncClient, session_id: str +): session_service = mock_vertex_ai_session_service() - with pytest.raises(ValueError) as excinfo: - await session_service.create_session( - app_name='123', user_id='user', session_id='1' - ) - assert str(excinfo.value) == ( - 'User-provided Session id is not supported for VertexAISessionService.' + mock_api_client_instance.event_dict[session_id] = ( + [], + None, + ) + + session = await session_service.create_session( + app_name='123', user_id='user', session_id=session_id + ) + assert session.id == session_id + assert session.app_name == '123' + assert session.user_id == 'user' + assert session.last_update_time is not None + assert session == await session_service.get_session( + app_name='123', user_id='user', session_id=session_id )