diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index b12c045548..bfbd65b1b5 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -2715,6 +2715,11 @@ async def predict_async( my_predictions = response.predictions ``` + For dedicated endpoints (``dedicated_endpoint_enabled=True``), the call + is routed via async HTTPS POST to the endpoint's dedicated DNS, + mirroring the synchronous ``predict()`` dedicated-endpoint path. + Otherwise the GAPIC async prediction client is used. + Args: instances (List): Required. The instances that are the input to the @@ -2740,29 +2745,96 @@ async def predict_async( Returns: prediction (aiplatform.Prediction): Prediction with returned predictions and Model ID. + + Raises: + ValueError: If the dedicated endpoint DNS is empty for dedicated + endpoints, or if the prediction request to a dedicated endpoint + returns a non-200 status. """ self.wait() - prediction_response = await self._prediction_async_client.predict( - endpoint=self._gca_resource.name, - instances=instances, - parameters=parameters, - timeout=timeout, - ) - if prediction_response._pb.metadata: - metadata = json_format.MessageToDict(prediction_response._pb.metadata) + if not self.dedicated_endpoint_enabled: + prediction_response = await self._prediction_async_client.predict( + endpoint=self._gca_resource.name, + instances=instances, + parameters=parameters, + timeout=timeout, + ) + if prediction_response._pb.metadata: + metadata = json_format.MessageToDict(prediction_response._pb.metadata) + else: + metadata = None + + return Prediction( + predictions=[ + json_format.MessageToDict(item) + for item in prediction_response.predictions.pb + ], + metadata=metadata, + deployed_model_id=prediction_response.deployed_model_id, + model_version_id=prediction_response.model_version_id, + model_resource_name=prediction_response.model, + ) + + # Dedicated endpoint: REST POST to the dedicated DNS via aiohttp. + # aiohttp is imported lazily so it is not a hard dependency for callers + # that only use the synchronous predict() path. + try: + import aiohttp + except ImportError as exc: + raise ImportError( + "Cannot import the aiohttp library required for async prediction" + " on dedicated endpoints. Please install aiohttp." + ) from exc + + if not self.dedicated_endpoint_dns: + raise ValueError( + "Dedicated endpoint DNS is empty. Please make sure endpoint" + "and model are ready before making a prediction." + ) + + if parameters is not None: + data = json.dumps({"instances": instances, "parameters": parameters}) else: - metadata = None + data = json.dumps({"instances": instances}) + + # Refresh the bearer token per call. ``AuthorizedSession`` (sync) + # handles refresh internally; in the async path we do it explicitly. + self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES + self.credentials.refresh(google_auth_requests.Request()) + headers = { + "Authorization": f"Bearer {self.credentials.token}", + "Content-Type": "application/json", + } + url = f"https://{self.dedicated_endpoint_dns}/v1/{self.resource_name}:predict" + aiohttp_timeout = ( + aiohttp.ClientTimeout(total=timeout) if timeout is not None else None + ) + + # Use a per-call session so the underlying connector is always closed, + # avoiding leaked ClientSessions (Python has no async destructor to do + # this for a cached session). + async with aiohttp.ClientSession() as session: + async with session.post( + url=url, + data=data, + headers=headers, + timeout=aiohttp_timeout, + ) as response: + if response.status != 200: + text = await response.text() + raise ValueError( + f"Failed to make prediction request. Status code:" + f"{response.status}, response: {text}." + ) + prediction_response = await response.json() return Prediction( - predictions=[ - json_format.MessageToDict(item) - for item in prediction_response.predictions.pb - ], - metadata=metadata, - deployed_model_id=prediction_response.deployed_model_id, - model_version_id=prediction_response.model_version_id, - model_resource_name=prediction_response.model, + predictions=prediction_response.get("predictions"), + metadata=prediction_response.get("metadata"), + deployed_model_id=prediction_response.get("deployedModelId"), + model_resource_name=prediction_response.get("model"), + model_version_id=prediction_response.get("modelVersionId"), ) def raw_predict( diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py index 89b53f6aef..94ca70ec65 100644 --- a/tests/unit/aiplatform/test_endpoints.py +++ b/tests/unit/aiplatform/test_endpoints.py @@ -19,6 +19,8 @@ from datetime import datetime, timedelta from importlib import reload import json +import aiohttp +from aiohttp import web as aiohttp_web import requests from unittest import mock from google.protobuf import duration_pb2 @@ -662,6 +664,38 @@ def predict_async_client_predict_mock(): yield predict_mock +@pytest.fixture +def predict_endpoint_aiohttp_mock(): + """Mocks aiohttp.ClientSession.post for dedicated-endpoint async predict. + + ClientSession.post returns a _RequestContextManager (async ctx mgr), so the + mock must implement __aenter__/__aexit__. Also patches Credentials.refresh + to a no-op so we don't hit real auth during tests. + """ + payload = { + "predictions": _TEST_PREDICTION, + "metadata": _TEST_METADATA, + "deployedModelId": _TEST_DEPLOYED_MODELS[0].id, + "model": _TEST_MODEL_NAME, + "modelVersionId": "1", + } + response = mock.AsyncMock() + response.status = 200 + response.json = mock.AsyncMock(return_value=payload) + response.text = mock.AsyncMock(return_value=json.dumps(payload)) + + post_cm = mock.MagicMock() + post_cm.__aenter__ = mock.AsyncMock(return_value=response) + post_cm.__aexit__ = mock.AsyncMock(return_value=None) + + with mock.patch.object( + aiohttp.ClientSession, "post", return_value=post_cm + ) as post_mock, mock.patch.object( + auth_credentials.AnonymousCredentials, "refresh" + ): + yield post_mock + + @pytest.fixture def predict_client_direct_predict_mock(): with mock.patch.object( @@ -3687,6 +3721,197 @@ async def test_predict_async(self, predict_async_client_predict_mock): timeout=None, ) + @pytest.mark.asyncio + @pytest.mark.usefixtures("get_dedicated_endpoint_mock") + async def test_predict_async_dedicated_endpoint( + self, predict_endpoint_aiohttp_mock + ): + """Async predict on a dedicated endpoint routes via aiohttp REST.""" + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + + test_prediction = await test_endpoint.predict_async( + instances=_TEST_INSTANCES, parameters={"param": 3.0} + ) + + true_prediction = models.Prediction( + predictions=_TEST_PREDICTION, + deployed_model_id=_TEST_ID, + metadata=_TEST_METADATA, + model_version_id=_TEST_VERSION_ID, + model_resource_name=_TEST_MODEL_NAME, + ) + assert true_prediction == test_prediction + + predict_endpoint_aiohttp_mock.assert_called_once() + call_kwargs = predict_endpoint_aiohttp_mock.call_args.kwargs + assert call_kwargs["url"] == ( + f"https://{_TEST_DEDICATED_ENDPOINT_DNS}/v1/projects/" + f"{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:predict" + ) + assert call_kwargs["data"] == ( + '{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]],' + ' "parameters": {"param": 3.0}}' + ) + assert call_kwargs["headers"]["Content-Type"] == "application/json" + assert call_kwargs["headers"]["Authorization"].startswith("Bearer ") + assert call_kwargs["timeout"] is None + + @pytest.mark.asyncio + @pytest.mark.usefixtures("get_dedicated_endpoint_no_dns_mock") + async def test_predict_async_dedicated_endpoint_without_dns( + self, predict_endpoint_aiohttp_mock + ): + """Async predict on a dedicated endpoint with no DNS raises ValueError.""" + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + + with pytest.raises(ValueError) as err: + await test_endpoint.predict_async( + instances=_TEST_INSTANCES, parameters={"param": 3.0} + ) + assert err.match( + regexp=r"Dedicated endpoint DNS is empty. Please make sure endpoint" + "and model are ready before making a prediction." + ) + + @pytest.mark.asyncio + @pytest.mark.usefixtures("get_dedicated_endpoint_mock") + async def test_predict_async_dedicated_endpoint_with_timeout( + self, predict_endpoint_aiohttp_mock + ): + """A non-None timeout is forwarded as an aiohttp.ClientTimeout.""" + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + + await test_endpoint.predict_async( + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + timeout=_TEST_PREDICT_TIMEOUT, + ) + + predict_endpoint_aiohttp_mock.assert_called_once() + call_timeout = predict_endpoint_aiohttp_mock.call_args.kwargs["timeout"] + assert isinstance(call_timeout, aiohttp.ClientTimeout) + assert call_timeout.total == _TEST_PREDICT_TIMEOUT + + @pytest.mark.asyncio + @pytest.mark.usefixtures( + "get_dedicated_endpoint_mock", "predict_endpoint_aiohttp_mock" + ) + async def test_predict_async_dedicated_endpoint_closes_session(self): + """The per-call ClientSession is closed, so no session is leaked.""" + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + + with mock.patch.object( + aiohttp.ClientSession, "close", new_callable=mock.AsyncMock + ) as close_mock: + await test_endpoint.predict_async( + instances=_TEST_INSTANCES, parameters={"param": 3.0} + ) + + close_mock.assert_awaited_once() + + @pytest.mark.asyncio + @pytest.mark.usefixtures("get_dedicated_endpoint_mock") + async def test_predict_async_dedicated_endpoint_integration(self): + """Hermetic integration test: real aiohttp request against a local server. + + Unlike the mock-based tests above, this exercises the real + aiohttp.ClientSession code path (connection, request encoding, + response decoding, JSON parsing, async context manager teardown) by + pointing the dedicated DNS at an in-process aiohttp.web server. + + The server validates the URL path, headers, and body shape the SDK + sends, and returns a Vertex-shaped JSON response. ``predict_async`` + creates its own ClientSession per call, so we patch the constructor to + return a session whose ``post`` downgrades the scheme from ``https://`` + to ``http://`` so the local server can serve plain HTTP without needing + self-signed TLS plumbing in the test. + """ + captured = {} + + async def predict_handler(request: aiohttp_web.Request): + captured["path"] = request.path + captured["headers"] = dict(request.headers) + captured["body"] = await request.json() + return aiohttp_web.json_response( + { + "predictions": _TEST_PREDICTION, + "metadata": _TEST_METADATA, + "deployedModelId": _TEST_DEPLOYED_MODELS[0].id, + "model": _TEST_MODEL_NAME, + "modelVersionId": "1", + } + ) + + app = aiohttp_web.Application() + app.router.add_post( + f"/v1/{_TEST_ENDPOINT_NAME}:predict", predict_handler + ) + runner = aiohttp_web.AppRunner(app) + await runner.setup() + site = aiohttp_web.TCPSite(runner, "127.0.0.1", 0) + await site.start() + port = site._server.sockets[0].getsockname()[1] + + # Build a real session whose post() downgrades https:// -> http:// so + # the local test server can serve plain HTTP. Real aiohttp does + # everything else (real socket, real request/response cycle). Capture + # the real constructor here so the patched one below doesn't recurse. + real_client_session_cls = aiohttp.ClientSession + + def make_http_session(*args, **kwargs): + session = real_client_session_cls(*args, **kwargs) + original_post = session.post + + def http_post(url, **post_kwargs): + return original_post( + url.replace("https://", "http://", 1), **post_kwargs + ) + + session.post = http_post + return session + + try: + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + + # Force the dedicated DNS to point at the local server. + with mock.patch.object( + models.Endpoint, + "dedicated_endpoint_dns", + new_callable=mock.PropertyMock, + return_value=f"127.0.0.1:{port}", + ), mock.patch.object( + auth_credentials.AnonymousCredentials, "refresh" + ), mock.patch.object( + aiohttp, "ClientSession", side_effect=make_http_session + ): + # Give the mocked credentials a fake token. + test_endpoint.credentials.token = "test-token" + + prediction = await test_endpoint.predict_async( + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + ) + + # Verify the SDK sent the right wire request. + assert captured["path"] == f"/v1/{_TEST_ENDPOINT_NAME}:predict" + assert captured["headers"]["Content-Type"] == "application/json" + assert captured["headers"]["Authorization"] == "Bearer test-token" + assert captured["body"] == { + "instances": _TEST_INSTANCES, + "parameters": {"param": 3.0}, + } + + # Verify the response was parsed correctly into a Prediction. + assert prediction == models.Prediction( + predictions=_TEST_PREDICTION, + deployed_model_id=_TEST_ID, + metadata=_TEST_METADATA, + model_version_id=_TEST_VERSION_ID, + model_resource_name=_TEST_MODEL_NAME, + ) + finally: + await runner.cleanup() + @pytest.mark.usefixtures("get_endpoint_mock") def test_explain(self, predict_client_explain_mock): test_endpoint = models.Endpoint(_TEST_ID)