diff --git a/src/google/adk/cli/cli_create.py b/src/google/adk/cli/cli_create.py index a1be9a0273..c5bf7c3819 100644 --- a/src/google/adk/cli/cli_create.py +++ b/src/google/adk/cli/cli_create.py @@ -22,6 +22,7 @@ import click from ..apps.app import validate_app_name +from .utils import gcp_utils _INIT_PY_TEMPLATE = """\ from . import agent @@ -61,11 +62,26 @@ https://google.github.io/adk-docs/agents/models """ +_EXPRESS_TOS_MSG = """ +Google Cloud Express Mode Terms of Service: https://cloud.google.com/terms/google-cloud-express +By continuing, you agree to the Terms of Service for Vertex AI Express Mode. +Would you like to proceed? (yes/no) +""" + +_NOT_ELIGIBLE_MSG = """ +You are not eligible for Express Mode. +Please follow these instructions to set up a full Google Cloud project: +https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai +""" + _SUCCESS_MSG_CODE = """ Agent created in {agent_folder}: - .env - __init__.py - agent.py + +⚠️ WARNING: Secrets (like GOOGLE_API_KEY) are stored in .env. +Please ensure .env is added to your .gitignore to avoid committing secrets to version control. """ _SUCCESS_MSG_CONFIG = """ @@ -73,6 +89,9 @@ - .env - __init__.py - root_agent.yaml + +⚠️ WARNING: Secrets (like GOOGLE_API_KEY) are stored in .env. +Please ensure .env is added to your .gitignore to avoid committing secrets to version control. """ @@ -187,10 +206,10 @@ def _generate_files( with open(dotenv_file_path, "w", encoding="utf-8") as f: lines = [] - if google_api_key: - lines.append("GOOGLE_GENAI_USE_VERTEXAI=0") - elif google_cloud_project and google_cloud_region: + if google_cloud_project and google_cloud_region: lines.append("GOOGLE_GENAI_USE_VERTEXAI=1") + elif google_api_key: + lines.append("GOOGLE_GENAI_USE_VERTEXAI=0") if google_api_key: lines.append(f"GOOGLE_API_KEY={google_api_key}") if google_cloud_project: @@ -247,8 +266,8 @@ def _prompt_to_choose_backend( A tuple of (google_api_key, google_cloud_project, google_cloud_region). """ backend_choice = click.prompt( - "1. Google AI\n2. Vertex AI\nChoose a backend", - type=click.Choice(["1", "2"]), + "1. Google AI\n2. Vertex AI\n3. Login with Google\nChoose a backend", + type=click.Choice(["1", "2", "3"]), ) if backend_choice == "1": google_api_key = _prompt_for_google_api_key(google_api_key) @@ -256,9 +275,101 @@ def _prompt_to_choose_backend( click.secho(_GOOGLE_CLOUD_SETUP_MSG, fg="green") google_cloud_project = _prompt_for_google_cloud(google_cloud_project) google_cloud_region = _prompt_for_google_cloud_region(google_cloud_region) + elif backend_choice == "3": + google_api_key, google_cloud_project, google_cloud_region = ( + _handle_login_with_google() + ) return google_api_key, google_cloud_project, google_cloud_region +def _handle_login_with_google() -> ( + Tuple[Optional[str], Optional[str], Optional[str]] +): + """Handles the "Login with Google" flow.""" + if not gcp_utils.check_adc(): + click.secho( + "No Application Default Credentials found. " + "Opening browser for login...", + fg="yellow", + ) + try: + gcp_utils.login_adc() + except RuntimeError as e: + click.secho(str(e), fg="red") + raise click.Abort() + + # Check for existing Express project + express_project = gcp_utils.retrieve_express_project() + if express_project: + api_key = express_project.get("api_key") + project_id = express_project.get("project_id") + region = express_project.get("region", "us-central1") + if project_id: + click.secho(f"Using existing Express project: {project_id}", fg="green") + return api_key, project_id, region + + # Check for existing full GCP projects + projects = gcp_utils.list_gcp_projects(limit=20) + if projects: + click.secho("Recently created Google Cloud projects found:", fg="green") + click.echo("0. Enter project ID manually") + for i, (p_id, p_name) in enumerate(projects, 1): + click.echo(f"{i}. {p_name} ({p_id})") + + project_index = click.prompt( + "Select a project", + type=click.IntRange(0, len(projects)), + ) + if project_index == 0: + selected_project_id = _prompt_for_google_cloud(None) + else: + selected_project_id = projects[project_index - 1][0] + region = _prompt_for_google_cloud_region(None) + return None, selected_project_id, region + else: + if click.confirm( + "No projects found automatically. Would you like to enter one" + " manually?", + default=False, + ): + selected_project_id = _prompt_for_google_cloud(None) + region = _prompt_for_google_cloud_region(None) + return None, selected_project_id, region + + # Check Express eligibility + if gcp_utils.check_express_eligibility(): + click.secho(_EXPRESS_TOS_MSG, fg="yellow") + if click.confirm("Do you accept the Terms of Service?", default=False): + selected_region = click.prompt( + """\ +Choose a region for Express Mode: +1. us-central1 +2. europe-west1 +3. asia-southeast1 +Choose region""", + type=click.Choice(["1", "2", "3"]), + default="1", + ) + region_map = { + "1": "us-central1", + "2": "europe-west1", + "3": "asia-southeast1", + } + region = region_map[selected_region] + express_info = gcp_utils.sign_up_express(location=region) + api_key = express_info.get("api_key") + project_id = express_info.get("project_id") + region = express_info.get("region", region) + click.secho( + f"Express Mode project created: {project_id}", + fg="green", + ) + return api_key, project_id, region + + click.secho(_NOT_ELIGIBLE_MSG, fg="red") + raise click.Abort() + + def _prompt_to_choose_type() -> str: """Prompts user to choose type of agent to create.""" type_choice = click.prompt( diff --git a/src/google/adk/cli/utils/gcp_utils.py b/src/google/adk/cli/utils/gcp_utils.py new file mode 100644 index 0000000000..09fc26058f --- /dev/null +++ b/src/google/adk/cli/utils/gcp_utils.py @@ -0,0 +1,176 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for GCP authentication and Vertex AI Express Mode.""" + +from __future__ import annotations + +import subprocess +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import google.auth +import google.auth.exceptions +from google.auth.transport.requests import AuthorizedSession +from google.auth.transport.requests import Request +from google.cloud import resourcemanager_v3 +import requests + +# TODO: Update to production endpoint before making this public. +_STAGING_ENDPOINT = ( + "https://{location}-staging-aiplatform.sandbox.googleapis.com/v1beta1" +) + + +def check_adc() -> bool: + """Checks if Application Default Credentials exist.""" + try: + google.auth.default() + return True + except google.auth.exceptions.DefaultCredentialsError: + return False + + +def login_adc() -> None: + """Prompts user to login via gcloud ADC.""" + try: + subprocess.run( + ["gcloud", "auth", "application-default", "login"], check=True + ) + except (subprocess.CalledProcessError, FileNotFoundError): + raise RuntimeError( + "gcloud is not installed or failed to run. " + "Please install gcloud to login to Application Default Credentials." + ) + + +def get_access_token() -> str: + """Gets the ADC access token.""" + try: + credentials, _ = google.auth.default() + if not credentials.valid: + credentials.refresh(Request()) + return credentials.token or "" + except google.auth.exceptions.DefaultCredentialsError: + raise RuntimeError("Application Default Credentials not found.") + + +def _call_vertex_express_api( + method: str, + action: str, + location: str = "us-central1", + data: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """Calls a Vertex AI Express API.""" + credentials, _ = google.auth.default() + session = AuthorizedSession(credentials) + url = f"{_STAGING_ENDPOINT.format(location=location)}/vertexExpress{action}" + headers = { + "Content-Type": "application/json", + } + + if method == "GET": + response = session.get(url, headers=headers, params=params) + elif method == "POST": + response = session.post(url, headers=headers, json=data, params=params) + else: + raise ValueError(f"Unsupported method: {method}") + + response.raise_for_status() + return response.json() + + +def retrieve_express_project( + location: str = "us-central1", +) -> Optional[Dict[str, Any]]: + """Retrieves existing Express project info.""" + try: + response = _call_vertex_express_api( + "GET", + ":retrieveExpressProject", + location=location, + params={"get_default_api_key": True}, + ) + project = response.get("expressProject") + if not project: + return None + + return { + "project_id": project.get("projectId"), + "api_key": project.get("defaultApiKey"), + "region": project.get("region", location), + } + except requests.exceptions.HTTPError as e: + if e.response.status_code == 404: + return None + raise + + +def check_express_eligibility( + location: str = "us-central1", +) -> bool: + """Checks if user is eligible for Express Mode.""" + try: + result = _call_vertex_express_api( + "GET", "/Eligibility:check", location=location + ) + return result.get("eligibility") == "IN_SCOPE" + except (requests.exceptions.HTTPError, KeyError) as e: + return False + + +def sign_up_express( + location: str = "us-central1", +) -> Dict[str, Any]: + """Signs up for Express Mode.""" + project = _call_vertex_express_api( + "POST", + ":signUp", + location=location, + data={"region": location, "tos_accepted": True}, + ) + return { + "project_id": project.get("projectId"), + "api_key": project.get("defaultApiKey"), + "region": project.get("region", location), + } + + +def list_gcp_projects(limit: int = 20) -> List[Tuple[str, str]]: + """Lists GCP projects available to the user. + + Args: + limit: The maximum number of projects to return. + + Returns: + A list of (project_id, name) tuples. + """ + try: + client = resourcemanager_v3.ProjectsClient() + search_results = client.search_projects() + + projects = [] + for project in search_results: + if len(projects) >= limit: + break + projects.append( + (project.project_id, project.display_name or project.project_id) + ) + return projects + except Exception: + return [] diff --git a/tests/unittests/cli/utils/test_cli_create.py b/tests/unittests/cli/utils/test_cli_create.py index 0d76ab8a54..520245b3fc 100644 --- a/tests/unittests/cli/utils/test_cli_create.py +++ b/tests/unittests/cli/utils/test_cli_create.py @@ -26,6 +26,7 @@ import click import google.adk.cli.cli_create as cli_create +from google.adk.cli.utils import gcp_utils import pytest @@ -87,6 +88,23 @@ def test_generate_files_with_gcp(agent_folder: Path) -> None: assert "GOOGLE_GENAI_USE_VERTEXAI=1" in env_content +def test_generate_files_with_express_mode(agent_folder: Path) -> None: + """Files should be created with Vertex AI backend when both project and API key are present (Express Mode).""" + cli_create._generate_files( + str(agent_folder), + google_api_key="express-api-key", + google_cloud_project="express-project-id", + google_cloud_region="us-central1", + model="gemini-2.0-flash-001", + type="code", + ) + + env_content = (agent_folder / ".env").read_text() + assert "GOOGLE_GENAI_USE_VERTEXAI=1" in env_content + assert "GOOGLE_API_KEY=express-api-key" in env_content + assert "GOOGLE_CLOUD_PROJECT=express-project-id" in env_content + + def test_generate_files_overwrite(agent_folder: Path) -> None: """Existing files should be overwritten when generating again.""" agent_folder.mkdir(parents=True, exist_ok=True) @@ -284,6 +302,87 @@ def test_prompt_to_choose_backend_vertex( assert region == "region" +def test_prompt_to_choose_backend_login( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Choosing Login with Google returns (api_key, project, region) from handler.""" + monkeypatch.setattr(click, "prompt", lambda *a, **k: "3") + monkeypatch.setattr( + cli_create, + "_handle_login_with_google", + lambda: ("api-key", "proj", "region"), + ) + + api_key, proj, region = cli_create._prompt_to_choose_backend(None, None, None) + assert api_key == "api-key" + assert proj == "proj" + assert region == "region" + + +def test_handle_login_with_google_existing_express( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Handler should return existing Express project if found.""" + monkeypatch.setattr(gcp_utils, "check_adc", lambda: True) + monkeypatch.setattr( + gcp_utils, + "retrieve_express_project", + lambda: {"api_key": "key", "project_id": "proj", "region": "us-central1"}, + ) + + api_key, proj, region = cli_create._handle_login_with_google() + assert api_key == "key" + assert proj == "proj" + assert region == "us-central1" + + +def test_handle_login_with_google_select_gcp_project( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Handler should prompt for project selection if no Express project found.""" + monkeypatch.setattr(gcp_utils, "check_adc", lambda: True) + monkeypatch.setattr(gcp_utils, "retrieve_express_project", lambda: None) + monkeypatch.setattr( + gcp_utils, "list_gcp_projects", lambda limit: [("p1", "Project 1")] + ) + monkeypatch.setattr(click, "prompt", lambda *a, **k: 1) + monkeypatch.setattr( + cli_create, "_prompt_for_google_cloud_region", lambda _v: "us-east1" + ) + + api_key, proj, region = cli_create._handle_login_with_google() + assert api_key is None + assert proj == "p1" + assert region == "us-east1" + + +def test_handle_login_with_google_express_signup( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Handler should sign up for Express if eligible and user accepts TOS.""" + monkeypatch.setattr(gcp_utils, "check_adc", lambda: True) + monkeypatch.setattr(gcp_utils, "retrieve_express_project", lambda: None) + monkeypatch.setattr(gcp_utils, "list_gcp_projects", lambda limit: []) + monkeypatch.setattr(gcp_utils, "check_express_eligibility", lambda: True) + confirms = iter([False, True]) + monkeypatch.setattr(click, "confirm", lambda *a, **k: next(confirms)) + monkeypatch.setattr(click, "prompt", lambda *a, **k: "1") + monkeypatch.setattr( + gcp_utils, + "sign_up_express", + lambda location="us-central1": { + "api_key": "new-key", + "project_id": "new-proj", + "region": location, + }, + ) + + api_key, proj, region = cli_create._handle_login_with_google() + assert api_key == "new-key" + assert proj == "new-proj" + assert region == "us-central1" + + # prompt_str def test_prompt_str_non_empty(monkeypatch: pytest.MonkeyPatch) -> None: """_prompt_str should retry until a non-blank string is provided.""" @@ -317,3 +416,42 @@ def test_get_gcp_region_from_gcloud_fail( ), ) assert cli_create._get_gcp_region_from_gcloud() == "" + + +def test_handle_login_with_google_manual_project( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Handler should allow manual project ID entry when '0' is selected.""" + monkeypatch.setattr(gcp_utils, "check_adc", lambda: True) + monkeypatch.setattr(gcp_utils, "retrieve_express_project", lambda: None) + monkeypatch.setattr( + gcp_utils, "list_gcp_projects", lambda limit: [("p1", "Project 1")] + ) + # First prompt is for project selection (0), second is for manual ID entry, + # third is for region selection. + prompts = iter([0, "manual-proj", "us-east1"]) + monkeypatch.setattr(click, "prompt", lambda *a, **k: next(prompts)) + + api_key, proj, region = cli_create._handle_login_with_google() + assert api_key is None + assert proj == "manual-proj" + assert region == "us-east1" + + +def test_handle_login_with_google_empty_projects_manual_entry( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Handler should allow manual entry if no projects are found and user accepts.""" + monkeypatch.setattr(gcp_utils, "check_adc", lambda: True) + monkeypatch.setattr(gcp_utils, "retrieve_express_project", lambda: None) + monkeypatch.setattr(gcp_utils, "list_gcp_projects", lambda limit: []) + + # User says Yes to "enter manually", then provides project ID and region + prompts = iter(["manual-proj", "us-east1"]) + monkeypatch.setattr(click, "confirm", lambda *a, **k: True) + monkeypatch.setattr(click, "prompt", lambda *a, **k: next(prompts)) + + api_key, proj, region = cli_create._handle_login_with_google() + assert api_key is None + assert proj == "manual-proj" + assert region == "us-east1" diff --git a/tests/unittests/cli/utils/test_gcp_utils.py b/tests/unittests/cli/utils/test_gcp_utils.py new file mode 100644 index 0000000000..fc26af4997 --- /dev/null +++ b/tests/unittests/cli/utils/test_gcp_utils.py @@ -0,0 +1,173 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for gcp_utils.""" + +import json +import unittest +from unittest import mock + +from google.adk.cli.utils import gcp_utils +import google.auth +import google.auth.exceptions +import requests + + +class TestGcpUtils(unittest.TestCase): + + @mock.patch("google.auth.default") + def test_check_adc_success(self, mock_auth_default): + mock_auth_default.return_value = (mock.Mock(), "test-project") + self.assertTrue(gcp_utils.check_adc()) + + @mock.patch("google.auth.default") + def test_check_adc_failure(self, mock_auth_default): + mock_auth_default.side_effect = ( + google.auth.exceptions.DefaultCredentialsError() + ) + self.assertFalse(gcp_utils.check_adc()) + + @mock.patch("google.auth.default") + def test_get_access_token(self, mock_auth_default): + mock_creds = mock.Mock() + mock_creds.token = "test-token" + mock_creds.valid = True + mock_auth_default.return_value = (mock_creds, "test-project") + self.assertEqual(gcp_utils.get_access_token(), "test-token") + + @mock.patch("google.auth.default") + def test_get_gcloud_project_success(self, mock_auth_default): + mock_auth_default.return_value = (mock.Mock(), "my-project") + self.assertEqual(gcp_utils._get_gcloud_project(), "my-project") + + @mock.patch("google.auth.default") + def test_get_gcloud_project_unset(self, mock_auth_default): + mock_auth_default.return_value = (mock.Mock(), None) + self.assertIsNone(gcp_utils._get_gcloud_project()) + + @mock.patch("google.adk.cli.utils.gcp_utils.AuthorizedSession") + @mock.patch("google.auth.default") + @mock.patch("google.adk.cli.utils.gcp_utils._get_gcloud_project") + def test_retrieve_express_project_success( + self, mock_get_project, mock_auth_default, mock_session_cls + ): + mock_auth_default.return_value = (mock.Mock(), "test-project-id") + mock_get_project.return_value = "test-project-id" + + mock_session = mock.Mock() + mock_session_cls.return_value = mock_session + mock_response = mock.Mock() + mock_response.json.return_value = { + "expressProject": { + "projectId": "test-project", + "defaultApiKey": "test-api-key", + "region": "us-central1", + } + } + mock_session.get.return_value = mock_response + + result = gcp_utils.retrieve_express_project() + self.assertEqual(result["project_id"], "test-project") + self.assertEqual(result["api_key"], "test-api-key") + self.assertEqual(result["region"], "us-central1") + mock_session.get.assert_called_once() + args, kwargs = mock_session.get.call_args + self.assertEqual(kwargs["params"], {"get_default_api_key": True}) + + @mock.patch("google.adk.cli.utils.gcp_utils.AuthorizedSession") + @mock.patch("google.auth.default") + @mock.patch("google.adk.cli.utils.gcp_utils._get_gcloud_project") + def test_retrieve_express_project_not_found( + self, mock_get_project, mock_auth_default, mock_session_cls + ): + mock_auth_default.return_value = (mock.Mock(), "test-project-id") + mock_get_project.return_value = "test-project-id" + + mock_session = mock.Mock() + mock_session_cls.return_value = mock_session + mock_response = mock.Mock() + mock_response.status_code = 404 + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError( + response=mock_response + ) + mock_session.get.return_value = mock_response + + result = gcp_utils.retrieve_express_project() + self.assertIsNone(result) + + @mock.patch("google.adk.cli.utils.gcp_utils.AuthorizedSession") + @mock.patch("google.auth.default") + @mock.patch("google.adk.cli.utils.gcp_utils._get_gcloud_project") + def test_check_express_eligibility( + self, mock_get_project, mock_auth_default, mock_session_cls + ): + mock_auth_default.return_value = (mock.Mock(), "test-project-id") + mock_get_project.return_value = "test-project-id" + + mock_session = mock.Mock() + mock_session_cls.return_value = mock_session + mock_response = mock.Mock() + mock_response.json.return_value = {"eligibility": "IN_SCOPE"} + mock_session.get.return_value = mock_response + + self.assertTrue(gcp_utils.check_express_eligibility()) + + @mock.patch("google.adk.cli.utils.gcp_utils.AuthorizedSession") + @mock.patch("google.auth.default") + @mock.patch("google.adk.cli.utils.gcp_utils._get_gcloud_project") + def test_sign_up_express( + self, mock_get_project, mock_auth_default, mock_session_cls + ): + mock_auth_default.return_value = (mock.Mock(), "test-project-id") + mock_get_project.return_value = "test-project-id" + + mock_session = mock.Mock() + mock_session_cls.return_value = mock_session + mock_response = mock.Mock() + mock_response.json.return_value = { + "projectId": "new-project", + "defaultApiKey": "new-api-key", + "region": "us-central1", + } + mock_session.post.return_value = mock_response + + result = gcp_utils.sign_up_express() + self.assertEqual(result["project_id"], "new-project") + self.assertEqual(result["api_key"], "new-api-key") + + @mock.patch( + "google.adk.cli.utils.gcp_utils.resourcemanager_v3.ProjectsClient" + ) + def test_list_gcp_projects(self, mock_client_cls): + mock_client = mock.Mock() + mock_client_cls.return_value = mock_client + + mock_project1 = mock.Mock() + mock_project1.project_id = "p1" + mock_project1.display_name = "Project 1" + + mock_project2 = mock.Mock() + mock_project2.project_id = "p2" + mock_project2.display_name = None + + mock_client.search_projects.return_value = [mock_project1, mock_project2] + + projects = gcp_utils.list_gcp_projects() + self.assertEqual(len(projects), 2) + self.assertEqual(projects[0], ("p1", "Project 1")) + self.assertEqual(projects[1], ("p2", "p2")) + + +if __name__ == "__main__": + unittest.main()