diff --git a/pyproject.toml b/pyproject.toml index 2789bcf82a..931ca994f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,7 @@ test = [ "a2a-sdk>=0.3.0,<0.4.0", "anthropic>=0.43.0", # For anthropic model tests "crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool tests; chromadb/pypika fail on 3.12+ + "google-cloud-firestore>=2.11.0", "kubernetes>=29.0.0", # For GkeCodeExecutor "langchain-community>=0.3.17", "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent @@ -157,6 +158,7 @@ extensions = [ "beautifulsoup4>=3.2.2", # For load_web_page tool. "crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool; chromadb/pypika fail on 3.12+ "docker>=7.0.0", # For ContainerCodeExecutor + "google-cloud-firestore>=2.11.0", # For Firestore services "kubernetes>=29.0.0", # For GkeCodeExecutor "k8s-agent-sandbox>=0.1.1.post3", # For GkeCodeExecutor sandbox mode "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent diff --git a/src/google/adk/errors/already_exists_error.py b/src/google/adk/errors/already_exists_error.py index 8bd14f9ad6..bf8d357a81 100644 --- a/src/google/adk/errors/already_exists_error.py +++ b/src/google/adk/errors/already_exists_error.py @@ -18,7 +18,7 @@ class AlreadyExistsError(Exception): """Represents an error that occurs when an entity already exists.""" - def __init__(self, message="The resource already exists."): + def __init__(self, message: str = "The resource already exists."): """Initializes the AlreadyExistsError exception. Args: diff --git a/src/google/adk/firestore_database_runner.py b/src/google/adk/firestore_database_runner.py new file mode 100644 index 0000000000..aa86dae022 --- /dev/null +++ b/src/google/adk/firestore_database_runner.py @@ -0,0 +1,64 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from typing import Optional +from typing import TYPE_CHECKING + +from .artifacts.gcs_artifact_service import GcsArtifactService +from .memory.firestore_memory_service import FirestoreMemoryService +from .runners import Runner +from .sessions.firestore_session_service import FirestoreSessionService + +if TYPE_CHECKING: + from .agents.base_agent import BaseAgent + + +def create_firestore_runner( + agent: BaseAgent, + gcs_bucket_name: Optional[str] = None, + firestore_root_collection: Optional[str] = None, +) -> Runner: + """Creates a Runner configured with Firestore and GCS services. + + Args: + agent: The root agent to run. + gcs_bucket_name: The GCS bucket name for artifacts. + firestore_root_collection: The root collection name for Firestore. + + Returns: + A Runner instance configured with Firestore services. + """ + bucket_name = gcs_bucket_name or os.environ.get("ADK_GCS_BUCKET_NAME") + if not bucket_name: + raise ValueError( + "Required property 'ADK_GCS_BUCKET_NAME' is not set. This" + " is needed for the GcsArtifactService." + ) + artifact_service = GcsArtifactService(bucket_name=bucket_name) + + session_service = FirestoreSessionService( + root_collection=firestore_root_collection + ) + memory_service = FirestoreMemoryService() + + return Runner( + app_name=agent.name, + agent=agent, + session_service=session_service, + artifact_service=artifact_service, + memory_service=memory_service, + ) diff --git a/src/google/adk/memory/firestore_memory_service.py b/src/google/adk/memory/firestore_memory_service.py new file mode 100644 index 0000000000..ad9fb57992 --- /dev/null +++ b/src/google/adk/memory/firestore_memory_service.py @@ -0,0 +1,359 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import logging +import os +import re +from typing import Any +from typing import Optional +from typing import TYPE_CHECKING + +from typing_extensions import override + +from . import _utils +from ..events.event import Event +from .base_memory_service import BaseMemoryService +from .base_memory_service import SearchMemoryResponse +from .memory_entry import MemoryEntry + +if TYPE_CHECKING: + from google.cloud import firestore + + from ..sessions.session import Session + +logger = logging.getLogger("google_adk." + __name__) + +DEFAULT_EVENTS_COLLECTION = "events" + +# Standard English stop words +DEFAULT_STOP_WORDS = { + "a", + "an", + "the", + "and", + "or", + "but", + "if", + "then", + "else", + "to", + "of", + "in", + "on", + "for", + "with", + "is", + "are", + "was", + "were", + "be", + "been", + "being", + "have", + "has", + "had", + "do", + "does", + "did", + "can", + "could", + "will", + "would", + "should", + "shall", + "may", + "might", + "must", + "up", + "down", + "out", + "in", + "over", + "under", + "again", + "further", + "then", + "once", + "here", + "there", + "when", + "where", + "why", + "how", + "all", + "any", + "both", + "each", + "few", + "more", + "most", + "other", + "some", + "such", + "no", + "nor", + "not", + "only", + "own", + "same", + "so", + "than", + "too", + "very", + "i", + "me", + "my", + "myself", + "we", + "our", + "ours", + "ourselves", + "you", + "your", + "yours", + "yourself", + "yourselves", + "he", + "him", + "his", + "himself", + "she", + "her", + "hers", + "herself", + "it", + "its", + "itself", + "they", + "them", + "their", + "theirs", + "themselves", + "what", + "which", + "who", + "whom", + "this", + "that", + "these", + "those", + "am", + "is", + "are", + "was", + "were", + "be", + "been", + "being", + "have", + "has", + "had", + "having", + "do", + "does", + "did", + "doing", + "a", + "an", + "the", + "and", + "but", + "if", + "or", + "because", + "as", + "until", + "while", + "of", + "at", + "by", + "for", + "with", + "about", + "against", + "between", + "into", + "through", + "during", + "before", + "after", + "above", + "below", + "to", + "from", + "up", + "down", + "in", + "out", + "on", + "off", + "over", + "under", + "again", + "further", + "then", + "once", + "here", + "there", + "when", + "where", + "why", + "how", + "all", + "any", + "both", + "each", + "few", + "more", + "most", + "other", + "some", + "such", + "no", + "nor", + "not", + "only", + "own", + "same", + "so", + "than", + "too", + "very", + "s", + "t", + "can", + "will", + "just", + "don", + "should", + "now", +} + + +class FirestoreMemoryService(BaseMemoryService): + """Memory service that uses Google Cloud Firestore as the backend.""" + + def __init__( + self, + client: Optional[firestore.AsyncClient] = None, + events_collection: Optional[str] = None, + stop_words: Optional[set[str]] = None, + ): + """Initializes the Firestore memory service. + + Args: + client: An optional Firestore AsyncClient. If not provided, a new one + will be created. + events_collection: The name of the events collection or collection group. + Defaults to 'events'. + stop_words: A set of words to ignore when extracting keywords. Defaults to + a standard English stop words list. + """ + if client is None: + from google.cloud import firestore + + self.client = firestore.AsyncClient() + else: + self.client = client + self.events_collection = events_collection or DEFAULT_EVENTS_COLLECTION + self.stop_words = ( + stop_words if stop_words is not None else DEFAULT_STOP_WORDS + ) + + @override + async def add_session_to_memory(self, session: Session) -> None: + """No-op. Assumes events are written to Firestore by FirestoreSessionService.""" + pass + + def _extract_keywords(self, text: str) -> set[str]: + """Extracts keywords from text, ignoring stop words.""" + words = re.findall(r"[A-Za-z]+", text.lower()) + return {word for word in words if word not in self.stop_words} + + async def _search_by_keyword( + self, app_name: str, user_id: str, keyword: str + ) -> list[MemoryEntry]: + """Searches for events matching a single keyword.""" + # This requires a collection group index in Firestore for 'events' with + # appName == X, userId == Y, and keywords array-contains Z. + query = ( + self.client.collection_group(self.events_collection) + .where("appName", "==", app_name) + .where("userId", "==", user_id) + .where("keywords", "array_contains", keyword) + ) + + docs = await query.get() + entries = [] + for doc in docs: + data = doc.to_dict() + if data and "event_data" in data: + try: + event = Event.model_validate(data["event_data"]) + if event.content: + entries.append( + MemoryEntry( + content=event.content, + author=event.author, + timestamp=_utils.format_timestamp(event.timestamp), + ) + ) + except Exception as e: + logger.warning("Failed to parse event from Firestore: %s", e) + + return entries + + @override + async def search_memory( + self, *, app_name: str, user_id: str, query: str + ) -> SearchMemoryResponse: + """Searches memory for events matching the query.""" + keywords = self._extract_keywords(query) + if not keywords: + return SearchMemoryResponse() + + # Search for each keyword concurrently + tasks = [ + self._search_by_keyword(app_name, user_id, keyword) + for keyword in keywords + ] + results = await asyncio.gather(*tasks) + + # Merge results and deduplicate by MemoryEntry content/author/timestamp + # (MemoryEntry is not hashable by default if it contains complex objects, + # so we might need to deduplicate by id if available, or by content string). + # Since we convert Event to MemoryEntry, we don't have event.id in MemoryEntry + # unless we add it. The Java code use custom hash/equals for MemoryEntry. + # In Python, MemoryEntry is a Pydantic model. We can deduplicate by model_dump_json() + # or by a custom key. + seen = set() + memories = [] + for result_list in results: + for entry in result_list: + # Deduplicate by a key of (author, content_text) + # Content might be complex, so let's use its json representation or text + content_text = "" + if entry.content and entry.content.parts: + content_text = " ".join( + [part.text for part in entry.content.parts if part.text] + ) + key = (entry.author, content_text, entry.timestamp) + if key not in seen: + seen.add(key) + memories.append(entry) + + return SearchMemoryResponse(memories=memories) diff --git a/src/google/adk/sessions/firestore_session_service.py b/src/google/adk/sessions/firestore_session_service.py new file mode 100644 index 0000000000..306bc4c87e --- /dev/null +++ b/src/google/adk/sessions/firestore_session_service.py @@ -0,0 +1,356 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import datetime +from datetime import timezone +import logging +import os +from typing import Any +from typing import Optional + +from google.cloud import firestore +from pydantic import BaseModel + +from ..events.event import Event +from .base_session_service import BaseSessionService +from .base_session_service import GetSessionConfig +from .base_session_service import ListSessionsResponse +from .session import Session + +logger = logging.getLogger("google_adk." + __name__) + +DEFAULT_ROOT_COLLECTION = "adk-session" +DEFAULT_SESSIONS_COLLECTION = "sessions" +DEFAULT_EVENTS_COLLECTION = "events" +DEFAULT_APP_STATE_COLLECTION = "app_states" +DEFAULT_USER_STATE_COLLECTION = "user_states" + + +class FirestoreSessionService(BaseSessionService): + """Session service that uses Google Cloud Firestore as the backend.""" + + def __init__( + self, + client: Optional[firestore.AsyncClient] = None, + root_collection: Optional[str] = None, + ): + """Initializes the Firestore session service. + + Args: + client: An optional Firestore AsyncClient. If not provided, a new one + will be created. + root_collection: The root collection name. Defaults to 'adk-session' or + the value of ADK_FIRESTORE_ROOT_COLLECTION env var. + """ + self.client = client or firestore.AsyncClient() + self.root_collection = ( + root_collection + or os.environ.get("ADK_FIRESTORE_ROOT_COLLECTION") + or DEFAULT_ROOT_COLLECTION + ) + self.sessions_collection = DEFAULT_SESSIONS_COLLECTION + self.events_collection = DEFAULT_EVENTS_COLLECTION + self.app_state_collection = DEFAULT_APP_STATE_COLLECTION + self.user_state_collection = DEFAULT_USER_STATE_COLLECTION + + def _get_sessions_ref( + self, user_id: str + ) -> firestore.AsyncCollectionReference: + return ( + self.client.collection(self.root_collection) + .document(user_id) + .collection(self.sessions_collection) + ) + + async def create_session( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> Session: + """Creates a new session in Firestore.""" + if not session_id: + from google.adk.platform import uuid as platform_uuid + + session_id = platform_uuid.new_uuid() + + initial_state = state or {} + now = firestore.SERVER_TIMESTAMP + + session_ref = self._get_sessions_ref(user_id).document(session_id) + + # Check if session already exists + doc = await session_ref.get() + if doc.exists: + from ..errors.already_exists_error import AlreadyExistsError + + raise AlreadyExistsError(f"Session {session_id} already exists.") + + session_data = { + "id": session_id, + "appName": app_name, + "userId": user_id, + "state": initial_state, + "createTime": now, + "updateTime": now, + } + + await session_ref.set(session_data) + + # We need a timestamp for the Session object. Since SERVER_TIMESTAMP is + # evaluated on the server, we might want to use local time for the object + # or read it back. Reading it back is expensive. We'll use local time for + # the object, but the DB will have SERVER_TIMESTAMP. + local_now = datetime.now(timezone.utc).timestamp() + + return Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=initial_state, + events=[], + last_update_time=local_now, + ) + + async def get_session( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, + ) -> Optional[Session]: + """Gets a session from Firestore.""" + session_ref = self._get_sessions_ref(user_id).document(session_id) + doc = await session_ref.get() + + if not doc.exists: + return None + + data = doc.to_dict() + if not data: + return None + + # Fetch events + events_ref = session_ref.collection(self.events_collection) + query = events_ref.order_by("timestamp") + + if config: + if config.after_timestamp: + after_dt = datetime.fromtimestamp(config.after_timestamp) + query = query.where("timestamp", ">=", after_dt) + if config.num_recent_events: + query = query.limit_to_last(config.num_recent_events) + + events_docs = await query.get() + events = [] + for event_doc in events_docs: + event_data = event_doc.to_dict() + if event_data and "event_data" in event_data: + # The Java code serializes individual fields, but Python schema/v1 uses + # JSON serialization of the whole event. We'll stick to Pythonic JSON + # serialization (event.model_dump) for consistency with Python ADK. + ed = event_data["event_data"] + # Restore timestamp if needed, or assume it's in event_data + events.append(Event.model_validate(ed)) + + # Let's continue getting session. + session_state = data.get("state", {}) + + # Convert timestamp + update_time = data.get("updateTime") + last_update_time = 0.0 + if update_time: + # If it's a datetime object (Firestore might return it) + if isinstance(update_time, datetime): + last_update_time = update_time.timestamp() + else: + # Assuming it's a string or float + try: + last_update_time = float(update_time) + except (ValueError, TypeError): + pass + + return Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=session_state, + events=events, + last_update_time=last_update_time, + ) + + async def list_sessions( + self, *, app_name: str, user_id: Optional[str] = None + ) -> ListSessionsResponse: + """Lists sessions from Firestore.""" + # If user_id is provided, we can list directly. + # If not, we might need a collection group query or list all users first. + # Java listSessions takes appName and userId. It always scopes to user. + # Python list_sessions has user_id optional. + # If user_id is None, we should list all sessions for the app across all users. + # This requires a collection group query on 'sessions'. + if user_id: + query = self._get_sessions_ref(user_id).where("appName", "==", app_name) + docs = await query.get() + else: + # Collection group query + query = self.client.collection_group(self.sessions_collection).where( + "appName", "==", app_name + ) + docs = await query.get() + + sessions = [] + for doc in docs: + data = doc.to_dict() + if data: + # Session state is empty for listing as per in_memory + sessions.append( + Session( + id=data["id"], + app_name=data["appName"], + user_id=data["userId"], + state={}, # Empty state for listing + events=[], # Empty events for listing + last_update_time=0.0, # Or parse from updateTime + ) + ) + + return ListSessionsResponse(sessions=sessions) + + async def delete_session( + self, *, app_name: str, user_id: str, session_id: str + ) -> None: + """Deletes a session and its events from Firestore.""" + session_ref = self._get_sessions_ref(user_id).document(session_id) + + # Delete events subcollection first (Firestore requires manual subcollection deletion) + events_ref = session_ref.collection(self.events_collection) + events_docs = await events_ref.get() + + # Batch delete + batch = self.client.batch() + for event_doc in events_docs: + batch.delete(event_doc.reference) + await batch.commit() + + # Delete session doc + await session_ref.delete() + + async def append_event(self, session: Session, event: Event) -> Event: + """Appends an event to a session in Firestore.""" + if event.partial: + return event + + # Apply temp state to in-memory session (from base class) + self._apply_temp_state(session, event) + event = self._trim_temp_delta_state(event) + + session_ref = self._get_sessions_ref(session.user_id).document(session.id) + + # Handle state deltas (app and user state) + if event.actions and event.actions.state_delta: + state_delta = event.actions.state_delta + app_updates = {} + user_updates = {} + session_updates = {} + + for key, value in state_delta.items(): + if key.startswith("_app_"): + app_updates[key[len("_app_") :]] = value + elif key.startswith("_user_"): + user_updates[key[len("_user_") :]] = value + else: + session_updates[key] = value + + # Update session doc with new state and updateTime + # We'll do it outside the batch or inside if we can. + # Let's use batch for everything to be atomic. + # Wait, I didn't add session_ref to batch yet. + # Let's create a batch. + batch = self.client.batch() + + if app_updates: + app_ref = self.client.collection(self.app_state_collection).document( + session.app_name + ) + batch.set(app_ref, app_updates, merge=True) + + if user_updates: + user_ref = ( + self.client.collection(self.user_state_collection) + .document(session.app_name) + .collection("users") + .document(session.user_id) + ) + batch.set(user_ref, user_updates, merge=True) + + # Update session state in-memory first + for k, v in session_updates.items(): + session.state[k] = v + + # Update session doc + batch.update( + session_ref, + { + "state": session.state, + "updateTime": firestore.SERVER_TIMESTAMP, + }, + ) + + # Add event + event_id = event.id + event_ref = session_ref.collection(self.events_collection).document( + event_id + ) + # Store event data as JSON serialized string or dict + event_data = event.model_dump(exclude_none=True, mode="json") + batch.set( + event_ref, + { + "event_data": event_data, + "timestamp": firestore.SERVER_TIMESTAMP, + "appName": session.app_name, + "userId": session.user_id, + }, + ) + + await batch.commit() + else: + # No state delta, just add event and update session timestamp + batch = self.client.batch() + event_id = event.id + event_ref = session_ref.collection(self.events_collection).document( + event_id + ) + event_data = event.model_dump(exclude_none=True, mode="json") + batch.set( + event_ref, + { + "event_data": event_data, + "timestamp": firestore.SERVER_TIMESTAMP, + "appName": session.app_name, + "userId": session.user_id, + }, + ) + batch.update(session_ref, {"updateTime": firestore.SERVER_TIMESTAMP}) + await batch.commit() + + # Also update the in-memory session (adds event to list) + await super().append_event(session, event) + return event diff --git a/tests/unittests/memory/test_firestore_memory_service.py b/tests/unittests/memory/test_firestore_memory_service.py new file mode 100644 index 0000000000..00b0099782 --- /dev/null +++ b/tests/unittests/memory/test_firestore_memory_service.py @@ -0,0 +1,101 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +from google.adk.events.event import Event +from google.adk.memory.firestore_memory_service import FirestoreMemoryService +from google.genai import types +import pytest + + +@pytest.fixture +def mock_firestore_client(): + client = mock.MagicMock() + collection_ref = mock.MagicMock() + client.collection_group.return_value = collection_ref + + # where() should return self (collection_ref) to allow chaining + collection_ref.where.return_value = collection_ref + + # Mock get() for documents + doc_snapshot = mock.MagicMock() + doc_snapshot.to_dict.return_value = {} + + collection_ref.get = mock.AsyncMock(return_value=[doc_snapshot]) + + return client + + +def test_extract_keywords(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + text = "The quick brown fox jumps over the lazy dog." + keywords = service._extract_keywords(text) + + # Check that stopwords like "the", "over" are removed + assert "the" not in keywords + assert "over" not in keywords + assert "quick" in keywords + assert "brown" in keywords + assert "fox" in keywords + assert "jumps" in keywords + assert "lazy" in keywords + assert "dog" in keywords + + +@pytest.mark.asyncio +async def test_search_memory_empty_query(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + response = await service.search_memory( + app_name="test_app", user_id="test_user", query="" + ) + assert not response.memories + mock_firestore_client.collection_group.assert_not_called() + + +@pytest.mark.asyncio +async def test_search_memory_with_results(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + query = "quick fox" + + # Mock document snapshot to return event data + doc_snapshot = mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get.return_value[ + 0 + ] + event = Event( + invocation_id="test_inv", + author="user", + content=types.Content(parts=[types.Part(text="quick fox jumps")]), + ) + doc_snapshot.to_dict.return_value = { + "event_data": event.model_dump(exclude_none=True, mode="json") + } + + response = await service.search_memory( + app_name=app_name, user_id=user_id, query=query + ) + + assert response.memories + assert len(response.memories) == 1 + assert response.memories[0].author == "user" + + # Verify Firestore calls + mock_firestore_client.collection_group.assert_called_with("events") + collection_ref = mock_firestore_client.collection_group.return_value + # Verify where calls (order might vary, so we just check it was called or check the chain) + collection_ref.where.assert_called() diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/sessions/test_firestore_session_service.py new file mode 100644 index 0000000000..21511a0dab --- /dev/null +++ b/tests/unittests/sessions/test_firestore_session_service.py @@ -0,0 +1,234 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +from google.adk.events.event import Event +from google.adk.sessions.firestore_session_service import FirestoreSessionService +import pytest + + +@pytest.fixture +def mock_firestore_client(): + client = mock.MagicMock() + collection_ref = mock.MagicMock() + doc_ref = mock.MagicMock() + subcollection_ref = mock.MagicMock() + subdoc_ref = mock.MagicMock() + + client.collection.return_value = collection_ref + collection_ref.document.return_value = doc_ref + doc_ref.collection.return_value = subcollection_ref + subcollection_ref.document.return_value = subdoc_ref + + # Mock get() for documents + doc_snapshot = mock.MagicMock() + doc_snapshot.exists = False + doc_snapshot.to_dict.return_value = {} + + doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) + subdoc_ref.get = mock.AsyncMock(return_value=doc_snapshot) + + # Set methods used in create_session and delete_session to AsyncMock + subdoc_ref.set = mock.AsyncMock() + subdoc_ref.delete = mock.AsyncMock() + + # Mock events subcollection + events_collection_ref = mock.MagicMock() + subdoc_ref.collection.return_value = events_collection_ref + events_collection_ref.order_by.return_value = events_collection_ref + events_collection_ref.where.return_value = events_collection_ref + events_collection_ref.limit_to_last.return_value = events_collection_ref + events_collection_ref.get = mock.AsyncMock(return_value=[]) + + # Mock subcollection get() (for sessions listing) + subcollection_ref.get = mock.AsyncMock(return_value=[]) + subcollection_ref.where.return_value = subcollection_ref + + # Mock collection group + client.collection_group.return_value = collection_ref + + # Mock batch + batch = mock.MagicMock() + client.batch.return_value = batch + batch.commit = mock.AsyncMock() + + return client + + +@pytest.mark.asyncio +async def test_create_session(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + + session = await service.create_session(app_name=app_name, user_id=user_id) + + assert session.app_name == app_name + assert session.user_id == user_id + assert session.id + + # Verify Firestore calls + mock_firestore_client.collection.assert_called_once_with("adk-session") + collection_ref = mock_firestore_client.collection.return_value + collection_ref.document.assert_called_once_with(user_id) + doc_ref = collection_ref.document.return_value + doc_ref.collection.assert_called_once_with("sessions") + sessions_ref = doc_ref.collection.return_value + sessions_ref.document.assert_called_once_with(session.id) + session_doc_ref = sessions_ref.document.return_value + session_doc_ref.set.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_session_not_found(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + session = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert session is None + + +@pytest.mark.asyncio +async def test_get_session_found(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + # Mock document snapshot to return data + doc_snapshot = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + ) + doc_snapshot.exists = True + doc_snapshot.to_dict.return_value = { + "id": session_id, + "appName": app_name, + "userId": user_id, + "state": {"key": "value"}, + "updateTime": 1234567890.0, + } + + session = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert session is not None + assert session.id == session_id + assert session.state == {"key": "value"} + + +@pytest.mark.asyncio +async def test_delete_session(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + # Mock events subcollection + events_ref = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + ) + event_doc = mock.AsyncMock() + events_ref.get = mock.AsyncMock(return_value=[event_doc]) + + await service.delete_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + # Verify events deletion + events_ref.get.assert_called_once() + mock_firestore_client.batch.assert_called_once() + batch = mock_firestore_client.batch.return_value + batch.delete.assert_called_once_with(event_doc.reference) + batch.commit.assert_called_once() + + # Verify session deletion + session_doc_ref = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value + ) + session_doc_ref.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_append_event(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name=app_name, user_id=user_id) + event = Event(invocation_id="test_inv", author="user") + + await service.append_event(session, event) + + mock_firestore_client.batch.assert_called_once() + batch = mock_firestore_client.batch.return_value + batch.set.assert_called_once() # For event + batch.update.assert_called_once() # For session updateTime + batch.commit.assert_called_once() + + +@pytest.mark.asyncio +async def test_append_event_with_state_delta(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name=app_name, user_id=user_id) + + # Using MagicMock for Event to bypass complex pydantic validation for test + event = mock.MagicMock() + event.partial = False + event.id = "test_event_id" + # Mock actions.state_delta + event.actions.state_delta = { + "_app_my_key": "app_val", + "_user_my_key": "user_val", + "session_key": "session_val", + } + # Mock model_dump to return valid event data + event.model_dump.return_value = {"id": "test_event_id", "author": "user"} + + await service.append_event(session, event) + + mock_firestore_client.batch.assert_called_once() + batch = mock_firestore_client.batch.return_value + + # Verify app state set + # In code: batch.set(app_ref, app_updates, merge=True) + # But app_ref is a mock! Which mock? + # It's mock_firestore_client.collection().document() + # In fixture: collection_ref = mock.AsyncMock() + # doc_ref = mock.AsyncMock() + # client.collection.return_value = collection_ref + # collection_ref.document.return_value = doc_ref + # So batch.set is called with app_ref (which is doc_ref) + batch.set.assert_called() + + # Verify session state updated in memory + assert session.state["session_key"] == "session_val" + + # Verify batch update was called for session + batch.update.assert_called_once() + + batch.commit.assert_called_once() diff --git a/tests/unittests/test_firestore_database_runner.py b/tests/unittests/test_firestore_database_runner.py new file mode 100644 index 0000000000..89c1a1f677 --- /dev/null +++ b/tests/unittests/test_firestore_database_runner.py @@ -0,0 +1,73 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +from google.adk.agents.base_agent import BaseAgent +from google.adk.firestore_database_runner import create_firestore_runner +import pytest + + +@pytest.fixture +def mock_agent(): + agent = mock.MagicMock(spec=BaseAgent) + agent.name = "test_agent" + return agent + + +def test_create_firestore_runner_with_arg(mock_agent, monkeypatch): + monkeypatch.delenv("ADK_GCS_BUCKET_NAME", raising=False) + + with ( + mock.patch( + "google.adk.firestore_database_runner.FirestoreSessionService" + ), + mock.patch("google.adk.firestore_database_runner.FirestoreMemoryService"), + mock.patch( + "google.adk.firestore_database_runner.GcsArtifactService" + ) as mock_gcs, + ): + runner = create_firestore_runner(mock_agent, gcs_bucket_name="test_bucket") + + assert runner is not None + mock_gcs.assert_called_once_with(bucket_name="test_bucket") + + +def test_create_firestore_runner_with_env(mock_agent, monkeypatch): + monkeypatch.setenv("ADK_GCS_BUCKET_NAME", "env_bucket") + + with ( + mock.patch( + "google.adk.firestore_database_runner.FirestoreSessionService" + ), + mock.patch("google.adk.firestore_database_runner.FirestoreMemoryService"), + mock.patch( + "google.adk.firestore_database_runner.GcsArtifactService" + ) as mock_gcs, + ): + runner = create_firestore_runner(mock_agent) + + assert runner is not None + mock_gcs.assert_called_once_with(bucket_name="env_bucket") + + +def test_create_firestore_runner_missing_bucket(mock_agent, monkeypatch): + monkeypatch.delenv("ADK_GCS_BUCKET_NAME", raising=False) + + with pytest.raises( + ValueError, match="Required property 'ADK_GCS_BUCKET_NAME' is not set" + ): + create_firestore_runner(mock_agent)