Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions inference/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
KEYPOINTS_DETECTION_TASK = "keypoint-detection"
PROCESSING_TIME_HEADER = "X-Processing-Time"
MODEL_COLD_START_HEADER = "X-Model-Cold-Start"
MODEL_COLD_START_COUNT_HEADER = "X-Model-Cold-Start-Count"
MODEL_LOAD_TIME_HEADER = "X-Model-Load-Time"
MODEL_LOAD_DETAILS_HEADER = "X-Model-Load-Details"
MODEL_ID_HEADER = "X-Model-Id"
Expand Down
107 changes: 38 additions & 69 deletions inference/core/interfaces/http/http_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

from inference.core import logger
from inference.core.constants import (
MODEL_COLD_START_COUNT_HEADER,
MODEL_COLD_START_HEADER,
MODEL_ID_HEADER,
MODEL_LOAD_DETAILS_HEADER,
Expand Down Expand Up @@ -229,6 +230,12 @@
orjson_response,
orjson_response_keeping_parent_id,
)
from inference.core.interfaces.http.request_metrics import (
REMOTE_PROCESSING_TIME_HEADER,
REMOTE_PROCESSING_TIMES_HEADER,
GCPServerlessMiddleware,
build_model_response_headers,
)
from inference.core.interfaces.stream_manager.api.entities import (
CommandContext,
CommandResponse,
Expand Down Expand Up @@ -316,23 +323,9 @@
from inference.core.version import __version__

try:
from inference_sdk.config import (
EXECUTION_ID_HEADER,
INTERNAL_REMOTE_EXEC_REQ_HEADER,
INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER,
RemoteProcessingTimeCollector,
apply_duration_minimum,
execution_id,
remote_processing_times,
)
from inference_sdk.config import EXECUTION_ID_HEADER
except ImportError:
execution_id = None
remote_processing_times = None
RemoteProcessingTimeCollector = None
EXECUTION_ID_HEADER = None
INTERNAL_REMOTE_EXEC_REQ_HEADER = None
INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER = None
apply_duration_minimum = None


def get_content_type(request: Request) -> str:
Expand All @@ -348,8 +341,6 @@ async def dispatch(self, request, call_next):
return response


REMOTE_PROCESSING_TIME_HEADER = "X-Remote-Processing-Time"
REMOTE_PROCESSING_TIMES_HEADER = "X-Remote-Processing-Times"
AUTH_CACHE_TTL_SECONDS = 3600
SHORT_AUTH_CACHE_TTL_SECONDS = 60

Expand All @@ -362,48 +353,6 @@ class AuthorizationCacheEntry:
message: Optional[str] = None


class GCPServerlessMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
if execution_id is not None:
execution_id_value = request.headers.get(EXECUTION_ID_HEADER)
if not execution_id_value:
execution_id_value = f"{time.time_ns()}_{uuid4().hex[:4]}"
execution_id.set(execution_id_value)
is_verified_internal = False
if apply_duration_minimum is not None:
is_verified_internal = bool(
ROBOFLOW_INTERNAL_SERVICE_SECRET
and INTERNAL_REMOTE_EXEC_REQ_HEADER
and request.headers.get(INTERNAL_REMOTE_EXEC_REQ_HEADER)
== ROBOFLOW_INTERNAL_SERVICE_SECRET
)
apply_duration_minimum.set(not is_verified_internal)
collector = None
if (
WORKFLOWS_REMOTE_EXECUTION_TIME_FORWARDING
and remote_processing_times is not None
and RemoteProcessingTimeCollector is not None
):
collector = RemoteProcessingTimeCollector()
remote_processing_times.set(collector)
t1 = time.time()
response = await call_next(request)
t2 = time.time()
response.headers[PROCESSING_TIME_HEADER] = str(t2 - t1)
if collector is not None and collector.has_data():
total, detail = collector.summarize()
response.headers[REMOTE_PROCESSING_TIME_HEADER] = str(total)
if detail is not None:
response.headers[REMOTE_PROCESSING_TIMES_HEADER] = detail
if execution_id is not None:
response.headers[EXECUTION_ID_HEADER] = execution_id_value
if INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER is not None:
response.headers[INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER] = str(
is_verified_internal
).lower()
return response


class HttpInterface(BaseInterface):
"""Roboflow defined HTTP interface for a general-purpose inference server.

Expand Down Expand Up @@ -510,6 +459,7 @@ async def on_shutdown():
REMOTE_PROCESSING_TIME_HEADER,
REMOTE_PROCESSING_TIMES_HEADER,
MODEL_COLD_START_HEADER,
MODEL_COLD_START_COUNT_HEADER,
MODEL_LOAD_TIME_HEADER,
MODEL_LOAD_DETAILS_HEADER,
MODEL_ID_HEADER,
Expand Down Expand Up @@ -820,17 +770,35 @@ async def track_model_load(request: Request, call_next):
ids_collector = RequestModelIds()
request_model_ids.set(ids_collector)
response = await call_next(request)
if load_collector.has_data():
total, detail = load_collector.summarize()
response.headers[MODEL_COLD_START_HEADER] = "true"
response.headers[MODEL_LOAD_TIME_HEADER] = str(total)
if detail is not None:
response.headers[MODEL_LOAD_DETAILS_HEADER] = detail
remote_processing_collector = getattr(
request.state, "remote_processing_time_collector", None
)
if remote_processing_collector is not None:
remote_model_ids = remote_processing_collector.snapshot_model_ids()
remote_cold_start_entries = (
remote_processing_collector.snapshot_cold_start_entries()
)
remote_cold_start_count = (
remote_processing_collector.snapshot_cold_start_count()
)
remote_cold_start_total_load_time = (
remote_processing_collector.snapshot_cold_start_total_load_time()
)
else:
response.headers[MODEL_COLD_START_HEADER] = "false"
model_ids = ids_collector.get_ids()
if model_ids:
response.headers[MODEL_ID_HEADER] = ",".join(sorted(model_ids))
remote_model_ids = set()
remote_cold_start_entries = []
remote_cold_start_count = 0
remote_cold_start_total_load_time = 0.0
response.headers.update(
build_model_response_headers(
local_model_ids=ids_collector.get_ids(),
local_cold_start_entries=load_collector.snapshot_entries(),
remote_model_ids=remote_model_ids,
remote_cold_start_entries=remote_cold_start_entries,
remote_cold_start_count=remote_cold_start_count,
remote_cold_start_total_load_time=remote_cold_start_total_load_time,
)
)
wf_id = request_workflow_id.get(None)
if wf_id:
response.headers[WORKFLOW_ID_HEADER] = wf_id
Expand All @@ -856,6 +824,7 @@ async def structured_access_log(request: Request, call_next):
"request_id": CORRELATION_ID_HEADER,
"processing_time": PROCESSING_TIME_HEADER,
"model_cold_start": MODEL_COLD_START_HEADER,
"model_cold_start_count": MODEL_COLD_START_COUNT_HEADER,
"model_load_time": MODEL_LOAD_TIME_HEADER,
"model_id": MODEL_ID_HEADER,
"workflow_id": WORKFLOW_ID_HEADER,
Expand Down
134 changes: 134 additions & 0 deletions inference/core/interfaces/http/request_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import json
import time
from typing import Dict, List, Optional, Tuple
from uuid import uuid4

from starlette.middleware.base import BaseHTTPMiddleware

from inference.core.constants import (
MODEL_COLD_START_COUNT_HEADER,
MODEL_COLD_START_HEADER,
MODEL_ID_HEADER,
MODEL_LOAD_DETAILS_HEADER,
MODEL_LOAD_TIME_HEADER,
PROCESSING_TIME_HEADER,
)
from inference.core.env import (
ROBOFLOW_INTERNAL_SERVICE_SECRET,
WORKFLOWS_REMOTE_EXECUTION_TIME_FORWARDING,
)

# ``inference_sdk`` is an optional dependency. When it is absent, every
# SDK-provided contextvar / header-name constant degrades to ``None`` and the
# middleware below skips the corresponding behaviour (each use site guards
# with an ``is not None`` check).
try:
    from inference_sdk.config import (
        EXECUTION_ID_HEADER,
        INTERNAL_REMOTE_EXEC_REQ_HEADER,
        INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER,
        RemoteProcessingTimeCollector,
        apply_duration_minimum,
        execution_id,
        remote_processing_times,
    )
except ImportError:
    execution_id = None
    remote_processing_times = None
    RemoteProcessingTimeCollector = None
    EXECUTION_ID_HEADER = None
    INTERNAL_REMOTE_EXEC_REQ_HEADER = None
    INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER = None
    apply_duration_minimum = None


# Response headers carrying remote-execution processing times back to callers.
REMOTE_PROCESSING_TIME_HEADER = "X-Remote-Processing-Time"
REMOTE_PROCESSING_TIMES_HEADER = "X-Remote-Processing-Times"


def summarize_model_load_entries(
    entries: List[Tuple[str, float]], max_detail_bytes: int = 4096
) -> Tuple[float, Optional[str]]:
    """Aggregate model-load entries into a total and a JSON detail string.

    Each entry is a ``(model_id, load_time)`` pair. Returns the summed load
    time and a JSON array of ``{"m": model_id, "t": load_time}`` objects; the
    detail string is dropped (``None``) when its serialized length exceeds
    *max_detail_bytes*, so it stays safe to place in an HTTP header.
    """
    load_times = [load_time for _, load_time in entries]
    total = sum(load_times)
    payload = [{"m": model_id, "t": load_time} for model_id, load_time in entries]
    detail: Optional[str] = json.dumps(payload)
    if len(detail) > max_detail_bytes:
        # Too large for a header value — report only the total.
        detail = None
    return total, detail


def build_model_response_headers(
    local_model_ids: set,
    local_cold_start_entries: List[Tuple[str, float]],
    remote_model_ids: set,
    remote_cold_start_entries: List[Tuple[str, float]],
    remote_cold_start_count: int,
    remote_cold_start_total_load_time: float,
) -> Dict[str, str]:
    """Combine local and remote model / cold-start stats into response headers.

    The cold-start flag and count headers are always present. The model-id
    header appears only when at least one model id is known. Load-time
    headers are added only when a cold start happened, and the per-model
    detail header is emitted only when every counted cold start (local and
    remote) has a matching entry, so the detail never shows a partial picture.
    """
    headers: Dict[str, str] = {
        MODEL_COLD_START_HEADER: "false",
        MODEL_COLD_START_COUNT_HEADER: "0",
    }
    combined_ids = sorted(local_model_ids | remote_model_ids)
    if combined_ids:
        headers[MODEL_ID_HEADER] = ",".join(combined_ids)
    total_cold_starts = len(local_cold_start_entries) + remote_cold_start_count
    headers[MODEL_COLD_START_COUNT_HEADER] = str(total_cold_starts)
    if total_cold_starts == 0:
        return headers
    headers[MODEL_COLD_START_HEADER] = "true"
    local_load_time = sum(load_time for _, load_time in local_cold_start_entries)
    headers[MODEL_LOAD_TIME_HEADER] = str(
        local_load_time + remote_cold_start_total_load_time
    )
    combined_entries = local_cold_start_entries + remote_cold_start_entries
    # Only publish per-model detail when each counted cold start has an entry.
    if len(combined_entries) == total_cold_starts:
        _, detail = summarize_model_load_entries(entries=combined_entries)
        if detail is not None:
            headers[MODEL_LOAD_DETAILS_HEADER] = detail
    return headers


class GCPServerlessMiddleware(BaseHTTPMiddleware):
    """Per-request middleware for serverless deployments.

    For each request it:
      * propagates (or mints) an execution id via the SDK ``execution_id``
        contextvar and echoes it back in the response,
      * marks verified internal service-to-service calls detected via a
        shared-secret header,
      * installs a ``RemoteProcessingTimeCollector`` so downstream remote
        calls can report their processing times and cold-start data,
      * stamps total and (optionally) remote processing-time headers on
        the response.

    Every SDK-dependent step is skipped when ``inference_sdk`` is not
    installed (the module-level names are then ``None``).
    """

    async def dispatch(self, request, call_next):
        # Reuse the caller-supplied execution id when present; otherwise mint
        # one from the current nanosecond timestamp plus a short random suffix.
        if execution_id is not None:
            execution_id_value = request.headers.get(EXECUTION_ID_HEADER)
            if not execution_id_value:
                execution_id_value = f"{time.time_ns()}_{uuid4().hex[:4]}"
            execution_id.set(execution_id_value)
        is_verified_internal = False
        if apply_duration_minimum is not None:
            # Verified internal only when the shared secret is configured AND
            # the request presents a matching header value.
            is_verified_internal = bool(
                ROBOFLOW_INTERNAL_SERVICE_SECRET
                and INTERNAL_REMOTE_EXEC_REQ_HEADER
                and request.headers.get(INTERNAL_REMOTE_EXEC_REQ_HEADER)
                == ROBOFLOW_INTERNAL_SERVICE_SECRET
            )
            # Duration minimum is applied only to external (unverified)
            # callers — exact semantics live in inference_sdk; TODO confirm.
            apply_duration_minimum.set(not is_verified_internal)
        collector = None
        if (
            remote_processing_times is not None
            and RemoteProcessingTimeCollector is not None
        ):
            collector = RemoteProcessingTimeCollector()
            # Also exposed on request.state so later middleware (the model
            # load tracker) can read remote model/cold-start data from it.
            request.state.remote_processing_time_collector = collector
            remote_processing_times.set(collector)
        t1 = time.time()
        response = await call_next(request)
        t2 = time.time()
        # Wall-clock time spent in downstream handlers, in seconds.
        response.headers[PROCESSING_TIME_HEADER] = str(t2 - t1)
        # Forward aggregated remote processing times only when the feature
        # flag is enabled and at least one remote call reported data.
        if (
            WORKFLOWS_REMOTE_EXECUTION_TIME_FORWARDING
            and collector is not None
            and collector.has_data()
        ):
            total, detail = collector.snapshot_summary()
            response.headers[REMOTE_PROCESSING_TIME_HEADER] = str(total)
            if detail is not None:
                response.headers[REMOTE_PROCESSING_TIMES_HEADER] = detail
        if execution_id is not None:
            response.headers[EXECUTION_ID_HEADER] = execution_id_value
        if INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER is not None:
            # "true"/"false" — lowercased for header-value consistency.
            response.headers[INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER] = str(
                is_verified_internal
            ).lower()
        return response
7 changes: 5 additions & 2 deletions inference/core/managers/model_load_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@ def has_data(self) -> bool:
with self._lock:
return len(self._entries) > 0

def snapshot_entries(self) -> list:
with self._lock:
return list(self._entries)

def summarize(self, max_detail_bytes: int = 4096) -> Tuple[float, Optional[str]]:
"""Return (total_load_time, entries_json_or_none).

Returns the total model load time and a JSON string of individual
entries. If the JSON exceeds *max_detail_bytes*, the detail string
is omitted (None).
"""
with self._lock:
entries = list(self._entries)
entries = self.snapshot_entries()
total = sum(t for _, t in entries)
detail = json.dumps([{"m": m, "t": t} for m, t in entries])
if len(detail) > max_detail_bytes:
Expand Down
Loading
Loading