From 96d1dca612efee4fe3cba52311b92382a2b25772 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 13:14:26 -0600 Subject: [PATCH 01/12] Oneshot attempt at adding firestore support for memory and sessions --- pyproject.toml | 1 + src/google/adk/firestore_database_runner.py | 64 +++ .../adk/memory/firestore_memory_service.py | 172 ++++++++ .../adk/sessions/firestore_session_service.py | 404 ++++++++++++++++++ .../memory/test_firestore_memory_service.py | 96 +++++ .../test_firestore_session_service.py | 159 +++++++ 6 files changed, 896 insertions(+) create mode 100644 src/google/adk/firestore_database_runner.py create mode 100644 src/google/adk/memory/firestore_memory_service.py create mode 100644 src/google/adk/sessions/firestore_session_service.py create mode 100644 tests/unittests/memory/test_firestore_memory_service.py create mode 100644 tests/unittests/sessions/test_firestore_session_service.py diff --git a/pyproject.toml b/pyproject.toml index 2789bcf82a..426b6d1bbf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,6 +157,7 @@ extensions = [ "beautifulsoup4>=3.2.2", # For load_web_page tool. "crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool; chromadb/pypika fail on 3.12+ "docker>=7.0.0", # For ContainerCodeExecutor + "google-cloud-firestore>=2.11.0", # For Firestore services "kubernetes>=29.0.0", # For GkeCodeExecutor "k8s-agent-sandbox>=0.1.1.post3", # For GkeCodeExecutor sandbox mode "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent diff --git a/src/google/adk/firestore_database_runner.py b/src/google/adk/firestore_database_runner.py new file mode 100644 index 0000000000..0ea7aa4f16 --- /dev/null +++ b/src/google/adk/firestore_database_runner.py @@ -0,0 +1,64 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Optional + +from .artifacts.gcs_artifact_service import GcsArtifactService +from .memory.firestore_memory_service import FirestoreMemoryService +from .runners import Runner +from .sessions.firestore_session_service import FirestoreSessionService + +if TYPE_CHECKING: + from .agents.base_agent import BaseAgent + + +def create_firestore_runner( + agent: BaseAgent, + gcs_bucket_name: Optional[str] = None, + firestore_root_collection: Optional[str] = None, +) -> Runner: + """Creates a Runner configured with Firestore and GCS services. + + Args: + agent: The root agent to run. + gcs_bucket_name: The GCS bucket name for artifacts. + firestore_root_collection: The root collection name for Firestore. + + Returns: + A Runner instance configured with Firestore services. + """ + # GcsArtifactService might require bucket name in constructor or read from env. + # Let's assume it reads from env or takes it. + # If we pass it, we might need to check its signature. + # Let's assume it takes bucket_name if provided, or reads from env. + artifact_service = GcsArtifactService() + if gcs_bucket_name: + # If GcsArtifactService supports setting it, we set it. + # Or we can assume it reads from ADK_GCS_BUCKET_NAME env var. + pass + + session_service = FirestoreSessionService( + root_collection=firestore_root_collection + ) + memory_service = FirestoreMemoryService() + + return Runner( + agent=agent, + session_service=session_service, + artifact_service=artifact_service, + memory_service=memory_service, + ) diff --git a/src/google/adk/memory/firestore_memory_service.py b/src/google/adk/memory/firestore_memory_service.py new file mode 100644 index 0000000000..57c0de645f --- /dev/null +++ b/src/google/adk/memory/firestore_memory_service.py @@ -0,0 +1,172 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import logging +import os +import re +from typing import Any +from typing import Optional + +from google.cloud import firestore +from typing_extensions import override + +from ..events.event import Event +from . import _utils +from .base_memory_service import BaseMemoryService +from .base_memory_service import SearchMemoryResponse +from .memory_entry import MemoryEntry + +if False: # TYPE_CHECKING + from ..sessions.session import Session + +logger = logging.getLogger("google_adk." + __name__) + +DEFAULT_EVENTS_COLLECTION = "events" + +# Standard English stop words +DEFAULT_STOP_WORDS = { + "a", "an", "the", "and", "or", "but", "if", "then", "else", "to", "of", + "in", "on", "for", "with", "is", "are", "was", "were", "be", "been", + "being", "have", "has", "had", "do", "does", "did", "can", "could", + "will", "would", "should", "shall", "may", "might", "must", "up", "down", + "out", "in", "over", "under", "again", "further", "then", "once", "here", + "there", "when", "where", "why", "how", "all", "any", "both", "each", + "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", + "own", "same", "so", "than", "too", "very", "i", "me", "my", "myself", + "we", "our", "ours", "ourselves", "you", "your", "yours", "yourself", + "yourselves", "he", "him", "his", "himself", "she", "her", "hers", + "herself", "it", "its", "itself", "they", "them", "their", "theirs", + "themselves", "what", "which", "who", "whom", "this", "that", "these", + "those", "am", "is", "are", "was", "were", "be", "been", "being", + "have", "has", "had", "having", "do", "does", "did", "doing", + "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", + "while", "of", "at", "by", "for", "with", "about", "against", "between", + "into", "through", "during", "before", "after", "above", "below", "to", + "from", "up", "down", "in", "out", "on", "off", "over", "under", "again", + "further", "then", "once", "here", "there", "when", "where", "why", "how", + "all", "any", "both", "each", "few", "more", "most", "other", "some", + "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", + "very", "s", "t", "can", "will", "just", "don", "should", "now" +} + + +class FirestoreMemoryService(BaseMemoryService): + """Memory service that uses Google Cloud Firestore as the backend.""" + + def __init__( + self, + client: Optional[firestore.AsyncClient] = None, + events_collection: Optional[str] = None, + stop_words: Optional[set[str]] = None, + ): + """Initializes the Firestore memory service. + + Args: + client: An optional Firestore AsyncClient. If not provided, a new one + will be created. + events_collection: The name of the events collection or collection group. + Defaults to 'events'. + stop_words: A set of words to ignore when extracting keywords. Defaults to + a standard English stop words list. + """ + self.client = client or firestore.AsyncClient() + self.events_collection = events_collection or DEFAULT_EVENTS_COLLECTION + self.stop_words = stop_words if stop_words is not None else DEFAULT_STOP_WORDS + + @override + async def add_session_to_memory(self, session: Session) -> None: + """No-op. Assumes events are written to Firestore by FirestoreSessionService.""" + pass + + def _extract_keywords(self, text: str) -> set[str]: + """Extracts keywords from text, ignoring stop words.""" + words = re.findall(r"[A-Za-z]+", text.lower()) + return {word for word in words if word not in self.stop_words} + + async def _search_by_keyword( + self, app_name: str, user_id: str, keyword: str + ) -> list[MemoryEntry]: + """Searches for events matching a single keyword.""" + # This requires a collection group index in Firestore for 'events' with + # appName == X, userId == Y, and keywords array-contains Z. + query = ( + self.client.collection_group(self.events_collection) + .where("appName", "==", app_name) + .where("userId", "==", user_id) + .where("keywords", "array_contains", keyword) + ) + + docs = await query.get() + entries = [] + for doc in docs: + data = doc.to_dict() + if data and "event_data" in data: + try: + event = Event.model_validate(data["event_data"]) + if event.content: + entries.append( + MemoryEntry( + content=event.content, + author=event.author, + timestamp=_utils.format_timestamp(event.timestamp), + ) + ) + except Exception as e: + logger.warning("Failed to parse event from Firestore: %s", e) + + return entries + + @override + async def search_memory( + self, *, app_name: str, user_id: str, query: str + ) -> SearchMemoryResponse: + """Searches memory for events matching the query.""" + keywords = self._extract_keywords(query) + if not keywords: + return SearchMemoryResponse() + + # Search for each keyword concurrently + tasks = [ + self._search_by_keyword(app_name, user_id, keyword) + for keyword in keywords + ] + results = await asyncio.gather(*tasks) + + # Merge results and deduplicate by MemoryEntry content/author/timestamp + # (MemoryEntry is not hashable by default if it contains complex objects, + # so we might need to deduplicate by id if available, or by content string). + # Since we convert Event to MemoryEntry, we don't have event.id in MemoryEntry + # unless we add it. The Java code use custom hash/equals for MemoryEntry. + # In Python, MemoryEntry is a Pydantic model. We can deduplicate by model_dump_json() + # or by a custom key. + seen = set() + memories = [] + for result_list in results: + for entry in result_list: + # Deduplicate by a key of (author, content_text) + # Content might be complex, so let's use its json representation or text + content_text = "" + if entry.content and entry.content.parts: + content_text = " ".join( + [part.text for part in entry.content.parts if part.text] + ) + key = (entry.author, content_text, entry.timestamp) + if key not in seen: + seen.add(key) + memories.append(entry) + + return SearchMemoryResponse(memories=memories) diff --git a/src/google/adk/sessions/firestore_session_service.py b/src/google/adk/sessions/firestore_session_service.py new file mode 100644 index 0000000000..b6ad01d4e4 --- /dev/null +++ b/src/google/adk/sessions/firestore_session_service.py @@ -0,0 +1,404 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import os +from typing import Any +from typing import Optional + +from google.cloud import firestore +from pydantic import BaseModel + +from ..events.event import Event +from .base_session_service import BaseSessionService +from .base_session_service import GetSessionConfig +from .base_session_service import ListSessionsResponse +from .session import Session + + +logger = logging.getLogger("google_adk." + __name__) + +DEFAULT_ROOT_COLLECTION = "adk-session" +DEFAULT_SESSIONS_COLLECTION = "sessions" +DEFAULT_EVENTS_COLLECTION = "events" +DEFAULT_APP_STATE_COLLECTION = "app_states" +DEFAULT_USER_STATE_COLLECTION = "user_states" + + +class FirestoreSessionService(BaseSessionService): + """Session service that uses Google Cloud Firestore as the backend.""" + + def __init__( + self, + client: Optional[firestore.AsyncClient] = None, + root_collection: Optional[str] = None, + ): + """Initializes the Firestore session service. + + Args: + client: An optional Firestore AsyncClient. If not provided, a new one + will be created. + root_collection: The root collection name. Defaults to 'adk-session' or + the value of ADK_FIRESTORE_ROOT_COLLECTION env var. + """ + self.client = client or firestore.AsyncClient() + self.root_collection = ( + root_collection + or os.environ.get("ADK_FIRESTORE_ROOT_COLLECTION") + or DEFAULT_ROOT_COLLECTION + ) + self.sessions_collection = DEFAULT_SESSIONS_COLLECTION + self.events_collection = DEFAULT_EVENTS_COLLECTION + self.app_state_collection = DEFAULT_APP_STATE_COLLECTION + self.user_state_collection = DEFAULT_USER_STATE_COLLECTION + + def _get_sessions_ref(self, user_id: str) -> firestore.AsyncCollectionReference: + return ( + self.client.collection(self.root_collection) + .document(user_id) + .collection(self.sessions_collection) + ) + + async def create_session( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> Session: + """Creates a new session in Firestore.""" + if not session_id: + from google.adk.platform import uuid as platform_uuid + session_id = platform_uuid.new_uuid() + + initial_state = state or {} + now = firestore.SERVER_TIMESTAMP + + session_ref = self._get_sessions_ref(user_id).document(session_id) + + # Check if session already exists + doc = await session_ref.get() + if doc.exists: + from ..errors.already_exists_error import AlreadyExistsError + raise AlreadyExistsError(f"Session {session_id} already exists.") + + session_data = { + "id": session_id, + "appName": app_name, + "userId": user_id, + "state": initial_state, + "createTime": now, + "updateTime": now, + } + + await session_ref.set(session_data) + + # We need a timestamp for the Session object. Since SERVER_TIMESTAMP is + # evaluated on the server, we might want to use local time for the object + # or read it back. Reading it back is expensive. We'll use local time for + # the object, but the DB will have SERVER_TIMESTAMP. + from datetime import datetime + from datetime import timezone + local_now = datetime.now(timezone.utc).timestamp() + + return Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=initial_state, + events=[], + last_update_time=local_now, + ) + + async def get_session( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, + ) -> Optional[Session]: + """Gets a session from Firestore.""" + session_ref = self._get_sessions_ref(user_id).document(session_id) + doc = await session_ref.get() + + if not doc.exists: + return None + + data = doc.to_dict() + if not data: + return None + + # Fetch events + events_ref = session_ref.collection(self.events_collection) + query = events_ref.order_by("timestamp") + + if config: + if config.after_timestamp: + after_dt = datetime.fromtimestamp(config.after_timestamp) + query = query.where("timestamp", ">=", after_dt) + if config.num_recent_events: + query = query.limit_to_last(config.num_recent_events) + + events_docs = await query.get() + events = [] + for event_doc in events_docs: + event_data = event_doc.to_dict() + if event_data and "event_data" in event_data: + # The Java code serializes individual fields, but Python schema/v1 uses + # JSON serialization of the whole event. We'll stick to Pythonic JSON + # serialization (event.model_dump) for consistency with Python ADK. + ed = event_data["event_data"] + # Restore timestamp if needed, or assume it's in event_data + events.append(Event.model_validate(ed)) + + # Fetch states (app and user) if we want to merge them, similar to + # DatabaseSessionService. The Java code seems to merge them in listSessions + # but let's see if getSession does it. + # In Java, getSession fetches app/user state if needed? The Java code I read: + # It didn't seem to fetch app/user state in getSession, only in appendEvent + # where it updates them, and listSessions where it mergers. + # Wait, let's re-read Java getSession. + # It doesn't seem to fetch app/user state in getSession either? + # Actually, in Java `FirestoreSessionService.java` `getSession`: + # It reads the session doc, then reads events. It doesn't seem to read + # app/user state docs. + # But `DatabaseSessionService` in Python DOES read them in `get_session`. + # Let's align with Python `DatabaseSessionService` if possible, as it's the + # standard in Python ADK. + # Python `DatabaseSessionService` reads `StorageAppState` and `StorageUserState` + # and merges them. + # If I want to be consistent with Python ADK, I should probably do it. + # But if I want to be consistent with Java ADK port, I should follow Java. + # The user asked to "Port this firestore support over to ADK Python". + # I should follow the Java logic but make it Pythonic. + # The Java logic doesn't seem to merge app/user state in `getSession`, it + # just returns session state. + # Wait, let's check Java `listSessions`. It read `StorageAppState`? No, it + # just read sessions. + # Let's stick to the Java logic if it works, or adapt to Python if it's better. + # Since `DatabaseSessionService` in Python merges them, maybe it's a newer + # feature in Python ADK that Java doesn't have or does differently. + # Let's check `FirestoreSessionService.java` again. + # In Java `listSessions`, it doesn't seem to fetch app/user state. + # In Java `appendEvent`, it updates app/user state if `state_delta` has + # `_app_` or `_user_` prefixes. + # Let's stick to the Java behavior unless it conflicts with Python interfaces. + # The Python `BaseSessionService` doesn't enforce merging, it just defines + # the interface. `DatabaseSessionService` implements merging. + # I'll stick to the Java behavior (no merging in get/list, only update in append) + # for now, as it's a port of Java. Or I can implement merging if it's easy. + # Let's look at Java `appendEvent`: + # It checks `_app_` and `_user_` prefixes in `state_delta` and updates + # separate collections! + # ```java + # firestore.collection(APP_STATE_COLLECTION).document(appName).set(...) + # ``` + # So it DOES use separate collections for app/user state. + # If it uses them, it should probably read them somewhere. In Java, it seems + # it might not read them in `getSession`? Wait, let's check `FirestoreSessionService.java` + # again. I see `listSessions` doesn't read them. `getSession` doesn't read them. + # That might be a bug or partial implementation in Java? Or maybe they are + # read elsewhere? + # In Python `DatabaseSessionService` reads them in `get_session` and `list_sessions`. + # Let's implement reading them in Python `FirestoreSessionService` to be + # consistent with Python ADK standards if possible, or at least support it. + # I'll implement it without merging first to match Java, then see if I should + # add it. The Java code didn't do it. + + # Let's continue getting session. + session_state = data.get("state", {}) + + # Convert timestamp + update_time = data.get("updateTime") + last_update_time = 0.0 + if update_time: + # If it's a datetime object (Firestore might return it) + if isinstance(update_time, datetime): + last_update_time = update_time.timestamp() + else: + # Assuming it's a string or float + try: + last_update_time = float(update_time) + except (ValueError, TypeError): + pass + + return Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=session_state, + events=events, + last_update_time=last_update_time, + ) + + async def list_sessions( + self, *, app_name: str, user_id: Optional[str] = None + ) -> ListSessionsResponse: + """Lists sessions from Firestore.""" + # If user_id is provided, we can list directly. + # If not, we might need a collection group query or list all users first. + # Java listSessions takes appName and userId. It always scopes to user. + # Python list_sessions has user_id optional. + # If user_id is None, we should list all sessions for the app across all users. + # This requires a collection group query on 'sessions'. + if user_id: + query = self._get_sessions_ref(user_id).where("appName", "==", app_name) + docs = await query.get() + else: + # Collection group query + query = self.client.collection_group(self.sessions_collection).where( + "appName", "==", app_name + ) + docs = await query.get() + + sessions = [] + for doc in docs: + data = doc.to_dict() + if data: + # Session state is empty for listing as per in_memory + sessions.append( + Session( + id=data["id"], + app_name=data["appName"], + user_id=data["userId"], + state={}, # Empty state for listing + events=[], # Empty events for listing + last_update_time=0.0, # Or parse from updateTime + ) + ) + + return ListSessionsResponse(sessions=sessions) + + async def delete_session( + self, *, app_name: str, user_id: str, session_id: str + ) -> None: + """Deletes a session and its events from Firestore.""" + session_ref = self._get_sessions_ref(user_id).document(session_id) + + # Delete events subcollection first (Firestore requires manual subcollection deletion) + events_ref = session_ref.collection(self.events_collection) + events_docs = await events_ref.get() + + # Batch delete + batch = self.client.batch() + for event_doc in events_docs: + batch.delete(event_doc.reference) + await batch.commit() + + # Delete session doc + await session_ref.delete() + + async def append_event(self, session: Session, event: Event) -> Event: + """Appends an event to a session in Firestore.""" + if event.partial: + return event + + # Apply temp state to in-memory session (from base class) + self._apply_temp_state(session, event) + event = self._trim_temp_delta_state(event) + + session_ref = self._get_sessions_ref(session.user_id).document(session.id) + + # Handle state deltas (app and user state) + if event.actions and event.actions.state_delta: + state_delta = event.actions.state_delta + app_updates = {} + user_updates = {} + session_updates = {} + + for key, value in state_delta.items(): + if key.startswith("_app_"): + app_updates[key[len("_app_"):]] = value + elif key.startswith("_user_"): + user_updates[key[len("_user_"):]] = value + else: + session_updates[key] = value + + + # Update session doc with new state and updateTime + # We'll do it outside the batch or inside if we can. + # Let's use batch for everything to be atomic. + # Wait, I didn't add session_ref to batch yet. + # Let's create a batch. + batch = self.client.batch() + + if app_updates: + app_ref = self.client.collection(self.app_state_collection).document( + session.app_name + ) + batch.set(app_ref, app_updates, merge=True) + + if user_updates: + user_ref = ( + self.client.collection(self.user_state_collection) + .document(session.app_name) + .collection("users") + .document(session.user_id) + ) + batch.set(user_ref, user_updates, merge=True) + + # Update session state in-memory first + for k, v in session_updates.items(): + session.state[k] = v + + # Update session doc + batch.update( + session_ref, + { + "state": session.state, + "updateTime": firestore.SERVER_TIMESTAMP, + }, + ) + + # Add event + event_id = event.id + event_ref = session_ref.collection(self.events_collection).document(event_id) + # Store event data as JSON serialized string or dict + event_data = event.model_dump(exclude_none=True, mode="json") + batch.set( + event_ref, + { + "event_data": event_data, + "timestamp": firestore.SERVER_TIMESTAMP, + "appName": session.app_name, + "userId": session.user_id, + }, + ) + + await batch.commit() + else: + # No state delta, just add event and update session timestamp + batch = self.client.batch() + event_id = event.id + event_ref = session_ref.collection(self.events_collection).document(event_id) + event_data = event.model_dump(exclude_none=True, mode="json") + batch.set( + event_ref, + { + "event_data": event_data, + "timestamp": firestore.SERVER_TIMESTAMP, + "appName": session.app_name, + "userId": session.user_id, + }, + ) + batch.update(session_ref, {"updateTime": firestore.SERVER_TIMESTAMP}) + await batch.commit() + + # Also update the in-memory session (adds event to list) + await super().append_event(session, event) + return event diff --git a/tests/unittests/memory/test_firestore_memory_service.py b/tests/unittests/memory/test_firestore_memory_service.py new file mode 100644 index 0000000000..b41e14fb94 --- /dev/null +++ b/tests/unittests/memory/test_firestore_memory_service.py @@ -0,0 +1,96 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +import pytest +from google.adk.events.event import Event +from google.adk.memory.firestore_memory_service import FirestoreMemoryService +from google.genai import types + + +@pytest.fixture +def mock_firestore_client(): + client = mock.AsyncMock() + collection_ref = mock.AsyncMock() + client.collection_group.return_value = collection_ref + collection_ref.where.return_value = collection_ref + + # Mock get() for documents + doc_snapshot = mock.AsyncMock() + doc_snapshot.to_dict.return_value = {} + collection_ref.get.return_value = [doc_snapshot] + + return client + + +def test_extract_keywords(): + service = FirestoreMemoryService() + text = "The quick brown fox jumps over the lazy dog." + keywords = service._extract_keywords(text) + + # Check that stopwords like "the", "over" are removed + assert "the" not in keywords + assert "over" not in keywords + assert "quick" in keywords + assert "brown" in keywords + assert "fox" in keywords + assert "jumps" in keywords + assert "lazy" in keywords + assert "dog" in keywords + + +@pytest.mark.asyncio +async def test_search_memory_empty_query(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + response = await service.search_memory( + app_name="test_app", user_id="test_user", query="" + ) + assert not response.memories + mock_firestore_client.collection_group.assert_not_called() + + +@pytest.mark.asyncio +async def test_search_memory_with_results(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + query = "quick fox" + + # Mock document snapshot to return event data + doc_snapshot = mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get.return_value[0] + event = Event( + invocation_id="test_inv", + author="user", + content=types.Content(parts=[types.Part(text="quick fox jumps")]), + ) + doc_snapshot.to_dict.return_value = { + "event_data": event.model_dump(exclude_none=True, mode="json") + } + + response = await service.search_memory( + app_name=app_name, user_id=user_id, query=query + ) + + assert response.memories + assert len(response.memories) == 1 + assert response.memories[0].author == "user" + + # Verify Firestore calls + mock_firestore_client.collection_group.assert_called_with("events") + collection_ref = mock_firestore_client.collection_group.return_value + # Verify where calls (order might vary, so we just check it was called or check the chain) + collection_ref.where.assert_called() diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/sessions/test_firestore_session_service.py new file mode 100644 index 0000000000..ec2af4ad04 --- /dev/null +++ b/tests/unittests/sessions/test_firestore_session_service.py @@ -0,0 +1,159 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +import pytest +from google.adk.events.event import Event +from google.adk.sessions.firestore_session_service import FirestoreSessionService + + +@pytest.fixture +def mock_firestore_client(): + client = mock.AsyncMock() + # Mock collection and document references + collection_ref = mock.AsyncMock() + doc_ref = mock.AsyncMock() + subcollection_ref = mock.AsyncMock() + subdoc_ref = mock.AsyncMock() + + client.collection.return_value = collection_ref + collection_ref.document.return_value = doc_ref + doc_ref.collection.return_value = subcollection_ref + subcollection_ref.document.return_value = subdoc_ref + + # Mock get() for documents + doc_snapshot = mock.AsyncMock() + doc_snapshot.exists = False + doc_snapshot.to_dict.return_value = {} + doc_ref.get.return_value = doc_snapshot + subdoc_ref.get.return_value = doc_snapshot + + # Mock collection group + client.collection_group.return_value = collection_ref + + return client + + +@pytest.mark.asyncio +async def test_create_session(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + + session = await service.create_session(app_name=app_name, user_id=user_id) + + assert session.app_name == app_name + assert session.user_id == user_id + assert session.id + + # Verify Firestore calls + mock_firestore_client.collection.assert_called_once_with("adk-session") + collection_ref = mock_firestore_client.collection.return_value + collection_ref.document.assert_called_once_with(user_id) + doc_ref = collection_ref.document.return_value + doc_ref.collection.assert_called_once_with("sessions") + sessions_ref = doc_ref.collection.return_value + sessions_ref.document.assert_called_once_with(session.id) + session_doc_ref = sessions_ref.document.return_value + session_doc_ref.set.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_session_not_found(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + session = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert session is None + + +@pytest.mark.asyncio +async def test_get_session_found(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + # Mock document snapshot to return data + doc_snapshot = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + doc_snapshot.exists = True + doc_snapshot.to_dict.return_value = { + "id": session_id, + "appName": app_name, + "userId": user_id, + "state": {"key": "value"}, + "updateTime": 1234567890.0, + } + + session = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert session is not None + assert session.id == session_id + assert session.state == {"key": "value"} + + +@pytest.mark.asyncio +async def test_delete_session(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + # Mock events subcollection + events_ref = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + event_doc = mock.AsyncMock() + events_ref.get.return_value = [event_doc] + + await service.delete_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + # Verify events deletion + events_ref.get.assert_called_once() + mock_firestore_client.batch.assert_called_once() + batch = mock_firestore_client.batch.return_value + batch.delete.assert_called_once_with(event_doc.reference) + batch.commit.assert_called_once() + + # Verify session deletion + session_doc_ref = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value + session_doc_ref.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_append_event(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + from google.adk.sessions.session import Session + session = Session(id="test_session", app_name=app_name, user_id=user_id) + event = Event(invocation_id="test_inv", author="user") + + await service.append_event(session, event) + + mock_firestore_client.batch.assert_called_once() + batch = mock_firestore_client.batch.return_value + batch.set.assert_called_once() # For event + batch.update.assert_called_once() # For session updateTime + batch.commit.assert_called_once() From d2d223168e45384488f983b2bec35b4319dc90b4 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 13:30:59 -0600 Subject: [PATCH 02/12] Formatting and fixing the bucket name handling --- contributing/samples/gepa/experiment.py | 1 - contributing/samples/gepa/run_experiment.py | 1 - src/google/adk/firestore_database_runner.py | 19 +- .../adk/memory/firestore_memory_service.py | 228 ++++++++++++++++-- .../adk/sessions/firestore_session_service.py | 21 +- .../memory/test_firestore_memory_service.py | 6 +- .../test_firestore_session_service.py | 19 +- 7 files changed, 244 insertions(+), 51 deletions(-) diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index f3751206a8..2710c3894c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index d857da9635..e31db15788 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/firestore_database_runner.py b/src/google/adk/firestore_database_runner.py index 0ea7aa4f16..b3abbd45b0 100644 --- a/src/google/adk/firestore_database_runner.py +++ b/src/google/adk/firestore_database_runner.py @@ -14,8 +14,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import os from typing import Optional +from typing import TYPE_CHECKING from .artifacts.gcs_artifact_service import GcsArtifactService from .memory.firestore_memory_service import FirestoreMemoryService @@ -41,15 +42,13 @@ def create_firestore_runner( Returns: A Runner instance configured with Firestore services. """ - # GcsArtifactService might require bucket name in constructor or read from env. - # Let's assume it reads from env or takes it. - # If we pass it, we might need to check its signature. - # Let's assume it takes bucket_name if provided, or reads from env. - artifact_service = GcsArtifactService() - if gcs_bucket_name: - # If GcsArtifactService supports setting it, we set it. - # Or we can assume it reads from ADK_GCS_BUCKET_NAME env var. - pass + bucket_name = gcs_bucket_name or os.environ.get("ADK_GCS_BUCKET_NAME") + if not bucket_name: + raise ValueError( + "Required property 'ADK_GCS_BUCKET_NAME' is not set. This" + " is needed for the GcsArtifactService." + ) + artifact_service = GcsArtifactService(bucket_name=bucket_name) session_service = FirestoreSessionService( root_collection=firestore_root_collection diff --git a/src/google/adk/memory/firestore_memory_service.py b/src/google/adk/memory/firestore_memory_service.py index 57c0de645f..97ade10b89 100644 --- a/src/google/adk/memory/firestore_memory_service.py +++ b/src/google/adk/memory/firestore_memory_service.py @@ -24,8 +24,8 @@ from google.cloud import firestore from typing_extensions import override -from ..events.event import Event from . import _utils +from ..events.event import Event from .base_memory_service import BaseMemoryService from .base_memory_service import SearchMemoryResponse from .memory_entry import MemoryEntry @@ -39,28 +39,206 @@ # Standard English stop words DEFAULT_STOP_WORDS = { - "a", "an", "the", "and", "or", "but", "if", "then", "else", "to", "of", - "in", "on", "for", "with", "is", "are", "was", "were", "be", "been", - "being", "have", "has", "had", "do", "does", "did", "can", "could", - "will", "would", "should", "shall", "may", "might", "must", "up", "down", - "out", "in", "over", "under", "again", "further", "then", "once", "here", - "there", "when", "where", "why", "how", "all", "any", "both", "each", - "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", - "own", "same", "so", "than", "too", "very", "i", "me", "my", "myself", - "we", "our", "ours", "ourselves", "you", "your", "yours", "yourself", - "yourselves", "he", "him", "his", "himself", "she", "her", "hers", - "herself", "it", "its", "itself", "they", "them", "their", "theirs", - "themselves", "what", "which", "who", "whom", "this", "that", "these", - "those", "am", "is", "are", "was", "were", "be", "been", "being", - "have", "has", "had", "having", "do", "does", "did", "doing", - "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", - "while", "of", "at", "by", "for", "with", "about", "against", "between", - "into", "through", "during", "before", "after", "above", "below", "to", - "from", "up", "down", "in", "out", "on", "off", "over", "under", "again", - "further", "then", "once", "here", "there", "when", "where", "why", "how", - "all", "any", "both", "each", "few", "more", "most", "other", "some", - "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", - "very", "s", "t", "can", "will", "just", "don", "should", "now" + "a", + "an", + "the", + "and", + "or", + "but", + "if", + "then", + "else", + "to", + "of", + "in", + "on", + "for", + "with", + "is", + "are", + "was", + "were", + "be", + "been", + "being", + "have", + "has", + "had", + "do", + "does", + "did", + "can", + "could", + "will", + "would", + "should", + "shall", + "may", + "might", + "must", + "up", + "down", + "out", + "in", + "over", + "under", + "again", + "further", + "then", + "once", + "here", + "there", + "when", + "where", + "why", + "how", + "all", + "any", + "both", + "each", + "few", + "more", + "most", + "other", + "some", + "such", + "no", + "nor", + "not", + "only", + "own", + "same", + "so", + "than", + "too", + "very", + "i", + "me", + "my", + "myself", + "we", + "our", + "ours", + "ourselves", + "you", + "your", + "yours", + "yourself", + "yourselves", + "he", + "him", + "his", + "himself", + "she", + "her", + "hers", + "herself", + "it", + "its", + "itself", + "they", + "them", + "their", + "theirs", + "themselves", + "what", + "which", + "who", + "whom", + "this", + "that", + "these", + "those", + "am", + "is", + "are", + "was", + "were", + "be", + "been", + "being", + "have", + "has", + "had", + "having", + "do", + "does", + "did", + "doing", + "a", + "an", + "the", + "and", + "but", + "if", + "or", + "because", + "as", + "until", + "while", + "of", + "at", + "by", + "for", + "with", + "about", + "against", + "between", + "into", + "through", + "during", + "before", + "after", + "above", + "below", + "to", + "from", + "up", + "down", + "in", + "out", + "on", + "off", + "over", + "under", + "again", + "further", + "then", + "once", + "here", + "there", + "when", + "where", + "why", + "how", + "all", + "any", + "both", + "each", + "few", + "more", + "most", + "other", + "some", + "such", + "no", + "nor", + "not", + "only", + "own", + "same", + "so", + "than", + "too", + "very", + "s", + "t", + "can", + "will", + "just", + "don", + "should", + "now", } @@ -85,7 +263,9 @@ def __init__( """ self.client = client or firestore.AsyncClient() self.events_collection = events_collection or DEFAULT_EVENTS_COLLECTION - self.stop_words = stop_words if stop_words is not None else DEFAULT_STOP_WORDS + self.stop_words = ( + stop_words if stop_words is not None else DEFAULT_STOP_WORDS + ) @override async def add_session_to_memory(self, session: Session) -> None: diff --git a/src/google/adk/sessions/firestore_session_service.py b/src/google/adk/sessions/firestore_session_service.py index b6ad01d4e4..1f1000ad11 100644 --- a/src/google/adk/sessions/firestore_session_service.py +++ b/src/google/adk/sessions/firestore_session_service.py @@ -28,7 +28,6 @@ from .base_session_service import ListSessionsResponse from .session import Session - logger = logging.getLogger("google_adk." + __name__) DEFAULT_ROOT_COLLECTION = "adk-session" @@ -65,7 +64,9 @@ def __init__( self.app_state_collection = DEFAULT_APP_STATE_COLLECTION self.user_state_collection = DEFAULT_USER_STATE_COLLECTION - def _get_sessions_ref(self, user_id: str) -> firestore.AsyncCollectionReference: + def _get_sessions_ref( + self, user_id: str + ) -> firestore.AsyncCollectionReference: return ( self.client.collection(self.root_collection) .document(user_id) @@ -83,6 +84,7 @@ async def create_session( """Creates a new session in Firestore.""" if not session_id: from google.adk.platform import uuid as platform_uuid + session_id = platform_uuid.new_uuid() initial_state = state or {} @@ -94,6 +96,7 @@ async def create_session( doc = await session_ref.get() if doc.exists: from ..errors.already_exists_error import AlreadyExistsError + raise AlreadyExistsError(f"Session {session_id} already exists.") session_data = { @@ -113,6 +116,7 @@ async def create_session( # the object, but the DB will have SERVER_TIMESTAMP. from datetime import datetime from datetime import timezone + local_now = datetime.now(timezone.utc).timestamp() return Session( @@ -323,13 +327,12 @@ async def append_event(self, session: Session, event: Event) -> Event: for key, value in state_delta.items(): if key.startswith("_app_"): - app_updates[key[len("_app_"):]] = value + app_updates[key[len("_app_") :]] = value elif key.startswith("_user_"): - user_updates[key[len("_user_"):]] = value + user_updates[key[len("_user_") :]] = value else: session_updates[key] = value - # Update session doc with new state and updateTime # We'll do it outside the batch or inside if we can. # Let's use batch for everything to be atomic. @@ -367,7 +370,9 @@ async def append_event(self, session: Session, event: Event) -> Event: # Add event event_id = event.id - event_ref = session_ref.collection(self.events_collection).document(event_id) + event_ref = session_ref.collection(self.events_collection).document( + event_id + ) # Store event data as JSON serialized string or dict event_data = event.model_dump(exclude_none=True, mode="json") batch.set( @@ -385,7 +390,9 @@ async def append_event(self, session: Session, event: Event) -> Event: # No state delta, just add event and update session timestamp batch = self.client.batch() event_id = event.id - event_ref = session_ref.collection(self.events_collection).document(event_id) + event_ref = session_ref.collection(self.events_collection).document( + event_id + ) event_data = event.model_dump(exclude_none=True, mode="json") batch.set( event_ref, diff --git a/tests/unittests/memory/test_firestore_memory_service.py b/tests/unittests/memory/test_firestore_memory_service.py index b41e14fb94..6cd878f336 100644 --- a/tests/unittests/memory/test_firestore_memory_service.py +++ b/tests/unittests/memory/test_firestore_memory_service.py @@ -16,10 +16,10 @@ from unittest import mock -import pytest from google.adk.events.event import Event from google.adk.memory.firestore_memory_service import FirestoreMemoryService from google.genai import types +import pytest @pytest.fixture @@ -71,7 +71,9 @@ async def test_search_memory_with_results(mock_firestore_client): query = "quick fox" # Mock document snapshot to return event data - doc_snapshot = mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get.return_value[0] + doc_snapshot = mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get.return_value[ + 0 + ] event = Event( invocation_id="test_inv", author="user", diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/sessions/test_firestore_session_service.py index ec2af4ad04..485ef668d1 100644 --- a/tests/unittests/sessions/test_firestore_session_service.py +++ b/tests/unittests/sessions/test_firestore_session_service.py @@ -16,9 +16,9 @@ from unittest import mock -import pytest from google.adk.events.event import Event from google.adk.sessions.firestore_session_service import FirestoreSessionService +import pytest @pytest.fixture @@ -94,7 +94,9 @@ async def test_get_session_found(mock_firestore_client): session_id = "test_session" # Mock document snapshot to return data - doc_snapshot = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + doc_snapshot = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + ) doc_snapshot.exists = True doc_snapshot.to_dict.return_value = { "id": session_id, @@ -121,7 +123,9 @@ async def test_delete_session(mock_firestore_client): session_id = "test_session" # Mock events subcollection - events_ref = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + events_ref = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + ) event_doc = mock.AsyncMock() events_ref.get.return_value = [event_doc] @@ -137,7 +141,9 @@ async def test_delete_session(mock_firestore_client): batch.commit.assert_called_once() # Verify session deletion - session_doc_ref = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value + session_doc_ref = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value + ) session_doc_ref.delete.assert_called_once() @@ -147,6 +153,7 @@ async def test_append_event(mock_firestore_client): app_name = "test_app" user_id = "test_user" from google.adk.sessions.session import Session + session = Session(id="test_session", app_name=app_name, user_id=user_id) event = Event(invocation_id="test_inv", author="user") @@ -154,6 +161,6 @@ async def test_append_event(mock_firestore_client): mock_firestore_client.batch.assert_called_once() batch = mock_firestore_client.batch.return_value - batch.set.assert_called_once() # For event - batch.update.assert_called_once() # For session updateTime + batch.set.assert_called_once() # For event + batch.update.assert_called_once() # For session updateTime batch.commit.assert_called_once() From 8312dc82c47c8931289932c514549f48096ba45d Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 13:50:25 -0600 Subject: [PATCH 03/12] Correct imports for firestore --- src/google/adk/memory/firestore_memory_service.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/google/adk/memory/firestore_memory_service.py b/src/google/adk/memory/firestore_memory_service.py index 97ade10b89..ad9fb57992 100644 --- a/src/google/adk/memory/firestore_memory_service.py +++ b/src/google/adk/memory/firestore_memory_service.py @@ -20,8 +20,8 @@ import re from typing import Any from typing import Optional +from typing import TYPE_CHECKING -from google.cloud import firestore from typing_extensions import override from . import _utils @@ -30,7 +30,9 @@ from .base_memory_service import SearchMemoryResponse from .memory_entry import MemoryEntry -if False: # TYPE_CHECKING +if TYPE_CHECKING: + from google.cloud import firestore + from ..sessions.session import Session logger = logging.getLogger("google_adk." + __name__) @@ -261,7 +263,12 @@ def __init__( stop_words: A set of words to ignore when extracting keywords. Defaults to a standard English stop words list. """ - self.client = client or firestore.AsyncClient() + if client is None: + from google.cloud import firestore + + self.client = firestore.AsyncClient() + else: + self.client = client self.events_collection = events_collection or DEFAULT_EVENTS_COLLECTION self.stop_words = ( stop_words if stop_words is not None else DEFAULT_STOP_WORDS From 7760b70e2d340b3882e31443c9337a9977981aac Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 14:00:10 -0600 Subject: [PATCH 04/12] Add firestore to test dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 426b6d1bbf..931ca994f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,7 @@ test = [ "a2a-sdk>=0.3.0,<0.4.0", "anthropic>=0.43.0", # For anthropic model tests "crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool tests; chromadb/pypika fail on 3.12+ + "google-cloud-firestore>=2.11.0", "kubernetes>=29.0.0", # For GkeCodeExecutor "langchain-community>=0.3.17", "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent From b29afb73cfe2a5a5f15210e9ecc1ceadf3c874c7 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 14:12:25 -0600 Subject: [PATCH 05/12] Fix tests --- .../test_firestore_session_service.py | 47 ++++++++++++++ .../test_firestore_database_runner.py | 62 +++++++++++++++++++ 2 files changed, 109 insertions(+) create mode 100644 tests/unittests/test_firestore_database_runner.py diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/sessions/test_firestore_session_service.py index 485ef668d1..6048097997 100644 --- a/tests/unittests/sessions/test_firestore_session_service.py +++ b/tests/unittests/sessions/test_firestore_session_service.py @@ -164,3 +164,50 @@ async def test_append_event(mock_firestore_client): batch.set.assert_called_once() # For event batch.update.assert_called_once() # For session updateTime batch.commit.assert_called_once() + + +@pytest.mark.asyncio +async def test_append_event_with_state_delta(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name=app_name, user_id=user_id) + + # Using MagicMock for Event to bypass complex pydantic validation for test + event = mock.MagicMock() + event.partial = False + event.id = "test_event_id" + # Mock actions.state_delta + event.actions.state_delta = { + "_app_my_key": "app_val", + "_user_my_key": "user_val", + "session_key": "session_val", + } + # Mock model_dump to return valid event data + event.model_dump.return_value = {"id": "test_event_id", "author": "user"} + + await service.append_event(session, event) + + mock_firestore_client.batch.assert_called_once() + batch = mock_firestore_client.batch.return_value + + # Verify app state set + # In code: batch.set(app_ref, app_updates, merge=True) + # But app_ref is a mock! Which mock? + # It's mock_firestore_client.collection().document() + # In fixture: collection_ref = mock.AsyncMock() + # doc_ref = mock.AsyncMock() + # client.collection.return_value = collection_ref + # collection_ref.document.return_value = doc_ref + # So batch.set is called with app_ref (which is doc_ref) + batch.set.assert_called() + + # Verify session state updated in memory + assert session.state["session_key"] == "session_val" + + # Verify batch update was called for session + batch.update.assert_called_once() + + batch.commit.assert_called_once() diff --git a/tests/unittests/test_firestore_database_runner.py b/tests/unittests/test_firestore_database_runner.py new file mode 100644 index 0000000000..a4e8fb1889 --- /dev/null +++ b/tests/unittests/test_firestore_database_runner.py @@ -0,0 +1,62 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +import pytest +from google.adk.agents.base_agent import BaseAgent +from google.adk.firestore_database_runner import create_firestore_runner + + +@pytest.fixture +def mock_agent(): + agent = mock.MagicMock(spec=BaseAgent) + agent.name = "test_agent" + return agent + + +def test_create_firestore_runner_with_arg(mock_agent, monkeypatch): + monkeypatch.delenv("ADK_GCS_BUCKET_NAME", raising=False) + + # Mock GcsArtifactService to avoid real client init + with mock.patch( + "google.adk.firestore_database_runner.GcsArtifactService" + ) as mock_gcs: + runner = create_firestore_runner(mock_agent, gcs_bucket_name="test_bucket") + + assert runner is not None + mock_gcs.assert_called_once_with(bucket_name="test_bucket") + + +def test_create_firestore_runner_with_env(mock_agent, monkeypatch): + monkeypatch.setenv("ADK_GCS_BUCKET_NAME", "env_bucket") + + with mock.patch( + "google.adk.firestore_database_runner.GcsArtifactService" + ) as mock_gcs: + runner = create_firestore_runner(mock_agent) + + assert runner is not None + mock_gcs.assert_called_once_with(bucket_name="env_bucket") + + +def test_create_firestore_runner_missing_bucket(mock_agent, monkeypatch): + monkeypatch.delenv("ADK_GCS_BUCKET_NAME", raising=False) + + with pytest.raises( + ValueError, match="Required property 'ADK_GCS_BUCKET_NAME' is not set" + ): + create_firestore_runner(mock_agent) From 31ffb864422333f44a7db2a3918d07f1ad596569 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 14:37:07 -0600 Subject: [PATCH 06/12] Fix mypy errors --- src/google/adk/errors/already_exists_error.py | 2 +- .../adk/sessions/firestore_session_service.py | 59 +------------------ 2 files changed, 3 insertions(+), 58 deletions(-) diff --git a/src/google/adk/errors/already_exists_error.py b/src/google/adk/errors/already_exists_error.py index 8bd14f9ad6..bf8d357a81 100644 --- a/src/google/adk/errors/already_exists_error.py +++ b/src/google/adk/errors/already_exists_error.py @@ -18,7 +18,7 @@ class AlreadyExistsError(Exception): """Represents an error that occurs when an entity already exists.""" - def __init__(self, message="The resource already exists."): + def __init__(self, message: str = "The resource already exists."): """Initializes the AlreadyExistsError exception. Args: diff --git a/src/google/adk/sessions/firestore_session_service.py b/src/google/adk/sessions/firestore_session_service.py index 1f1000ad11..5ace2d94a9 100644 --- a/src/google/adk/sessions/firestore_session_service.py +++ b/src/google/adk/sessions/firestore_session_service.py @@ -16,6 +16,8 @@ import logging import os +from datetime import datetime +from datetime import timezone from typing import Any from typing import Optional @@ -114,9 +116,6 @@ async def create_session( # evaluated on the server, we might want to use local time for the object # or read it back. Reading it back is expensive. We'll use local time for # the object, but the DB will have SERVER_TIMESTAMP. - from datetime import datetime - from datetime import timezone - local_now = datetime.now(timezone.utc).timestamp() return Session( @@ -170,60 +169,6 @@ async def get_session( # Restore timestamp if needed, or assume it's in event_data events.append(Event.model_validate(ed)) - # Fetch states (app and user) if we want to merge them, similar to - # DatabaseSessionService. The Java code seems to merge them in listSessions - # but let's see if getSession does it. - # In Java, getSession fetches app/user state if needed? The Java code I read: - # It didn't seem to fetch app/user state in getSession, only in appendEvent - # where it updates them, and listSessions where it mergers. - # Wait, let's re-read Java getSession. - # It doesn't seem to fetch app/user state in getSession either? - # Actually, in Java `FirestoreSessionService.java` `getSession`: - # It reads the session doc, then reads events. It doesn't seem to read - # app/user state docs. - # But `DatabaseSessionService` in Python DOES read them in `get_session`. - # Let's align with Python `DatabaseSessionService` if possible, as it's the - # standard in Python ADK. - # Python `DatabaseSessionService` reads `StorageAppState` and `StorageUserState` - # and merges them. - # If I want to be consistent with Python ADK, I should probably do it. - # But if I want to be consistent with Java ADK port, I should follow Java. - # The user asked to "Port this firestore support over to ADK Python". - # I should follow the Java logic but make it Pythonic. - # The Java logic doesn't seem to merge app/user state in `getSession`, it - # just returns session state. - # Wait, let's check Java `listSessions`. It read `StorageAppState`? No, it - # just read sessions. - # Let's stick to the Java logic if it works, or adapt to Python if it's better. - # Since `DatabaseSessionService` in Python merges them, maybe it's a newer - # feature in Python ADK that Java doesn't have or does differently. - # Let's check `FirestoreSessionService.java` again. - # In Java `listSessions`, it doesn't seem to fetch app/user state. - # In Java `appendEvent`, it updates app/user state if `state_delta` has - # `_app_` or `_user_` prefixes. - # Let's stick to the Java behavior unless it conflicts with Python interfaces. - # The Python `BaseSessionService` doesn't enforce merging, it just defines - # the interface. `DatabaseSessionService` implements merging. - # I'll stick to the Java behavior (no merging in get/list, only update in append) - # for now, as it's a port of Java. Or I can implement merging if it's easy. - # Let's look at Java `appendEvent`: - # It checks `_app_` and `_user_` prefixes in `state_delta` and updates - # separate collections! - # ```java - # firestore.collection(APP_STATE_COLLECTION).document(appName).set(...) - # ``` - # So it DOES use separate collections for app/user state. - # If it uses them, it should probably read them somewhere. In Java, it seems - # it might not read them in `getSession`? Wait, let's check `FirestoreSessionService.java` - # again. I see `listSessions` doesn't read them. `getSession` doesn't read them. - # That might be a bug or partial implementation in Java? Or maybe they are - # read elsewhere? - # In Python `DatabaseSessionService` reads them in `get_session` and `list_sessions`. - # Let's implement reading them in Python `FirestoreSessionService` to be - # consistent with Python ADK standards if possible, or at least support it. - # I'll implement it without merging first to match Java, then see if I should - # add it. The Java code didn't do it. - # Let's continue getting session. session_state = data.get("state", {}) From 565cc616d37d1b3bc9aa3ef878227c001fbbc1f6 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 14:40:37 -0600 Subject: [PATCH 07/12] Undo unintended changes --- contributing/samples/gepa/experiment.py | 1 + contributing/samples/gepa/run_experiment.py | 1 + 2 files changed, 2 insertions(+) diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 2710c3894c..f3751206a8 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,6 +43,7 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib + import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index e31db15788..d857da9635 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,6 +25,7 @@ from absl import flags import experiment from google.genai import types + import utils _OUTPUT_DIR = flags.DEFINE_string( From 9387cb3e948875a885a4d90e54c80ff5f30ea90a Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 14:42:04 -0600 Subject: [PATCH 08/12] Sorting imports --- src/google/adk/sessions/firestore_session_service.py | 4 ++-- tests/unittests/test_firestore_database_runner.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/google/adk/sessions/firestore_session_service.py b/src/google/adk/sessions/firestore_session_service.py index 5ace2d94a9..306bc4c87e 100644 --- a/src/google/adk/sessions/firestore_session_service.py +++ b/src/google/adk/sessions/firestore_session_service.py @@ -14,10 +14,10 @@ from __future__ import annotations -import logging -import os from datetime import datetime from datetime import timezone +import logging +import os from typing import Any from typing import Optional diff --git a/tests/unittests/test_firestore_database_runner.py b/tests/unittests/test_firestore_database_runner.py index a4e8fb1889..4b51ffb99a 100644 --- a/tests/unittests/test_firestore_database_runner.py +++ b/tests/unittests/test_firestore_database_runner.py @@ -16,9 +16,9 @@ from unittest import mock -import pytest from google.adk.agents.base_agent import BaseAgent from google.adk.firestore_database_runner import create_firestore_runner +import pytest @pytest.fixture From 03910e99e24de7dc376177f9448928c6ce0f48fa Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 15:29:27 -0600 Subject: [PATCH 09/12] Fix async mocks --- .../memory/test_firestore_memory_service.py | 11 +++++--- .../test_firestore_session_service.py | 26 ++++++++++++------- .../test_firestore_database_runner.py | 25 +++++++++++++----- 3 files changed, 42 insertions(+), 20 deletions(-) diff --git a/tests/unittests/memory/test_firestore_memory_service.py b/tests/unittests/memory/test_firestore_memory_service.py index 6cd878f336..d7497735aa 100644 --- a/tests/unittests/memory/test_firestore_memory_service.py +++ b/tests/unittests/memory/test_firestore_memory_service.py @@ -24,15 +24,18 @@ @pytest.fixture def mock_firestore_client(): - client = mock.AsyncMock() - collection_ref = mock.AsyncMock() + client = mock.MagicMock() + collection_ref = mock.MagicMock() client.collection_group.return_value = collection_ref + + # where() should return self (collection_ref) to allow chaining collection_ref.where.return_value = collection_ref # Mock get() for documents - doc_snapshot = mock.AsyncMock() + doc_snapshot = mock.MagicMock() doc_snapshot.to_dict.return_value = {} - collection_ref.get.return_value = [doc_snapshot] + + collection_ref.get = mock.AsyncMock(return_value=[doc_snapshot]) return client diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/sessions/test_firestore_session_service.py index 6048097997..294b8114fe 100644 --- a/tests/unittests/sessions/test_firestore_session_service.py +++ b/tests/unittests/sessions/test_firestore_session_service.py @@ -23,12 +23,11 @@ @pytest.fixture def mock_firestore_client(): - client = mock.AsyncMock() - # Mock collection and document references - collection_ref = mock.AsyncMock() - doc_ref = mock.AsyncMock() - subcollection_ref = mock.AsyncMock() - subdoc_ref = mock.AsyncMock() + client = mock.MagicMock() + collection_ref = mock.MagicMock() + doc_ref = mock.MagicMock() + subcollection_ref = mock.MagicMock() + subdoc_ref = mock.MagicMock() client.collection.return_value = collection_ref collection_ref.document.return_value = doc_ref @@ -36,15 +35,24 @@ def mock_firestore_client(): subcollection_ref.document.return_value = subdoc_ref # Mock get() for documents - doc_snapshot = mock.AsyncMock() + doc_snapshot = mock.MagicMock() doc_snapshot.exists = False doc_snapshot.to_dict.return_value = {} - doc_ref.get.return_value = doc_snapshot - subdoc_ref.get.return_value = doc_snapshot + + doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) + subdoc_ref.get = mock.AsyncMock(return_value=doc_snapshot) + + # Mock subcollection get() (for events list in delete_session) + subcollection_ref.get = mock.AsyncMock(return_value=[]) # Mock collection group client.collection_group.return_value = collection_ref + # Mock batch + batch = mock.MagicMock() + client.batch.return_value = batch + batch.commit = mock.AsyncMock() + return client diff --git a/tests/unittests/test_firestore_database_runner.py b/tests/unittests/test_firestore_database_runner.py index 4b51ffb99a..89c1a1f677 100644 --- a/tests/unittests/test_firestore_database_runner.py +++ b/tests/unittests/test_firestore_database_runner.py @@ -31,10 +31,15 @@ def mock_agent(): def test_create_firestore_runner_with_arg(mock_agent, monkeypatch): monkeypatch.delenv("ADK_GCS_BUCKET_NAME", raising=False) - # Mock GcsArtifactService to avoid real client init - with mock.patch( - "google.adk.firestore_database_runner.GcsArtifactService" - ) as mock_gcs: + with ( + mock.patch( + "google.adk.firestore_database_runner.FirestoreSessionService" + ), + mock.patch("google.adk.firestore_database_runner.FirestoreMemoryService"), + mock.patch( + "google.adk.firestore_database_runner.GcsArtifactService" + ) as mock_gcs, + ): runner = create_firestore_runner(mock_agent, gcs_bucket_name="test_bucket") assert runner is not None @@ -44,9 +49,15 @@ def test_create_firestore_runner_with_arg(mock_agent, monkeypatch): def test_create_firestore_runner_with_env(mock_agent, monkeypatch): monkeypatch.setenv("ADK_GCS_BUCKET_NAME", "env_bucket") - with mock.patch( - "google.adk.firestore_database_runner.GcsArtifactService" - ) as mock_gcs: + with ( + mock.patch( + "google.adk.firestore_database_runner.FirestoreSessionService" + ), + mock.patch("google.adk.firestore_database_runner.FirestoreMemoryService"), + mock.patch( + "google.adk.firestore_database_runner.GcsArtifactService" + ) as mock_gcs, + ): runner = create_firestore_runner(mock_agent) assert runner is not None From 49c7bf5a02ad6ed3e3d74ba8b21bbf31f02d1dc5 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 16:50:08 -0600 Subject: [PATCH 10/12] Fixing tests again again again --- src/google/adk/firestore_database_runner.py | 1 + .../sessions/test_firestore_session_service.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/google/adk/firestore_database_runner.py b/src/google/adk/firestore_database_runner.py index b3abbd45b0..aa86dae022 100644 --- a/src/google/adk/firestore_database_runner.py +++ b/src/google/adk/firestore_database_runner.py @@ -56,6 +56,7 @@ def create_firestore_runner( memory_service = FirestoreMemoryService() return Runner( + app_name=agent.name, agent=agent, session_service=session_service, artifact_service=artifact_service, diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/sessions/test_firestore_session_service.py index 294b8114fe..21511a0dab 100644 --- a/tests/unittests/sessions/test_firestore_session_service.py +++ b/tests/unittests/sessions/test_firestore_session_service.py @@ -42,8 +42,21 @@ def mock_firestore_client(): doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) subdoc_ref.get = mock.AsyncMock(return_value=doc_snapshot) - # Mock subcollection get() (for events list in delete_session) + # Set methods used in create_session and delete_session to AsyncMock + subdoc_ref.set = mock.AsyncMock() + subdoc_ref.delete = mock.AsyncMock() + + # Mock events subcollection + events_collection_ref = mock.MagicMock() + subdoc_ref.collection.return_value = events_collection_ref + events_collection_ref.order_by.return_value = events_collection_ref + events_collection_ref.where.return_value = events_collection_ref + events_collection_ref.limit_to_last.return_value = events_collection_ref + events_collection_ref.get = mock.AsyncMock(return_value=[]) + + # Mock subcollection get() (for sessions listing) subcollection_ref.get = mock.AsyncMock(return_value=[]) + subcollection_ref.where.return_value = subcollection_ref # Mock collection group client.collection_group.return_value = collection_ref @@ -135,7 +148,7 @@ async def test_delete_session(mock_firestore_client): mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value ) event_doc = mock.AsyncMock() - events_ref.get.return_value = [event_doc] + events_ref.get = mock.AsyncMock(return_value=[event_doc]) await service.delete_session( app_name=app_name, user_id=user_id, session_id=session_id From fbd16eba4d2e4f5d05c042e14c30e10bb433b239 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Wed, 1 Apr 2026 00:17:01 -0600 Subject: [PATCH 11/12] Empty commit From 5645fe8f4e29a78aa9ba10a911ba2c664e111515 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Wed, 1 Apr 2026 00:28:18 -0600 Subject: [PATCH 12/12] Fixing one more test to use the firestore mock client --- tests/unittests/memory/test_firestore_memory_service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/memory/test_firestore_memory_service.py b/tests/unittests/memory/test_firestore_memory_service.py index d7497735aa..00b0099782 100644 --- a/tests/unittests/memory/test_firestore_memory_service.py +++ b/tests/unittests/memory/test_firestore_memory_service.py @@ -40,8 +40,8 @@ def mock_firestore_client(): return client -def test_extract_keywords(): - service = FirestoreMemoryService() +def test_extract_keywords(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) text = "The quick brown fox jumps over the lazy dog." keywords = service._extract_keywords(text)