|
1 | | -# Copyright 2026 Google LLC |
| 1 | +# Copyright 2025 Google LLC |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +"""Session model with event filtering capabilities. |
| 16 | +
|
| 17 | +This module contains the Session class which represents a series of interactions |
| 18 | +between a user and agents, including methods for retrieving and filtering events. |
| 19 | +""" |
| 20 | + |
15 | 21 | from __future__ import annotations |
16 | 22 |
|
17 | 23 | from typing import Any |
|
20 | 26 | from pydantic import BaseModel |
21 | 27 | from pydantic import ConfigDict |
22 | 28 | from pydantic import Field |
23 | | -from pydantic import PrivateAttr |
24 | 29 |
|
25 | 30 | from ..events.event import Event |
26 | 31 |
|
@@ -50,5 +55,95 @@ class Session(BaseModel): |
50 | 55 | last_update_time: float = 0.0 |
51 | 56 | """The last update time of the session.""" |
52 | 57 |
|
53 | | - _storage_update_marker: str | None = PrivateAttr(default=None) |
54 | | - """Internal storage revision marker used for stale-session detection.""" |
| 58 | + def get_events(self) -> list[Event]: |
| 59 | + """Returns all events in the session. |
| 60 | +
|
| 61 | + This method provides a consistent API for accessing events alongside |
| 62 | + the filter_events() method. |
| 63 | +
|
| 64 | + Returns: |
| 65 | + A list containing all events in the session. |
| 66 | +
|
| 67 | + Example: |
| 68 | + >>> for event in session.get_events(): |
| 69 | + ... print(event.author, event.content) |
| 70 | + """ |
| 71 | + return self.events |
| 72 | + |
| 73 | + def filter_events(self, *, exclude_rewound: bool = True) -> list[Event]: |
| 74 | + """Returns filtered events from the session. |
| 75 | +
|
| 76 | + This method provides convenient filtering of session events, with the |
| 77 | + primary use case being exclusion of events that have been invalidated |
| 78 | + by rewind operations. |
| 79 | +
|
| 80 | + Args: |
| 81 | + exclude_rewound: If True (default), excludes events that have been |
| 82 | + invalidated by a rewind operation. When a session is rewound, |
| 83 | + all events from the rewind target invocation onwards are |
| 84 | + considered "rewound" and will be excluded. |
| 85 | +
|
| 86 | + Returns: |
| 87 | + A filtered list of events based on the specified criteria. |
| 88 | +
|
| 89 | + Example: |
| 90 | + >>> # Get only active events (excluding rewound ones) |
| 91 | + >>> for event in session.filter_events(): |
| 92 | + ... process_event(event) |
| 93 | +
|
| 94 | + >>> # Get all events including rewound ones |
| 95 | + >>> for event in session.filter_events(exclude_rewound=False): |
| 96 | + ... process_all_events(event) |
| 97 | + """ |
| 98 | + if not exclude_rewound: |
| 99 | + return self.events |
| 100 | + return self._filter_rewound_events() |
| 101 | + |
| 102 | + def _filter_rewound_events(self) -> list[Event]: |
| 103 | + """Filter out events that have been invalidated by a rewind operation. |
| 104 | +
|
| 105 | + This method implements the rewind filtering logic: it iterates backward |
| 106 | + through the events, and when a rewind event is found (identified by |
| 107 | + having a non-None `actions.rewind_before_invocation_id`), it skips all |
| 108 | + events from the rewind target invocation up to and including the rewind |
| 109 | + event itself. |
| 110 | +
|
| 111 | + The algorithm works as follows: |
| 112 | + 1. Iterate through events from the end to the beginning |
| 113 | + 2. When a rewind event is encountered, find the first event with the |
| 114 | + target invocation_id and skip all events from that point to the |
| 115 | + rewind event |
| 116 | + 3. Events not affected by any rewind are included in the result |
| 117 | + 4. The final list is reversed to maintain chronological order |
| 118 | +
|
| 119 | + Returns: |
| 120 | + A list of events with rewound events filtered out. |
| 121 | + """ |
| 122 | + if not self.events: |
| 123 | + return [] |
| 124 | + |
| 125 | + filtered: list[Event] = [] |
| 126 | + i = len(self.events) - 1 |
| 127 | + |
| 128 | + while i >= 0: |
| 129 | + event = self.events[i] |
| 130 | + |
| 131 | + # Check if this is a rewind event |
| 132 | + if event.actions and event.actions.rewind_before_invocation_id: |
| 133 | + rewind_invocation_id = event.actions.rewind_before_invocation_id |
| 134 | + |
| 135 | + # Find the first event with the target invocation_id and skip to it |
| 136 | + for j in range(0, i): |
| 137 | + if self.events[j].invocation_id == rewind_invocation_id: |
| 138 | + # Skip all events from j to i (inclusive of the rewind event) |
| 139 | + i = j |
| 140 | + break |
| 141 | + else: |
| 142 | + # Not a rewind event, include it |
| 143 | + filtered.append(event) |
| 144 | + |
| 145 | + i -= 1 |
| 146 | + |
| 147 | + # Reverse to restore chronological order |
| 148 | + filtered.reverse() |
| 149 | + return filtered |
0 commit comments