diff --git a/.env.example b/.env.example index 0a2462f..4982435 100644 --- a/.env.example +++ b/.env.example @@ -24,5 +24,11 @@ REPORT_GENERATION_DB__DRIVER="sqlite" REPORT_GENERATION_DB__DATABASE="implementations/report_generation/data/OnlineRetail.db" REPORT_GENERATION_DB__QUERY__MODE="ro" +# Vertex AI Search (custom knowledge base) - no API key needed, uses ADC +# On Coder/GCE workspaces the attached service account handles auth automatically. +# Required IAM roles on the service account: roles/discoveryengine.viewer, roles/aiplatform.user +GOOGLE_CLOUD_LOCATION="us-central1" +VERTEX_AI_DATASTORE_ID="projects/{project}/locations/global/collections/default_collection/dataStores/{datastore-id}" + # Report Generation (all optional, defaults are in implementations/report_generation/env_vars.py) REPORT_GENERATION_OUTPUT_PATH="..." diff --git a/aieng-eval-agents/aieng/agent_evals/configs.py b/aieng-eval-agents/aieng/agent_evals/configs.py index ab25b4e..05e48f2 100644 --- a/aieng-eval-agents/aieng/agent_evals/configs.py +++ b/aieng-eval-agents/aieng/agent_evals/configs.py @@ -160,6 +160,21 @@ class Configs(BaseSettings): web_search_base_url: str | None = Field(default=None, description="Base URL for web search service.") web_search_api_key: SecretStr | None = Field(default=None, description="API key for web search service.") + # === Vertex AI Search (custom knowledge base) === + google_cloud_location: str = Field( + default="us-central1", + description="GCP region for Vertex AI model calls. Must match a region that supports Gemini.", + ) + vertex_datastore_id: str | None = Field( + default=None, + validation_alias="VERTEX_AI_DATASTORE_ID", + description=( + "Full Vertex AI Search data store resource name. " + "Format: projects/{project}/locations/global/collections/default_collection/dataStores/{id}. " + "Authentication uses Application Default Credentials (ADC) — no API key required." 
+ ), + ) + # === Report Generation === # Defaults are set in the implementations/report_generation/env_vars.py file report_generation_output_path: str | None = Field( diff --git a/aieng-eval-agents/aieng/agent_evals/knowledge_qa/__init__.py b/aieng-eval-agents/aieng/agent_evals/knowledge_qa/__init__.py index 9252b12..a9d058f 100644 --- a/aieng-eval-agents/aieng/agent_evals/knowledge_qa/__init__.py +++ b/aieng-eval-agents/aieng/agent_evals/knowledge_qa/__init__.py @@ -21,14 +21,13 @@ format_response_with_citations, ) -from .agent import KnowledgeAgentManager, KnowledgeGroundedAgent +from .agent import KnowledgeGroundedAgent from .data import DeepSearchQADataset, DSQAExample __all__ = [ # Agent "KnowledgeGroundedAgent", - "KnowledgeAgentManager", # Grounding tool "create_google_search_tool", "format_response_with_citations", diff --git a/aieng-eval-agents/aieng/agent_evals/knowledge_qa/agent.py b/aieng-eval-agents/aieng/agent_evals/knowledge_qa/agent.py index 5869061..06b5c8e 100644 --- a/aieng-eval-agents/aieng/agent_evals/knowledge_qa/agent.py +++ b/aieng-eval-agents/aieng/agent_evals/knowledge_qa/agent.py @@ -747,80 +747,3 @@ def answer( """ logger.info(f"Answering question (sync): {question[:100]}...") return asyncio.run(self.answer_async(question, session_id)) - - -class KnowledgeAgentManager: - """Manages KnowledgeGroundedAgent lifecycle with lazy initialization. - - This class provides convenient lifecycle management for the knowledge agent, - with lazy initialization and state tracking. - - Parameters - ---------- - config : Configs, optional - Configuration object for client setup. If not provided, creates default. 
- - Examples - -------- - >>> manager = KnowledgeAgentManager() - >>> agent = manager.agent - >>> response = await agent.answer_async("What is quantum computing?") - >>> print(response.text) - >>> manager.close() - """ - - def __init__( - self, - config: Configs | None = None, - enable_caching: bool = True, - enable_planning: bool = True, - enable_compaction: bool = True, - ) -> None: - """Initialize the client manager. - - Parameters - ---------- - config : Configs, optional - Configuration object. If not provided, creates default config. - enable_caching : bool, default True - Whether to enable context caching. - enable_planning : bool, default True - Whether to enable built-in planning (Gemini thinking mode). - enable_compaction : bool, default True - Whether to enable context compaction. - """ - self._config = config - self._enable_caching = enable_caching - self._enable_planning = enable_planning - self._enable_compaction = enable_compaction - self._agent: KnowledgeGroundedAgent | None = None - self._initialized = False - - @property - def config(self) -> Configs: - """Get or create the config instance.""" - if self._config is None: - self._config = Configs() # type: ignore[call-arg] - return self._config - - @property - def agent(self) -> KnowledgeGroundedAgent: - """Get or create the knowledge-grounded agent.""" - if self._agent is None: - self._agent = KnowledgeGroundedAgent( - config=self.config, - enable_caching=self._enable_caching, - enable_planning=self._enable_planning, - enable_compaction=self._enable_compaction, - ) - self._initialized = True - return self._agent - - def close(self) -> None: - """Close all initialized clients and reset state.""" - self._agent = None - self._initialized = False - - def is_initialized(self) -> bool: - """Check if any clients have been initialized.""" - return self._initialized diff --git a/aieng-eval-agents/aieng/agent_evals/tools/__init__.py b/aieng-eval-agents/aieng/agent_evals/tools/__init__.py index 
c51ab67..d1a71f6 100644 --- a/aieng-eval-agents/aieng/agent_evals/tools/__init__.py +++ b/aieng-eval-agents/aieng/agent_evals/tools/__init__.py @@ -2,6 +2,7 @@ This package provides modular tools for: - Google Search (search.py) +- Vertex AI Search / custom knowledge base (vertex_search.py) - Web content fetching - HTML and PDF (web.py) - File downloading and searching - CSV, XLSX, text (file.py) - SQL Database access (sql_database.py) @@ -24,6 +25,7 @@ google_search, ) from .sql_database import ReadOnlySqlDatabase, ReadOnlySqlPolicy +from .vertex_search import create_vertex_search_tool, vertex_search from .web import ( create_web_fetch_tool, web_fetch, @@ -31,13 +33,15 @@ __all__ = [ - # Search tools + # Google Search tools "create_google_search_tool", "google_search", "format_response_with_citations", - "google_search", "GroundedResponse", "GroundingChunk", + # Vertex AI Search tools (custom knowledge base) + "create_vertex_search_tool", + "vertex_search", # Web tools (HTML pages and PDFs) "web_fetch", "create_web_fetch_tool", diff --git a/aieng-eval-agents/aieng/agent_evals/tools/vertex_search.py b/aieng-eval-agents/aieng/agent_evals/tools/vertex_search.py new file mode 100644 index 0000000..99bc5b2 --- /dev/null +++ b/aieng-eval-agents/aieng/agent_evals/tools/vertex_search.py @@ -0,0 +1,252 @@ +"""Vertex AI Search tool for knowledge-grounded QA using a custom data store. + +This module provides a search tool that queries a Vertex AI Search data store, +returning grounded summaries with document citations. Unlike the Google Search +tool, content is retrieved by the grounding mechanism — no separate fetch step +is required and no API key is needed (authentication uses ADC). 
+""" + +import logging +from typing import Any + +from aieng.agent_evals.configs import Configs +from google.adk.tools.function_tool import FunctionTool +from google.genai import Client, types + + +logger = logging.getLogger(__name__) + + +def _parse_project_from_datastore_id(datastore_id: str) -> str | None: + """Parse GCP project ID from a Vertex AI Search data store resource name. + + Parameters + ---------- + datastore_id : str + Full resource name, e.g. + ``projects/my-project/locations/global/collections/default_collection/dataStores/my-store``. + + Returns + ------- + str or None + The project ID, or None if the resource name is not in the expected format. + """ + parts = datastore_id.split("/") + if len(parts) >= 2 and parts[0] == "projects": + return parts[1] + return None + + +def _extract_datastore_sources(response: Any) -> list[dict[str, str]]: + """Extract grounding sources from a Vertex AI Search grounded response. + + Vertex AI Search returns ``retrieved_context`` chunks (not ``web`` chunks). + Each chunk has a ``uri`` (GCS path or document resource name) and an + optional ``title``. + + Parameters + ---------- + response : Any + The Gemini API response object from a Vertex AI Search grounded call. + + Returns + ------- + list[dict[str, str]] + List of source dictionaries with ``'title'`` and ``'uri'`` keys. + Sources with an empty URI are excluded. 
+ """ + sources: list[dict[str, str]] = [] + if not response.candidates: + return sources + + gm = getattr(response.candidates[0], "grounding_metadata", None) + if not gm or not hasattr(gm, "grounding_chunks") or not gm.grounding_chunks: + return sources + + for chunk in gm.grounding_chunks: + rc = getattr(chunk, "retrieved_context", None) + if rc: + # Vertex AI Search returns 'document_name' (full resource path), not 'uri' + uri = getattr(rc, "document_name", "") or "" + title = getattr(rc, "title", "") or "" + if uri: + sources.append({"title": title, "uri": uri}) + + return sources + + +async def _vertex_search_async( + query: str, + model: str, + datastore_id: str, + location: str, + temperature: float = 1.0, +) -> dict[str, Any]: + """Query a Vertex AI Search data store with grounding enabled. + + Parameters + ---------- + query : str + The search query. + model : str + The Gemini model to use (accessed via the Vertex AI endpoint). + datastore_id : str + Full resource name of the Vertex AI Search data store. + location : str + GCP region for the Vertex AI model call (e.g. ``'us-central1'``). + This is the *compute* region and may differ from the data store's + ``global`` location. + temperature : float, default=1.0 + Temperature for generation. 
+ + Returns + ------- + dict + Search results with the following keys: + + - **status** (str): ``"success"`` or ``"error"`` + - **summary** (str): Grounded text answer drawn from the data store + - **sources** (list[dict]): Each entry has: + - **title** (str): Document title + - **uri** (str): GCS path or Vertex AI document resource name + - **source_count** (int): Number of sources cited (success case only) + - **error** (str): Error message (error case only) + """ + project = _parse_project_from_datastore_id(datastore_id) + client = Client(vertexai=True, project=project, location=location) + try: + response = client.models.generate_content( + model=model, + contents=query, + config=types.GenerateContentConfig( + tools=[ + types.Tool(retrieval=types.Retrieval(vertex_ai_search=types.VertexAISearch(datastore=datastore_id))) + ], + temperature=temperature, + ), + ) + + summary = "" + if response.candidates and response.candidates[0].content and response.candidates[0].content.parts: + for part in response.candidates[0].content.parts: + if hasattr(part, "text") and part.text: + summary += part.text + + sources = _extract_datastore_sources(response) + return { + "status": "success", + "summary": summary, + "sources": sources, + "source_count": len(sources), + } + + except Exception as e: + logger.exception("Vertex AI Search failed: %s", e) + return { + "status": "error", + "error": str(e), + "summary": "", + "sources": [], + } + finally: + client.close() + + +async def vertex_search(query: str, model: str | None = None) -> dict[str, Any]: + """Search the custom knowledge base and return grounded results with citations. + + Use this tool to find information from internal documents and knowledge bases. + Results are grounded directly from retrieved document content — the summary + is more reliable than web search snippets and no separate fetch step is needed. + + Authentication uses Application Default Credentials (ADC) — no API key is + required. 
On GCE/Coder workspaces the attached service account is used + automatically. + + Parameters + ---------- + query : str + The search query. Be specific and include key terms. + model : str, optional + The Gemini model to use. Defaults to ``config.default_worker_model``. + + Returns + ------- + dict + Search results with the following keys: + + - **status** (str): ``"success"`` or ``"error"`` + - **summary** (str): Grounded answer from the knowledge base + - **sources** (list[dict]): Each with ``'title'`` and ``'uri'`` + - **source_count** (int): Number of sources cited (success case only) + - **error** (str): Error message (error case only) + + Raises + ------ + ValueError + If ``VERTEX_AI_DATASTORE_ID`` is not set in config. + + Examples + -------- + >>> result = await vertex_search("What is the company leave policy?") + >>> print(result["summary"]) + >>> for source in result["sources"]: + ... print(f"{source['title']}: {source['uri']}") + """ + config = Configs() # type: ignore[call-arg] + if not config.vertex_datastore_id: + raise ValueError( + "VERTEX_AI_DATASTORE_ID must be set to use vertex_search. " + "Set it in your .env file or as an environment variable." + ) + if model is None: + model = config.default_worker_model + + return await _vertex_search_async( + query, + model=model, + datastore_id=config.vertex_datastore_id, + location=config.google_cloud_location, + temperature=config.default_temperature, + ) + + +def create_vertex_search_tool(config: Configs | None = None) -> FunctionTool: + """Create a search tool backed by a custom Vertex AI Search data store. + + Authentication uses Application Default Credentials (ADC) — no API key is + needed. On GCE/Coder workspaces the attached service account handles auth + automatically. + + Parameters + ---------- + config : Configs, optional + Configuration settings. If not provided, creates default config. + Must have ``vertex_datastore_id`` set. 
+ + Returns + ------- + FunctionTool + An ADK-compatible tool that returns grounded summaries with citations. + + Raises + ------ + ValueError + If ``VERTEX_AI_DATASTORE_ID`` is not set in config. + + Examples + -------- + >>> from aieng.agent_evals.tools import create_vertex_search_tool + >>> tool = create_vertex_search_tool() + >>> agent = Agent(tools=[tool]) + """ + if config is None: + config = Configs() # type: ignore[call-arg] + + if not config.vertex_datastore_id: + raise ValueError( + "VERTEX_AI_DATASTORE_ID must be set to use create_vertex_search_tool. " + "Set it in your .env file or as an environment variable." + ) + + return FunctionTool(func=vertex_search) diff --git a/aieng-eval-agents/tests/aieng/agent_evals/knowledge_qa/test_agent.py b/aieng-eval-agents/tests/aieng/agent_evals/knowledge_qa/test_agent.py index 2c0c1cc..fa8e348 100644 --- a/aieng-eval-agents/tests/aieng/agent_evals/knowledge_qa/test_agent.py +++ b/aieng-eval-agents/tests/aieng/agent_evals/knowledge_qa/test_agent.py @@ -5,7 +5,6 @@ import pytest from aieng.agent_evals.knowledge_qa.agent import ( AgentResponse, - KnowledgeAgentManager, KnowledgeGroundedAgent, StepExecution, ) @@ -471,64 +470,6 @@ def test_agent_with_custom_model( assert agent.model == "gemini-2.5-pro" -class TestKnowledgeAgentManager: - """Tests for the KnowledgeAgentManager class.""" - - @patch("aieng.agent_evals.knowledge_qa.agent.PlanReActPlanner") - @patch("aieng.agent_evals.knowledge_qa.agent.Runner") - @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") - @patch("aieng.agent_evals.knowledge_qa.agent.Agent") - @patch("aieng.agent_evals.knowledge_qa.agent.create_read_file_tool") - @patch("aieng.agent_evals.knowledge_qa.agent.create_grep_file_tool") - @patch("aieng.agent_evals.knowledge_qa.agent.create_fetch_file_tool") - @patch("aieng.agent_evals.knowledge_qa.agent.create_web_fetch_tool") - @patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - def test_lazy_initialization(self, 
*_mocks): - """Test that agent is lazily initialized.""" - with patch("aieng.agent_evals.knowledge_qa.agent.Configs") as mock_config_class: - mock_config = MagicMock() - mock_config.default_worker_model = "gemini-2.5-flash" - mock_config.default_temperature = 0.0 - mock_config.openai_api_key.get_secret_value.return_value = "test-api-key" - mock_config_class.return_value = mock_config - - manager = KnowledgeAgentManager(enable_caching=False, enable_compaction=False) - - # Should not be initialized yet - assert not manager.is_initialized() - - # Access agent to trigger initialization - _ = manager.agent - - # Now should be initialized - assert manager.is_initialized() - - @patch("aieng.agent_evals.knowledge_qa.agent.PlanReActPlanner") - @patch("aieng.agent_evals.knowledge_qa.agent.Runner") - @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") - @patch("aieng.agent_evals.knowledge_qa.agent.Agent") - @patch("aieng.agent_evals.knowledge_qa.agent.create_read_file_tool") - @patch("aieng.agent_evals.knowledge_qa.agent.create_grep_file_tool") - @patch("aieng.agent_evals.knowledge_qa.agent.create_fetch_file_tool") - @patch("aieng.agent_evals.knowledge_qa.agent.create_web_fetch_tool") - @patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - def test_close(self, *_mocks): - """Test closing the client manager.""" - with patch("aieng.agent_evals.knowledge_qa.agent.Configs") as mock_config_class: - mock_config = MagicMock() - mock_config.default_worker_model = "gemini-2.5-flash" - mock_config.default_temperature = 0.0 - mock_config.openai_api_key.get_secret_value.return_value = "test-api-key" - mock_config_class.return_value = mock_config - - manager = KnowledgeAgentManager(enable_caching=False, enable_compaction=False) - _ = manager.agent - assert manager.is_initialized() - - manager.close() - assert not manager.is_initialized() - - class TestAgentResponse: """Tests for the AgentResponse model.""" diff --git 
a/aieng-eval-agents/tests/aieng/agent_evals/tools/test_vertex_search.py b/aieng-eval-agents/tests/aieng/agent_evals/tools/test_vertex_search.py new file mode 100644 index 0000000..106fb83 --- /dev/null +++ b/aieng-eval-agents/tests/aieng/agent_evals/tools/test_vertex_search.py @@ -0,0 +1,639 @@ +"""Tests for Vertex AI Search tool.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from aieng.agent_evals.configs import Configs +from aieng.agent_evals.tools import create_vertex_search_tool, vertex_search +from aieng.agent_evals.tools.vertex_search import ( + _extract_datastore_sources, + _parse_project_from_datastore_id, + _vertex_search_async, +) +from dotenv import load_dotenv +from google.adk.tools.function_tool import FunctionTool + + +SAMPLE_DATASTORE_ID = "projects/my-project/locations/global/collections/default_collection/dataStores/my-store" + + +class TestParseProjectFromDatastoreId: + """Tests for _parse_project_from_datastore_id.""" + + def test_standard_resource_name(self): + """Test parsing project from a well-formed resource name.""" + assert _parse_project_from_datastore_id(SAMPLE_DATASTORE_ID) == "my-project" + + def test_numeric_project_id(self): + """Test parsing a numeric project number.""" + datastore_id = "projects/123456789/locations/global/collections/default_collection/dataStores/my-store" + assert _parse_project_from_datastore_id(datastore_id) == "123456789" + + def test_returns_none_for_invalid_format(self): + """Test that a non-resource-name string returns None.""" + assert _parse_project_from_datastore_id("not-a-resource-name") is None + + def test_returns_none_for_empty_string(self): + """Test that an empty string returns None.""" + assert _parse_project_from_datastore_id("") is None + + def test_returns_none_when_no_projects_prefix(self): + """Test that a string without 'projects' prefix returns None.""" + assert _parse_project_from_datastore_id("locations/global/collections/default_collection") is None + + 
+class TestExtractDatastoreSources: + """Tests for _extract_datastore_sources.""" + + def test_no_candidates_returns_empty(self): + """Test that an empty candidates list yields no sources.""" + response = MagicMock() + response.candidates = [] + assert _extract_datastore_sources(response) == [] + + def test_no_grounding_metadata_returns_empty(self): + """Test that a candidate with no grounding_metadata yields no sources.""" + candidate = MagicMock() + candidate.grounding_metadata = None + response = MagicMock() + response.candidates = [candidate] + assert _extract_datastore_sources(response) == [] + + def test_grounding_chunks_attribute_missing_returns_empty(self): + """Test that grounding_metadata without grounding_chunks yields no sources.""" + gm = MagicMock(spec=[]) # hasattr(gm, "grounding_chunks") → False + candidate = MagicMock() + candidate.grounding_metadata = gm + response = MagicMock() + response.candidates = [candidate] + assert _extract_datastore_sources(response) == [] + + def test_none_grounding_chunks_returns_empty(self): + """Test that grounding_chunks=None yields no sources.""" + candidate = MagicMock() + candidate.grounding_metadata.grounding_chunks = None + response = MagicMock() + response.candidates = [candidate] + assert _extract_datastore_sources(response) == [] + + def test_empty_grounding_chunks_returns_empty(self): + """Test that an empty grounding_chunks list yields no sources.""" + candidate = MagicMock() + candidate.grounding_metadata.grounding_chunks = [] + response = MagicMock() + response.candidates = [candidate] + assert _extract_datastore_sources(response) == [] + + def test_single_valid_retrieved_context_chunk(self): + """Test that a single retrieved_context chunk with document_name is returned.""" + chunk = MagicMock() + chunk.retrieved_context.document_name = "projects/my-project/locations/global/collections/default_collection/dataStores/my-store/branches/0/documents/doc-1" + chunk.retrieved_context.title = "Company Policy" + 
candidate = MagicMock() + candidate.grounding_metadata.grounding_chunks = [chunk] + response = MagicMock() + response.candidates = [candidate] + + result = _extract_datastore_sources(response) + + assert result == [ + { + "title": "Company Policy", + "uri": "projects/my-project/locations/global/collections/default_collection/dataStores/my-store/branches/0/documents/doc-1", + } + ] + + def test_multiple_chunks_preserved_in_order(self): + """Test that multiple retrieved_context chunks are returned in order.""" + chunk1 = MagicMock() + chunk1.retrieved_context.document_name = ( + "projects/p/locations/global/collections/c/dataStores/s/branches/0/documents/doc1" + ) + chunk1.retrieved_context.title = "Document 1" + chunk2 = MagicMock() + chunk2.retrieved_context.document_name = ( + "projects/p/locations/global/collections/c/dataStores/s/branches/0/documents/doc2" + ) + chunk2.retrieved_context.title = "Document 2" + candidate = MagicMock() + candidate.grounding_metadata.grounding_chunks = [chunk1, chunk2] + response = MagicMock() + response.candidates = [candidate] + + result = _extract_datastore_sources(response) + + assert result == [ + { + "title": "Document 1", + "uri": "projects/p/locations/global/collections/c/dataStores/s/branches/0/documents/doc1", + }, + { + "title": "Document 2", + "uri": "projects/p/locations/global/collections/c/dataStores/s/branches/0/documents/doc2", + }, + ] + + def test_chunk_without_retrieved_context_is_skipped(self): + """Test that chunks with no retrieved_context attribute are ignored.""" + chunk_no_rc = MagicMock(spec=[]) # getattr returns None fallback + chunk_valid = MagicMock() + chunk_valid.retrieved_context.document_name = ( + "projects/p/locations/global/collections/c/dataStores/s/branches/0/documents/valid" + ) + chunk_valid.retrieved_context.title = "Valid Doc" + candidate = MagicMock() + candidate.grounding_metadata.grounding_chunks = [chunk_no_rc, chunk_valid] + response = MagicMock() + response.candidates = [candidate] + 
+ result = _extract_datastore_sources(response) + + assert result == [ + { + "title": "Valid Doc", + "uri": "projects/p/locations/global/collections/c/dataStores/s/branches/0/documents/valid", + } + ] + + def test_chunk_with_none_retrieved_context_is_skipped(self): + """Test that chunks whose retrieved_context is falsy are skipped.""" + chunk = MagicMock() + chunk.retrieved_context = None + candidate = MagicMock() + candidate.grounding_metadata.grounding_chunks = [chunk] + response = MagicMock() + response.candidates = [candidate] + + assert _extract_datastore_sources(response) == [] + + def test_empty_document_name_is_excluded(self): + """Test that a retrieved_context with an empty document_name is excluded.""" + chunk = MagicMock() + chunk.retrieved_context.document_name = "" + chunk.retrieved_context.title = "No Name" + candidate = MagicMock() + candidate.grounding_metadata.grounding_chunks = [chunk] + response = MagicMock() + response.candidates = [candidate] + + assert _extract_datastore_sources(response) == [] + + def test_none_document_name_is_excluded(self): + """Test that a retrieved_context with a None document_name is excluded.""" + chunk = MagicMock() + chunk.retrieved_context.document_name = None + chunk.retrieved_context.title = "None Name" + candidate = MagicMock() + candidate.grounding_metadata.grounding_chunks = [chunk] + response = MagicMock() + response.candidates = [candidate] + + assert _extract_datastore_sources(response) == [] + + def test_empty_title_is_preserved(self): + """Test that a chunk with an empty title is returned with empty title string.""" + chunk = MagicMock() + chunk.retrieved_context.document_name = ( + "projects/p/locations/global/collections/c/dataStores/s/branches/0/documents/doc" + ) + chunk.retrieved_context.title = "" + candidate = MagicMock() + candidate.grounding_metadata.grounding_chunks = [chunk] + response = MagicMock() + response.candidates = [candidate] + + result = _extract_datastore_sources(response) + + assert 
result == [ + {"title": "", "uri": "projects/p/locations/global/collections/c/dataStores/s/branches/0/documents/doc"} + ] + + def test_none_title_normalised_to_empty_string(self): + """Test that a None title is coerced to an empty string.""" + chunk = MagicMock() + chunk.retrieved_context.document_name = ( + "projects/p/locations/global/collections/c/dataStores/s/branches/0/documents/doc" + ) + chunk.retrieved_context.title = None + candidate = MagicMock() + candidate.grounding_metadata.grounding_chunks = [chunk] + response = MagicMock() + response.candidates = [candidate] + + result = _extract_datastore_sources(response) + + assert result == [ + {"title": "", "uri": "projects/p/locations/global/collections/c/dataStores/s/branches/0/documents/doc"} + ] + + def test_only_first_candidate_is_used(self): + """Test that only the first candidate's grounding chunks are considered.""" + chunk1 = MagicMock() + chunk1.retrieved_context.document_name = ( + "projects/p/locations/global/collections/c/dataStores/s/branches/0/documents/first" + ) + chunk1.retrieved_context.title = "First" + chunk2 = MagicMock() + chunk2.retrieved_context.document_name = ( + "projects/p/locations/global/collections/c/dataStores/s/branches/0/documents/second" + ) + chunk2.retrieved_context.title = "Second" + candidate1 = MagicMock() + candidate1.grounding_metadata.grounding_chunks = [chunk1] + candidate2 = MagicMock() + candidate2.grounding_metadata.grounding_chunks = [chunk2] + response = MagicMock() + response.candidates = [candidate1, candidate2] + + result = _extract_datastore_sources(response) + + assert result == [ + { + "title": "First", + "uri": "projects/p/locations/global/collections/c/dataStores/s/branches/0/documents/first", + } + ] + + def test_web_chunks_are_not_retrieved_context_and_skipped(self): + """Test that web-type grounding chunks are skipped (not datastore sources).""" + chunk = MagicMock(spec=["web"]) # has 'web' but not 'retrieved_context' + chunk.web.uri = 
"https://example.com" + chunk.web.title = "Web Result" + candidate = MagicMock() + candidate.grounding_metadata.grounding_chunks = [chunk] + response = MagicMock() + response.candidates = [candidate] + + # getattr(chunk, "retrieved_context", None) returns None for spec=["web"] + assert _extract_datastore_sources(response) == [] + + +class TestVertexSearchAsync: + """Tests for _vertex_search_async.""" + + @pytest.mark.asyncio + async def test_success_returns_expected_structure(self): + """Test that a successful call returns the correct response dict.""" + mock_part = MagicMock() + mock_part.text = "The leave policy allows 20 days per year." + mock_candidate = MagicMock() + mock_candidate.content.parts = [mock_part] + + mock_chunk = MagicMock() + mock_chunk.retrieved_context.document_name = "projects/my-project/locations/global/collections/default_collection/dataStores/my-store/branches/0/documents/policy" + mock_chunk.retrieved_context.title = "Leave Policy" + mock_candidate.grounding_metadata.grounding_chunks = [mock_chunk] + + mock_response = MagicMock() + mock_response.candidates = [mock_candidate] + + mock_client = MagicMock() + mock_client.models.generate_content.return_value = mock_response + mock_client.close = MagicMock() + + with patch("aieng.agent_evals.tools.vertex_search.Client", return_value=mock_client): + result = await _vertex_search_async( + query="What is the leave policy?", + model="gemini-2.5-flash", + datastore_id=SAMPLE_DATASTORE_ID, + location="us-central1", + ) + + assert result["status"] == "success" + assert result["summary"] == "The leave policy allows 20 days per year." 
+ assert result["sources"] == [ + { + "title": "Leave Policy", + "uri": "projects/my-project/locations/global/collections/default_collection/dataStores/my-store/branches/0/documents/policy", + } + ] + assert result["source_count"] == 1 + + @pytest.mark.asyncio + async def test_client_created_with_correct_project_and_location(self): + """Test that the client is created with the correct project and location.""" + mock_client = MagicMock() + mock_client.models.generate_content.return_value = MagicMock(candidates=[]) + mock_client.close = MagicMock() + + with patch("aieng.agent_evals.tools.vertex_search.Client", return_value=mock_client) as mock_cls: + await _vertex_search_async( + query="test", + model="gemini-2.5-flash", + datastore_id=SAMPLE_DATASTORE_ID, + location="us-central1", + ) + + mock_cls.assert_called_once_with(vertexai=True, project="my-project", location="us-central1") + + @pytest.mark.asyncio + async def test_client_closed_on_success(self): + """Test that the client is always closed after a successful call.""" + mock_client = MagicMock() + mock_client.models.generate_content.return_value = MagicMock(candidates=[]) + mock_client.close = MagicMock() + + with patch("aieng.agent_evals.tools.vertex_search.Client", return_value=mock_client): + await _vertex_search_async("test", "gemini-2.5-flash", SAMPLE_DATASTORE_ID, "us-central1") + + mock_client.close.assert_called_once() + + @pytest.mark.asyncio + async def test_client_closed_on_exception(self): + """Test that the client is closed even when generate_content raises.""" + mock_client = MagicMock() + mock_client.models.generate_content.side_effect = RuntimeError("API error") + mock_client.close = MagicMock() + + with patch("aieng.agent_evals.tools.vertex_search.Client", return_value=mock_client): + result = await _vertex_search_async("test", "gemini-2.5-flash", SAMPLE_DATASTORE_ID, "us-central1") + + mock_client.close.assert_called_once() + assert result["status"] == "error" + assert "API error" in 
result["error"] + + @pytest.mark.asyncio + async def test_exception_returns_error_structure(self): + """Test that an exception produces a well-formed error response.""" + mock_client = MagicMock() + mock_client.models.generate_content.side_effect = Exception("Connection timeout") + mock_client.close = MagicMock() + + with patch("aieng.agent_evals.tools.vertex_search.Client", return_value=mock_client): + result = await _vertex_search_async("test", "gemini-2.5-flash", SAMPLE_DATASTORE_ID, "us-central1") + + assert result["status"] == "error" + assert "Connection timeout" in result["error"] + assert result["summary"] == "" + assert result["sources"] == [] + + @pytest.mark.asyncio + async def test_uses_vertex_ai_search_tool_in_config(self): + """Test that generate_content is called with a VertexAISearch retrieval tool.""" + mock_client = MagicMock() + mock_client.models.generate_content.return_value = MagicMock(candidates=[]) + mock_client.close = MagicMock() + + with patch("aieng.agent_evals.tools.vertex_search.Client", return_value=mock_client): + await _vertex_search_async("test", "gemini-2.5-flash", SAMPLE_DATASTORE_ID, "us-central1") + + call_kwargs = mock_client.models.generate_content.call_args + config_arg = call_kwargs.kwargs["config"] + assert config_arg.tools is not None + tool = config_arg.tools[0] + assert tool.retrieval is not None + assert tool.retrieval.vertex_ai_search.datastore == SAMPLE_DATASTORE_ID + + @pytest.mark.asyncio + async def test_empty_candidates_returns_empty_summary_and_sources(self): + """Test that a response with no candidates returns empty summary and sources.""" + mock_client = MagicMock() + mock_client.models.generate_content.return_value = MagicMock(candidates=[]) + mock_client.close = MagicMock() + + with patch("aieng.agent_evals.tools.vertex_search.Client", return_value=mock_client): + result = await _vertex_search_async("test", "gemini-2.5-flash", SAMPLE_DATASTORE_ID, "us-central1") + + assert result["status"] == "success" + 
assert result["summary"] == "" + assert result["sources"] == [] + assert result["source_count"] == 0 + + +class TestCreateVertexSearchTool: + """Tests for create_vertex_search_tool.""" + + def test_raises_if_no_datastore_id(self): + """Test that ValueError is raised when vertex_datastore_id is not configured.""" + mock_config = MagicMock() + mock_config.vertex_datastore_id = None + + with pytest.raises(ValueError, match="VERTEX_AI_DATASTORE_ID"): + create_vertex_search_tool(config=mock_config) + + def test_raises_if_empty_datastore_id(self): + """Test that ValueError is raised for an empty vertex_datastore_id.""" + mock_config = MagicMock() + mock_config.vertex_datastore_id = "" + + with pytest.raises(ValueError, match="VERTEX_AI_DATASTORE_ID"): + create_vertex_search_tool(config=mock_config) + + def test_returns_function_tool(self): + """Test that a FunctionTool is returned when datastore_id is set.""" + mock_config = MagicMock() + mock_config.vertex_datastore_id = SAMPLE_DATASTORE_ID + + result = create_vertex_search_tool(config=mock_config) + + assert isinstance(result, FunctionTool) + + def test_tool_func_named_vertex_search(self): + """Test that the wrapped function is named 'vertex_search' for ADK discovery.""" + mock_config = MagicMock() + mock_config.vertex_datastore_id = SAMPLE_DATASTORE_ID + + result = create_vertex_search_tool(config=mock_config) + + assert result.func.__name__ == "vertex_search" + + @pytest.mark.asyncio + async def test_tool_calls_vertex_search_async(self): + """Test that the tool delegates to _vertex_search_async with config values.""" + mock_config = MagicMock() + mock_config.vertex_datastore_id = SAMPLE_DATASTORE_ID + mock_config.default_worker_model = "gemini-2.5-flash" + mock_config.default_temperature = 1.0 + mock_config.google_cloud_location = "us-central1" + + tool = create_vertex_search_tool(config=mock_config) + + expected = {"status": "success", "summary": "Answer.", "sources": [], "source_count": 0} + with ( + 
patch("aieng.agent_evals.tools.vertex_search.Configs", return_value=mock_config), + patch( + "aieng.agent_evals.tools.vertex_search._vertex_search_async", + new=AsyncMock(return_value=expected), + ) as mock_async, + ): + result = await tool.func("leave policy?") + + mock_async.assert_called_once_with( + "leave policy?", + model="gemini-2.5-flash", + datastore_id=SAMPLE_DATASTORE_ID, + location="us-central1", + temperature=1.0, + ) + assert result == expected + + +class TestVertexSearchPublicFunction: + """Tests for the standalone vertex_search public function.""" + + @pytest.mark.asyncio + async def test_raises_if_datastore_id_not_configured(self): + """Test that ValueError is raised when VERTEX_AI_DATASTORE_ID is not set.""" + mock_config = MagicMock() + mock_config.vertex_datastore_id = None + + with ( + patch("aieng.agent_evals.tools.vertex_search.Configs", return_value=mock_config), + pytest.raises(ValueError, match="VERTEX_AI_DATASTORE_ID"), + ): + await vertex_search("test query") + + @pytest.mark.asyncio + async def test_uses_default_model_from_config(self): + """Test that the worker model defaults to config when none is specified.""" + mock_config = MagicMock() + mock_config.vertex_datastore_id = SAMPLE_DATASTORE_ID + mock_config.default_worker_model = "gemini-2.5-flash" + mock_config.default_temperature = 1.0 + mock_config.google_cloud_location = "us-central1" + + expected = {"status": "success", "summary": "ok", "sources": [], "source_count": 0} + with ( + patch("aieng.agent_evals.tools.vertex_search.Configs", return_value=mock_config), + patch( + "aieng.agent_evals.tools.vertex_search._vertex_search_async", + new=AsyncMock(return_value=expected), + ) as mock_async, + ): + await vertex_search("test query") + + mock_async.assert_called_once_with( + "test query", + model="gemini-2.5-flash", + datastore_id=SAMPLE_DATASTORE_ID, + location="us-central1", + temperature=1.0, + ) + + @pytest.mark.asyncio + async def test_explicit_model_overrides_config(self): + 
"""Test that passing a model explicitly overrides the config default.""" + mock_config = MagicMock() + mock_config.vertex_datastore_id = SAMPLE_DATASTORE_ID + mock_config.default_worker_model = "gemini-2.5-flash" + mock_config.default_temperature = 1.0 + mock_config.google_cloud_location = "us-central1" + + expected = {"status": "success", "summary": "ok", "sources": [], "source_count": 0} + with ( + patch("aieng.agent_evals.tools.vertex_search.Configs", return_value=mock_config), + patch( + "aieng.agent_evals.tools.vertex_search._vertex_search_async", + new=AsyncMock(return_value=expected), + ) as mock_async, + ): + await vertex_search("test query", model="gemini-2.5-pro") + + mock_async.assert_called_once_with( + "test query", + model="gemini-2.5-pro", + datastore_id=SAMPLE_DATASTORE_ID, + location="us-central1", + temperature=1.0, + ) + + +@pytest.mark.integration_test +class TestVertexSearchIntegration: + """Integration tests for the Vertex AI Search tool. + + These tests run against a real Vertex AI Search data store loaded with + ``aieng-eval-agents/tests/fixtures/vertex_test_data.jsonl`` (synthetic + Northstar Analytics content). Provision the store once before running: + + uv run python -m scripts.create_test_datastore --bucket <bucket-name> + + Then set in .env: + VERTEX_AI_DATASTORE_ID="projects/agentic-ai-evaluation-bootcamp/locations/global/..." + GOOGLE_CLOUD_LOCATION="us-central1" + + Authentication is handled automatically via ADC / the GCE service account.
+ """ + + @pytest.fixture(autouse=True) + def skip_if_not_configured(self): + """Skip the entire class if VERTEX_AI_DATASTORE_ID is not set.""" + load_dotenv(verbose=False) + config = Configs() # type: ignore[call-arg] + if not config.vertex_datastore_id: + pytest.skip("VERTEX_AI_DATASTORE_ID not set — run scripts/create_test_datastore.py first") + + def test_create_vertex_search_tool_real(self): + """Test that a FunctionTool can be created against the real data store.""" + tool = create_vertex_search_tool() + assert isinstance(tool, FunctionTool) + assert tool.func.__name__ == "vertex_search" + + @pytest.mark.asyncio + async def test_response_structure(self): + """Test that vertex_search returns a well-formed response dict.""" + result = await vertex_search("What does Northstar Analytics do?") + + assert result["status"] == "success", f"Unexpected error: {result.get('error')}" + assert isinstance(result["summary"], str) + assert result["summary"], "Expected a non-empty summary" + assert isinstance(result["sources"], list) + assert isinstance(result["source_count"], int) + assert result["source_count"] == len(result["sources"]) + + @pytest.mark.asyncio + async def test_sources_use_uri_not_url(self): + """Test that sources contain 'uri' (GCS / resource name) not 'url'.""" + result = await vertex_search("What does Northstar Analytics do?") + + assert result["status"] == "success" + if result["sources"]: + source = result["sources"][0] + assert "uri" in source, "Sources must have a 'uri' key" + assert "title" in source, "Sources must have a 'title' key" + assert "url" not in source, "Sources must not use 'url' — this is not a web search tool" + + @pytest.mark.asyncio + async def test_grounding_professional_tier_price(self): + """Test that grounding retrieves the Professional tier price from the datastore. + + The fixture contains: Professional at $899/month. 
+ This number does not exist in model training data (it is synthetic), + so its presence in the summary confirms the data store was consulted. + """ + result = await vertex_search("What is the monthly price for the Professional tier at Northstar Analytics?") + + assert result["status"] == "success", f"Search failed: {result.get('error')}" + assert result["source_count"] > 0, "Expected at least one grounding source" + assert "899" in result["summary"], ( + f"Expected '899' (Professional tier price) in summary, got: {result['summary']}" + ) + + @pytest.mark.asyncio + async def test_grounding_enterprise_sla(self): + """Test that grounding retrieves the Enterprise SLA uptime figure. + + The fixture contains: Enterprise tier 99.95% uptime guarantee. + """ + result = await vertex_search("What uptime does Northstar Analytics guarantee for Enterprise customers?") + + assert result["status"] == "success", f"Search failed: {result.get('error')}" + assert result["source_count"] > 0, "Expected at least one grounding source" + assert "99.95" in result["summary"], f"Expected '99.95' (Enterprise SLA) in summary, got: {result['summary']}" + + @pytest.mark.asyncio + async def test_grounding_api_rate_limit(self): + """Test that grounding retrieves the Enterprise API rate limit. + + The fixture contains: Enterprise tier 8,000 requests per minute. + """ + result = await vertex_search( + "How many API requests per minute does the Enterprise tier allow at Northstar Analytics?" 
+ ) + + assert result["status"] == "success", f"Search failed: {result.get('error')}" + assert result["source_count"] > 0, "Expected at least one grounding source" + # Allow for comma-formatted (8,000) or plain (8000) + assert "8,000" in result["summary"] or "8000" in result["summary"], ( + f"Expected '8,000' or '8000' (Enterprise API limit) in summary, got: {result['summary']}" + ) diff --git a/aieng-eval-agents/tests/fixtures/vertex_test_data.jsonl b/aieng-eval-agents/tests/fixtures/vertex_test_data.jsonl new file mode 100644 index 0000000..acb8e6c --- /dev/null +++ b/aieng-eval-agents/tests/fixtures/vertex_test_data.jsonl @@ -0,0 +1,5 @@ +{"id":"northstar-founding","text":"Northstar Analytics was founded in 2019 by Dr. Elena Vasquez and Marcus Thibodeau in Waterloo, Ontario. The company specialises in real-time supply chain optimisation software.","title":"Company Founding","category":"company"} +{"id":"northstar-pricing","text":"Northstar Analytics offers three pricing tiers: Starter at $299 per month for up to 5 users, Professional at $899 per month for up to 25 users, and Enterprise at $2,450 per month for unlimited users.","title":"Pricing Tiers","category":"product"} +{"id":"northstar-sla","text":"Northstar Analytics guarantees 99.7% uptime under its standard SLA. Customers on the Enterprise tier receive a 99.95% uptime guarantee and a maximum incident response time of 15 minutes.","title":"Service Level Agreement","category":"support"} +{"id":"northstar-storage","text":"Each Northstar Analytics workspace is allocated 500 GB of data storage on the Starter tier, 2 TB on Professional, and 10 TB on Enterprise. Storage is hosted exclusively in Canadian data centres.","title":"Storage Limits","category":"product"} +{"id":"northstar-api-limit","text":"The Northstar Analytics REST API allows 1,200 requests per minute on the Professional tier and 8,000 requests per minute on the Enterprise tier. 
Starter tier users are limited to 200 requests per minute.","title":"API Rate Limits","category":"technical"} diff --git a/implementations/knowledge_qa/04_custom_knowledge_base.ipynb b/implementations/knowledge_qa/04_custom_knowledge_base.ipynb new file mode 100644 index 0000000..08c474f --- /dev/null +++ b/implementations/knowledge_qa/04_custom_knowledge_base.ipynb @@ -0,0 +1,324 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "intro", + "metadata": {}, + "source": [ + "# 04: Custom Knowledge Base with Vertex AI Search\n", + "\n", + "This notebook demonstrates how to query a **custom knowledge base** stored in Vertex AI Search.\n", + "\n", + "Unlike the Google Search tool (which searches the public web), `vertex_search` queries a\n", + "**private data store** that you control. The model's answer is grounded directly in the\n", + "retrieved document content — not the public internet and not the model's training data.\n", + "\n", + "## What You'll Learn\n", + "\n", + "1. How to call `vertex_search` directly and read the result\n", + "2. How grounding works — what the model can and cannot answer\n", + "3. 
How to wrap the tool for use inside an ADK agent\n", + "\n", + "## Prerequisites\n", + "\n", + "- `VERTEX_AI_DATASTORE_ID` set in your `.env` file\n", + "- `GOOGLE_CLOUD_LOCATION` set (default: `us-central1`)\n", + "- Authentication via Application Default Credentials (ADC) — automatic on GCE/Coder workspaces" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "setup", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from pathlib import Path\n", + "\n", + "from aieng.agent_evals.configs import Configs\n", + "from aieng.agent_evals.tools import create_vertex_search_tool, vertex_search\n", + "from dotenv import load_dotenv\n", + "from rich.console import Console\n", + "from rich.panel import Panel\n", + "from rich.table import Table\n", + "\n", + "\n", + "if Path(\"\").absolute().name == \"eval-agents\":\n", + " print(f\"Working directory: {Path('').absolute()}\")\n", + "else:\n", + " os.chdir(Path(\"\").absolute().parent.parent)\n", + " print(f\"Working directory set to: {Path('').absolute()}\")\n", + "\n", + "load_dotenv(verbose=True)\n", + "console = Console(width=100)" + ] + }, + { + "cell_type": "markdown", + "id": "s1-intro", + "metadata": {}, + "source": [ + "## 1. Configuration\n", + "\n", + "The tool reads its configuration from the environment. Let's check what's loaded." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "check-config", + "metadata": {}, + "outputs": [], + "source": [ + "config = Configs() # type: ignore[call-arg]\n", + "\n", + "if not config.vertex_datastore_id:\n", + " console.print(\"[red]✗[/red] VERTEX_AI_DATASTORE_ID is not set.\")\n", + " console.print(\"[dim]Set it in your .env file and re-run this cell.[/dim]\")\n", + "else:\n", + " cfg_table = Table(title=\"Vertex AI Search Configuration\", show_header=False)\n", + " cfg_table.add_column(\"Key\", style=\"cyan\")\n", + " cfg_table.add_column(\"Value\", style=\"white\")\n", + " cfg_table.add_row(\"Data store\", config.vertex_datastore_id)\n", + " cfg_table.add_row(\"Region\", config.google_cloud_location)\n", + " cfg_table.add_row(\"Model\", config.default_worker_model)\n", + " cfg_table.add_row(\"Auth\", \"Application Default Credentials (ADC)\")\n", + " console.print(cfg_table)" + ] + }, + { + "cell_type": "markdown", + "id": "s2-intro", + "metadata": {}, + "source": [ + "## 2. 
Querying the Knowledge Base\n", + "\n", + "`vertex_search(query)` sends a query to the data store and returns a grounded answer.\n", + "The model only has access to the documents in your data store — it cannot draw on the\n", + "public web or its training data to answer.\n", + "\n", + "The result dict always contains:\n", + "\n", + "| Key | Type | Description |\n", + "|-----|------|-------------|\n", + "| `status` | `str` | `\"success\"` or `\"error\"` |\n", + "| `summary` | `str` | Grounded answer drawn from retrieved documents |\n", + "| `sources` | `list[dict]` | Retrieved documents — each with `title` and `uri` |\n", + "| `source_count` | `int` | Number of documents cited |\n", + "| `error` | `str` | Error message (only present when `status == \"error\"`) |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "single-query", + "metadata": {}, + "outputs": [], + "source": [ + "result = await vertex_search(\"What are the pricing tiers for Northstar Analytics?\")\n", + "\n", + "console.print(result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "display-result", + "metadata": {}, + "outputs": [], + "source": [ + "# Display the result in a more readable format\n", + "console.print(\n", + " Panel(\n", + " result[\"summary\"],\n", + " title=\"Answer\",\n", + " border_style=\"cyan\",\n", + " )\n", + ")\n", + "\n", + "if result[\"sources\"]:\n", + " src_table = Table(title=f\"Sources ({result['source_count']} retrieved)\")\n", + " src_table.add_column(\"#\", style=\"dim\", width=3)\n", + " src_table.add_column(\"Title\", style=\"cyan\")\n", + " src_table.add_column(\"Document\", style=\"dim\")\n", + "\n", + " for i, src in enumerate(result[\"sources\"], 1):\n", + " # URI is the full document resource name — the last segment is the document ID\n", + " doc_id = src[\"uri\"].split(\"/\")[-1] if src[\"uri\"] else \"\"\n", + " src_table.add_row(str(i), src[\"title\"], doc_id)\n", + "\n", + " console.print(src_table)" + ] + }, + { 
+ "cell_type": "markdown", + "id": "s3-intro", + "metadata": {}, + "source": [ + "## 3. Grounding in Practice\n", + "\n", + "The key property of this tool: the model's answer is **constrained to what's in the data store**.\n", + "Below we run several queries and verify that specific values from the documents appear in the answers.\n", + "\n", + "The test data store contains five documents about a fictional company, **Northstar Analytics**,\n", + "with invented values (prices, SLA figures, rate limits) that don't exist in the model's training data.\n", + "If those values appear in the answer, they came from the data store." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "grounding-queries", + "metadata": {}, + "outputs": [], + "source": [ + "queries = [\n", + " (\"Professional tier monthly price\", \"What is the monthly price for the Professional tier?\"),\n", + " (\"Enterprise SLA uptime\", \"What uptime does the Enterprise plan guarantee?\"),\n", + " (\"Enterprise API rate limit\", \"How many API requests per minute does the Enterprise tier allow?\"),\n", + " (\"Storage on Starter tier\", \"How much storage does the Starter tier include?\"),\n", + " (\"Company founders\", \"Who founded Northstar Analytics and when?\"),\n", + "]\n", + "\n", + "results_table = Table(title=\"Grounded Query Results\")\n", + "results_table.add_column(\"Query\", style=\"cyan\", width=28)\n", + "results_table.add_column(\"Summary\", style=\"white\", width=55)\n", + "results_table.add_column(\"Sources\", style=\"dim\", justify=\"right\", width=7)\n", + "\n", + "for label, query in queries:\n", + " r = await vertex_search(query)\n", + " summary = r[\"summary\"][:110] + \"...\" if len(r[\"summary\"]) > 110 else r[\"summary\"]\n", + " status = r[\"source_count\"]\n", + " results_table.add_row(label, summary, str(status))\n", + "\n", + "console.print(results_table)" + ] + }, + { + "cell_type": "markdown", + "id": "s3-out-of-scope", + "metadata": {}, + "source": [ + 
"### What happens for out-of-scope questions?\n", + "\n", + "If the question cannot be answered from the data store, the model says so rather than hallucinating." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "out-of-scope", + "metadata": {}, + "outputs": [], + "source": [ + "oos_result = await vertex_search(\"What is the capital of France?\")\n", + "\n", + "console.print(\n", + " Panel(\n", + " f\"[bold]Summary:[/bold] {oos_result['summary']}\\n[bold]Sources:[/bold] {oos_result['source_count']}\",\n", + " title=\"Out-of-scope question\",\n", + " border_style=\"yellow\",\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "s4-intro", + "metadata": {}, + "source": [ + "## 4. Using the Tool Inside an Agent\n", + "\n", + "`create_vertex_search_tool()` wraps `vertex_search` as an ADK `FunctionTool` that can be\n", + "passed directly to a `KnowledgeGroundedAgent` or any other ADK agent.\n", + "\n", + "The tool's docstring becomes the model's description of what the tool does — write it clearly." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "create-tool", + "metadata": {}, + "outputs": [], + "source": [ + "tool = create_vertex_search_tool()\n", + "\n", + "console.print(f\"Type: [cyan]{type(tool).__name__}[/cyan]\")\n", + "console.print(f\"Function name: [cyan]{tool.func.__name__}[/cyan]\")\n", + "console.print()\n", + "console.print(Panel(tool.func.__doc__ or \"\", title=\"Tool description (seen by the model)\", border_style=\"dim\"))" + ] + }, + { + "cell_type": "markdown", + "id": "s4-agent-usage", + "metadata": {}, + "source": [ + "### Passing the tool to an agent\n", + "\n", + "Replace `create_google_search_tool()` with `create_vertex_search_tool()` in your agent\n", + "setup, or add it alongside the web search tool:\n", + "\n", + "```python\n", + "from aieng.agent_evals.tools import create_vertex_search_tool\n", + "from google.adk.agents import Agent\n", + "\n", + "agent = Agent(\n", + " model=\"gemini-2.5-flash\",\n", + " tools=[\n", + " create_vertex_search_tool(), # custom knowledge base\n", + " ],\n", + " instruction=\"Use vertex_search to answer questions from the knowledge base.\",\n", + ")\n", + "```\n", + "\n", + "The agent will call `vertex_search` when it determines the question can be answered\n", + "from the knowledge base, and the grounded answer is returned directly — no separate\n", + "fetch step required." 
+ ] + }, + { + "cell_type": "markdown", + "id": "summary", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "| | `google_search` + `web_fetch` | `vertex_search` |\n", + "|---|---|---|\n", + "| **Data source** | Public web | Your private data store |\n", + "| **Auth** | API key | ADC (automatic on GCE) |\n", + "| **Steps to get an answer** | Search → Fetch → Read | Single call |\n", + "| **Grounding** | Model reads fetched HTML | Model reads retrieved document chunks |\n", + "| **Out-of-scope queries** | Searches the web anyway | Tells the user it doesn't know |\n", + "\n", + "Use `vertex_search` when you want the agent to be strictly grounded in a controlled\n", + "document set. Use `google_search` when the agent needs to find current information\n", + "from the public internet." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/scripts/create_test_datastore.py b/scripts/create_test_datastore.py new file mode 100644 index 0000000..d11bc21 --- /dev/null +++ b/scripts/create_test_datastore.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 +r"""Create a Vertex AI Search test data store for integration testing. + +Provisions the GCS bucket, uploads the fixture JSONL, creates a CONTENT_REQUIRED +structured data store, imports the documents, and waits for indexing.
+ +Usage +----- + # Authenticate first (or rely on the GCE service account in CI) + gcloud auth application-default login + + # Run from the repo root + uv run python -m scripts.create_test_datastore \\ + --bucket <bucket-name> \\ + [--project agentic-ai-evaluation-bootcamp] \\ + [--datastore-id vertex-search-integration-test] + +After the script finishes it prints the VERTEX_AI_DATASTORE_ID value +to add to your .env file. +""" + +import argparse +import base64 +import json +import sys +import time +from pathlib import Path + +import google.auth +import google.auth.transport.requests + + +DISCOVERY_ENGINE_BASE = "https://discoveryengine.googleapis.com/v1" +STORAGE_BASE = "https://storage.googleapis.com/storage/v1" +STORAGE_UPLOAD_BASE = "https://storage.googleapis.com/upload/storage/v1" + +# Fixture file relative to the repo root +FIXTURE_PATH = Path(__file__).parent.parent / "aieng-eval-agents" / "tests" / "fixtures" / "vertex_test_data.jsonl" + + +def get_session() -> google.auth.transport.requests.AuthorizedSession: + """Return an authorised requests session using Application Default Credentials.""" + credentials, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"]) + return google.auth.transport.requests.AuthorizedSession(credentials) + + +def create_bucket(session, project: str, bucket: str) -> None: + """Create a GCS bucket in us-central1, skipping if it already exists.""" + url = f"{STORAGE_BASE}/b?project={project}" + body = {"name": bucket, "location": "us-central1", "storageClass": "STANDARD"} + resp = session.post(url, json=body) + if resp.status_code == 409: + print(f" Bucket gs://{bucket} already exists — skipping creation.") + elif resp.status_code in (200, 201): + print(f" Created bucket gs://{bucket}") + else: + print(f" Error creating bucket: {resp.status_code} {resp.text}", file=sys.stderr) + resp.raise_for_status() + + +def transform_to_content_required(source_path: Path) -> bytes: + """Transform participant-format JSONL to
Discovery Engine CONTENT_REQUIRED format. + + Participant format (flat): + {"id": "x", "text": "...", "title": "...", "category": "..."} + + Discovery Engine CONTENT_REQUIRED format: + { + "id": "x", + "content": {"mimeType": "text/plain", "rawBytes": "<base64>"}, + "structData": {...} + } + + The ``text`` field becomes the indexed document content (stored as base64 rawBytes). + All other fields (except ``id``) become metadata in ``structData``. + """ + if not source_path.exists(): + raise FileNotFoundError(f"Fixture file not found: {source_path}") + + output_lines = [] + for raw_line in source_path.read_text(encoding="utf-8").strip().splitlines(): + row = json.loads(raw_line) + doc_id = row.pop("id") + text = row.pop("text", "") + doc = { + "id": doc_id, + "content": { + "mimeType": "text/plain", + "rawBytes": base64.b64encode(text.encode("utf-8")).decode("ascii"), + }, + "structData": row, # title, category, and any other metadata fields + } + output_lines.append(json.dumps(doc)) + + return "\n".join(output_lines).encode("utf-8") + + +def upload_fixture(session, bucket: str, object_name: str) -> None: + """Transform and upload the JSONL fixture to GCS in CONTENT_REQUIRED format.""" + payload = transform_to_content_required(FIXTURE_PATH) + url = f"{STORAGE_UPLOAD_BASE}/b/{bucket}/o?uploadType=media&name={object_name}" + resp = session.post( + url, + data=payload, + headers={"Content-Type": "application/json"}, + ) + resp.raise_for_status() + print(f" Transformed and uploaded {FIXTURE_PATH.name} → gs://{bucket}/{object_name}") + + +def create_datastore(session, project: str, datastore_id: str) -> None: + """Create a CONTENT_REQUIRED structured search data store, skipping if it exists.""" + url = ( + f"{DISCOVERY_ENGINE_BASE}/projects/{project}/locations/global" + f"/collections/default_collection/dataStores?dataStoreId={datastore_id}" + ) + body = { + "displayName": "Vertex Search Integration Test", + "industryVertical": "GENERIC", + "contentConfig": "CONTENT_REQUIRED", +
"solutionTypes": ["SOLUTION_TYPE_SEARCH"], + } + resp = session.post(url, json=body) + if resp.status_code == 409: + print(f" Data store '{datastore_id}' already exists — skipping creation.") + elif resp.status_code in (200, 201): + print(f" Created data store '{datastore_id}'") + # Allow a moment for the data store to become fully ready + time.sleep(5) + else: + print(f" Error creating data store: {resp.status_code} {resp.text}", file=sys.stderr) + resp.raise_for_status() + + +def import_documents(session, project: str, datastore_id: str, gcs_uri: str) -> str: + """Trigger an async document import from GCS. Returns the operation name.""" + url = ( + f"{DISCOVERY_ENGINE_BASE}/projects/{project}/locations/global" + f"/collections/default_collection/dataStores/{datastore_id}" + f"/branches/default_branch/documents:import" + ) + body = { + "gcsSource": { + "inputUris": [gcs_uri], + # "document" matches our JSONL format: + # {id, content:{mimeType,rawBytes}, structData:{...}} + "dataSchema": "document", + }, + # FULL replaces all existing documents, keeping the test store deterministic + "reconciliationMode": "FULL", + } + resp = session.post(url, json=body) + resp.raise_for_status() + operation_name = resp.json()["name"] + print(f" Import operation started: {operation_name}") + return operation_name + + +def wait_for_operation( + session, + operation_name: str, + timeout_sec: int = 600, + poll_interval: int = 15, +) -> dict: + """Poll the operation until it is done or the timeout is reached.""" + url = f"{DISCOVERY_ENGINE_BASE}/{operation_name}" + start = time.time() + deadline = start + timeout_sec + + while time.time() < deadline: + resp = session.get(url) + resp.raise_for_status() + op = resp.json() + + if op.get("done"): + if "error" in op: + raise RuntimeError(f"Import operation failed: {op['error']}") + # Check per-document failure count in metadata + metadata = op.get("metadata", {}) + failure_count = int(metadata.get("failureCount", 0)) + total_count = 
int(metadata.get("totalCount", 0)) + if failure_count > 0: + samples = op.get("response", {}).get("errorSamples", []) + sample_msg = samples[0]["message"] if samples else "unknown error" + raise RuntimeError( + f"Import completed but {failure_count}/{total_count} documents failed. First error: {sample_msg}" + ) + print(f" Indexing complete — {total_count} documents imported.") + return op + + elapsed = int(time.time() - start) + print(f" Indexing in progress… ({elapsed}s elapsed, checking again in {poll_interval}s)") + time.sleep(poll_interval) + + raise TimeoutError(f"Operation did not complete within {timeout_sec}s: {operation_name}") + + +def main() -> None: + """Parse CLI arguments and provision the Vertex AI Search test data store.""" + parser = argparse.ArgumentParser( + description="Provision a Vertex AI Search test data store.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--project", + default="agentic-ai-evaluation-bootcamp", + help="GCP project ID (default: agentic-ai-evaluation-bootcamp)", + ) + parser.add_argument( + "--bucket", + required=True, + help="GCS bucket name for staging the import file (must be globally unique)", + ) + parser.add_argument( + "--datastore-id", + default="vertex-search-integration-test", + help="Vertex AI Search data store ID (default: vertex-search-integration-test)", + ) + args = parser.parse_args() + + gcs_object = "vertex-search-test/vertex_test_data.jsonl" + gcs_uri = f"gs://{args.bucket}/{gcs_object}" + datastore_resource = ( + f"projects/{args.project}/locations/global/collections/default_collection/dataStores/{args.datastore_id}" + ) + + print("Vertex AI Search — test data store provisioning") + print("=" * 55) + print(f" Project: {args.project}") + print(f" Bucket: gs://{args.bucket}") + print(f" Data store: {datastore_resource}") + print() + + session = get_session() + + print("Step 1/5 Creating GCS bucket…") + create_bucket(session, args.project, args.bucket) + + print("Step 
2/5 Uploading fixture data to GCS…") + upload_fixture(session, args.bucket, gcs_object) + + print("Step 3/5 Creating Vertex AI Search data store…") + create_datastore(session, args.project, args.datastore_id) + + print("Step 4/5 Importing documents…") + operation_name = import_documents(session, args.project, args.datastore_id, gcs_uri) + + print("Step 5/5 Waiting for indexing (may take several minutes)…") + wait_for_operation(session, operation_name) + + print() + print("=" * 55) + print("Done! Add this to your .env file:") + print() + print(f'VERTEX_AI_DATASTORE_ID="{datastore_resource}"') + print("=" * 55) + + +if __name__ == "__main__": + main()