From 03e285210b6bc524a4eda033765b805aedee11d5 Mon Sep 17 00:00:00 2001 From: ecantn Date: Mon, 23 Mar 2026 12:43:51 +0100 Subject: [PATCH 1/3] feat(sessions): add get_events() and filter_events() methods to Session --- src/google/adk/sessions/session.py | 103 +++++- .../sessions/test_session_filter_events.py | 307 ++++++++++++++++++ 2 files changed, 406 insertions(+), 4 deletions(-) create mode 100644 tests/unittests/sessions/test_session_filter_events.py diff --git a/src/google/adk/sessions/session.py b/src/google/adk/sessions/session.py index 24d200efdb..e1ea7f136d 100644 --- a/src/google/adk/sessions/session.py +++ b/src/google/adk/sessions/session.py @@ -1,4 +1,4 @@ -# Copyright 2026 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Session model with event filtering capabilities. + +This module contains the Session class which represents a series of interactions +between a user and agents, including methods for retrieving and filtering events. +""" + from __future__ import annotations from typing import Any @@ -20,7 +26,6 @@ from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field -from pydantic import PrivateAttr from ..events.event import Event @@ -50,5 +55,95 @@ class Session(BaseModel): last_update_time: float = 0.0 """The last update time of the session.""" - _storage_update_marker: str | None = PrivateAttr(default=None) - """Internal storage revision marker used for stale-session detection.""" + def get_events(self) -> list[Event]: + """Returns all events in the session. + + This method provides a consistent API for accessing events alongside + the filter_events() method. + + Returns: + A list containing all events in the session. + + Example: + >>> for event in session.get_events(): + ... print(event.author, event.content) + """ + return self.events + + def filter_events(self, *, exclude_rewound: bool = True) -> list[Event]: + """Returns filtered events from the session. + + This method provides convenient filtering of session events, with the + primary use case being exclusion of events that have been invalidated + by rewind operations. + + Args: + exclude_rewound: If True (default), excludes events that have been + invalidated by a rewind operation. When a session is rewound, + all events from the rewind target invocation onwards are + considered "rewound" and will be excluded. + + Returns: + A filtered list of events based on the specified criteria. + + Example: + >>> # Get only active events (excluding rewound ones) + >>> for event in session.filter_events(): + ... process_event(event) + + >>> # Get all events including rewound ones + >>> for event in session.filter_events(exclude_rewound=False): + ... process_all_events(event) + """ + if not exclude_rewound: + return self.events + return self._filter_rewound_events() + + def _filter_rewound_events(self) -> list[Event]: + """Filter out events that have been invalidated by a rewind operation. + + This method implements the rewind filtering logic: it iterates backward + through the events, and when a rewind event is found (identified by + having a non-None `actions.rewind_before_invocation_id`), it skips all + events from the rewind target invocation up to and including the rewind + event itself. + + The algorithm works as follows: + 1. Iterate through events from the end to the beginning + 2. When a rewind event is encountered, find the first event with the + target invocation_id and skip all events from that point to the + rewind event + 3. Events not affected by any rewind are included in the result + 4. The final list is reversed to maintain chronological order + + Returns: + A list of events with rewound events filtered out. + """ + if not self.events: + return [] + + filtered: list[Event] = [] + i = len(self.events) - 1 + + while i >= 0: + event = self.events[i] + + # Check if this is a rewind event + if event.actions and event.actions.rewind_before_invocation_id: + rewind_invocation_id = event.actions.rewind_before_invocation_id + + # Find the first event with the target invocation_id and skip to it + for j in range(0, i): + if self.events[j].invocation_id == rewind_invocation_id: + # Skip all events from j to i (inclusive of the rewind event) + i = j + break + else: + # Not a rewind event, include it + filtered.append(event) + + i -= 1 + + # Reverse to restore chronological order + filtered.reverse() + return filtered diff --git a/tests/unittests/sessions/test_session_filter_events.py b/tests/unittests/sessions/test_session_filter_events.py new file mode 100644 index 0000000000..3acb7fedd9 --- /dev/null +++ b/tests/unittests/sessions/test_session_filter_events.py @@ -0,0 +1,307 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Session.get_events() and Session.filter_events() methods.""" + +import pytest +from unittest.mock import MagicMock + +from google.adk.sessions.session import Session +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions + + +def _make_event( + invocation_id: str, + author: str = "user", + rewind_before_invocation_id: str | None = None, +) -> Event: + """Helper to create a mock Event with specified properties.""" + actions = EventActions() + if rewind_before_invocation_id: + actions.rewind_before_invocation_id = rewind_before_invocation_id + + return Event( + invocation_id=invocation_id, + author=author, + actions=actions, + ) + + +def _make_session(events: list[Event]) -> Session: + """Helper to create a Session with specified events.""" + return Session( + id="test-session-id", + app_name="test-app", + user_id="test-user", + events=events, + ) + + +class TestGetEvents: + """Tests for Session.get_events() method.""" + + def test_get_events_returns_all_events(self): + """Test that get_events returns all events in the session.""" + event_a = _make_event("inv-1", "user") + event_b = _make_event("inv-1", "agent") + event_c = _make_event("inv-2", "user") + + session = _make_session([event_a, event_b, event_c]) + + result = session.get_events() + + assert len(result) == 3 + assert result[0] is event_a + assert result[1] is event_b + assert result[2] is event_c + + def test_get_events_returns_empty_list_when_no_events(self): + """Test that get_events returns empty list for session with no events.""" + session = _make_session([]) + + result = session.get_events() + + assert result == [] + + def test_get_events_returns_same_reference_as_events_property(self): + """Test that get_events returns the same list as session.events.""" + event_a = _make_event("inv-1") + session = _make_session([event_a]) + + result = session.get_events() + + assert result is session.events + + +class TestFilterEventsNoRewind: + """Tests for Session.filter_events() when no rewind events exist.""" + + def test_filter_events_returns_all_when_no_rewinds(self): + """Test that all events are returned when there are no rewind events.""" + event_a = _make_event("inv-1", "user") + event_b = _make_event("inv-1", "agent") + event_c = _make_event("inv-2", "user") + + session = _make_session([event_a, event_b, event_c]) + + result = session.filter_events() + + assert len(result) == 3 + assert result[0].invocation_id == "inv-1" + assert result[0].author == "user" + assert result[1].invocation_id == "inv-1" + assert result[1].author == "agent" + assert result[2].invocation_id == "inv-2" + + def test_filter_events_returns_empty_list_when_no_events(self): + """Test that filter_events returns empty list for empty session.""" + session = _make_session([]) + + result = session.filter_events() + + assert result == [] + + def test_filter_events_exclude_rewound_false_returns_all(self): + """Test that exclude_rewound=False returns all events.""" + event_a = _make_event("inv-1") + event_b = _make_event("inv-2") + + session = _make_session([event_a, event_b]) + + result = session.filter_events(exclude_rewound=False) + + assert len(result) == 2 + assert result is session.events # Should return same reference + + +class TestFilterEventsWithRewind: + """Tests for Session.filter_events() with rewind events.""" + + def test_filter_events_single_rewind(self): + """Test that events from rewound invocation are filtered out.""" + # inv-1: user message + agent response (should remain) + event_a = _make_event("inv-1", "user") + event_b = _make_event("inv-1", "agent") + + # inv-2: user message + agent response (should be filtered) + event_c = _make_event("inv-2", "user") + event_d = _make_event("inv-2", "agent") + + # Rewind event targeting inv-2 + rewind_event = _make_event( + "inv-rewind", "user", rewind_before_invocation_id="inv-2" + ) + + session = _make_session([event_a, event_b, event_c, event_d, rewind_event]) + + result = session.filter_events() + + # Should only contain events from inv-1 + assert len(result) == 2 + assert result[0].invocation_id == "inv-1" + assert result[0].author == "user" + assert result[1].invocation_id == "inv-1" + assert result[1].author == "agent" + + def test_filter_events_multiple_sequential_rewinds(self): + """Test multiple sequential rewinds filter correctly.""" + # inv-1: should remain + event_a = _make_event("inv-1", "user") + + # inv-2: should be filtered by first rewind + event_b = _make_event("inv-2", "user") + + # First rewind: removes inv-2 + rewind_1 = _make_event( + "inv-rewind-1", "user", rewind_before_invocation_id="inv-2" + ) + + # inv-3: should remain + event_c = _make_event("inv-3", "user") + + # inv-4: should be filtered by second rewind + event_d = _make_event("inv-4", "user") + + # Second rewind: removes inv-4 + rewind_2 = _make_event( + "inv-rewind-2", "user", rewind_before_invocation_id="inv-4" + ) + + session = _make_session( + [event_a, event_b, rewind_1, event_c, event_d, rewind_2] + ) + + result = session.filter_events() + + # Should only contain inv-1 and inv-3 + assert len(result) == 2 + assert result[0].invocation_id == "inv-1" + assert result[1].invocation_id == "inv-3" + + def test_filter_events_rewind_to_first_invocation(self): + """Test rewind that goes back to the very first invocation.""" + event_a = _make_event("inv-1", "user") + event_b = _make_event("inv-2", "user") + event_c = _make_event("inv-3", "user") + + # Rewind all the way back to inv-1 + rewind_event = _make_event( + "inv-rewind", "user", rewind_before_invocation_id="inv-1" + ) + + session = _make_session([event_a, event_b, event_c, rewind_event]) + + result = session.filter_events() + + # All events should be filtered out + assert len(result) == 0 + + def test_filter_events_with_exclude_rewound_false_includes_all(self): + """Test that exclude_rewound=False includes rewound events.""" + event_a = _make_event("inv-1", "user") + event_b = _make_event("inv-2", "user") + rewind_event = _make_event( + "inv-rewind", "user", rewind_before_invocation_id="inv-2" + ) + + session = _make_session([event_a, event_b, rewind_event]) + + result = session.filter_events(exclude_rewound=False) + + # Should include all events including rewound ones + assert len(result) == 3 + + +class TestFilterEventsEdgeCases: + """Edge case tests for Session.filter_events().""" + + def test_filter_events_rewind_target_not_found(self): + """Test behavior when rewind target invocation doesn't exist.""" + event_a = _make_event("inv-1", "user") + + # Rewind targeting non-existent invocation + rewind_event = _make_event( + "inv-rewind", "user", rewind_before_invocation_id="inv-nonexistent" + ) + + session = _make_session([event_a, rewind_event]) + + result = session.filter_events() + + # Should still filter out the rewind event itself, keep inv-1 + assert len(result) == 1 + assert result[0].invocation_id == "inv-1" + + def test_filter_events_preserves_chronological_order(self): + """Test that filtered events maintain chronological order.""" + event_a = _make_event("inv-1", "user") + event_b = _make_event("inv-2", "user") + event_c = _make_event("inv-3", "user") + event_d = _make_event("inv-4", "user") + + # Rewind inv-3 only + rewind_event = _make_event( + "inv-rewind", "user", rewind_before_invocation_id="inv-3" + ) + + event_e = _make_event("inv-5", "user") + + session = _make_session( + [event_a, event_b, event_c, event_d, rewind_event, event_e] + ) + + result = session.filter_events() + + # Should be inv-1, inv-2, inv-5 in order + assert len(result) == 3 + assert result[0].invocation_id == "inv-1" + assert result[1].invocation_id == "inv-2" + assert result[2].invocation_id == "inv-5" + + def test_filter_events_single_event_no_rewind(self): + """Test with single event and no rewind.""" + event_a = _make_event("inv-1", "user") + session = _make_session([event_a]) + + result = session.filter_events() + + assert len(result) == 1 + assert result[0].invocation_id == "inv-1" + + def test_filter_events_multiple_events_same_invocation(self): + """Test filtering with multiple events in the same invocation.""" + # Multiple events in inv-1 + event_a = _make_event("inv-1", "user") + event_b = _make_event("inv-1", "agent") + event_c = _make_event("inv-1", "agent") # Multiple agent responses + + # inv-2 events + event_d = _make_event("inv-2", "user") + event_e = _make_event("inv-2", "agent") + + # Rewind to inv-2 - should remove all inv-2 events + rewind_event = _make_event( + "inv-rewind", "user", rewind_before_invocation_id="inv-2" + ) + + session = _make_session( + [event_a, event_b, event_c, event_d, event_e, rewind_event] + ) + + result = session.filter_events() + + # Should only have the 3 events from inv-1 + assert len(result) == 3 + assert all(e.invocation_id == "inv-1" for e in result) From f024441a3477eff600fd241fdbc07271334b66fb Mon Sep 17 00:00:00 2001 From: ecantn Date: Mon, 23 Mar 2026 12:43:51 +0100 Subject: [PATCH 2/3] chore: update copyright year to 2026 in session filtering files --- src/google/adk/sessions/session.py | 2 +- tests/unittests/sessions/test_session_filter_events.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/google/adk/sessions/session.py b/src/google/adk/sessions/session.py index e1ea7f136d..69abf39e0d 100644 --- a/src/google/adk/sessions/session.py +++ b/src/google/adk/sessions/session.py @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/unittests/sessions/test_session_filter_events.py b/tests/unittests/sessions/test_session_filter_events.py index 3acb7fedd9..1908e76160 100644 --- a/tests/unittests/sessions/test_session_filter_events.py +++ b/tests/unittests/sessions/test_session_filter_events.py @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 45a34af2c416770d3e9b2bf7e96cf6008cf66bf4 Mon Sep 17 00:00:00 2001 From: ecantn Date: Tue, 24 Mar 2026 10:19:45 +0100 Subject: [PATCH 3/3] style: apply autoformat to test_session_filter_events.py --- .../sessions/test_session_filter_events.py | 400 +++++++++--------- 1 file changed, 200 insertions(+), 200 deletions(-) diff --git a/tests/unittests/sessions/test_session_filter_events.py b/tests/unittests/sessions/test_session_filter_events.py index 1908e76160..9ad0da91af 100644 --- a/tests/unittests/sessions/test_session_filter_events.py +++ b/tests/unittests/sessions/test_session_filter_events.py @@ -14,12 +14,12 @@ """Unit tests for Session.get_events() and Session.filter_events() methods.""" -import pytest from unittest.mock import MagicMock -from google.adk.sessions.session import Session from google.adk.events.event import Event from google.adk.events.event_actions import EventActions +from google.adk.sessions.session import Session +import pytest def _make_event( @@ -27,281 +27,281 @@ def _make_event( author: str = "user", rewind_before_invocation_id: str | None = None, ) -> Event: - """Helper to create a mock Event with specified properties.""" - actions = EventActions() - if rewind_before_invocation_id: - actions.rewind_before_invocation_id = rewind_before_invocation_id - - return Event( - invocation_id=invocation_id, - author=author, - actions=actions, - ) + """Helper to create a mock Event with specified properties.""" + actions = EventActions() + if rewind_before_invocation_id: + actions.rewind_before_invocation_id = rewind_before_invocation_id + + return Event( + invocation_id=invocation_id, + author=author, + actions=actions, + ) def _make_session(events: list[Event]) -> Session: - """Helper to create a Session with specified events.""" - return Session( - id="test-session-id", - app_name="test-app", - user_id="test-user", - events=events, - ) + """Helper to create a Session with specified events.""" + return Session( + id="test-session-id", + app_name="test-app", + user_id="test-user", + events=events, + ) class TestGetEvents: - """Tests for Session.get_events() method.""" + """Tests for Session.get_events() method.""" - def test_get_events_returns_all_events(self): - """Test that get_events returns all events in the session.""" - event_a = _make_event("inv-1", "user") - event_b = _make_event("inv-1", "agent") - event_c = _make_event("inv-2", "user") + def test_get_events_returns_all_events(self): + """Test that get_events returns all events in the session.""" + event_a = _make_event("inv-1", "user") + event_b = _make_event("inv-1", "agent") + event_c = _make_event("inv-2", "user") - session = _make_session([event_a, event_b, event_c]) + session = _make_session([event_a, event_b, event_c]) - result = session.get_events() + result = session.get_events() - assert len(result) == 3 - assert result[0] is event_a - assert result[1] is event_b - assert result[2] is event_c + assert len(result) == 3 + assert result[0] is event_a + assert result[1] is event_b + assert result[2] is event_c - def test_get_events_returns_empty_list_when_no_events(self): - """Test that get_events returns empty list for session with no events.""" - session = _make_session([]) + def test_get_events_returns_empty_list_when_no_events(self): + """Test that get_events returns empty list for session with no events.""" + session = _make_session([]) - result = session.get_events() + result = session.get_events() - assert result == [] + assert result == [] - def test_get_events_returns_same_reference_as_events_property(self): - """Test that get_events returns the same list as session.events.""" - event_a = _make_event("inv-1") - session = _make_session([event_a]) + def test_get_events_returns_same_reference_as_events_property(self): + """Test that get_events returns the same list as session.events.""" + event_a = _make_event("inv-1") + session = _make_session([event_a]) - result = session.get_events() + result = session.get_events() - assert result is session.events + assert result is session.events class TestFilterEventsNoRewind: - """Tests for Session.filter_events() when no rewind events exist.""" + """Tests for Session.filter_events() when no rewind events exist.""" - def test_filter_events_returns_all_when_no_rewinds(self): - """Test that all events are returned when there are no rewind events.""" - event_a = _make_event("inv-1", "user") - event_b = _make_event("inv-1", "agent") - event_c = _make_event("inv-2", "user") + def test_filter_events_returns_all_when_no_rewinds(self): + """Test that all events are returned when there are no rewind events.""" + event_a = _make_event("inv-1", "user") + event_b = _make_event("inv-1", "agent") + event_c = _make_event("inv-2", "user") - session = _make_session([event_a, event_b, event_c]) + session = _make_session([event_a, event_b, event_c]) - result = session.filter_events() + result = session.filter_events() - assert len(result) == 3 - assert result[0].invocation_id == "inv-1" - assert result[0].author == "user" - assert result[1].invocation_id == "inv-1" - assert result[1].author == "agent" - assert result[2].invocation_id == "inv-2" + assert len(result) == 3 + assert result[0].invocation_id == "inv-1" + assert result[0].author == "user" + assert result[1].invocation_id == "inv-1" + assert result[1].author == "agent" + assert result[2].invocation_id == "inv-2" - def test_filter_events_returns_empty_list_when_no_events(self): - """Test that filter_events returns empty list for empty session.""" - session = _make_session([]) + def test_filter_events_returns_empty_list_when_no_events(self): + """Test that filter_events returns empty list for empty session.""" + session = _make_session([]) - result = session.filter_events() + result = session.filter_events() - assert result == [] + assert result == [] - def test_filter_events_exclude_rewound_false_returns_all(self): - """Test that exclude_rewound=False returns all events.""" - event_a = _make_event("inv-1") - event_b = _make_event("inv-2") + def test_filter_events_exclude_rewound_false_returns_all(self): + """Test that exclude_rewound=False returns all events.""" + event_a = _make_event("inv-1") + event_b = _make_event("inv-2") - session = _make_session([event_a, event_b]) + session = _make_session([event_a, event_b]) - result = session.filter_events(exclude_rewound=False) + result = session.filter_events(exclude_rewound=False) - assert len(result) == 2 - assert result is session.events # Should return same reference + assert len(result) == 2 + assert result is session.events # Should return same reference class TestFilterEventsWithRewind: - """Tests for Session.filter_events() with rewind events.""" + """Tests for Session.filter_events() with rewind events.""" - def test_filter_events_single_rewind(self): - """Test that events from rewound invocation are filtered out.""" - # inv-1: user message + agent response (should remain) - event_a = _make_event("inv-1", "user") - event_b = _make_event("inv-1", "agent") + def test_filter_events_single_rewind(self): + """Test that events from rewound invocation are filtered out.""" + # inv-1: user message + agent response (should remain) + event_a = _make_event("inv-1", "user") + event_b = _make_event("inv-1", "agent") - # inv-2: user message + agent response (should be filtered) - event_c = _make_event("inv-2", "user") - event_d = _make_event("inv-2", "agent") + # inv-2: user message + agent response (should be filtered) + event_c = _make_event("inv-2", "user") + event_d = _make_event("inv-2", "agent") - # Rewind event targeting inv-2 - rewind_event = _make_event( - "inv-rewind", "user", rewind_before_invocation_id="inv-2" - ) + # Rewind event targeting inv-2 + rewind_event = _make_event( + "inv-rewind", "user", rewind_before_invocation_id="inv-2" + ) - session = _make_session([event_a, event_b, event_c, event_d, rewind_event]) + session = _make_session([event_a, event_b, event_c, event_d, rewind_event]) - result = session.filter_events() + result = session.filter_events() - # Should only contain events from inv-1 - assert len(result) == 2 - assert result[0].invocation_id == "inv-1" - assert result[0].author == "user" - assert result[1].invocation_id == "inv-1" - assert result[1].author == "agent" + # Should only contain events from inv-1 + assert len(result) == 2 + assert result[0].invocation_id == "inv-1" + assert result[0].author == "user" + assert result[1].invocation_id == "inv-1" + assert result[1].author == "agent" - def test_filter_events_multiple_sequential_rewinds(self): - """Test multiple sequential rewinds filter correctly.""" - # inv-1: should remain - event_a = _make_event("inv-1", "user") + def test_filter_events_multiple_sequential_rewinds(self): + """Test multiple sequential rewinds filter correctly.""" + # inv-1: should remain + event_a = _make_event("inv-1", "user") - # inv-2: should be filtered by first rewind - event_b = _make_event("inv-2", "user") + # inv-2: should be filtered by first rewind + event_b = _make_event("inv-2", "user") - # First rewind: removes inv-2 - rewind_1 = _make_event( - "inv-rewind-1", "user", rewind_before_invocation_id="inv-2" - ) + # First rewind: removes inv-2 + rewind_1 = _make_event( + "inv-rewind-1", "user", rewind_before_invocation_id="inv-2" + ) - # inv-3: should remain - event_c = _make_event("inv-3", "user") + # inv-3: should remain + event_c = _make_event("inv-3", "user") - # inv-4: should be filtered by second rewind - event_d = _make_event("inv-4", "user") + # inv-4: should be filtered by second rewind + event_d = _make_event("inv-4", "user") - # Second rewind: removes inv-4 - rewind_2 = _make_event( - "inv-rewind-2", "user", rewind_before_invocation_id="inv-4" - ) + # Second rewind: removes inv-4 + rewind_2 = _make_event( + "inv-rewind-2", "user", rewind_before_invocation_id="inv-4" + ) - session = _make_session( - [event_a, event_b, rewind_1, event_c, event_d, rewind_2] - ) + session = _make_session( + [event_a, event_b, rewind_1, event_c, event_d, rewind_2] + ) - result = session.filter_events() + result = session.filter_events() - # Should only contain inv-1 and inv-3 - assert len(result) == 2 - assert result[0].invocation_id == "inv-1" - assert result[1].invocation_id == "inv-3" + # Should only contain inv-1 and inv-3 + assert len(result) == 2 + assert result[0].invocation_id == "inv-1" + assert result[1].invocation_id == "inv-3" - def test_filter_events_rewind_to_first_invocation(self): - """Test rewind that goes back to the very first invocation.""" - event_a = _make_event("inv-1", "user") - event_b = _make_event("inv-2", "user") - event_c = _make_event("inv-3", "user") + def test_filter_events_rewind_to_first_invocation(self): + """Test rewind that goes back to the very first invocation.""" + event_a = _make_event("inv-1", "user") + event_b = _make_event("inv-2", "user") + event_c = _make_event("inv-3", "user") - # Rewind all the way back to inv-1 - rewind_event = _make_event( - "inv-rewind", "user", rewind_before_invocation_id="inv-1" - ) + # Rewind all the way back to inv-1 + rewind_event = _make_event( + "inv-rewind", "user", rewind_before_invocation_id="inv-1" + ) - session = _make_session([event_a, event_b, event_c, rewind_event]) + session = _make_session([event_a, event_b, event_c, rewind_event]) - result = session.filter_events() + result = session.filter_events() - # All events should be filtered out - assert len(result) == 0 + # All events should be filtered out + assert len(result) == 0 - def test_filter_events_with_exclude_rewound_false_includes_all(self): - """Test that exclude_rewound=False includes rewound events.""" - event_a = _make_event("inv-1", "user") - event_b = _make_event("inv-2", "user") - rewind_event = _make_event( - "inv-rewind", "user", rewind_before_invocation_id="inv-2" - ) + def test_filter_events_with_exclude_rewound_false_includes_all(self): + """Test that exclude_rewound=False includes rewound events.""" + event_a = _make_event("inv-1", "user") + event_b = _make_event("inv-2", "user") + rewind_event = _make_event( + "inv-rewind", "user", rewind_before_invocation_id="inv-2" + ) - session = _make_session([event_a, event_b, rewind_event]) + session = _make_session([event_a, event_b, rewind_event]) - result = session.filter_events(exclude_rewound=False) + result = session.filter_events(exclude_rewound=False) - # Should include all events including rewound ones - assert len(result) == 3 + # Should include all events including rewound ones + assert len(result) == 3 class TestFilterEventsEdgeCases: - """Edge case tests for Session.filter_events().""" + """Edge case tests for Session.filter_events().""" - def test_filter_events_rewind_target_not_found(self): - """Test behavior when rewind target invocation doesn't exist.""" - event_a = _make_event("inv-1", "user") + def test_filter_events_rewind_target_not_found(self): + """Test behavior when rewind target invocation doesn't exist.""" + event_a = _make_event("inv-1", "user") - # Rewind targeting non-existent invocation - rewind_event = _make_event( - "inv-rewind", "user", rewind_before_invocation_id="inv-nonexistent" - ) + # Rewind targeting non-existent invocation + rewind_event = _make_event( + "inv-rewind", "user", rewind_before_invocation_id="inv-nonexistent" + ) - session = _make_session([event_a, rewind_event]) + session = _make_session([event_a, rewind_event]) - result = session.filter_events() + result = session.filter_events() - # Should still filter out the rewind event itself, keep inv-1 - assert len(result) == 1 - assert result[0].invocation_id == "inv-1" + # Should still filter out the rewind event itself, keep inv-1 + assert len(result) == 1 + assert result[0].invocation_id == "inv-1" - def test_filter_events_preserves_chronological_order(self): - """Test that filtered events maintain chronological order.""" - event_a = _make_event("inv-1", "user") - event_b = _make_event("inv-2", "user") - event_c = _make_event("inv-3", "user") - event_d = _make_event("inv-4", "user") + def test_filter_events_preserves_chronological_order(self): + """Test that filtered events maintain chronological order.""" + event_a = _make_event("inv-1", "user") + event_b = _make_event("inv-2", "user") + event_c = _make_event("inv-3", "user") + event_d = _make_event("inv-4", "user") - # Rewind inv-3 only - rewind_event = _make_event( - "inv-rewind", "user", rewind_before_invocation_id="inv-3" - ) + # Rewind inv-3 only + rewind_event = _make_event( + "inv-rewind", "user", rewind_before_invocation_id="inv-3" + ) - event_e = _make_event("inv-5", "user") + event_e = _make_event("inv-5", "user") - session = _make_session( - [event_a, event_b, event_c, event_d, rewind_event, event_e] - ) + session = _make_session( + [event_a, event_b, event_c, event_d, rewind_event, event_e] + ) - result = session.filter_events() + result = session.filter_events() - # Should be inv-1, inv-2, inv-5 in order - assert len(result) == 3 - assert result[0].invocation_id == "inv-1" - assert result[1].invocation_id == "inv-2" - assert result[2].invocation_id == "inv-5" + # Should be inv-1, inv-2, inv-5 in order + assert len(result) == 3 + assert result[0].invocation_id == "inv-1" + assert result[1].invocation_id == "inv-2" + assert result[2].invocation_id == "inv-5" - def test_filter_events_single_event_no_rewind(self): - """Test with single event and no rewind.""" - event_a = _make_event("inv-1", "user") - session = _make_session([event_a]) + def test_filter_events_single_event_no_rewind(self): + """Test with single event and no rewind.""" + event_a = _make_event("inv-1", "user") + session = _make_session([event_a]) - result = session.filter_events() + result = session.filter_events() - assert len(result) == 1 - assert result[0].invocation_id == "inv-1" + assert len(result) == 1 + assert result[0].invocation_id == "inv-1" - def test_filter_events_multiple_events_same_invocation(self): - """Test filtering with multiple events in the same invocation.""" - # Multiple events in inv-1 - event_a = _make_event("inv-1", "user") - event_b = _make_event("inv-1", "agent") - event_c = _make_event("inv-1", "agent") # Multiple agent responses + def test_filter_events_multiple_events_same_invocation(self): + """Test filtering with multiple events in the same invocation.""" + # Multiple events in inv-1 + event_a = _make_event("inv-1", "user") + event_b = _make_event("inv-1", "agent") + event_c = _make_event("inv-1", "agent") # Multiple agent responses - # inv-2 events - event_d = _make_event("inv-2", "user") - event_e = _make_event("inv-2", "agent") + # inv-2 events + event_d = _make_event("inv-2", "user") + event_e = _make_event("inv-2", "agent") - # Rewind to inv-2 - should remove all inv-2 events - rewind_event = _make_event( - "inv-rewind", "user", rewind_before_invocation_id="inv-2" - ) + # Rewind to inv-2 - should remove all inv-2 events + rewind_event = _make_event( + "inv-rewind", "user", rewind_before_invocation_id="inv-2" + ) - session = _make_session( - [event_a, event_b, event_c, event_d, event_e, rewind_event] - ) + session = _make_session( + [event_a, event_b, event_c, event_d, event_e, rewind_event] + ) - result = session.filter_events() + result = session.filter_events() - # Should only have the 3 events from inv-1 - assert len(result) == 3 - assert all(e.invocation_id == "inv-1" for e in result) + # Should only have the 3 events from inv-1 + assert len(result) == 3 + assert all(e.invocation_id == "inv-1" for e in result)