diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 9e5c9bb2ec..a3efeee60b 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -43,6 +43,7 @@ logger = logging.getLogger('google_adk.' + __name__) _COMPACTION_CUSTOM_METADATA_KEY = '_compaction' +_REWIND_CUSTOM_METADATA_KEY = '_rewind_before_invocation_id' _USAGE_METADATA_CUSTOM_METADATA_KEY = '_usage_metadata' @@ -284,7 +285,8 @@ async def append_event(self, session: Session, event: Event) -> Event: }, # TODO: add requested_tool_confirmations, agent_state once # they are available in the API. - # Note: compaction is stored via event_metadata.custom_metadata. + # Note: compaction and rewind_before_invocation_id are stored via + # event_metadata.custom_metadata. } if event.error_code: config['error_code'] = event.error_code @@ -320,6 +322,16 @@ async def append_event(self, session: Session, event: Event) -> Event: key=_COMPACTION_CUSTOM_METADATA_KEY, value=compaction_dict, ) + # Store rewind_before_invocation_id in custom_metadata since the Vertex AI + # service does not yet support the field in EventActions. + # TODO: Stop writing to custom_metadata once the Vertex AI service + # supports rewind_before_invocation_id natively in EventActions. + if event.actions and event.actions.rewind_before_invocation_id: + _set_internal_custom_metadata( + metadata_dict, + key=_REWIND_CUSTOM_METADATA_KEY, + value=event.actions.rewind_before_invocation_id, + ) # Store usage_metadata in custom_metadata since the Vertex AI service # does not persist it in EventMetadata. if event.usage_metadata: @@ -405,15 +417,20 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: # written before native compaction support store compaction data # in custom_metadata under the compaction metadata key. compaction_data = None + rewind_data = None usage_metadata_data = None if custom_metadata and ( _COMPACTION_CUSTOM_METADATA_KEY in custom_metadata + or _REWIND_CUSTOM_METADATA_KEY in custom_metadata or _USAGE_METADATA_CUSTOM_METADATA_KEY in custom_metadata ): custom_metadata = dict(custom_metadata) # avoid mutating the API response compaction_data = custom_metadata.pop( _COMPACTION_CUSTOM_METADATA_KEY, None ) + rewind_data = custom_metadata.pop( + _REWIND_CUSTOM_METADATA_KEY, None + ) usage_metadata_data = custom_metadata.pop( _USAGE_METADATA_CUSTOM_METADATA_KEY, None ) @@ -431,6 +448,7 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: branch = None custom_metadata = None compaction_data = None + rewind_data = None usage_metadata_data = None grounding_metadata = None @@ -442,11 +460,18 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: } if compaction_data: renamed_actions_dict['compaction'] = compaction_data + if rewind_data: + renamed_actions_dict['rewind_before_invocation_id'] = rewind_data event_actions = EventActions.model_validate(renamed_actions_dict) else: - if compaction_data: + if compaction_data or rewind_data: event_actions = EventActions( - compaction=EventCompaction.model_validate(compaction_data) + compaction=( + EventCompaction.model_validate(compaction_data) + if compaction_data + else None + ), + rewind_before_invocation_id=rewind_data, ) else: event_actions = EventActions() diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 20fdbe3c6d..d65ea3f0ab 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -829,6 +829,35 @@ async def test_append_event(): assert retrieved_session.events[1] == event_to_append +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_append_event_with_rewind(): + """rewind_before_invocation_id round-trips through append_event and get_session.""" + session_service = mock_vertex_ai_session_service() + session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + event_to_append = Event( + invocation_id='rewind_invocation', + author='model', + timestamp=1734005533.0, + actions=EventActions( + rewind_before_invocation_id='target_invocation', + ), + ) + + await session_service.append_event(session, event_to_append) + + retrieved_session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + + appended_event = retrieved_session.events[-1] + assert ( + appended_event.actions.rewind_before_invocation_id == 'target_invocation' + ) + + @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') async def test_append_event_with_compaction():