diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 7ded1fc205..53451c8d29 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -14,6 +14,7 @@ from __future__ import annotations import asyncio +import copy import datetime import json import logging @@ -25,6 +26,7 @@ from google.genai import types from google.genai.errors import ClientError +import pydantic from typing_extensions import override if TYPE_CHECKING: @@ -339,17 +341,41 @@ async def append_event(self, session: Session, event: Event) -> Event: value=usage_dict, ) config['event_metadata'] = metadata_dict + config['raw_event'] = event.model_dump( + exclude_none=True, + mode='json', + by_alias=True, + ) + # Retry without raw_event if client side validation fails for older SDK + # versions. async with self._get_api_client() as api_client: - await api_client.agent_engines.sessions.events.append( - name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}', - author=event.author, - invocation_id=event.invocation_id, - timestamp=datetime.datetime.fromtimestamp( - event.timestamp, tz=datetime.timezone.utc - ), - config=config, - ) + try: + await api_client.agent_engines.sessions.events.append( + name=( + f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}' + ), + author=event.author, + invocation_id=event.invocation_id, + timestamp=datetime.datetime.fromtimestamp( + event.timestamp, tz=datetime.timezone.utc + ), + config=config, + ) + except pydantic.ValidationError: + if 'raw_event' in config: + del config['raw_event'] + await api_client.agent_engines.sessions.events.append( + name=( + f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}' + ), + author=event.author, + invocation_id=event.invocation_id, + timestamp=datetime.datetime.fromtimestamp( + event.timestamp, tz=datetime.timezone.utc + ), + config=config, + ) return event def _get_reasoning_engine_id(self, app_name: str): @@ -395,8 +421,33 @@ def _get_api_client(self) -> vertexai.AsyncClient: ).aio +def _get_raw_event(api_event_obj: Any) -> dict[str, Any] | None: + """Extracts raw_event dict from SessionEvent object safely.""" + try: + return api_event_obj.raw_event + except AttributeError: + try: + return api_event_obj.rawEvent + except AttributeError: + return None + + def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: """Converts an API event object to an Event object.""" + # Read event data from raw_event first before falling back to top level + # fields. + raw_event_dict = _get_raw_event(api_event_obj) + if raw_event_dict: + event_dict = copy.deepcopy(raw_event_dict) + timestamp_obj = getattr(api_event_obj, 'timestamp', None) + event_dict.update({ + 'id': api_event_obj.name.split('/')[-1], + 'invocation_id': getattr(api_event_obj, 'invocation_id', None), + 'author': getattr(api_event_obj, 'author', None), + 'timestamp': timestamp_obj.timestamp() if timestamp_obj else None, + }) + return Event.model_validate(event_dict) + actions = getattr(api_event_obj, 'actions', None) event_metadata = getattr(api_event_obj, 'event_metadata', None) if event_metadata: diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 8f2b44c68f..0a00627225 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -28,6 +28,7 @@ from google.adk.events.event import Event from google.adk.events.event_actions import EventActions from google.adk.events.event_actions import EventCompaction +from google.adk.models.cache_metadata import CacheMetadata from google.adk.sessions.base_session_service import GetSessionConfig from google.adk.sessions.session import Session from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService @@ -91,6 +92,7 @@ 'branch': '', 'long_running_tool_ids': ['tool1'], }, + 'raw_event': {}, }, ] MOCK_EVENT_JSON_2 = [ @@ -162,6 +164,96 @@ def _generate_mock_events_for_session_5(num_events): MANY_EVENTS_COUNT = 200 MOCK_EVENTS_JSON_5 = _generate_mock_events_for_session_5(MANY_EVENTS_COUNT) +MOCK_EVENT_WITH_OVERRIDE_JSON = [{ + 'name': ( + 'projects/test-project/locations/test-location/' + 'reasoningEngines/123/sessions/override/events/1' + ), + 'invocationId': 'override_invoke', + 'author': 'user_with_override', + 'timestamp': '2024-12-12T12:12:12.123456Z', + 'content': { + 'parts': [ + {'text': 'top_level_content'}, + ], + }, + 'actions': { + 'transferToAgent': 'top_level_agent', + }, + 'eventMetadata': { + 'partial': True, + 'turnComplete': False, + 'interrupted': False, + 'branch': 'top_level_branch', + }, + 'errorCode': '111', + 'errorMessage': 'top_level_error', + 'rawEvent': { + 'invocationId': 'wrong_invocation_id', + 'author': 'wrong_author', + 'content': { + 'parts': [ + {'text': 'raw_event_content'}, + ], + }, + 'actions': { + 'transferToAgent': 'raw_event_agent', + }, + 'partial': False, + 'turnComplete': True, + 'interrupted': True, + 'branch': 'raw_event_branch', + 'errorCode': '222', + 'errorMessage': 'raw_event_error', + }, +}] + +MOCK_EVENT_WITH_OVERRIDE_JSON_2 = [{ + 'name': ( + 'projects/test-project/locations/test-location/' + 'reasoningEngines/123/sessions/override/events/1' + ), + 'invocationId': 'override_invoke', + 'author': 'user_with_override', + 'content': {}, + 'actions': {}, + 'timestamp': '2024-12-12T12:12:12.123456Z', + 'rawEvent': { + 'invocationId': 'wrong_invocation_id', + 'author': 'wrong_author', + 'content': { + 'parts': [ + {'text': 'raw_event_content'}, + ], + }, + 'actions': { + 'skipSummarization': None, + 'stateDelta': {}, + 'artifactDelta': {}, + 'transferToAgent': 'raw_event_agent', + 'escalate': None, + 'requestedAuthConfigs': {}, + }, + 'errorCode': '222', + 'errorMessage': 'raw_event_error', + 'partial': False, + 'turnComplete': True, + 'interrupted': True, + 'branch': 'raw_event_branch', + 'customMetadata': None, + 'longRunningToolIds': None, + }, +}] + +MOCK_SESSION_WITH_OVERRIDE_JSON = { + 'name': ( + 'projects/test-project/locations/test-location/' + 'reasoningEngines/123/sessions/override' + ), + 'update_time': '2024-12-12T12:12:12.123456Z', + 'user_id': 'user_with_override', +} + MOCK_SESSION = Session( app_name='123', user_id='user', @@ -249,6 +341,8 @@ def _convert_to_object(data): 'artifact_delta', 'custom_metadata', 'requested_auth_configs', + 'rawEvent', + 'raw_event', ]: kwargs[key] = value else: @@ -680,6 +774,38 @@ async def test_get_session_keeps_events_newer_than_update_time( ) +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +@pytest.mark.parametrize( + 'mock_event_json', + [MOCK_EVENT_WITH_OVERRIDE_JSON, MOCK_EVENT_WITH_OVERRIDE_JSON_2], +) +async def test_get_session_from_raw_event( + mock_api_client_instance: MockAsyncClient, + mock_event_json, +) -> None: + mock_api_client_instance.session_dict['6'] = MOCK_SESSION_WITH_OVERRIDE_JSON + mock_api_client_instance.event_dict['6'] = ( + copy.deepcopy(mock_event_json), + None, + ) + session_service = mock_vertex_ai_session_service() + session = await session_service.get_session( + app_name='123', user_id='user_with_override', session_id='6' + ) + assert session is not None + assert len(session.events) == 1 + event = session.events[0] + assert event.content.parts[0].text == 'raw_event_content' + assert event.actions.transfer_to_agent == 'raw_event_agent' + assert not event.partial + assert event.turn_complete + assert event.interrupted + assert event.branch == 'raw_event_branch' + assert event.error_code == '222' + assert event.error_message == 'raw_event_error' + + @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') async def test_get_session_with_many_events(mock_api_client_instance): @@ -830,6 +956,36 @@ async def test_append_event(): branch='test_branch', custom_metadata={'custom': 'data'}, long_running_tool_ids={'tool2'}, + input_transcription=genai_types.Transcription( + text='test_input_transcription' + ), + output_transcription=genai_types.Transcription( + text='test_output_transcription' + ), + model_version='test_model_version', + avg_logprobs=0.5, + logprobs_result=genai_types.LogprobsResult( + chosen_candidates=[ + genai_types.LogprobsResultCandidate( + log_probability=0.5, + token='test_token', + token_id=0, + ) + ] + ), + cache_metadata=CacheMetadata( + cache_name='test_cache_name', + fingerprint='test_fingerprint', + contents_count=1, + ), + citation_metadata=genai_types.CitationMetadata( + citations=[ + genai_types.Citation( + uri='http://test.com', + title='test_title', + ) + ] + ), ) await session_service.append_event(session_before_append, event_to_append)