diff --git a/inference/core/constants.py b/inference/core/constants.py index 83d2a4aded..d9d51c2f1d 100644 --- a/inference/core/constants.py +++ b/inference/core/constants.py @@ -4,6 +4,7 @@ KEYPOINTS_DETECTION_TASK = "keypoint-detection" PROCESSING_TIME_HEADER = "X-Processing-Time" MODEL_COLD_START_HEADER = "X-Model-Cold-Start" +MODEL_COLD_START_COUNT_HEADER = "X-Model-Cold-Start-Count" MODEL_LOAD_TIME_HEADER = "X-Model-Load-Time" MODEL_LOAD_DETAILS_HEADER = "X-Model-Load-Details" MODEL_ID_HEADER = "X-Model-Id" diff --git a/inference/core/interfaces/http/http_api.py b/inference/core/interfaces/http/http_api.py index 01ccd55570..ef7f230ff2 100644 --- a/inference/core/interfaces/http/http_api.py +++ b/inference/core/interfaces/http/http_api.py @@ -32,6 +32,7 @@ from inference.core import logger from inference.core.constants import ( + MODEL_COLD_START_COUNT_HEADER, MODEL_COLD_START_HEADER, MODEL_ID_HEADER, MODEL_LOAD_DETAILS_HEADER, @@ -229,6 +230,12 @@ orjson_response, orjson_response_keeping_parent_id, ) +from inference.core.interfaces.http.request_metrics import ( + REMOTE_PROCESSING_TIME_HEADER, + REMOTE_PROCESSING_TIMES_HEADER, + GCPServerlessMiddleware, + build_model_response_headers, +) from inference.core.interfaces.stream_manager.api.entities import ( CommandContext, CommandResponse, @@ -316,23 +323,9 @@ from inference.core.version import __version__ try: - from inference_sdk.config import ( - EXECUTION_ID_HEADER, - INTERNAL_REMOTE_EXEC_REQ_HEADER, - INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER, - RemoteProcessingTimeCollector, - apply_duration_minimum, - execution_id, - remote_processing_times, - ) + from inference_sdk.config import EXECUTION_ID_HEADER except ImportError: - execution_id = None - remote_processing_times = None - RemoteProcessingTimeCollector = None EXECUTION_ID_HEADER = None - INTERNAL_REMOTE_EXEC_REQ_HEADER = None - INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER = None - apply_duration_minimum = None def get_content_type(request: Request) -> str: @@ -348,8 +341,6 @@ async def dispatch(self, request, call_next): return response -REMOTE_PROCESSING_TIME_HEADER = "X-Remote-Processing-Time" -REMOTE_PROCESSING_TIMES_HEADER = "X-Remote-Processing-Times" AUTH_CACHE_TTL_SECONDS = 3600 SHORT_AUTH_CACHE_TTL_SECONDS = 60 @@ -362,48 +353,6 @@ class AuthorizationCacheEntry: message: Optional[str] = None -class GCPServerlessMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): - if execution_id is not None: - execution_id_value = request.headers.get(EXECUTION_ID_HEADER) - if not execution_id_value: - execution_id_value = f"{time.time_ns()}_{uuid4().hex[:4]}" - execution_id.set(execution_id_value) - is_verified_internal = False - if apply_duration_minimum is not None: - is_verified_internal = bool( - ROBOFLOW_INTERNAL_SERVICE_SECRET - and INTERNAL_REMOTE_EXEC_REQ_HEADER - and request.headers.get(INTERNAL_REMOTE_EXEC_REQ_HEADER) - == ROBOFLOW_INTERNAL_SERVICE_SECRET - ) - apply_duration_minimum.set(not is_verified_internal) - collector = None - if ( - WORKFLOWS_REMOTE_EXECUTION_TIME_FORWARDING - and remote_processing_times is not None - and RemoteProcessingTimeCollector is not None - ): - collector = RemoteProcessingTimeCollector() - remote_processing_times.set(collector) - t1 = time.time() - response = await call_next(request) - t2 = time.time() - response.headers[PROCESSING_TIME_HEADER] = str(t2 - t1) - if collector is not None and collector.has_data(): - total, detail = collector.summarize() - response.headers[REMOTE_PROCESSING_TIME_HEADER] = str(total) - if detail is not None: - response.headers[REMOTE_PROCESSING_TIMES_HEADER] = detail - if execution_id is not None: - response.headers[EXECUTION_ID_HEADER] = execution_id_value - if INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER is not None: - response.headers[INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER] = str( - is_verified_internal - ).lower() - return response - - class HttpInterface(BaseInterface): """Roboflow defined HTTP interface for a general-purpose inference server. @@ -510,6 +459,7 @@ async def on_shutdown(): REMOTE_PROCESSING_TIME_HEADER, REMOTE_PROCESSING_TIMES_HEADER, MODEL_COLD_START_HEADER, + MODEL_COLD_START_COUNT_HEADER, MODEL_LOAD_TIME_HEADER, MODEL_LOAD_DETAILS_HEADER, MODEL_ID_HEADER, @@ -820,17 +770,35 @@ async def track_model_load(request: Request, call_next): ids_collector = RequestModelIds() request_model_ids.set(ids_collector) response = await call_next(request) - if load_collector.has_data(): - total, detail = load_collector.summarize() - response.headers[MODEL_COLD_START_HEADER] = "true" - response.headers[MODEL_LOAD_TIME_HEADER] = str(total) - if detail is not None: - response.headers[MODEL_LOAD_DETAILS_HEADER] = detail + remote_processing_collector = getattr( + request.state, "remote_processing_time_collector", None + ) + if remote_processing_collector is not None: + remote_model_ids = remote_processing_collector.snapshot_model_ids() + remote_cold_start_entries = ( + remote_processing_collector.snapshot_cold_start_entries() + ) + remote_cold_start_count = ( + remote_processing_collector.snapshot_cold_start_count() + ) + remote_cold_start_total_load_time = ( + remote_processing_collector.snapshot_cold_start_total_load_time() + ) else: - response.headers[MODEL_COLD_START_HEADER] = "false" - model_ids = ids_collector.get_ids() - if model_ids: - response.headers[MODEL_ID_HEADER] = ",".join(sorted(model_ids)) + remote_model_ids = set() + remote_cold_start_entries = [] + remote_cold_start_count = 0 + remote_cold_start_total_load_time = 0.0 + response.headers.update( + build_model_response_headers( + local_model_ids=ids_collector.get_ids(), + local_cold_start_entries=load_collector.snapshot_entries(), + remote_model_ids=remote_model_ids, + remote_cold_start_entries=remote_cold_start_entries, + remote_cold_start_count=remote_cold_start_count, + remote_cold_start_total_load_time=remote_cold_start_total_load_time, + ) + ) wf_id = request_workflow_id.get(None) if wf_id: response.headers[WORKFLOW_ID_HEADER] = wf_id @@ -856,6 +824,7 @@ async def structured_access_log(request: Request, call_next): "request_id": CORRELATION_ID_HEADER, "processing_time": PROCESSING_TIME_HEADER, "model_cold_start": MODEL_COLD_START_HEADER, + "model_cold_start_count": MODEL_COLD_START_COUNT_HEADER, "model_load_time": MODEL_LOAD_TIME_HEADER, "model_id": MODEL_ID_HEADER, "workflow_id": WORKFLOW_ID_HEADER, diff --git a/inference/core/interfaces/http/request_metrics.py b/inference/core/interfaces/http/request_metrics.py new file mode 100644 index 0000000000..beb92f1424 --- /dev/null +++ b/inference/core/interfaces/http/request_metrics.py @@ -0,0 +1,134 @@ +import json +import time +from typing import Dict, List, Optional, Tuple +from uuid import uuid4 + +from starlette.middleware.base import BaseHTTPMiddleware + +from inference.core.constants import ( + MODEL_COLD_START_COUNT_HEADER, + MODEL_COLD_START_HEADER, + MODEL_ID_HEADER, + MODEL_LOAD_DETAILS_HEADER, + MODEL_LOAD_TIME_HEADER, + PROCESSING_TIME_HEADER, +) +from inference.core.env import ( + ROBOFLOW_INTERNAL_SERVICE_SECRET, + WORKFLOWS_REMOTE_EXECUTION_TIME_FORWARDING, +) + +try: + from inference_sdk.config import ( + EXECUTION_ID_HEADER, + INTERNAL_REMOTE_EXEC_REQ_HEADER, + INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER, + RemoteProcessingTimeCollector, + apply_duration_minimum, + execution_id, + remote_processing_times, + ) +except ImportError: + execution_id = None + remote_processing_times = None + RemoteProcessingTimeCollector = None + EXECUTION_ID_HEADER = None + INTERNAL_REMOTE_EXEC_REQ_HEADER = None + INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER = None + apply_duration_minimum = None + + +REMOTE_PROCESSING_TIME_HEADER = "X-Remote-Processing-Time" +REMOTE_PROCESSING_TIMES_HEADER = "X-Remote-Processing-Times" + + +def summarize_model_load_entries( + entries: List[Tuple[str, float]], max_detail_bytes: int = 4096 +) -> Tuple[float, Optional[str]]: + total = sum(load_time for _, load_time in entries) + detail = json.dumps( + [{"m": model_id, "t": load_time} for model_id, load_time in entries] + ) + if len(detail) > max_detail_bytes: + detail = None + return total, detail + + +def build_model_response_headers( + local_model_ids: set, + local_cold_start_entries: List[Tuple[str, float]], + remote_model_ids: set, + remote_cold_start_entries: List[Tuple[str, float]], + remote_cold_start_count: int, + remote_cold_start_total_load_time: float, +) -> Dict[str, str]: + response_headers = { + MODEL_COLD_START_HEADER: "false", + MODEL_COLD_START_COUNT_HEADER: "0", + } + model_ids = sorted(local_model_ids | remote_model_ids) + if model_ids: + response_headers[MODEL_ID_HEADER] = ",".join(model_ids) + local_cold_start_count = len(local_cold_start_entries) + cold_start_count = local_cold_start_count + remote_cold_start_count + response_headers[MODEL_COLD_START_COUNT_HEADER] = str(cold_start_count) + if cold_start_count == 0: + return response_headers + response_headers[MODEL_COLD_START_HEADER] = "true" + local_load_time = sum(load_time for _, load_time in local_cold_start_entries) + response_headers[MODEL_LOAD_TIME_HEADER] = str( + local_load_time + remote_cold_start_total_load_time + ) + detailed_entries = local_cold_start_entries + remote_cold_start_entries + if len(detailed_entries) != cold_start_count: + return response_headers + _, detail = summarize_model_load_entries(entries=detailed_entries) + if detail is not None: + response_headers[MODEL_LOAD_DETAILS_HEADER] = detail + return response_headers + + +class GCPServerlessMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + if execution_id is not None: + execution_id_value = request.headers.get(EXECUTION_ID_HEADER) + if not execution_id_value: + execution_id_value = f"{time.time_ns()}_{uuid4().hex[:4]}" + execution_id.set(execution_id_value) + is_verified_internal = False + if apply_duration_minimum is not None: + is_verified_internal = bool( + ROBOFLOW_INTERNAL_SERVICE_SECRET + and INTERNAL_REMOTE_EXEC_REQ_HEADER + and request.headers.get(INTERNAL_REMOTE_EXEC_REQ_HEADER) + == ROBOFLOW_INTERNAL_SERVICE_SECRET + ) + apply_duration_minimum.set(not is_verified_internal) + collector = None + if ( + remote_processing_times is not None + and RemoteProcessingTimeCollector is not None + ): + collector = RemoteProcessingTimeCollector() + request.state.remote_processing_time_collector = collector + remote_processing_times.set(collector) + t1 = time.time() + response = await call_next(request) + t2 = time.time() + response.headers[PROCESSING_TIME_HEADER] = str(t2 - t1) + if ( + WORKFLOWS_REMOTE_EXECUTION_TIME_FORWARDING + and collector is not None + and collector.has_data() + ): + total, detail = collector.snapshot_summary() + response.headers[REMOTE_PROCESSING_TIME_HEADER] = str(total) + if detail is not None: + response.headers[REMOTE_PROCESSING_TIMES_HEADER] = detail + if execution_id is not None: + response.headers[EXECUTION_ID_HEADER] = execution_id_value + if INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER is not None: + response.headers[INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER] = str( + is_verified_internal + ).lower() + return response diff --git a/inference/core/managers/model_load_collector.py b/inference/core/managers/model_load_collector.py index 231571382f..97321aa786 100644 --- a/inference/core/managers/model_load_collector.py +++ b/inference/core/managers/model_load_collector.py @@ -25,6 +25,10 @@ def has_data(self) -> bool: with self._lock: return len(self._entries) > 0 + def snapshot_entries(self) -> list: + with self._lock: + return list(self._entries) + def summarize(self, max_detail_bytes: int = 4096) -> Tuple[float, Optional[str]]: """Return (total_load_time, entries_json_or_none). @@ -32,8 +36,7 @@ def summarize(self, max_detail_bytes: int = 4096) -> Tuple[float, Optional[str]] entries. If the JSON exceeds *max_detail_bytes*, the detail string is omitted (None). """ - with self._lock: - entries = list(self._entries) + entries = self.snapshot_entries() total = sum(t for _, t in entries) detail = json.dumps([{"m": m, "t": t} for m, t in entries]) if len(detail) > max_detail_bytes: diff --git a/inference_sdk/config.py b/inference_sdk/config.py index 3b31bd1db6..6785f3c7bc 100644 --- a/inference_sdk/config.py +++ b/inference_sdk/config.py @@ -2,7 +2,7 @@ import json import os import threading -from typing import Optional, Tuple +from typing import Iterable, Optional, Tuple from inference_sdk.utils.environment import str2bool @@ -23,12 +23,44 @@ class RemoteProcessingTimeCollector: def __init__(self): self._entries: list = [] # list of (model_id, time) tuples + self._model_ids: set = set() + self._cold_start_entries: list = [] # list of (model_id, load_time) tuples + self._cold_start_total_load_time: float = 0.0 + self._cold_start_count: int = 0 self._lock = threading.Lock() def add(self, processing_time: float, model_id: str = "unknown") -> None: with self._lock: self._entries.append((model_id, processing_time)) + def add_model_id(self, model_id: Optional[str]) -> None: + if model_id in (None, "", "unknown"): + return + with self._lock: + self._model_ids.add(model_id) + + def add_model_ids(self, model_ids: Iterable[str]) -> None: + filtered_ids = { + model_id for model_id in model_ids if model_id not in (None, "", "unknown") + } + if not filtered_ids: + return + with self._lock: + self._model_ids.update(filtered_ids) + + def record_cold_start( + self, + load_time: float, + model_id: Optional[str] = None, + count: int = 1, + ) -> None: + with self._lock: + self._cold_start_total_load_time += load_time + self._cold_start_count += count + if model_id not in (None, "", "unknown"): + self._cold_start_entries.append((model_id, load_time)) + self._model_ids.add(model_id) + def drain(self) -> list: """Atomically return all entries and clear the internal list.""" with self._lock: @@ -36,10 +68,45 @@ def drain(self) -> list: self._entries = [] return entries + def snapshot_entries(self) -> list: + with self._lock: + return list(self._entries) + + def snapshot_model_ids(self) -> set: + with self._lock: + return set(self._model_ids) + + def snapshot_cold_start_entries(self) -> list: + with self._lock: + return list(self._cold_start_entries) + + def snapshot_cold_start_total_load_time(self) -> float: + with self._lock: + return self._cold_start_total_load_time + + def snapshot_cold_start_count(self) -> int: + with self._lock: + return self._cold_start_count + def has_data(self) -> bool: with self._lock: return len(self._entries) > 0 + def has_cold_start_data(self) -> bool: + with self._lock: + return self._cold_start_count > 0 + + def snapshot_summary( + self, max_detail_bytes: int = 4096 + ) -> Tuple[float, Optional[str]]: + """Return (total_time, entries_json_or_none) without clearing entries.""" + entries = self.snapshot_entries() + total = sum(t for _, t in entries) + detail = json.dumps([{"m": m, "t": t} for m, t in entries]) + if len(detail) > max_detail_bytes: + detail = None + return total, detail + def summarize(self, max_detail_bytes: int = 4096) -> Tuple[float, Optional[str]]: """Atomically drain entries and return (total_time, entries_json_or_none). diff --git a/inference_sdk/http/utils/executors.py b/inference_sdk/http/utils/executors.py index f1987f76b4..ab8639a19a 100644 --- a/inference_sdk/http/utils/executors.py +++ b/inference_sdk/http/utils/executors.py @@ -1,10 +1,11 @@ import asyncio import contextvars +import json import logging from concurrent.futures import ThreadPoolExecutor from enum import Enum from functools import partial -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union import aiohttp import backoff @@ -30,6 +31,11 @@ RETRYABLE_STATUS_CODES = {429, 503, 504} UNKNOWN_MODEL_ID = "unknown" +MODEL_COLD_START_HEADER = "X-Model-Cold-Start" +MODEL_COLD_START_COUNT_HEADER = "X-Model-Cold-Start-Count" +MODEL_LOAD_TIME_HEADER = "X-Model-Load-Time" +MODEL_LOAD_DETAILS_HEADER = "X-Model-Load-Details" +MODEL_ID_HEADER = "X-Model-Id" class RequestMethod(Enum): @@ -92,6 +98,69 @@ def _extract_model_id_from_request_data(request_data: RequestData) -> str: return UNKNOWN_MODEL_ID +def _extract_model_ids_from_response( + response: Response, request_data: RequestData +) -> List[str]: + model_ids_header = response.headers.get(MODEL_ID_HEADER) + if model_ids_header: + return [ + model_id.strip() + for model_id in model_ids_header.split(",") + if model_id.strip() + ] + if request_data.payload and isinstance(request_data.payload, dict): + model_id = request_data.payload.get("model_id") + if model_id: + return [str(model_id)] + return [] + + +def _parse_model_load_details( + details_header: str, +) -> Optional[List[Tuple[Optional[str], float]]]: + try: + parsed = json.loads(details_header) + except json.JSONDecodeError: + return None + if not isinstance(parsed, list): + return None + result = [] + for entry in parsed: + if not isinstance(entry, dict) or "t" not in entry: + return None + try: + load_time = float(entry["t"]) + except (TypeError, ValueError): + return None + model_id = entry.get("m") + model_id = str(model_id) if model_id not in (None, "") else None + result.append((model_id, load_time)) + return result + + +def _extract_cold_start_count_from_response(response: Response) -> int: + count_header = response.headers.get(MODEL_COLD_START_COUNT_HEADER) + if count_header is None: + return 1 + try: + count = int(count_header) + except (TypeError, ValueError): + logging.warning( + "Malformed %s header value: %r", + MODEL_COLD_START_COUNT_HEADER, + count_header, + ) + return 1 + if count < 1: + logging.warning( + "Unexpected %s header value for cold start response: %r", + MODEL_COLD_START_COUNT_HEADER, + count_header, + ) + return 1 + return count + + def _collect_remote_processing_times( responses: List[Response], requests_data: List[RequestData], @@ -117,6 +186,45 @@ def _collect_remote_processing_times( logging.warning( "Malformed %s header value: %r", PROCESSING_TIME_HEADER, pt ) + model_ids = _extract_model_ids_from_response( + response=response, request_data=request_data + ) + collector.add_model_ids(model_ids=model_ids) + if response.headers.get(MODEL_COLD_START_HEADER, "").lower() != "true": + continue + details_header = response.headers.get(MODEL_LOAD_DETAILS_HEADER) + if details_header: + parsed_details = _parse_model_load_details(details_header) + if parsed_details is None: + logging.warning( + "Malformed %s header value: %r", + MODEL_LOAD_DETAILS_HEADER, + details_header, + ) + else: + for entry_model_id, load_time in parsed_details: + collector.record_cold_start( + load_time=load_time, + model_id=entry_model_id, + ) + continue + load_time = 0.0 + load_time_header = response.headers.get(MODEL_LOAD_TIME_HEADER) + if load_time_header is not None: + try: + load_time = float(load_time_header) + except (ValueError, TypeError): + logging.warning( + "Malformed %s header value: %r", + MODEL_LOAD_TIME_HEADER, + load_time_header, + ) + synthesized_model_id = model_ids[0] if len(model_ids) == 1 else None + collector.record_cold_start( + load_time=load_time, + count=_extract_cold_start_count_from_response(response=response), + model_id=synthesized_model_id, + ) def make_parallel_requests( @@ -364,7 +472,10 @@ def send_post_request( raise error if enable_retries and response.status_code in RETRYABLE_STATUS_CODES: raise RetryError( - f"Transient error in HTTP request - response with status code: {response.status_code} received.", + ( + "Transient error in HTTP request - response with status code: " + f"{response.status_code} received." + ), status_code=response.status_code, ) api_key_safe_raise_for_status(response=response) diff --git a/tests/inference/unit_tests/core/interfaces/http/test_model_response_headers.py b/tests/inference/unit_tests/core/interfaces/http/test_model_response_headers.py new file mode 100644 index 0000000000..0c6ade99be --- /dev/null +++ b/tests/inference/unit_tests/core/interfaces/http/test_model_response_headers.py @@ -0,0 +1,91 @@ +import json + +from inference.core.constants import ( + MODEL_COLD_START_COUNT_HEADER, + MODEL_COLD_START_HEADER, + MODEL_ID_HEADER, + MODEL_LOAD_DETAILS_HEADER, + MODEL_LOAD_TIME_HEADER, +) +from inference.core.interfaces.http.request_metrics import build_model_response_headers + + +def test_build_model_response_headers_for_remote_only_cold_start() -> None: + # when + result = build_model_response_headers( + local_model_ids=set(), + local_cold_start_entries=[], + remote_model_ids={"remote-model/1"}, + remote_cold_start_entries=[("remote-model/1", 0.8)], + remote_cold_start_count=1, + remote_cold_start_total_load_time=0.8, + ) + + # then + assert result[MODEL_COLD_START_HEADER] == "true" + assert result[MODEL_COLD_START_COUNT_HEADER] == "1" + assert result[MODEL_ID_HEADER] == "remote-model/1" + assert abs(float(result[MODEL_LOAD_TIME_HEADER]) - 0.8) < 1e-9 + assert json.loads(result[MODEL_LOAD_DETAILS_HEADER]) == [ + {"m": "remote-model/1", "t": 0.8} + ] + + +def test_build_model_response_headers_merges_local_and_remote_models() -> None: + # when + result = build_model_response_headers( + local_model_ids={"local-model/1"}, + local_cold_start_entries=[("local-model/1", 0.3)], + remote_model_ids={"remote-model/2"}, + remote_cold_start_entries=[("remote-model/2", 0.7)], + remote_cold_start_count=1, + remote_cold_start_total_load_time=0.7, + ) + + # then + assert result[MODEL_ID_HEADER] == "local-model/1,remote-model/2" + assert result[MODEL_COLD_START_HEADER] == "true" + assert result[MODEL_COLD_START_COUNT_HEADER] == "2" + assert abs(float(result[MODEL_LOAD_TIME_HEADER]) - 1.0) < 1e-9 + assert json.loads(result[MODEL_LOAD_DETAILS_HEADER]) == [ + {"m": "local-model/1", "t": 0.3}, + {"m": "remote-model/2", "t": 0.7}, + ] + + +def test_build_model_response_headers_omits_partial_remote_details() -> None: + # when + result = build_model_response_headers( + local_model_ids=set(), + local_cold_start_entries=[], + remote_model_ids={"model-a/1", "model-b/2"}, + remote_cold_start_entries=[], + remote_cold_start_count=1, + remote_cold_start_total_load_time=1.4, + ) + + # then + assert result[MODEL_COLD_START_HEADER] == "true" + assert result[MODEL_COLD_START_COUNT_HEADER] == "1" + assert result[MODEL_ID_HEADER] == "model-a/1,model-b/2" + assert abs(float(result[MODEL_LOAD_TIME_HEADER]) - 1.4) < 1e-9 + assert MODEL_LOAD_DETAILS_HEADER not in result + + +def test_build_model_response_headers_sets_zero_count_when_no_cold_start() -> None: + # when + result = build_model_response_headers( + local_model_ids={"model-a/1"}, + local_cold_start_entries=[], + remote_model_ids=set(), + remote_cold_start_entries=[], + remote_cold_start_count=0, + remote_cold_start_total_load_time=0.0, + ) + + # then + assert result[MODEL_COLD_START_HEADER] == "false" + assert result[MODEL_COLD_START_COUNT_HEADER] == "0" + assert result[MODEL_ID_HEADER] == "model-a/1" + assert MODEL_LOAD_TIME_HEADER not in result + assert MODEL_LOAD_DETAILS_HEADER not in result diff --git a/tests/inference/unit_tests/core/interfaces/http/test_remote_processing_time_middleware.py b/tests/inference/unit_tests/core/interfaces/http/test_remote_processing_time_middleware.py index 98f4aa4e92..df12abf180 100644 --- a/tests/inference/unit_tests/core/interfaces/http/test_remote_processing_time_middleware.py +++ b/tests/inference/unit_tests/core/interfaces/http/test_remote_processing_time_middleware.py @@ -6,7 +6,7 @@ from starlette.routing import Route from starlette.testclient import TestClient -from inference.core.interfaces.http.http_api import ( +from inference.core.interfaces.http.request_metrics import ( REMOTE_PROCESSING_TIME_HEADER, REMOTE_PROCESSING_TIMES_HEADER, GCPServerlessMiddleware, @@ -14,11 +14,12 @@ from inference_sdk.config import ( INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER, PROCESSING_TIME_HEADER, - RemoteProcessingTimeCollector, apply_duration_minimum, remote_processing_times, ) +REQUEST_METRICS_MODULE = "inference.core.interfaces.http.request_metrics" + def _endpoint_that_adds_remote_times(request): """Simulates a workflow that records remote processing times.""" @@ -34,6 +35,15 @@ def _endpoint_no_remote_times(request): return PlainTextResponse("OK") +def _endpoint_that_adds_remote_model_metadata(request): + collector = remote_processing_times.get() + if collector is None: + return PlainTextResponse("missing") + collector.add_model_ids(["remote-model/1"]) + collector.record_cold_start(load_time=0.4, model_id="remote-model/1") + return PlainTextResponse(str(collector.snapshot_cold_start_count())) + + def _create_app(routes): app = Starlette(routes=routes) app.add_middleware(GCPServerlessMiddleware) @@ -41,7 +51,7 @@ def _create_app(routes): @patch( - "inference.core.interfaces.http.http_api.WORKFLOWS_REMOTE_EXECUTION_TIME_FORWARDING", + f"{REQUEST_METRICS_MODULE}.WORKFLOWS_REMOTE_EXECUTION_TIME_FORWARDING", True, ) class TestGCPServerlessMiddlewareRemoteProcessingTimes: @@ -113,7 +123,7 @@ def _endpoint_many_times(request): @patch( - "inference.core.interfaces.http.http_api.WORKFLOWS_REMOTE_EXECUTION_TIME_FORWARDING", + f"{REQUEST_METRICS_MODULE}.WORKFLOWS_REMOTE_EXECUTION_TIME_FORWARDING", False, ) class TestGCPServerlessMiddlewareWithForwardingDisabled: @@ -134,6 +144,22 @@ def test_no_remote_headers_when_forwarding_disabled(self) -> None: # Wall-clock time still present assert PROCESSING_TIME_HEADER in response.headers + def test_collector_stays_available_for_model_metadata(self) -> None: + # given + app = _create_app( + [Route("/workflow", endpoint=_endpoint_that_adds_remote_model_metadata)] + ) + client = TestClient(app) + + # when + response = client.get("/workflow") + + # then + assert response.status_code == 200 + assert response.text == "1" + assert REMOTE_PROCESSING_TIME_HEADER not in response.headers + assert REMOTE_PROCESSING_TIMES_HEADER not in response.headers + def _endpoint_read_duration_minimum(request): """Returns the current value of apply_duration_minimum.""" @@ -142,7 +168,7 @@ def _endpoint_read_duration_minimum(request): @patch( - "inference.core.interfaces.http.http_api.WORKFLOWS_REMOTE_EXECUTION_TIME_FORWARDING", + f"{REQUEST_METRICS_MODULE}.WORKFLOWS_REMOTE_EXECUTION_TIME_FORWARDING", True, ) class TestApplyDurationMinimumContextVar: @@ -159,7 +185,7 @@ def test_direct_request_sets_apply_duration_minimum_true(self) -> None: assert response.headers[INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER] == "false" @patch( - "inference.core.interfaces.http.http_api.ROBOFLOW_INTERNAL_SERVICE_SECRET", + f"{REQUEST_METRICS_MODULE}.ROBOFLOW_INTERNAL_SERVICE_SECRET", "test-secret-123", ) def test_verified_internal_request_sets_apply_duration_minimum_false(self) -> None: @@ -178,7 +204,7 @@ def test_verified_internal_request_sets_apply_duration_minimum_false(self) -> No assert response.headers[INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER] == "true" @patch( - "inference.core.interfaces.http.http_api.ROBOFLOW_INTERNAL_SERVICE_SECRET", + f"{REQUEST_METRICS_MODULE}.ROBOFLOW_INTERNAL_SERVICE_SECRET", "test-secret-123", ) def test_wrong_secret_sets_apply_duration_minimum_true(self) -> None: @@ -197,7 +223,7 @@ def test_wrong_secret_sets_apply_duration_minimum_true(self) -> None: assert response.headers[INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER] == "false" @patch( - "inference.core.interfaces.http.http_api.ROBOFLOW_INTERNAL_SERVICE_SECRET", + f"{REQUEST_METRICS_MODULE}.ROBOFLOW_INTERNAL_SERVICE_SECRET", None, ) def test_no_secret_configured_sets_apply_duration_minimum_true(self) -> None: diff --git a/tests/inference_sdk/unit_tests/http/utils/test_remote_processing_time_collection.py b/tests/inference_sdk/unit_tests/http/utils/test_remote_processing_time_collection.py index 726b49e2af..1e73ae33c7 100644 --- a/tests/inference_sdk/unit_tests/http/utils/test_remote_processing_time_collection.py +++ b/tests/inference_sdk/unit_tests/http/utils/test_remote_processing_time_collection.py @@ -15,11 +15,18 @@ from inference_sdk.http.utils.request_building import RequestData -def _make_response(processing_time: str = None, status_code: int = 200) -> Response: +def _make_response( + processing_time: str = None, + status_code: int = 200, + extra_headers: dict = None, +) -> Response: response = Response() response.status_code = status_code if processing_time is not None: response.headers[PROCESSING_TIME_HEADER] = processing_time + if extra_headers: + for key, value in extra_headers.items(): + response.headers[key] = value return response @@ -211,3 +218,141 @@ def test_handles_more_responses_than_request_data(self, caplog) -> None: assert len(entries) == 1 assert entries[0][0] == "m1" assert "does not match" in caplog.text + + def test_collects_remote_model_id_from_response_header(self) -> None: + # given + collector = RemoteProcessingTimeCollector() + token = remote_processing_times.set(collector) + responses = [ + _make_response( + "0.5", + extra_headers={"X-Model-Id": "remote-model/1"}, + ) + ] + requests_data = [_make_request_data(model_id="payload-model/1")] + + try: + # when + _collect_remote_processing_times(responses, requests_data) + finally: + remote_processing_times.reset(token) + + # then + assert collector.snapshot_model_ids() == {"remote-model/1"} + + def test_collects_remote_cold_start_from_detailed_headers(self) -> None: + # given + collector = RemoteProcessingTimeCollector() + token = remote_processing_times.set(collector) + responses = [ + _make_response( + "0.5", + extra_headers={ + "X-Model-Id": "remote-model/1", + "X-Model-Cold-Start": "true", + "X-Model-Load-Time": "1.1", + "X-Model-Load-Details": '[{"m":"remote-model/1","t":1.1}]', + }, + ) + ] + requests_data = [_make_request_data(model_id="payload-model/1")] + + try: + # when + _collect_remote_processing_times(responses, requests_data) + finally: + remote_processing_times.reset(token) + + # then + assert collector.snapshot_model_ids() == {"remote-model/1"} + assert collector.snapshot_cold_start_entries() == [("remote-model/1", 1.1)] + assert collector.snapshot_cold_start_count() == 1 + assert abs(collector.snapshot_cold_start_total_load_time() - 1.1) < 1e-9 + + def test_collects_remote_cold_start_from_summary_headers_when_details_missing( + self, + ) -> None: + # given + collector = RemoteProcessingTimeCollector() + token = remote_processing_times.set(collector) + responses = [ + _make_response( + "0.5", + extra_headers={ + "X-Model-Id": "remote-model/1", + "X-Model-Cold-Start": "true", + "X-Model-Load-Time": "0.9", + }, + ) + ] + requests_data = [_make_request_data(model_id="payload-model/1")] + + try: + # when + _collect_remote_processing_times(responses, requests_data) + finally: + remote_processing_times.reset(token) + + # then + assert collector.snapshot_cold_start_entries() == [("remote-model/1", 0.9)] + assert collector.snapshot_cold_start_count() == 1 + assert abs(collector.snapshot_cold_start_total_load_time() - 0.9) < 1e-9 + + def test_collects_remote_cold_start_count_without_detail_when_model_ambiguous( + self, + ) -> None: + # given + collector = RemoteProcessingTimeCollector() + token = remote_processing_times.set(collector) + responses = [ + _make_response( + "0.5", + extra_headers={ + "X-Model-Id": "model-a/1,model-b/2", + "X-Model-Cold-Start": "true", + "X-Model-Load-Time": "1.4", + }, + ) + ] + requests_data = [_make_request_data(model_id="payload-model/1")] + + try: + # when + _collect_remote_processing_times(responses, requests_data) + finally: + remote_processing_times.reset(token) + + # then + assert collector.snapshot_model_ids() == {"model-a/1", "model-b/2"} + assert collector.snapshot_cold_start_entries() == [] + assert collector.snapshot_cold_start_count() == 1 + assert abs(collector.snapshot_cold_start_total_load_time() - 1.4) < 1e-9 + + def test_uses_cold_start_count_header_when_detail_is_unavailable(self) -> None: + # given + collector = RemoteProcessingTimeCollector() + token = remote_processing_times.set(collector) + responses = [ + _make_response( + "0.5", + extra_headers={ + "X-Model-Id": "model-a/1,model-b/2", + "X-Model-Cold-Start": "true", + "X-Model-Cold-Start-Count": "3", + "X-Model-Load-Time": "1.4", + }, + ) + ] + requests_data = [_make_request_data(model_id="payload-model/1")] + + try: + # when + _collect_remote_processing_times(responses, requests_data) + finally: + remote_processing_times.reset(token) + + # then + assert collector.snapshot_model_ids() == {"model-a/1", "model-b/2"} + assert collector.snapshot_cold_start_entries() == [] + assert collector.snapshot_cold_start_count() == 3 + assert abs(collector.snapshot_cold_start_total_load_time() - 1.4) < 1e-9 diff --git a/tests/inference_sdk/unit_tests/test_config.py b/tests/inference_sdk/unit_tests/test_config.py index 9add25ff73..546692ac9e 100644 --- a/tests/inference_sdk/unit_tests/test_config.py +++ b/tests/inference_sdk/unit_tests/test_config.py @@ -142,3 +142,50 @@ def test_contextvar_set_and_get() -> None: assert retrieved is collector finally: remote_processing_times.reset(token) + + +def test_snapshot_summary_does_not_clear_entries() -> None: + # given + collector = RemoteProcessingTimeCollector() + collector.add(0.5, model_id="yolov8") + collector.add(0.3, model_id="clip") + + # when + total, detail = collector.snapshot_summary() + + # then + assert abs(total - 0.8) < 1e-9 + parsed = json.loads(detail) + assert parsed == [ + {"m": "yolov8", "t": 0.5}, + {"m": "clip", "t": 0.3}, + ] + assert collector.snapshot_entries() == [("yolov8", 0.5), ("clip", 0.3)] + + +def test_collector_tracks_remote_model_ids() -> None: + # given + collector = RemoteProcessingTimeCollector() + + # when + collector.add_model_id("model-a/1") + collector.add_model_ids(["model-b/2", "model-a/1", "", "unknown"]) + + # then + assert collector.snapshot_model_ids() == {"model-a/1", "model-b/2"} + + +def test_collector_tracks_remote_cold_start_metadata() -> None: + # given + collector = RemoteProcessingTimeCollector() + + # when + collector.record_cold_start(load_time=0.4, model_id="model-a/1") + collector.record_cold_start(load_time=0.2, model_id=None, count=1) + + # then + assert collector.has_cold_start_data() is True + assert collector.snapshot_cold_start_entries() == [("model-a/1", 0.4)] + assert collector.snapshot_cold_start_count() == 2 + assert abs(collector.snapshot_cold_start_total_load_time() - 0.6) < 1e-9 + assert collector.snapshot_model_ids() == {"model-a/1"}