Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion automl-service/app/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,27 @@

from typing import Optional, Tuple

from fastapi import HTTPException
from fastapi import HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.utils import remap_shared_path
from app.db import crud


def resolve_request_project_id(request: Optional[Request]) -> Optional[str]:
"""Resolve project context from ``projectId`` / ``project_id`` query param.

Returns ``None`` when no project context is available.
"""
if request is not None:
for query_key in ("projectId", "project_id"):
query_project_id = request.query_params.get(query_key)
if query_project_id:
return query_project_id

return None


async def get_job_paths(
db: AsyncSession, job_id: str
) -> Tuple[str, str, Optional[str], Optional[str]]:
Expand Down
46 changes: 46 additions & 0 deletions automl-service/app/core/leaderboard_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Helpers for normalizing leaderboard payloads across AutoGluon model types."""

from __future__ import annotations

from typing import Any, Optional


def normalize_leaderboard_rows(
rows: Optional[list[dict[str, Any]]],
) -> Optional[list[dict[str, Any]]]:
"""Normalize leaderboard rows to expose common timing keys.

AutoGluon TimeSeries leaderboards expose ``fit_time_marginal`` but not
``fit_time``. Our UI expects ``fit_time`` for the training-time chart and
leaderboard table, so copy the marginal value into ``fit_time`` when the
cumulative field is absent.
"""
if rows is None:
return None

normalized_rows: list[dict[str, Any]] = []
for row in rows:
normalized = dict(row)

if normalized.get("fit_time") is None and normalized.get("fit_time_marginal") is not None:
normalized["fit_time"] = normalized["fit_time_marginal"]

if normalized.get("pred_time_val") is None and normalized.get("pred_time_val_marginal") is not None:
normalized["pred_time_val"] = normalized["pred_time_val_marginal"]

normalized_rows.append(normalized)

return normalized_rows


def normalize_leaderboard_payload(payload: Any) -> Any:
"""Normalize leaderboard payloads stored as either lists or dict wrappers."""
if isinstance(payload, list):
return normalize_leaderboard_rows(payload)

if isinstance(payload, dict) and isinstance(payload.get("models"), list):
normalized = dict(payload)
normalized["models"] = normalize_leaderboard_rows(payload["models"])
return normalized

return payload
54 changes: 54 additions & 0 deletions automl-service/tests/test_api_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Tests for shared API request helpers."""

from starlette.requests import Request

from app.api.utils import resolve_request_project_id


def _make_request(*, headers=None, query_string: bytes = b"") -> Request:
encoded_headers = [
(key.lower().encode("latin-1"), value.encode("latin-1"))
for key, value in (headers or {}).items()
]
return Request(
{
"type": "http",
"method": "GET",
"path": "/test",
"headers": encoded_headers,
"query_string": query_string,
}
)


def test_resolve_request_project_id_ignores_header(monkeypatch):
"""X-Project-Id header is not used — only query params."""
monkeypatch.delenv("DOMINO_PROJECT_ID", raising=False)
request = _make_request(headers={"X-Project-Id": "header-proj"})

assert resolve_request_project_id(request) is None


def test_resolve_request_project_id_reads_camel_case_query_param(monkeypatch):
monkeypatch.delenv("DOMINO_PROJECT_ID", raising=False)
request = _make_request(query_string=b"projectId=query-proj")

assert resolve_request_project_id(request) == "query-proj"


def test_resolve_request_project_id_reads_snake_case_query_param(monkeypatch):
monkeypatch.delenv("DOMINO_PROJECT_ID", raising=False)
request = _make_request(query_string=b"project_id=query-proj")

assert resolve_request_project_id(request) == "query-proj"


def test_resolve_request_project_id_ignores_environment_variable(monkeypatch):
"""DOMINO_PROJECT_ID is the App's own project — never use it as fallback."""
monkeypatch.setenv("DOMINO_PROJECT_ID", "env-proj")

assert resolve_request_project_id(None) is None


def test_resolve_request_project_id_none_without_request():
assert resolve_request_project_id(None) is None
65 changes: 65 additions & 0 deletions automl-service/tests/test_leaderboard_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Tests for leaderboard payload normalization helpers."""

from pathlib import Path
import sys

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from app.core.leaderboard_utils import ( # noqa: E402
normalize_leaderboard_payload,
normalize_leaderboard_rows,
)


def test_normalize_leaderboard_rows_copies_fit_time_from_marginal():
rows = [
{
"model": "WeightedEnsemble",
"fit_time": None,
"fit_time_marginal": 0.42,
"pred_time_val": 1.5,
}
]

normalized = normalize_leaderboard_rows(rows)

assert normalized == [
{
"model": "WeightedEnsemble",
"fit_time": 0.42,
"fit_time_marginal": 0.42,
"pred_time_val": 1.5,
}
]


def test_normalize_leaderboard_rows_preserves_existing_fit_time():
rows = [
{
"model": "DirectTabular",
"fit_time": 12.3,
"fit_time_marginal": 4.5,
}
]

normalized = normalize_leaderboard_rows(rows)

assert normalized[0]["fit_time"] == 12.3
assert normalized[0]["fit_time_marginal"] == 4.5


def test_normalize_leaderboard_payload_updates_models_wrapper():
payload = {
"models": [
{
"model": "Theta",
"fit_time_marginal": 0.07,
}
],
"best_model": "Theta",
}

normalized = normalize_leaderboard_payload(payload)

assert normalized["models"][0]["fit_time"] == 0.07
assert normalized["best_model"] == "Theta"