Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 98 additions & 3 deletions src/google/adk/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Session model with event filtering capabilities.

This module contains the Session class which represents a series of interactions
between a user and agents, including methods for retrieving and filtering events.
"""

from __future__ import annotations

from typing import Any
Expand All @@ -20,7 +26,6 @@
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import PrivateAttr

from ..events.event import Event

Expand Down Expand Up @@ -50,5 +55,95 @@ class Session(BaseModel):
last_update_time: float = 0.0
"""The last update time of the session."""

_storage_update_marker: str | None = PrivateAttr(default=None)
"""Internal storage revision marker used for stale-session detection."""
def get_events(self) -> list[Event]:
"""Returns all events in the session.

This method provides a consistent API for accessing events alongside
the filter_events() method.

Returns:
A list containing all events in the session.

Example:
>>> for event in session.get_events():
... print(event.author, event.content)
"""
return self.events

def filter_events(self, *, exclude_rewound: bool = True) -> list[Event]:
"""Returns filtered events from the session.

This method provides convenient filtering of session events, with the
primary use case being exclusion of events that have been invalidated
by rewind operations.

Args:
exclude_rewound: If True (default), excludes events that have been
invalidated by a rewind operation. When a session is rewound,
all events from the rewind target invocation onwards are
considered "rewound" and will be excluded.

Returns:
A filtered list of events based on the specified criteria.

Example:
>>> # Get only active events (excluding rewound ones)
>>> for event in session.filter_events():
... process_event(event)

>>> # Get all events including rewound ones
>>> for event in session.filter_events(exclude_rewound=False):
... process_all_events(event)
"""
if not exclude_rewound:
return self.events
return self._filter_rewound_events()

def _filter_rewound_events(self) -> list[Event]:
"""Filter out events that have been invalidated by a rewind operation.

This method implements the rewind filtering logic: it iterates backward
through the events, and when a rewind event is found (identified by
having a non-None `actions.rewind_before_invocation_id`), it skips all
events from the rewind target invocation up to and including the rewind
event itself.

The algorithm works as follows:
1. Iterate through events from the end to the beginning
2. When a rewind event is encountered, find the first event with the
target invocation_id and skip all events from that point to the
rewind event
3. Events not affected by any rewind are included in the result
4. The final list is reversed to maintain chronological order

Returns:
A list of events with rewound events filtered out.
"""
if not self.events:
return []

filtered: list[Event] = []
i = len(self.events) - 1

while i >= 0:
event = self.events[i]

# Check if this is a rewind event
if event.actions and event.actions.rewind_before_invocation_id:
rewind_invocation_id = event.actions.rewind_before_invocation_id

# Find the first event with the target invocation_id and skip to it
for j in range(0, i):
if self.events[j].invocation_id == rewind_invocation_id:
# Skip all events from j to i (inclusive of the rewind event)
i = j
break
else:
# Not a rewind event, include it
filtered.append(event)

i -= 1

# Reverse to restore chronological order
filtered.reverse()
return filtered
Loading