From 69e85ed643b9e3a87b7fcbf1e2fd9b2fd1f62977 Mon Sep 17 00:00:00 2001 From: Mateo Date: Mon, 4 May 2026 20:03:59 +0200 Subject: [PATCH 01/36] feat: add risk-api client and sequence confidence helper --- .env.example | 8 +++ src/app/core/config.py | 8 +++ src/app/services/risk.py | 73 +++++++++++++++++++++++++ src/app/services/sequence_confidence.py | 72 ++++++++++++++++++++++++ 4 files changed, 161 insertions(+) create mode 100644 src/app/services/risk.py create mode 100644 src/app/services/sequence_confidence.py diff --git a/.env.example b/.env.example index 7cbc4733..5b139946 100644 --- a/.env.example +++ b/.env.example @@ -23,6 +23,14 @@ POSTHOG_KEY= SUPPORT_EMAIL= TELEGRAM_TOKEN= +# Risk API (daily fire-weather index per camera) +RISK_API_HOST= +RISK_API_USERNAME= +RISK_API_PASSWORD= +RISK_REFRESH_HOUR_UTC=4 +FWI_LOW_MIN_CONF=0.45 +FWI_VERY_LOW_MIN_CONF=0.6 + # Production-only ACME_EMAIL= BACKEND_HOST= diff --git a/src/app/core/config.py b/src/app/core/config.py index f4c0f3e5..e14e0977 100644 --- a/src/app/core/config.py +++ b/src/app/core/config.py @@ -77,6 +77,14 @@ def sqlachmey_uri(cls, v: str) -> str: TELEGRAM_TOKEN: Union[str, None] = os.environ.get("TELEGRAM_TOKEN") PLATFORM_URL: str = os.environ.get("PLATFORM_URL", "") + # Risk API (daily fire-weather index per camera) + RISK_API_HOST: Union[str, None] = os.environ.get("RISK_API_HOST") + RISK_API_USERNAME: Union[str, None] = os.environ.get("RISK_API_USERNAME") + RISK_API_PASSWORD: Union[str, None] = os.environ.get("RISK_API_PASSWORD") + RISK_REFRESH_HOUR_UTC: int = int(os.environ.get("RISK_REFRESH_HOUR_UTC") or 4) + FWI_LOW_MIN_CONF: float = float(os.environ.get("FWI_LOW_MIN_CONF") or 0.45) + FWI_VERY_LOW_MIN_CONF: float = float(os.environ.get("FWI_VERY_LOW_MIN_CONF") or 0.6) + # Error monitoring SENTRY_DSN: Union[str, None] = os.environ.get("SENTRY_DSN") SERVER_NAME: str = os.environ.get("SERVER_NAME", socket.gethostname()) diff --git a/src/app/services/risk.py b/src/app/services/risk.py new file mode 100644 index 00000000..7ce9a20e --- /dev/null +++ b/src/app/services/risk.py @@ -0,0 +1,73 @@ +# Copyright (C) 2024-2026, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import logging +from typing import Union + +import httpx + +from app.core.config import settings + +logger = logging.getLogger("uvicorn.error") + +__all__ = ["risk_service"] + +# EFFIS classes that should trigger filtering. Anything else (moderate+) → no filter. +_LOW = "low" +_VERY_LOW = "very_low" + + +class RiskService: + """In-memory cache of the daily FWI class per camera. + + Refreshed once at startup and once a day from the pyro-risk-api service. + Fail-open: if the API is unreachable or a camera is unknown, no filter is applied. + """ + + def __init__(self) -> None: + self._scores: dict[int, str] = {} + + @property + def is_configured(self) -> bool: + return bool(settings.RISK_API_HOST and settings.RISK_API_USERNAME and settings.RISK_API_PASSWORD) + + def min_confidence(self, camera_id: int) -> Union[float, None]: + """Return the min confidence required for this camera, or None if no filter applies.""" + fwi_class = self._scores.get(camera_id) + if fwi_class == _VERY_LOW: + return settings.FWI_VERY_LOW_MIN_CONF + if fwi_class == _LOW: + return settings.FWI_LOW_MIN_CONF + return None + + async def refresh(self) -> None: + """Fetch fresh FWI classes from the risk API. 
On error, keep the previous cache.""" + if not self.is_configured: + return + host = settings.RISK_API_HOST.rstrip("/") # type: ignore[union-attr] + url = f"{host}/cameras" + try: + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.get( + url, + auth=(settings.RISK_API_USERNAME, settings.RISK_API_PASSWORD), # type: ignore[arg-type] + ) + response.raise_for_status() + payload = response.json() + except (httpx.HTTPError, ValueError) as exc: + logger.warning("Risk API refresh failed (%s); keeping previous cache (%d entries)", exc, len(self._scores)) + return + + scores: dict[int, str] = {} + for item in payload: + camera_id = item.get("id") + fwi_class = item.get("fwi_class") + if isinstance(camera_id, int) and isinstance(fwi_class, str): + scores[camera_id] = fwi_class + self._scores = scores + logger.info("Risk API refresh: cached FWI class for %d camera(s)", len(scores)) + + +risk_service = RiskService() diff --git a/src/app/services/sequence_confidence.py b/src/app/services/sequence_confidence.py new file mode 100644 index 00000000..a46b1392 --- /dev/null +++ b/src/app/services/sequence_confidence.py @@ -0,0 +1,72 @@ +# Copyright (C) 2024-2026, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +import logging +import re +from ast import literal_eval +from typing import Any, Dict, Iterable, List, Union, cast + +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.models import Detection +from app.schemas.detections import BOX_PATTERN + +logger = logging.getLogger("uvicorn.error") + +__all__ = ["max_conf_from_bboxes", "get_max_conf_by_sequence_ids"] + + +def max_conf_from_bboxes(*bbox_strings: Union[str, None]) -> Union[float, None]: + """Return the highest confidence found across the given bbox strings, or None if none parse.""" + best: Union[float, None] = None + for raw in bbox_strings: + if not raw: + continue + for match in re.finditer(BOX_PATTERN, raw): + try: + bbox = literal_eval(match.group(0)) + except (SyntaxError, ValueError): + continue + if not isinstance(bbox, tuple) or len(bbox) != 5: + continue + conf = bbox[4] + if not isinstance(conf, (int, float)): + continue + if best is None or conf > best: + best = float(conf) + return best + + +async def get_max_conf_by_sequence_ids( + session: AsyncSession, + sequence_ids: Iterable[int], +) -> Dict[int, float]: + """Return {sequence_id: max_conf} computed from all detections of those sequences. + + Sequences with no detections or unparseable bboxes are omitted from the result — + callers should treat a missing key as "unknown" and fail open. 
+ """ + seq_ids: List[int] = list({int(sid) for sid in sequence_ids}) + if not seq_ids: + return {} + + stmt: Any = select(Detection.sequence_id, Detection.bbox, Detection.others_bboxes).where( + cast(Any, Detection.sequence_id).in_(seq_ids) + ) + res = await session.exec(stmt) + + out: Dict[int, float] = {} + for sequence_id, bbox, others in res.all(): + if sequence_id is None: + continue + conf = max_conf_from_bboxes(bbox, others) + if conf is None: + continue + current = out.get(sequence_id) + if current is None or conf > current: + out[sequence_id] = conf + return out From 72e4990fe0df095d080bb8ab0f4dcb09108cb4db Mon Sep 17 00:00:00 2001 From: Mateo Date: Mon, 4 May 2026 20:04:18 +0200 Subject: [PATCH 02/36] feat: schedule daily risk-api refresh in app lifespan --- src/app/main.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/app/main.py b/src/app/main.py index 63065c1f..ac4fa8ed 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -3,8 +3,11 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. +import asyncio import logging import time +from contextlib import asynccontextmanager +from datetime import datetime, timedelta, timezone import sentry_sdk from fastapi import FastAPI, Request, status @@ -19,6 +22,7 @@ from app.api.api_v1.router import api_router from app.core.config import settings from app.schemas.base import Status +from app.services.risk import risk_service logger = logging.getLogger("uvicorn.error") @@ -40,6 +44,44 @@ logger.info(f"Sentry middleware enabled on server {settings.SERVER_NAME}") +def _seconds_until_next_utc_hour(target_hour: int) -> float: + now = datetime.now(tz=timezone.utc) + target = now.replace(hour=target_hour, minute=0, second=0, microsecond=0) + if target <= now: + target += timedelta(days=1) + return (target - now).total_seconds() + + +async def _risk_refresh_loop() -> None: + while True: + try: + await asyncio.sleep(_seconds_until_next_utc_hour(settings.RISK_REFRESH_HOUR_UTC)) + await risk_service.refresh() + except asyncio.CancelledError: + raise + except Exception: + logger.exception("Risk refresh loop iteration failed; continuing") + + +@asynccontextmanager +async def lifespan(_: FastAPI): + if risk_service.is_configured: + await risk_service.refresh() + task = asyncio.create_task(_risk_refresh_loop()) + else: + task = None + logger.info("Risk API not configured; skipping daily refresh") + try: + yield + finally: + if task is not None: + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + app = FastAPI( title=settings.PROJECT_NAME, description=settings.PROJECT_DESCRIPTION, @@ -47,6 +89,7 @@ version=settings.VERSION, openapi_url=f"{settings.API_V1_STR}/openapi.json", docs_url=None, + lifespan=lifespan, ) From 5357a9a17aa35b927d41079bb4dc81f7c576f1c1 Mon Sep 17 00:00:00 2001 From: Mateo Date: Mon, 4 May 2026 20:05:36 +0200 Subject: [PATCH 03/36] feat: filter alerts and sequences by risk-driven confidence threshold --- src/app/api/api_v1/endpoints/alerts.py | 25 ++++++++++++-- src/app/api/api_v1/endpoints/sequences.py | 3 ++ src/app/services/sequence_confidence.py | 41 +++++++++++++++++++++-- 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index 7cf64074..54cbf019 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -20,6 +20,7 @@ from app.schemas.alerts import 
AlertReadWithSequences from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceRead +from app.services.sequence_confidence import filter_sequences_by_risk from app.services.sequence_counts import get_detection_counts_by_sequence_ids from app.services.telemetry import telemetry_client @@ -47,6 +48,24 @@ async def _fetch_sequences_by_alert_ids(session: AsyncSession, alert_ids: List[i return mapping +async def _apply_risk_filter_to_alerts( + session: AsyncSession, + alerts: List[Alert], + seq_map: Dict[int, List[Sequence]], +) -> List[Alert]: + """Drop sequences below the risk threshold and alerts that end up empty.""" + all_sequences = [seq for seqs in seq_map.values() for seq in seqs] + kept_seqs = await filter_sequences_by_risk(session, all_sequences) + kept_ids = {seq.id for seq in kept_seqs} + kept_alerts: List[Alert] = [] + for alert in alerts: + filtered = [seq for seq in seq_map.get(alert.id, []) if seq.id in kept_ids] + if filtered: + seq_map[alert.id] = filtered + kept_alerts.append(alert) + return kept_alerts + + def _serialize_sequence(sequence: Sequence, detections_count: int = 0) -> SequenceRead: return SequenceRead(**sequence.model_dump(), detections_count=detections_count) @@ -128,9 +147,10 @@ async def fetch_latest_unlabeled_alerts( .limit(15) ) alerts_res = await session.exec(alerts_stmt) - alerts = alerts_res.unique().all() + alerts = list(alerts_res.unique().all()) alert_ids = [alert.id for alert in alerts] seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids) + alerts = await _apply_risk_filter_to_alerts(session, alerts, seq_map) detection_counts = await get_detection_counts_by_sequence_ids( session, list({sequence.id for sequences in seq_map.values() for sequence in sequences}), @@ -157,9 +177,10 @@ async def fetch_alerts_from_date( .offset(offset) ) alerts_res = await session.exec(alerts_stmt) - alerts = alerts_res.all() + alerts = list(alerts_res.all()) alert_ids = [alert.id for alert in alerts] seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids) + alerts = await _apply_risk_filter_to_alerts(session, alerts, seq_map) detection_counts = await get_detection_counts_by_sequence_ids( session, list({sequence.id for sequences in seq_map.values() for sequence in sequences}), diff --git a/src/app/api/api_v1/endpoints/sequences.py b/src/app/api/api_v1/endpoints/sequences.py index edacaa46..5e703cf3 100644 --- a/src/app/api/api_v1/endpoints/sequences.py +++ b/src/app/api/api_v1/endpoints/sequences.py @@ -22,6 +22,7 @@ from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceLabel, SequenceRead from app.services.overlap import compute_overlap +from app.services.sequence_confidence import filter_sequences_by_risk from app.services.sequence_counts import get_detection_counts_by_sequence_ids from app.services.storage import s3_service from app.services.telemetry import telemetry_client @@ -162,6 +163,7 @@ async def fetch_latest_unlabeled_sequences( .limit(15) ) ).all() + fetched_sequences = await filter_sequences_by_risk(session, fetched_sequences) counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences]) return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences] @@ -188,6 +190,7 @@ async def fetch_sequences_from_date( .offset(offset) ) ).all() + fetched_sequences = await filter_sequences_by_risk(session, fetched_sequences) counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence 
in fetched_sequences]) return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences] diff --git a/src/app/services/sequence_confidence.py b/src/app/services/sequence_confidence.py index a46b1392..6ad29ce9 100644 --- a/src/app/services/sequence_confidence.py +++ b/src/app/services/sequence_confidence.py @@ -7,17 +7,22 @@ import logging import re from ast import literal_eval -from typing import Any, Dict, Iterable, List, Union, cast +from typing import Any, Dict, Iterable, List, Sequence as TypingSequence, Union, cast from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession -from app.models import Detection +from app.models import Detection, Sequence from app.schemas.detections import BOX_PATTERN +from app.services.risk import risk_service logger = logging.getLogger("uvicorn.error") -__all__ = ["max_conf_from_bboxes", "get_max_conf_by_sequence_ids"] +__all__ = [ + "max_conf_from_bboxes", + "get_max_conf_by_sequence_ids", + "filter_sequences_by_risk", +] def max_conf_from_bboxes(*bbox_strings: Union[str, None]) -> Union[float, None]: @@ -70,3 +75,33 @@ async def get_max_conf_by_sequence_ids( if current is None or conf > current: out[sequence_id] = conf return out + + +async def filter_sequences_by_risk( + session: AsyncSession, + sequences: TypingSequence[Sequence], +) -> List[Sequence]: + """Drop sequences whose max conf is below the risk-driven threshold for their camera. + + Fail-open: a sequence is kept if either the camera has no FWI score (moderate+ or unknown) + or the sequence has no parseable confidence. + """ + if not sequences: + return [] + thresholds = {seq.camera_id: risk_service.min_confidence(seq.camera_id) for seq in sequences} + if all(threshold is None for threshold in thresholds.values()): + return list(sequences) + + seq_ids_to_check = [seq.id for seq in sequences if thresholds.get(seq.camera_id) is not None] + confs = await get_max_conf_by_sequence_ids(session, seq_ids_to_check) + + kept: List[Sequence] = [] + for seq in sequences: + threshold = thresholds.get(seq.camera_id) + if threshold is None: + kept.append(seq) + continue + conf = confs.get(seq.id) + if conf is None or conf >= threshold: + kept.append(seq) + return kept From a003b895334330a643e7a8e6613c3e7eab74976d Mon Sep 17 00:00:00 2001 From: Mateo Date: Mon, 4 May 2026 20:06:10 +0200 Subject: [PATCH 04/36] feat: skip slack alert when sequence max conf below risk threshold --- src/app/api/api_v1/endpoints/detections.py | 30 ++++++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/src/app/api/api_v1/endpoints/detections.py b/src/app/api/api_v1/endpoints/detections.py index 60ae87ed..2ecfaf55 100644 --- a/src/app/api/api_v1/endpoints/detections.py +++ b/src/app/api/api_v1/endpoints/detections.py @@ -5,6 +5,7 @@ import json +import logging import re from ast import literal_eval from datetime import datetime, timedelta @@ -56,11 +57,15 @@ from app.schemas.sequences import SequenceUpdate from app.services.cones import resolve_cone from app.services.overlap import compute_overlap, haversine_km +from app.services.risk import risk_service +from app.services.sequence_confidence import max_conf_from_bboxes from app.services.slack import slack_client from app.services.storage import s3_service, upload_file from app.services.telegram import telegram_client from app.services.telemetry import telemetry_client +logger = logging.getLogger("uvicorn.error") + router = APIRouter() @@ -490,11 +495,26 @@ async def create_detection( if org is 
None: org = cast(Organization, await organizations.get(token_payload.organization_id, strict=True)) if org.slack_hook: - slack_payload = jsonable_encoder(det) - slack_payload["sequence_azimuth"] = sequence_.sequence_azimuth - background_tasks.add_task( - slack_client.notify, org.slack_hook, json.dumps(slack_payload), camera.name, alert_id - ) + min_conf = risk_service.min_confidence(camera.id) + seq_max_conf: Optional[float] = None + if min_conf is not None: + seq_max_conf = max_conf_from_bboxes( + *[d.bbox for d in overlapping_dets], + *[d.others_bboxes for d in overlapping_dets], + ) + if min_conf is None or seq_max_conf is None or seq_max_conf >= min_conf: + slack_payload = jsonable_encoder(det) + slack_payload["sequence_azimuth"] = sequence_.sequence_azimuth + background_tasks.add_task( + slack_client.notify, org.slack_hook, json.dumps(slack_payload), camera.name, alert_id + ) + else: + logger.info( + "Skipping Slack notification for camera %s: max conf %.3f < threshold %.3f", + camera.name, + seq_max_conf, + min_conf, + ) created.append(det) From 7e3650f7158f2d9f7f5acc2002c147f022ad0d5b Mon Sep 17 00:00:00 2001 From: Mateo Date: Mon, 4 May 2026 20:08:15 +0200 Subject: [PATCH 05/36] feat: query risk-api per date for from_date endpoints --- src/app/api/api_v1/endpoints/alerts.py | 16 +++++-- src/app/api/api_v1/endpoints/sequences.py | 4 +- src/app/services/risk.py | 54 ++++++++++++++++------- src/app/services/sequence_confidence.py | 45 ++++++++++++++----- 4 files changed, 88 insertions(+), 31 deletions(-) diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index 54cbf019..9731300c 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -20,7 +20,7 @@ from app.schemas.alerts import AlertReadWithSequences from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceRead -from app.services.sequence_confidence import filter_sequences_by_risk +from app.services.sequence_confidence import filter_sequences_by_risk, filter_sequences_by_risk_for_date from app.services.sequence_counts import get_detection_counts_by_sequence_ids from app.services.telemetry import telemetry_client @@ -52,10 +52,18 @@ async def _apply_risk_filter_to_alerts( session: AsyncSession, alerts: List[Alert], seq_map: Dict[int, List[Sequence]], + target_date: Union[date, None] = None, ) -> List[Alert]: - """Drop sequences below the risk threshold and alerts that end up empty.""" + """Drop sequences below the risk threshold and alerts that end up empty. + + When ``target_date`` is provided, look up the FWI class persisted for that day; + otherwise use today's cached value. 
+ """ all_sequences = [seq for seqs in seq_map.values() for seq in seqs] - kept_seqs = await filter_sequences_by_risk(session, all_sequences) + if target_date is None: + kept_seqs = await filter_sequences_by_risk(session, all_sequences) + else: + kept_seqs = await filter_sequences_by_risk_for_date(session, all_sequences, target_date) kept_ids = {seq.id for seq in kept_seqs} kept_alerts: List[Alert] = [] for alert in alerts: @@ -180,7 +188,7 @@ async def fetch_alerts_from_date( alerts = list(alerts_res.all()) alert_ids = [alert.id for alert in alerts] seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids) - alerts = await _apply_risk_filter_to_alerts(session, alerts, seq_map) + alerts = await _apply_risk_filter_to_alerts(session, alerts, seq_map, target_date=from_date) detection_counts = await get_detection_counts_by_sequence_ids( session, list({sequence.id for sequences in seq_map.values() for sequence in sequences}), diff --git a/src/app/api/api_v1/endpoints/sequences.py b/src/app/api/api_v1/endpoints/sequences.py index 5e703cf3..744f3d32 100644 --- a/src/app/api/api_v1/endpoints/sequences.py +++ b/src/app/api/api_v1/endpoints/sequences.py @@ -22,7 +22,7 @@ from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceLabel, SequenceRead from app.services.overlap import compute_overlap -from app.services.sequence_confidence import filter_sequences_by_risk +from app.services.sequence_confidence import filter_sequences_by_risk, filter_sequences_by_risk_for_date from app.services.sequence_counts import get_detection_counts_by_sequence_ids from app.services.storage import s3_service from app.services.telemetry import telemetry_client @@ -190,7 +190,7 @@ async def fetch_sequences_from_date( .offset(offset) ) ).all() - fetched_sequences = await filter_sequences_by_risk(session, fetched_sequences) + fetched_sequences = await filter_sequences_by_risk_for_date(session, fetched_sequences, from_date) counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences]) return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences] diff --git a/src/app/services/risk.py b/src/app/services/risk.py index 7ce9a20e..d4334860 100644 --- a/src/app/services/risk.py +++ b/src/app/services/risk.py @@ -4,6 +4,7 @@ # See LICENSE or go to for full license details. import logging +from datetime import date from typing import Union import httpx @@ -12,13 +13,22 @@ logger = logging.getLogger("uvicorn.error") -__all__ = ["risk_service"] +__all__ = ["risk_service", "min_confidence_for_class"] # EFFIS classes that should trigger filtering. Anything else (moderate+) → no filter. _LOW = "low" _VERY_LOW = "very_low" +def min_confidence_for_class(fwi_class: Union[str, None]) -> Union[float, None]: + """Return the min confidence required for this FWI class, or None if no filter applies.""" + if fwi_class == _VERY_LOW: + return settings.FWI_VERY_LOW_MIN_CONF + if fwi_class == _LOW: + return settings.FWI_LOW_MIN_CONF + return None + + class RiskService: """In-memory cache of the daily FWI class per camera. 
@@ -34,30 +44,31 @@ def is_configured(self) -> bool: return bool(settings.RISK_API_HOST and settings.RISK_API_USERNAME and settings.RISK_API_PASSWORD) def min_confidence(self, camera_id: int) -> Union[float, None]: - """Return the min confidence required for this camera, or None if no filter applies.""" - fwi_class = self._scores.get(camera_id) - if fwi_class == _VERY_LOW: - return settings.FWI_VERY_LOW_MIN_CONF - if fwi_class == _LOW: - return settings.FWI_LOW_MIN_CONF - return None + """Return the min confidence required for this camera (today), or None if no filter.""" + return min_confidence_for_class(self._scores.get(camera_id)) - async def refresh(self) -> None: - """Fetch fresh FWI classes from the risk API. On error, keep the previous cache.""" + async def _fetch(self, path: str, params: Union[dict, None] = None) -> Union[list, dict, None]: if not self.is_configured: - return + return None host = settings.RISK_API_HOST.rstrip("/") # type: ignore[union-attr] - url = f"{host}/cameras" try: async with httpx.AsyncClient(timeout=5.0) as client: response = await client.get( - url, + f"{host}/{path.lstrip('/')}", + params=params, auth=(settings.RISK_API_USERNAME, settings.RISK_API_PASSWORD), # type: ignore[arg-type] ) response.raise_for_status() - payload = response.json() + return response.json() except (httpx.HTTPError, ValueError) as exc: - logger.warning("Risk API refresh failed (%s); keeping previous cache (%d entries)", exc, len(self._scores)) + logger.warning("Risk API call %s failed (%s)", path, exc) + return None + + async def refresh(self) -> None: + """Fetch fresh FWI classes from the risk API. On error, keep the previous cache.""" + payload = await self._fetch("cameras") + if not isinstance(payload, list): + logger.warning("Risk API refresh: keeping previous cache (%d entries)", len(self._scores)) return scores: dict[int, str] = {} @@ -69,5 +80,18 @@ async def refresh(self) -> None: self._scores = scores logger.info("Risk API refresh: cached FWI class for %d camera(s)", len(scores)) + async def get_scores_for_date(self, target_date: date) -> dict[int, str]: + """Fetch persisted FWI classes for a specific date. 
Returns {} on error or when not configured.""" + payload = await self._fetch(f"scores/{target_date.isoformat()}") + if not isinstance(payload, list): + return {} + scores: dict[int, str] = {} + for item in payload: + camera_id = item.get("id") or item.get("camera_id") + fwi_class = item.get("fwi_class") + if isinstance(camera_id, int) and isinstance(fwi_class, str): + scores[camera_id] = fwi_class + return scores + risk_service = RiskService() diff --git a/src/app/services/sequence_confidence.py b/src/app/services/sequence_confidence.py index 6ad29ce9..28a1536c 100644 --- a/src/app/services/sequence_confidence.py +++ b/src/app/services/sequence_confidence.py @@ -7,6 +7,7 @@ import logging import re from ast import literal_eval +from datetime import date from typing import Any, Dict, Iterable, List, Sequence as TypingSequence, Union, cast from sqlmodel import select @@ -14,7 +15,7 @@ from app.models import Detection, Sequence from app.schemas.detections import BOX_PATTERN -from app.services.risk import risk_service +from app.services.risk import min_confidence_for_class, risk_service logger = logging.getLogger("uvicorn.error") @@ -22,6 +23,7 @@ "max_conf_from_bboxes", "get_max_conf_by_sequence_ids", "filter_sequences_by_risk", + "filter_sequences_by_risk_for_date", ] @@ -77,18 +79,11 @@ async def get_max_conf_by_sequence_ids( return out -async def filter_sequences_by_risk( +async def _filter_sequences( session: AsyncSession, sequences: TypingSequence[Sequence], + thresholds: Dict[int, Union[float, None]], ) -> List[Sequence]: - """Drop sequences whose max conf is below the risk-driven threshold for their camera. - - Fail-open: a sequence is kept if either the camera has no FWI score (moderate+ or unknown) - or the sequence has no parseable confidence. - """ - if not sequences: - return [] - thresholds = {seq.camera_id: risk_service.min_confidence(seq.camera_id) for seq in sequences} if all(threshold is None for threshold in thresholds.values()): return list(sequences) @@ -105,3 +100,33 @@ async def filter_sequences_by_risk( if conf is None or conf >= threshold: kept.append(seq) return kept + + +async def filter_sequences_by_risk( + session: AsyncSession, + sequences: TypingSequence[Sequence], +) -> List[Sequence]: + """Drop sequences whose max conf is below today's risk-driven threshold for their camera. + + Fail-open: a sequence is kept if either the camera has no FWI score (moderate+ or unknown) + or the sequence has no parseable confidence. 
+ """ + if not sequences: + return [] + thresholds = {seq.camera_id: risk_service.min_confidence(seq.camera_id) for seq in sequences} + return await _filter_sequences(session, sequences, thresholds) + + +async def filter_sequences_by_risk_for_date( + session: AsyncSession, + sequences: TypingSequence[Sequence], + target_date: date, +) -> List[Sequence]: + """Like filter_sequences_by_risk, but uses the FWI class persisted for a specific date.""" + if not sequences: + return [] + scores = await risk_service.get_scores_for_date(target_date) + thresholds = { + seq.camera_id: min_confidence_for_class(scores.get(seq.camera_id)) for seq in sequences + } + return await _filter_sequences(session, sequences, thresholds) From dd2921d2795486a9cbd880899d51ac5fed5668fd Mon Sep 17 00:00:00 2001 From: Mateo Date: Mon, 4 May 2026 20:10:45 +0200 Subject: [PATCH 06/36] test: cover risk-driven filtering of alerts and sequences --- src/tests/endpoints/test_risk_filter.py | 154 ++++++++++++++++++ src/tests/services/test_risk.py | 38 +++++ .../services/test_sequence_confidence.py | 27 +++ 3 files changed, 219 insertions(+) create mode 100644 src/tests/endpoints/test_risk_filter.py create mode 100644 src/tests/services/test_risk.py create mode 100644 src/tests/services/test_sequence_confidence.py diff --git a/src/tests/endpoints/test_risk_filter.py b/src/tests/endpoints/test_risk_filter.py new file mode 100644 index 00000000..962889b3 --- /dev/null +++ b/src/tests/endpoints/test_risk_filter.py @@ -0,0 +1,154 @@ +# Copyright (C) 2024-2026, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from datetime import timedelta + +import pytest # type: ignore +from httpx import AsyncClient +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.core.time import utcnow +from app.models import Alert, AlertSequence, Detection, Sequence +from app.services.risk import risk_service + + +@pytest.fixture +def reset_risk_cache(): + previous = dict(risk_service._scores) + risk_service._scores = {} + yield + risk_service._scores = previous + + +async def _seed_unlabeled_sequence( + session: AsyncSession, + camera_id: int, + pose_id: int, + bbox: str, + minutes_ago: int = 30, +) -> Sequence: + now = utcnow() + seq = Sequence( + camera_id=camera_id, + pose_id=pose_id, + camera_azimuth=180.0, + sequence_azimuth=175.0, + cone_angle=5.0, + is_wildfire=None, + started_at=now - timedelta(minutes=minutes_ago), + last_seen_at=now - timedelta(minutes=minutes_ago - 1), + ) + session.add(seq) + await session.commit() + await session.refresh(seq) + session.add( + Detection( + camera_id=camera_id, + pose_id=pose_id, + sequence_id=seq.id, + bucket_key=f"risk-test-{seq.id}.jpg", + bbox=bbox, + others_bboxes=None, + created_at=now - timedelta(minutes=minutes_ago - 1), + ) + ) + await session.commit() + return seq + + +@pytest.mark.asyncio +async def test_unlabeled_latest_drops_low_conf_when_camera_is_low_risk( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + camera_id = pytest.camera_table[0]["id"] + pose_id = pytest.pose_table[0]["id"] + low_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, "[(.1,.1,.7,.8,.40)]", 30) + high_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, "[(.1,.1,.7,.8,.55)]", 20) + + risk_service._scores = {camera_id: "low"} + + auth = pytest.get_token( + pytest.user_table[0]["id"], + pytest.user_table[0]["role"].split(), + 
pytest.user_table[0]["organization_id"], + ) + response = await async_client.get("/sequences/unlabeled/latest", headers=auth) + assert response.status_code == 200, print(response.__dict__) + returned_ids = {item["id"] for item in response.json()} + assert low_seq.id not in returned_ids + assert high_seq.id in returned_ids + + +@pytest.mark.asyncio +async def test_unlabeled_latest_drops_below_very_low_threshold( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + camera_id = pytest.camera_table[0]["id"] + pose_id = pytest.pose_table[0]["id"] + # 0.55 passes the low threshold (0.45) but fails very_low (0.6) + seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, "[(.1,.1,.7,.8,.55)]", 25) + + risk_service._scores = {camera_id: "very_low"} + + auth = pytest.get_token( + pytest.user_table[0]["id"], + pytest.user_table[0]["role"].split(), + pytest.user_table[0]["organization_id"], + ) + response = await async_client.get("/sequences/unlabeled/latest", headers=auth) + assert response.status_code == 200, print(response.__dict__) + assert seq.id not in {item["id"] for item in response.json()} + + +@pytest.mark.asyncio +async def test_unlabeled_latest_keeps_all_when_class_is_moderate( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + camera_id = pytest.camera_table[0]["id"] + pose_id = pytest.pose_table[0]["id"] + low_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, "[(.1,.1,.7,.8,.10)]", 30) + + risk_service._scores = {camera_id: "moderate"} + + auth = pytest.get_token( + pytest.user_table[0]["id"], + pytest.user_table[0]["role"].split(), + pytest.user_table[0]["organization_id"], + ) + response = await async_client.get("/sequences/unlabeled/latest", headers=auth) + assert response.status_code == 200, print(response.__dict__) + assert low_seq.id in {item["id"] for item in response.json()} + + +@pytest.mark.asyncio +async def test_alerts_unlabeled_latest_drops_alert_when_all_seqs_below_threshold( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + camera_id = pytest.camera_table[1]["id"] # belongs to org 2 (user_idx 2) + pose_id = pytest.pose_table[2]["id"] + seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, "[(.1,.1,.7,.8,.30)]", 20) + + now = utcnow() + alert = Alert( + organization_id=2, + started_at=now - timedelta(minutes=20), + last_seen_at=now - timedelta(minutes=19), + ) + detection_session.add(alert) + await detection_session.commit() + await detection_session.refresh(alert) + detection_session.add(AlertSequence(alert_id=alert.id, sequence_id=seq.id)) + await detection_session.commit() + + risk_service._scores = {camera_id: "low"} + + auth = pytest.get_token( + pytest.user_table[2]["id"], + pytest.user_table[2]["role"].split(), + pytest.user_table[2]["organization_id"], + ) + response = await async_client.get("/alerts/unlabeled/latest", headers=auth) + assert response.status_code == 200, print(response.__dict__) + assert alert.id not in {item["id"] for item in response.json()} diff --git a/src/tests/services/test_risk.py b/src/tests/services/test_risk.py new file mode 100644 index 00000000..65b126ff --- /dev/null +++ b/src/tests/services/test_risk.py @@ -0,0 +1,38 @@ +# Copyright (C) 2024-2026, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
+ +import pytest + +from app.core.config import settings +from app.services.risk import RiskService, min_confidence_for_class + + +def test_min_confidence_for_class(): + assert min_confidence_for_class("very_low") == settings.FWI_VERY_LOW_MIN_CONF + assert min_confidence_for_class("low") == settings.FWI_LOW_MIN_CONF + assert min_confidence_for_class("moderate") is None + assert min_confidence_for_class("high") is None + assert min_confidence_for_class("very_high") is None + assert min_confidence_for_class("extreme") is None + assert min_confidence_for_class(None) is None + assert min_confidence_for_class("unexpected") is None + + +def test_risk_service_min_confidence_uses_cached_class(): + service = RiskService() + service._scores = {1: "very_low", 2: "low", 3: "moderate"} + assert service.min_confidence(1) == settings.FWI_VERY_LOW_MIN_CONF + assert service.min_confidence(2) == settings.FWI_LOW_MIN_CONF + assert service.min_confidence(3) is None + assert service.min_confidence(99) is None + + +@pytest.mark.asyncio +async def test_refresh_no_op_when_not_configured(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(settings, "RISK_API_HOST", None) + service = RiskService() + service._scores = {1: "low"} + await service.refresh() + assert service._scores == {1: "low"} diff --git a/src/tests/services/test_sequence_confidence.py b/src/tests/services/test_sequence_confidence.py new file mode 100644 index 00000000..0bbd3e9c --- /dev/null +++ b/src/tests/services/test_sequence_confidence.py @@ -0,0 +1,27 @@ +# Copyright (C) 2024-2026, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from app.services.sequence_confidence import max_conf_from_bboxes + + +def test_max_conf_from_single_bbox(): + assert max_conf_from_bboxes("[(.1,.1,.7,.8,.42)]") == 0.42 + + +def test_max_conf_picks_max_across_bboxes_and_strings(): + bbox = "[(.1,.1,.7,.8,.3)]" + others = "[(.1,.1,.5,.5,.7),(.2,.2,.6,.6,.55)]" + assert max_conf_from_bboxes(bbox, others) == 0.7 + + +def test_max_conf_handles_none_and_empty(): + assert max_conf_from_bboxes(None) is None + assert max_conf_from_bboxes("") is None + assert max_conf_from_bboxes(None, None) is None + + +def test_max_conf_skips_unparseable(): + assert max_conf_from_bboxes("[invalid]") is None + assert max_conf_from_bboxes("[(.1,.1,.7,.8,.5)]", "[garbage]") == 0.5 From 71da13255fa0402d7d11316d3b762e128807129f Mon Sep 17 00:00:00 2001 From: Mateo Date: Mon, 4 May 2026 20:17:18 +0200 Subject: [PATCH 07/36] chore: apply ruff fixes --- src/app/main.py | 5 ++--- src/app/services/risk.py | 2 +- src/app/services/sequence_confidence.py | 11 +++++------ 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/app/main.py b/src/app/main.py index ac4fa8ed..5a7a76f9 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -4,6 +4,7 @@ # See LICENSE or go to for full license details. 
import asyncio +import contextlib import logging import time from contextlib import asynccontextmanager @@ -76,10 +77,8 @@ async def lifespan(_: FastAPI): finally: if task is not None: task.cancel() - try: + with contextlib.suppress(asyncio.CancelledError): await task - except (asyncio.CancelledError, Exception): - pass app = FastAPI( diff --git a/src/app/services/risk.py b/src/app/services/risk.py index d4334860..016fd822 100644 --- a/src/app/services/risk.py +++ b/src/app/services/risk.py @@ -13,7 +13,7 @@ logger = logging.getLogger("uvicorn.error") -__all__ = ["risk_service", "min_confidence_for_class"] +__all__ = ["min_confidence_for_class", "risk_service"] # EFFIS classes that should trigger filtering. Anything else (moderate+) → no filter. _LOW = "low" diff --git a/src/app/services/sequence_confidence.py b/src/app/services/sequence_confidence.py index 28a1536c..a4d4439e 100644 --- a/src/app/services/sequence_confidence.py +++ b/src/app/services/sequence_confidence.py @@ -8,7 +8,8 @@ import re from ast import literal_eval from datetime import date -from typing import Any, Dict, Iterable, List, Sequence as TypingSequence, Union, cast +from typing import Any, Dict, Iterable, List, Union, cast +from typing import Sequence as TypingSequence from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -20,10 +21,10 @@ logger = logging.getLogger("uvicorn.error") __all__ = [ - "max_conf_from_bboxes", - "get_max_conf_by_sequence_ids", "filter_sequences_by_risk", "filter_sequences_by_risk_for_date", + "get_max_conf_by_sequence_ids", + "max_conf_from_bboxes", ] @@ -126,7 +127,5 @@ async def filter_sequences_by_risk_for_date( if not sequences: return [] scores = await risk_service.get_scores_for_date(target_date) - thresholds = { - seq.camera_id: min_confidence_for_class(scores.get(seq.camera_id)) for seq in sequences - } + thresholds = {seq.camera_id: min_confidence_for_class(scores.get(seq.camera_id)) for seq in sequences} return await _filter_sequences(session, sequences, thresholds) From c48824d7d0a25270060d2d7cb56df01e2c992291 Mon Sep 17 00:00:00 2001 From: Mateo Date: Mon, 4 May 2026 20:27:27 +0200 Subject: [PATCH 08/36] fix: clamp risk-refresh hour to 0..23 to avoid retry loop --- src/app/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/app/main.py b/src/app/main.py index 5a7a76f9..b20bd155 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -46,8 +46,9 @@ def _seconds_until_next_utc_hour(target_hour: int) -> float: + hour = max(0, min(23, target_hour)) now = datetime.now(tz=timezone.utc) - target = now.replace(hour=target_hour, minute=0, second=0, microsecond=0) + target = now.replace(hour=hour, minute=0, second=0, microsecond=0) if target <= now: target += timedelta(days=1) return (target - now).total_seconds() From a1cc3e2b5de95f9580618793327ca4a3d1544303 Mon Sep 17 00:00:00 2001 From: Mateo Date: Mon, 4 May 2026 20:27:31 +0200 Subject: [PATCH 09/36] feat: scope from_date risk lookup to caller organization, normalize fwi class casing --- src/app/api/api_v1/endpoints/alerts.py | 13 +++++++++---- src/app/api/api_v1/endpoints/sequences.py | 4 +++- src/app/services/risk.py | 23 +++++++++++++++-------- src/app/services/sequence_confidence.py | 8 ++++++-- src/tests/services/test_risk.py | 8 +++++++- 5 files changed, 40 insertions(+), 16 deletions(-) diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index 9731300c..0dd4bd65 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ 
b/src/app/api/api_v1/endpoints/alerts.py @@ -53,17 +53,20 @@ async def _apply_risk_filter_to_alerts( alerts: List[Alert], seq_map: Dict[int, List[Sequence]], target_date: Union[date, None] = None, + organization_id: Union[int, None] = None, ) -> List[Alert]: """Drop sequences below the risk threshold and alerts that end up empty. - When ``target_date`` is provided, look up the FWI class persisted for that day; - otherwise use today's cached value. + When ``target_date`` is provided, look up the FWI class persisted for that day + (scoped to ``organization_id`` if given); otherwise use today's cached value. """ all_sequences = [seq for seqs in seq_map.values() for seq in seqs] if target_date is None: kept_seqs = await filter_sequences_by_risk(session, all_sequences) else: - kept_seqs = await filter_sequences_by_risk_for_date(session, all_sequences, target_date) + kept_seqs = await filter_sequences_by_risk_for_date( + session, all_sequences, target_date, organization_id=organization_id + ) kept_ids = {seq.id for seq in kept_seqs} kept_alerts: List[Alert] = [] for alert in alerts: @@ -188,7 +191,9 @@ async def fetch_alerts_from_date( alerts = list(alerts_res.all()) alert_ids = [alert.id for alert in alerts] seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids) - alerts = await _apply_risk_filter_to_alerts(session, alerts, seq_map, target_date=from_date) + alerts = await _apply_risk_filter_to_alerts( + session, alerts, seq_map, target_date=from_date, organization_id=token_payload.organization_id + ) detection_counts = await get_detection_counts_by_sequence_ids( session, list({sequence.id for sequences in seq_map.values() for sequence in sequences}), diff --git a/src/app/api/api_v1/endpoints/sequences.py b/src/app/api/api_v1/endpoints/sequences.py index 744f3d32..4a3f5b45 100644 --- a/src/app/api/api_v1/endpoints/sequences.py +++ b/src/app/api/api_v1/endpoints/sequences.py @@ -190,7 +190,9 @@ async def fetch_sequences_from_date( .offset(offset) ) ).all() - fetched_sequences = await filter_sequences_by_risk_for_date(session, fetched_sequences, from_date) + fetched_sequences = await filter_sequences_by_risk_for_date( + session, fetched_sequences, from_date, organization_id=token_payload.organization_id + ) counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences]) return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences] diff --git a/src/app/services/risk.py b/src/app/services/risk.py index 016fd822..cebcfd0c 100644 --- a/src/app/services/risk.py +++ b/src/app/services/risk.py @@ -22,9 +22,12 @@ def min_confidence_for_class(fwi_class: Union[str, None]) -> Union[float, None]: """Return the min confidence required for this FWI class, or None if no filter applies.""" - if fwi_class == _VERY_LOW: + if not fwi_class: + return None + normalized = fwi_class.strip().lower().replace(" ", "_") + if normalized == _VERY_LOW: return settings.FWI_VERY_LOW_MIN_CONF - if fwi_class == _LOW: + if normalized == _LOW: return settings.FWI_LOW_MIN_CONF return None @@ -41,7 +44,7 @@ def __init__(self) -> None: @property def is_configured(self) -> bool: - return bool(settings.RISK_API_HOST and settings.RISK_API_USERNAME and settings.RISK_API_PASSWORD) + return bool(settings.RISK_API_URL and settings.RISK_API_LOGIN and settings.RISK_API_PWD) def min_confidence(self, camera_id: int) -> Union[float, None]: """Return the min confidence required for this camera (today), or None if no filter.""" @@ -50,13 +53,13 @@ def 
min_confidence(self, camera_id: int) -> Union[float, None]: async def _fetch(self, path: str, params: Union[dict, None] = None) -> Union[list, dict, None]: if not self.is_configured: return None - host = settings.RISK_API_HOST.rstrip("/") # type: ignore[union-attr] + host = settings.RISK_API_URL.rstrip("/") # type: ignore[union-attr] try: async with httpx.AsyncClient(timeout=5.0) as client: response = await client.get( f"{host}/{path.lstrip('/')}", params=params, - auth=(settings.RISK_API_USERNAME, settings.RISK_API_PASSWORD), # type: ignore[arg-type] + auth=(settings.RISK_API_LOGIN, settings.RISK_API_PWD), # type: ignore[arg-type] ) response.raise_for_status() return response.json() @@ -80,9 +83,13 @@ async def refresh(self) -> None: self._scores = scores logger.info("Risk API refresh: cached FWI class for %d camera(s)", len(scores)) - async def get_scores_for_date(self, target_date: date) -> dict[int, str]: - """Fetch persisted FWI classes for a specific date. Returns {} on error or when not configured.""" - payload = await self._fetch(f"scores/{target_date.isoformat()}") + async def get_scores_for_date(self, target_date: date, organization_id: Union[int, None] = None) -> dict[int, str]: + """Fetch persisted FWI classes for a specific date, optionally scoped to an organization. + + Returns {} on error or when not configured. + """ + params: Union[dict, None] = {"organization_id": organization_id} if organization_id is not None else None + payload = await self._fetch(f"scores/{target_date.isoformat()}", params=params) if not isinstance(payload, list): return {} scores: dict[int, str] = {} diff --git a/src/app/services/sequence_confidence.py b/src/app/services/sequence_confidence.py index a4d4439e..315bb154 100644 --- a/src/app/services/sequence_confidence.py +++ b/src/app/services/sequence_confidence.py @@ -122,10 +122,14 @@ async def filter_sequences_by_risk_for_date( session: AsyncSession, sequences: TypingSequence[Sequence], target_date: date, + organization_id: Union[int, None] = None, ) -> List[Sequence]: - """Like filter_sequences_by_risk, but uses the FWI class persisted for a specific date.""" + """Like filter_sequences_by_risk, but uses the FWI class persisted for a specific date. + + When ``organization_id`` is provided, the risk-api call is scoped to that organization. 
+ """ if not sequences: return [] - scores = await risk_service.get_scores_for_date(target_date) + scores = await risk_service.get_scores_for_date(target_date, organization_id=organization_id) thresholds = {seq.camera_id: min_confidence_for_class(scores.get(seq.camera_id)) for seq in sequences} return await _filter_sequences(session, sequences, thresholds) diff --git a/src/tests/services/test_risk.py b/src/tests/services/test_risk.py index 65b126ff..49af7d45 100644 --- a/src/tests/services/test_risk.py +++ b/src/tests/services/test_risk.py @@ -20,6 +20,12 @@ def test_min_confidence_for_class(): assert min_confidence_for_class("unexpected") is None +def test_min_confidence_for_class_normalizes_casing(): + assert min_confidence_for_class("Very Low") == settings.FWI_VERY_LOW_MIN_CONF + assert min_confidence_for_class("LOW") == settings.FWI_LOW_MIN_CONF + assert min_confidence_for_class(" very_low ") == settings.FWI_VERY_LOW_MIN_CONF + + def test_risk_service_min_confidence_uses_cached_class(): service = RiskService() service._scores = {1: "very_low", 2: "low", 3: "moderate"} @@ -31,7 +37,7 @@ def test_risk_service_min_confidence_uses_cached_class(): @pytest.mark.asyncio async def test_refresh_no_op_when_not_configured(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(settings, "RISK_API_HOST", None) + monkeypatch.setattr(settings, "RISK_API_URL", None) service = RiskService() service._scores = {1: "low"} await service.refresh() From 8750cbfaa14ffad01b3cf2451c8ef9d95e94668e Mon Sep 17 00:00:00 2001 From: Mateo Date: Mon, 4 May 2026 20:27:34 +0200 Subject: [PATCH 10/36] chore: rename risk-api env vars to RISK_API_URL/LOGIN/PWD --- .env.example | 6 +++--- docker-compose.yml | 6 ++++++ src/app/core/config.py | 6 +++--- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/.env.example b/.env.example index 5b139946..494b3ab6 100644 --- a/.env.example +++ b/.env.example @@ -24,9 +24,9 @@ SUPPORT_EMAIL= TELEGRAM_TOKEN= # Risk API (daily fire-weather index per camera) -RISK_API_HOST= -RISK_API_USERNAME= -RISK_API_PASSWORD= +RISK_API_URL= +RISK_API_LOGIN= +RISK_API_PWD= RISK_REFRESH_HOUR_UTC=4 FWI_LOW_MIN_CONF=0.45 FWI_VERY_LOW_MIN_CONF=0.6 diff --git a/docker-compose.yml b/docker-compose.yml index 6088b9d1..7405255f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -65,6 +65,12 @@ services: - S3_PROXY_URL=${S3_PROXY_URL} - SERVER_NAME=${SERVER_NAME} - PLATFORM_URL=${PLATFORM_URL:-https://platform.pyronear.org} + - RISK_API_URL=${RISK_API_URL} + - RISK_API_LOGIN=${RISK_API_LOGIN} + - RISK_API_PWD=${RISK_API_PWD} + - RISK_REFRESH_HOUR_UTC=${RISK_REFRESH_HOUR_UTC:-4} + - FWI_LOW_MIN_CONF=${FWI_LOW_MIN_CONF:-0.45} + - FWI_VERY_LOW_MIN_CONF=${FWI_VERY_LOW_MIN_CONF:-0.6} volumes: - ./src/:/app/ command: "sh -c 'alembic upgrade head && python app/db.py && uvicorn app.main:app --reload --host 0.0.0.0 --port 5050 --proxy-headers'" diff --git a/src/app/core/config.py b/src/app/core/config.py index e14e0977..7c0ff1b2 100644 --- a/src/app/core/config.py +++ b/src/app/core/config.py @@ -78,9 +78,9 @@ def sqlachmey_uri(cls, v: str) -> str: PLATFORM_URL: str = os.environ.get("PLATFORM_URL", "") # Risk API (daily fire-weather index per camera) - RISK_API_HOST: Union[str, None] = os.environ.get("RISK_API_HOST") - RISK_API_USERNAME: Union[str, None] = os.environ.get("RISK_API_USERNAME") - RISK_API_PASSWORD: Union[str, None] = os.environ.get("RISK_API_PASSWORD") + RISK_API_URL: Union[str, None] = os.environ.get("RISK_API_URL") + RISK_API_LOGIN: Union[str, None] = 
os.environ.get("RISK_API_LOGIN") + RISK_API_PWD: Union[str, None] = os.environ.get("RISK_API_PWD") RISK_REFRESH_HOUR_UTC: int = int(os.environ.get("RISK_REFRESH_HOUR_UTC") or 4) FWI_LOW_MIN_CONF: float = float(os.environ.get("FWI_LOW_MIN_CONF") or 0.45) FWI_VERY_LOW_MIN_CONF: float = float(os.environ.get("FWI_VERY_LOW_MIN_CONF") or 0.6) From 9d56748b774f2b9c75f45930ba76b8bd94c5ea30 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 04:01:46 +0200 Subject: [PATCH 11/36] feat: add max_conf column to sequences with backfill migration --- src/app/models.py | 3 + ...-b3d8a9c1e2f4_add_max_conf_to_sequences.py | 75 +++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 src/migrations/versions/2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py diff --git a/src/app/models.py b/src/app/models.py index c9eed1b1..04e231d1 100644 --- a/src/app/models.py +++ b/src/app/models.py @@ -104,6 +104,9 @@ class Sequence(SQLModel, table=True): cone_angle: Union[float, None] = Field(None, nullable=True) started_at: datetime = Field(..., nullable=False) last_seen_at: datetime = Field(..., nullable=False) + # Highest detection confidence ever attached to this sequence. + # Monotonic: never recomputed downward when detections are deleted/reassigned. + max_conf: Union[float, None] = Field(None, nullable=True) class Alert(SQLModel, table=True): diff --git a/src/migrations/versions/2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py b/src/migrations/versions/2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py new file mode 100644 index 00000000..9d77dcb7 --- /dev/null +++ b/src/migrations/versions/2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py @@ -0,0 +1,75 @@ +"""add max_conf column to sequences and backfill from detections + +Revision ID: b3d8a9c1e2f4 +Revises: a1b2c3d4e5f6 +Create Date: 2026-05-05 09:30:00.000000 + +""" + +import logging +import re +from ast import literal_eval +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "b3d8a9c1e2f4" +down_revision: Union[str, None] = "a1b2c3d4e5f6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger("alembic.runtime.migration") + +_FLOAT = r"-?\d+(?:\.\d+)?|-?\.\d+" +_BOX_PATTERN = rf"\({_FLOAT},{_FLOAT},{_FLOAT},{_FLOAT},{_FLOAT}\)" + + +def _max_conf(*bbox_strings: Union[str, None]) -> Union[float, None]: + best: Union[float, None] = None + for raw in bbox_strings: + if not raw: + continue + for match in re.finditer(_BOX_PATTERN, raw): + try: + bbox = literal_eval(match.group(0)) + except (SyntaxError, ValueError): + continue + if not isinstance(bbox, tuple) or len(bbox) != 5: + continue + conf = bbox[4] + if not isinstance(conf, (int, float)): + continue + if best is None or conf > best: + best = float(conf) + return best + + +def upgrade() -> None: + op.add_column("sequences", sa.Column("max_conf", sa.Float(), nullable=True)) + + bind = op.get_bind() + rows = bind.execute( + sa.text("SELECT sequence_id, bbox, others_bboxes FROM detections WHERE sequence_id IS NOT NULL") + ).fetchall() + + seq_max: dict[int, float] = {} + for sequence_id, bbox, others in rows: + conf = _max_conf(bbox, others) + if conf is None: + continue + current = seq_max.get(sequence_id) + if current is None or conf > current: + seq_max[sequence_id] = conf + + if seq_max: + bind.execute( + sa.text("UPDATE sequences SET max_conf = :conf WHERE id = :sid"), + [{"sid": sid, "conf": conf} for sid, conf in seq_max.items()], + ) + logger.info("Backfilled max_conf for %d sequence(s)", len(seq_max)) + + +def downgrade() -> None: + op.drop_column("sequences", "max_conf") From d1e9a8ce1ec8e68917ea8dea38f8f33d76426f68 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 04:01:47 +0200 Subject: [PATCH 12/36] feat: maintain sequence max_conf at ingest with atomic update --- src/app/api/api_v1/endpoints/detections.py | 18 ++++++++++-------- src/app/crud/crud_sequence.py | 13 ++++++++++++- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/app/api/api_v1/endpoints/detections.py b/src/app/api/api_v1/endpoints/detections.py index 2ecfaf55..48b3b835 100644 --- a/src/app/api/api_v1/endpoints/detections.py +++ b/src/app/api/api_v1/endpoints/detections.py @@ -432,6 +432,9 @@ async def create_detection( if matched_sequence is not None: await sequences.update(matched_sequence.id, SequenceUpdate(last_seen_at=det.created_at)) det = await detections.update(det.id, DetectionSequence(sequence_id=matched_sequence.id)) + det_max_conf = max_conf_from_bboxes(det.bbox, det.others_bboxes) + if det_max_conf is not None: + await sequences.bump_max_conf(matched_sequence.id, det_max_conf) else: det_filters: List[tuple[str, Any]] = [ ("camera_id", token_payload.sub), @@ -460,6 +463,10 @@ async def create_detection( if len(overlapping_dets) >= settings.SEQUENCE_MIN_INTERVAL_DETS: first_det = min(overlapping_dets, key=lambda item: item.created_at) cone_azimuth, cone_angle = resolve_cone(pose.azimuth, first_det.bbox, camera.angle_of_view) + seq_max_conf = max_conf_from_bboxes( + *[d.bbox for d in overlapping_dets], + *[d.others_bboxes for d in overlapping_dets], + ) sequence_ = await sequences.create( Sequence( camera_id=token_payload.sub, @@ -469,6 +476,7 @@ async def create_detection( cone_angle=cone_angle, started_at=first_det.created_at, last_seen_at=det.created_at, + max_conf=seq_max_conf, ) ) for det_ in overlapping_dets: @@ -496,13 +504,7 @@ async def create_detection( org = cast(Organization, await 
organizations.get(token_payload.organization_id, strict=True)) if org.slack_hook: min_conf = risk_service.min_confidence(camera.id) - seq_max_conf: Optional[float] = None - if min_conf is not None: - seq_max_conf = max_conf_from_bboxes( - *[d.bbox for d in overlapping_dets], - *[d.others_bboxes for d in overlapping_dets], - ) - if min_conf is None or seq_max_conf is None or seq_max_conf >= min_conf: + if min_conf is None or sequence_.max_conf is None or sequence_.max_conf >= min_conf: slack_payload = jsonable_encoder(det) slack_payload["sequence_azimuth"] = sequence_.sequence_azimuth background_tasks.add_task( @@ -512,7 +514,7 @@ async def create_detection( logger.info( "Skipping Slack notification for camera %s: max conf %.3f < threshold %.3f", camera.name, - seq_max_conf, + sequence_.max_conf, min_conf, ) diff --git a/src/app/crud/crud_sequence.py b/src/app/crud/crud_sequence.py index ed622e06..16c22efd 100644 --- a/src/app/crud/crud_sequence.py +++ b/src/app/crud/crud_sequence.py @@ -4,8 +4,9 @@ # See LICENSE or go to for full license details. -from typing import Union +from typing import Any, Union, cast +from sqlalchemy import func, update from sqlmodel.ext.asyncio.session import AsyncSession from app.crud.base import BaseCRUD @@ -18,3 +19,13 @@ class SequenceCRUD(BaseCRUD[Sequence, Sequence, Union[SequenceUpdate, SequenceLabel]]): def __init__(self, session: AsyncSession) -> None: super().__init__(session, Sequence) + + async def bump_max_conf(self, sequence_id: int, candidate: float) -> None: + """Atomically raise sequences.max_conf to candidate if higher (or set if NULL).""" + stmt: Any = ( + update(Sequence) + .where(cast(Any, Sequence.id) == sequence_id) + .values(max_conf=func.greatest(func.coalesce(Sequence.max_conf, candidate), candidate)) + ) + await self.session.exec(stmt) + await self.session.commit() From ce6f2d7e89a12cf94f8ea0c876343e8245005be1 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 04:01:49 +0200 Subject: [PATCH 13/36] refactor: read max_conf from sequence row instead of parsing detections --- src/app/api/api_v1/endpoints/alerts.py | 11 ++-- src/app/api/api_v1/endpoints/sequences.py | 4 +- src/app/services/sequence_confidence.py | 67 ++++------------------- 3 files changed, 16 insertions(+), 66 deletions(-) diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index 0dd4bd65..064478ee 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -49,7 +49,6 @@ async def _fetch_sequences_by_alert_ids(session: AsyncSession, alert_ids: List[i async def _apply_risk_filter_to_alerts( - session: AsyncSession, alerts: List[Alert], seq_map: Dict[int, List[Sequence]], target_date: Union[date, None] = None, @@ -62,11 +61,9 @@ async def _apply_risk_filter_to_alerts( """ all_sequences = [seq for seqs in seq_map.values() for seq in seqs] if target_date is None: - kept_seqs = await filter_sequences_by_risk(session, all_sequences) + kept_seqs = filter_sequences_by_risk(all_sequences) else: - kept_seqs = await filter_sequences_by_risk_for_date( - session, all_sequences, target_date, organization_id=organization_id - ) + kept_seqs = await filter_sequences_by_risk_for_date(all_sequences, target_date, organization_id=organization_id) kept_ids = {seq.id for seq in kept_seqs} kept_alerts: List[Alert] = [] for alert in alerts: @@ -161,7 +158,7 @@ async def fetch_latest_unlabeled_alerts( alerts = list(alerts_res.unique().all()) alert_ids = [alert.id for alert in alerts] seq_map = await 
_fetch_sequences_by_alert_ids(session, alert_ids) - alerts = await _apply_risk_filter_to_alerts(session, alerts, seq_map) + alerts = await _apply_risk_filter_to_alerts(alerts, seq_map) detection_counts = await get_detection_counts_by_sequence_ids( session, list({sequence.id for sequences in seq_map.values() for sequence in sequences}), @@ -192,7 +189,7 @@ async def fetch_alerts_from_date( alert_ids = [alert.id for alert in alerts] seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids) alerts = await _apply_risk_filter_to_alerts( - session, alerts, seq_map, target_date=from_date, organization_id=token_payload.organization_id + alerts, seq_map, target_date=from_date, organization_id=token_payload.organization_id ) detection_counts = await get_detection_counts_by_sequence_ids( session, diff --git a/src/app/api/api_v1/endpoints/sequences.py b/src/app/api/api_v1/endpoints/sequences.py index 4a3f5b45..a92fa300 100644 --- a/src/app/api/api_v1/endpoints/sequences.py +++ b/src/app/api/api_v1/endpoints/sequences.py @@ -163,7 +163,7 @@ async def fetch_latest_unlabeled_sequences( .limit(15) ) ).all() - fetched_sequences = await filter_sequences_by_risk(session, fetched_sequences) + fetched_sequences = filter_sequences_by_risk(fetched_sequences) counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences]) return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences] @@ -191,7 +191,7 @@ async def fetch_sequences_from_date( ) ).all() fetched_sequences = await filter_sequences_by_risk_for_date( - session, fetched_sequences, from_date, organization_id=token_payload.organization_id + fetched_sequences, from_date, organization_id=token_payload.organization_id ) counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences]) return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences] diff --git a/src/app/services/sequence_confidence.py b/src/app/services/sequence_confidence.py index 315bb154..3927dd1f 100644 --- a/src/app/services/sequence_confidence.py +++ b/src/app/services/sequence_confidence.py @@ -8,13 +8,10 @@ import re from ast import literal_eval from datetime import date -from typing import Any, Dict, Iterable, List, Union, cast +from typing import Dict, List, Union from typing import Sequence as TypingSequence -from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession - -from app.models import Detection, Sequence +from app.models import Sequence from app.schemas.detections import BOX_PATTERN from app.services.risk import min_confidence_for_class, risk_service @@ -23,7 +20,6 @@ __all__ = [ "filter_sequences_by_risk", "filter_sequences_by_risk_for_date", - "get_max_conf_by_sequence_ids", "max_conf_from_bboxes", ] @@ -49,82 +45,39 @@ def max_conf_from_bboxes(*bbox_strings: Union[str, None]) -> Union[float, None]: return best -async def get_max_conf_by_sequence_ids( - session: AsyncSession, - sequence_ids: Iterable[int], -) -> Dict[int, float]: - """Return {sequence_id: max_conf} computed from all detections of those sequences. - - Sequences with no detections or unparseable bboxes are omitted from the result — - callers should treat a missing key as "unknown" and fail open. 
- """ - seq_ids: List[int] = list({int(sid) for sid in sequence_ids}) - if not seq_ids: - return {} - - stmt: Any = select(Detection.sequence_id, Detection.bbox, Detection.others_bboxes).where( - cast(Any, Detection.sequence_id).in_(seq_ids) - ) - res = await session.exec(stmt) - - out: Dict[int, float] = {} - for sequence_id, bbox, others in res.all(): - if sequence_id is None: - continue - conf = max_conf_from_bboxes(bbox, others) - if conf is None: - continue - current = out.get(sequence_id) - if current is None or conf > current: - out[sequence_id] = conf - return out - - -async def _filter_sequences( - session: AsyncSession, +def _filter_sequences( sequences: TypingSequence[Sequence], thresholds: Dict[int, Union[float, None]], ) -> List[Sequence]: if all(threshold is None for threshold in thresholds.values()): return list(sequences) - seq_ids_to_check = [seq.id for seq in sequences if thresholds.get(seq.camera_id) is not None] - confs = await get_max_conf_by_sequence_ids(session, seq_ids_to_check) - kept: List[Sequence] = [] for seq in sequences: threshold = thresholds.get(seq.camera_id) - if threshold is None: - kept.append(seq) - continue - conf = confs.get(seq.id) - if conf is None or conf >= threshold: + if threshold is None or seq.max_conf is None or seq.max_conf >= threshold: kept.append(seq) return kept -async def filter_sequences_by_risk( - session: AsyncSession, - sequences: TypingSequence[Sequence], -) -> List[Sequence]: - """Drop sequences whose max conf is below today's risk-driven threshold for their camera. +def filter_sequences_by_risk(sequences: TypingSequence[Sequence]) -> List[Sequence]: + """Drop sequences whose stored ``max_conf`` is below today's risk-driven threshold for their camera. Fail-open: a sequence is kept if either the camera has no FWI score (moderate+ or unknown) - or the sequence has no parseable confidence. + or ``seq.max_conf`` is NULL. """ if not sequences: return [] thresholds = {seq.camera_id: risk_service.min_confidence(seq.camera_id) for seq in sequences} - return await _filter_sequences(session, sequences, thresholds) + return _filter_sequences(sequences, thresholds) async def filter_sequences_by_risk_for_date( - session: AsyncSession, sequences: TypingSequence[Sequence], target_date: date, organization_id: Union[int, None] = None, ) -> List[Sequence]: - """Like filter_sequences_by_risk, but uses the FWI class persisted for a specific date. + """Like ``filter_sequences_by_risk``, but uses the FWI class persisted for a specific date. When ``organization_id`` is provided, the risk-api call is scoped to that organization. 
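+
+    Illustrative usage (the date and organization id below are made-up values):
+
+        kept = await filter_sequences_by_risk_for_date(seqs, date(2026, 5, 4), organization_id=2)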
""" @@ -132,4 +85,4 @@ async def filter_sequences_by_risk_for_date( return [] scores = await risk_service.get_scores_for_date(target_date, organization_id=organization_id) thresholds = {seq.camera_id: min_confidence_for_class(scores.get(seq.camera_id)) for seq in sequences} - return await _filter_sequences(session, sequences, thresholds) + return _filter_sequences(sequences, thresholds) From 507ed531a03be41d56711132d6b51b9efd44e4b3 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 04:01:49 +0200 Subject: [PATCH 14/36] test: seed max_conf directly on test sequences --- src/tests/conftest.py | 2 ++ src/tests/endpoints/test_risk_filter.py | 27 ++++++++----------------- 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 91851241..aef86314 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -199,6 +199,7 @@ "cone_angle": 54.8, "started_at": datetime.strptime("2023-11-07T15:08:19.226673", dt_format), "last_seen_at": datetime.strptime("2023-11-07T15:28:19.226673", dt_format), + "max_conf": None, }, { "id": 2, @@ -210,6 +211,7 @@ "cone_angle": 54.8, "started_at": datetime.strptime("2023-11-07T16:08:19.226673", dt_format), "last_seen_at": datetime.strptime("2023-11-07T16:08:19.226673", dt_format), + "max_conf": None, }, ] diff --git a/src/tests/endpoints/test_risk_filter.py b/src/tests/endpoints/test_risk_filter.py index 962889b3..284a4211 100644 --- a/src/tests/endpoints/test_risk_filter.py +++ b/src/tests/endpoints/test_risk_filter.py @@ -10,7 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from app.core.time import utcnow -from app.models import Alert, AlertSequence, Detection, Sequence +from app.models import Alert, AlertSequence, Sequence from app.services.risk import risk_service @@ -26,7 +26,7 @@ async def _seed_unlabeled_sequence( session: AsyncSession, camera_id: int, pose_id: int, - bbox: str, + max_conf: float, minutes_ago: int = 30, ) -> Sequence: now = utcnow() @@ -39,22 +39,11 @@ async def _seed_unlabeled_sequence( is_wildfire=None, started_at=now - timedelta(minutes=minutes_ago), last_seen_at=now - timedelta(minutes=minutes_ago - 1), + max_conf=max_conf, ) session.add(seq) await session.commit() await session.refresh(seq) - session.add( - Detection( - camera_id=camera_id, - pose_id=pose_id, - sequence_id=seq.id, - bucket_key=f"risk-test-{seq.id}.jpg", - bbox=bbox, - others_bboxes=None, - created_at=now - timedelta(minutes=minutes_ago - 1), - ) - ) - await session.commit() return seq @@ -64,8 +53,8 @@ async def test_unlabeled_latest_drops_low_conf_when_camera_is_low_risk( ): camera_id = pytest.camera_table[0]["id"] pose_id = pytest.pose_table[0]["id"] - low_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, "[(.1,.1,.7,.8,.40)]", 30) - high_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, "[(.1,.1,.7,.8,.55)]", 20) + low_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.40, minutes_ago=30) + high_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.55, minutes_ago=20) risk_service._scores = {camera_id: "low"} @@ -88,7 +77,7 @@ async def test_unlabeled_latest_drops_below_very_low_threshold( camera_id = pytest.camera_table[0]["id"] pose_id = pytest.pose_table[0]["id"] # 0.55 passes the low threshold (0.45) but fails very_low (0.6) - seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, "[(.1,.1,.7,.8,.55)]", 25) + seq = await 
_seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.55, minutes_ago=25) risk_service._scores = {camera_id: "very_low"} @@ -108,7 +97,7 @@ async def test_unlabeled_latest_keeps_all_when_class_is_moderate( ): camera_id = pytest.camera_table[0]["id"] pose_id = pytest.pose_table[0]["id"] - low_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, "[(.1,.1,.7,.8,.10)]", 30) + low_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.10, minutes_ago=30) risk_service._scores = {camera_id: "moderate"} @@ -128,7 +117,7 @@ async def test_alerts_unlabeled_latest_drops_alert_when_all_seqs_below_threshold ): camera_id = pytest.camera_table[1]["id"] # belongs to org 2 (user_idx 2) pose_id = pytest.pose_table[2]["id"] - seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, "[(.1,.1,.7,.8,.30)]", 20) + seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.30, minutes_ago=20) now = utcnow() alert = Alert( From 104bdc9ccbf823d4b423e6a6c1bad147059fc6fc Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 04:05:37 +0200 Subject: [PATCH 15/36] fix: portable max_conf bump and validate fwi thresholds in [0,1] --- src/app/core/config.py | 7 +++++++ src/app/crud/crud_sequence.py | 15 +++++++++------ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/app/core/config.py b/src/app/core/config.py index 7c0ff1b2..c407e33b 100644 --- a/src/app/core/config.py +++ b/src/app/core/config.py @@ -85,6 +85,13 @@ def sqlachmey_uri(cls, v: str) -> str: FWI_LOW_MIN_CONF: float = float(os.environ.get("FWI_LOW_MIN_CONF") or 0.45) FWI_VERY_LOW_MIN_CONF: float = float(os.environ.get("FWI_VERY_LOW_MIN_CONF") or 0.6) + @field_validator("FWI_LOW_MIN_CONF", "FWI_VERY_LOW_MIN_CONF") + @classmethod + def fwi_min_conf_in_unit_range(cls, v: float) -> float: + if not 0.0 <= v <= 1.0: + raise ValueError("FWI confidence thresholds must lie in [0, 1]") + return v + # Error monitoring SENTRY_DSN: Union[str, None] = os.environ.get("SENTRY_DSN") SERVER_NAME: str = os.environ.get("SERVER_NAME", socket.gethostname()) diff --git a/src/app/crud/crud_sequence.py b/src/app/crud/crud_sequence.py index 16c22efd..a493f741 100644 --- a/src/app/crud/crud_sequence.py +++ b/src/app/crud/crud_sequence.py @@ -6,7 +6,7 @@ from typing import Any, Union, cast -from sqlalchemy import func, update +from sqlalchemy import case, or_, update from sqlmodel.ext.asyncio.session import AsyncSession from app.crud.base import BaseCRUD @@ -21,11 +21,14 @@ def __init__(self, session: AsyncSession) -> None: super().__init__(session, Sequence) async def bump_max_conf(self, sequence_id: int, candidate: float) -> None: - """Atomically raise sequences.max_conf to candidate if higher (or set if NULL).""" - stmt: Any = ( - update(Sequence) - .where(cast(Any, Sequence.id) == sequence_id) - .values(max_conf=func.greatest(func.coalesce(Sequence.max_conf, candidate), candidate)) + """Atomically raise sequences.max_conf to candidate if higher (or set if NULL). + + Uses a portable CASE expression so it runs on SQLite as well as Postgres. 
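+
+        Rough Python equivalent of what the UPDATE does to the targeted row (illustrative only):
+
+            max_conf = candidate if max_conf is None or max_conf < candidate else max_conf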
+ """ + bumped = case( + (or_(Sequence.max_conf.is_(None), Sequence.max_conf < candidate), candidate), # type: ignore[union-attr] + else_=Sequence.max_conf, ) + stmt: Any = update(Sequence).where(cast(Any, Sequence.id) == sequence_id).values(max_conf=bumped) await self.session.exec(stmt) await self.session.commit() From b6a4e6b355a3d3f0d0538104723d999049be1b1f Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 04:26:15 +0200 Subject: [PATCH 16/36] feat: add risk_score query param to override fwi class on alerts endpoints --- src/app/api/api_v1/endpoints/alerts.py | 48 ++++++++++++-- src/app/services/sequence_confidence.py | 19 ++++++ src/tests/endpoints/test_risk_filter.py | 85 +++++++++++++++++++++---- 3 files changed, 134 insertions(+), 18 deletions(-) diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index 064478ee..d2e310a5 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -5,6 +5,7 @@ from datetime import date, timedelta +from enum import Enum from typing import Any, Dict, List, Union, cast from fastapi import APIRouter, Depends, HTTPException, Path, Query, Security, status @@ -20,10 +21,26 @@ from app.schemas.alerts import AlertReadWithSequences from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceRead -from app.services.sequence_confidence import filter_sequences_by_risk, filter_sequences_by_risk_for_date +from app.services.sequence_confidence import ( + filter_sequences_by_class, + filter_sequences_by_risk, + filter_sequences_by_risk_for_date, +) from app.services.sequence_counts import get_detection_counts_by_sequence_ids from app.services.telemetry import telemetry_client + +class FwiClass(str, Enum): + """FWI classes accepted as a manual ``risk_score`` override.""" + + very_low = "very_low" + low = "low" + moderate = "moderate" + high = "high" + very_high = "very_high" + extreme = "extreme" + + router = APIRouter() @@ -53,14 +70,19 @@ async def _apply_risk_filter_to_alerts( seq_map: Dict[int, List[Sequence]], target_date: Union[date, None] = None, organization_id: Union[int, None] = None, + override_class: Union[str, None] = None, ) -> List[Alert]: """Drop sequences below the risk threshold and alerts that end up empty. - When ``target_date`` is provided, look up the FWI class persisted for that day - (scoped to ``organization_id`` if given); otherwise use today's cached value. + When ``override_class`` is provided, that single FWI class is applied to every + sequence (no risk-api lookup). Otherwise: ``target_date`` triggers a per-date + risk-api call (scoped to ``organization_id`` if given); without it we use today's + cached value. 
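+
+    Illustrative call shapes (``some_date`` stands in for a real ``date`` value):
+
+        await _apply_risk_filter_to_alerts(alerts, seq_map, override_class="low")   # manual override
+        await _apply_risk_filter_to_alerts(alerts, seq_map, target_date=some_date)  # classes persisted for that day
+        await _apply_risk_filter_to_alerts(alerts, seq_map)                         # today's cached classes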
""" all_sequences = [seq for seqs in seq_map.values() for seq in seqs] - if target_date is None: + if override_class is not None: + kept_seqs = filter_sequences_by_class(all_sequences, override_class) + elif target_date is None: kept_seqs = filter_sequences_by_risk(all_sequences) else: kept_seqs = await filter_sequences_by_risk_for_date(all_sequences, target_date, organization_id=organization_id) @@ -140,6 +162,10 @@ async def fetch_alert_sequences( summary="Fetch all the alerts with unlabeled sequences from the last 24 hours", ) async def fetch_latest_unlabeled_alerts( + risk_score: Union[FwiClass, None] = Query( + None, + description="Override FWI class applied to every sequence; bypasses risk-api lookup.", + ), session: AsyncSession = Depends(get_session), token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]), ) -> List[AlertReadWithSequences]: @@ -158,7 +184,9 @@ async def fetch_latest_unlabeled_alerts( alerts = list(alerts_res.unique().all()) alert_ids = [alert.id for alert in alerts] seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids) - alerts = await _apply_risk_filter_to_alerts(alerts, seq_map) + alerts = await _apply_risk_filter_to_alerts( + alerts, seq_map, override_class=risk_score.value if risk_score is not None else None + ) detection_counts = await get_detection_counts_by_sequence_ids( session, list({sequence.id for sequences in seq_map.values() for sequence in sequences}), @@ -171,6 +199,10 @@ async def fetch_alerts_from_date( from_date: date = Query(), limit: Union[int, None] = Query(15, description="Maximum number of alerts to fetch"), offset: Union[int, None] = Query(0, description="Number of alerts to skip before starting to fetch"), + risk_score: Union[FwiClass, None] = Query( + None, + description="Override FWI class applied to every sequence; bypasses risk-api lookup.", + ), session: AsyncSession = Depends(get_session), token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]), ) -> List[AlertReadWithSequences]: @@ -189,7 +221,11 @@ async def fetch_alerts_from_date( alert_ids = [alert.id for alert in alerts] seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids) alerts = await _apply_risk_filter_to_alerts( - alerts, seq_map, target_date=from_date, organization_id=token_payload.organization_id + alerts, + seq_map, + target_date=from_date, + organization_id=token_payload.organization_id, + override_class=risk_score.value if risk_score is not None else None, ) detection_counts = await get_detection_counts_by_sequence_ids( session, diff --git a/src/app/services/sequence_confidence.py b/src/app/services/sequence_confidence.py index 3927dd1f..518ee35a 100644 --- a/src/app/services/sequence_confidence.py +++ b/src/app/services/sequence_confidence.py @@ -18,6 +18,7 @@ logger = logging.getLogger("uvicorn.error") __all__ = [ + "filter_sequences_by_class", "filter_sequences_by_risk", "filter_sequences_by_risk_for_date", "max_conf_from_bboxes", @@ -60,6 +61,24 @@ def _filter_sequences( return kept +def filter_sequences_by_class( + sequences: TypingSequence[Sequence], + fwi_class: Union[str, None], +) -> List[Sequence]: + """Apply a single FWI class threshold to every sequence, regardless of camera. + + Used when callers pass an explicit ``risk_score`` override instead of consulting + the risk-api. ``moderate``/``high``/etc. yield no filtering (returns the input). 
+ """ + if not sequences: + return [] + threshold = min_confidence_for_class(fwi_class) + if threshold is None: + return list(sequences) + thresholds: Dict[int, Union[float, None]] = {seq.camera_id: threshold for seq in sequences} + return _filter_sequences(sequences, thresholds) + + def filter_sequences_by_risk(sequences: TypingSequence[Sequence]) -> List[Sequence]: """Drop sequences whose stored ``max_conf`` is below today's risk-driven threshold for their camera. diff --git a/src/tests/endpoints/test_risk_filter.py b/src/tests/endpoints/test_risk_filter.py index 284a4211..564eca18 100644 --- a/src/tests/endpoints/test_risk_filter.py +++ b/src/tests/endpoints/test_risk_filter.py @@ -111,6 +111,21 @@ async def test_unlabeled_latest_keeps_all_when_class_is_moderate( assert low_seq.id in {item["id"] for item in response.json()} +async def _seed_alert_with_sequence(session: AsyncSession, organization_id: int, seq: Sequence) -> Alert: + now = utcnow() + alert = Alert( + organization_id=organization_id, + started_at=now - timedelta(minutes=20), + last_seen_at=now - timedelta(minutes=19), + ) + session.add(alert) + await session.commit() + await session.refresh(alert) + session.add(AlertSequence(alert_id=alert.id, sequence_id=seq.id)) + await session.commit() + return alert + + @pytest.mark.asyncio async def test_alerts_unlabeled_latest_drops_alert_when_all_seqs_below_threshold( async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache @@ -118,18 +133,7 @@ async def test_alerts_unlabeled_latest_drops_alert_when_all_seqs_below_threshold camera_id = pytest.camera_table[1]["id"] # belongs to org 2 (user_idx 2) pose_id = pytest.pose_table[2]["id"] seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.30, minutes_ago=20) - - now = utcnow() - alert = Alert( - organization_id=2, - started_at=now - timedelta(minutes=20), - last_seen_at=now - timedelta(minutes=19), - ) - detection_session.add(alert) - await detection_session.commit() - await detection_session.refresh(alert) - detection_session.add(AlertSequence(alert_id=alert.id, sequence_id=seq.id)) - await detection_session.commit() + alert = await _seed_alert_with_sequence(detection_session, organization_id=2, seq=seq) risk_service._scores = {camera_id: "low"} @@ -141,3 +145,60 @@ async def test_alerts_unlabeled_latest_drops_alert_when_all_seqs_below_threshold response = await async_client.get("/alerts/unlabeled/latest", headers=auth) assert response.status_code == 200, print(response.__dict__) assert alert.id not in {item["id"] for item in response.json()} + + +@pytest.mark.asyncio +async def test_alerts_unlabeled_latest_risk_score_override( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + camera_id = pytest.camera_table[1]["id"] + pose_id = pytest.pose_table[2]["id"] + seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.30, minutes_ago=20) + alert = await _seed_alert_with_sequence(detection_session, organization_id=2, seq=seq) + + # Risk-api would say "moderate" (no filter), but the override forces "low" -> 0.45 threshold drops it. 
+ risk_service._scores = {camera_id: "moderate"} + + auth = pytest.get_token( + pytest.user_table[2]["id"], + pytest.user_table[2]["role"].split(), + pytest.user_table[2]["organization_id"], + ) + response = await async_client.get("/alerts/unlabeled/latest?risk_score=low", headers=auth) + assert response.status_code == 200, print(response.__dict__) + assert alert.id not in {item["id"] for item in response.json()} + + +@pytest.mark.asyncio +async def test_alerts_unlabeled_latest_risk_score_moderate_keeps_everything( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + camera_id = pytest.camera_table[1]["id"] + pose_id = pytest.pose_table[2]["id"] + seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.10, minutes_ago=20) + alert = await _seed_alert_with_sequence(detection_session, organization_id=2, seq=seq) + + # Cache says very_low (would drop everything), but the override forces moderate. + risk_service._scores = {camera_id: "very_low"} + + auth = pytest.get_token( + pytest.user_table[2]["id"], + pytest.user_table[2]["role"].split(), + pytest.user_table[2]["organization_id"], + ) + response = await async_client.get("/alerts/unlabeled/latest?risk_score=moderate", headers=auth) + assert response.status_code == 200, print(response.__dict__) + assert alert.id in {item["id"] for item in response.json()} + + +@pytest.mark.asyncio +async def test_alerts_unlabeled_latest_risk_score_invalid_value_returns_422( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + auth = pytest.get_token( + pytest.user_table[2]["id"], + pytest.user_table[2]["role"].split(), + pytest.user_table[2]["organization_id"], + ) + response = await async_client.get("/alerts/unlabeled/latest?risk_score=bogus", headers=auth) + assert response.status_code == 422 From 0bf0b0380d9ccfb2a277a2513c4bf6e25fd40e66 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 04:33:29 +0200 Subject: [PATCH 17/36] refactor: collapse risk filter helpers into one and use literal type for risk_score --- src/app/api/api_v1/endpoints/alerts.py | 55 ++++++----------- src/app/api/api_v1/endpoints/sequences.py | 12 ++-- src/app/services/risk.py | 55 ++++++++--------- src/app/services/sequence_confidence.py | 74 +++++------------------ 4 files changed, 65 insertions(+), 131 deletions(-) diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index d2e310a5..864d9e54 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -5,8 +5,7 @@ from datetime import date, timedelta -from enum import Enum -from typing import Any, Dict, List, Union, cast +from typing import Any, Dict, List, Literal, Union, cast from fastapi import APIRouter, Depends, HTTPException, Path, Query, Security, status from sqlalchemy import asc, desc @@ -21,24 +20,13 @@ from app.schemas.alerts import AlertReadWithSequences from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceRead -from app.services.sequence_confidence import ( - filter_sequences_by_class, - filter_sequences_by_risk, - filter_sequences_by_risk_for_date, -) +from app.services.risk import risk_service +from app.services.sequence_confidence import filter_by_class_per_camera from app.services.sequence_counts import get_detection_counts_by_sequence_ids from app.services.telemetry import telemetry_client - -class FwiClass(str, Enum): - """FWI classes accepted as a manual ``risk_score`` override.""" - - very_low = 
"very_low" - low = "low" - moderate = "moderate" - high = "high" - very_high = "very_high" - extreme = "extreme" +# FWI classes accepted as a manual ``risk_score`` override on the listing endpoints. +FwiClass = Literal["very_low", "low", "moderate", "high", "very_high", "extreme"] router = APIRouter() @@ -74,19 +62,20 @@ async def _apply_risk_filter_to_alerts( ) -> List[Alert]: """Drop sequences below the risk threshold and alerts that end up empty. - When ``override_class`` is provided, that single FWI class is applied to every - sequence (no risk-api lookup). Otherwise: ``target_date`` triggers a per-date - risk-api call (scoped to ``organization_id`` if given); without it we use today's - cached value. + Resolution priority for the per-camera FWI class: + ``override_class`` (single value applied to all cameras) → risk-api ``/scores/{date}`` + when ``target_date`` is set → today's RAM cache. """ - all_sequences = [seq for seqs in seq_map.values() for seq in seqs] + all_seqs = [seq for seqs in seq_map.values() for seq in seqs] if override_class is not None: - kept_seqs = filter_sequences_by_class(all_sequences, override_class) - elif target_date is None: - kept_seqs = filter_sequences_by_risk(all_sequences) + class_per_camera: Dict[int, Union[str, None]] = {seq.camera_id: override_class for seq in all_seqs} + elif target_date is not None: + scores = await risk_service.get_scores_for_date(target_date, organization_id=organization_id) + class_per_camera = {seq.camera_id: scores.get(seq.camera_id) for seq in all_seqs} else: - kept_seqs = await filter_sequences_by_risk_for_date(all_sequences, target_date, organization_id=organization_id) - kept_ids = {seq.id for seq in kept_seqs} + class_per_camera = {seq.camera_id: risk_service.class_for_camera(seq.camera_id) for seq in all_seqs} + + kept_ids = {seq.id for seq in filter_by_class_per_camera(all_seqs, class_per_camera)} kept_alerts: List[Alert] = [] for alert in alerts: filtered = [seq for seq in seq_map.get(alert.id, []) if seq.id in kept_ids] @@ -163,8 +152,7 @@ async def fetch_alert_sequences( ) async def fetch_latest_unlabeled_alerts( risk_score: Union[FwiClass, None] = Query( - None, - description="Override FWI class applied to every sequence; bypasses risk-api lookup.", + None, description="Override FWI class applied to every sequence; bypasses risk-api lookup." 
), session: AsyncSession = Depends(get_session), token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]), @@ -184,9 +172,7 @@ async def fetch_latest_unlabeled_alerts( alerts = list(alerts_res.unique().all()) alert_ids = [alert.id for alert in alerts] seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids) - alerts = await _apply_risk_filter_to_alerts( - alerts, seq_map, override_class=risk_score.value if risk_score is not None else None - ) + alerts = await _apply_risk_filter_to_alerts(alerts, seq_map, override_class=risk_score) detection_counts = await get_detection_counts_by_sequence_ids( session, list({sequence.id for sequences in seq_map.values() for sequence in sequences}), @@ -200,8 +186,7 @@ async def fetch_alerts_from_date( limit: Union[int, None] = Query(15, description="Maximum number of alerts to fetch"), offset: Union[int, None] = Query(0, description="Number of alerts to skip before starting to fetch"), risk_score: Union[FwiClass, None] = Query( - None, - description="Override FWI class applied to every sequence; bypasses risk-api lookup.", + None, description="Override FWI class applied to every sequence; bypasses risk-api lookup." ), session: AsyncSession = Depends(get_session), token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]), @@ -225,7 +210,7 @@ async def fetch_alerts_from_date( seq_map, target_date=from_date, organization_id=token_payload.organization_id, - override_class=risk_score.value if risk_score is not None else None, + override_class=risk_score, ) detection_counts = await get_detection_counts_by_sequence_ids( session, diff --git a/src/app/api/api_v1/endpoints/sequences.py b/src/app/api/api_v1/endpoints/sequences.py index a92fa300..90652587 100644 --- a/src/app/api/api_v1/endpoints/sequences.py +++ b/src/app/api/api_v1/endpoints/sequences.py @@ -22,7 +22,8 @@ from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceLabel, SequenceRead from app.services.overlap import compute_overlap -from app.services.sequence_confidence import filter_sequences_by_risk, filter_sequences_by_risk_for_date +from app.services.risk import risk_service +from app.services.sequence_confidence import filter_by_class_per_camera from app.services.sequence_counts import get_detection_counts_by_sequence_ids from app.services.storage import s3_service from app.services.telemetry import telemetry_client @@ -163,7 +164,8 @@ async def fetch_latest_unlabeled_sequences( .limit(15) ) ).all() - fetched_sequences = filter_sequences_by_risk(fetched_sequences) + classes = {seq.camera_id: risk_service.class_for_camera(seq.camera_id) for seq in fetched_sequences} + fetched_sequences = filter_by_class_per_camera(fetched_sequences, classes) counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences]) return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences] @@ -190,9 +192,9 @@ async def fetch_sequences_from_date( .offset(offset) ) ).all() - fetched_sequences = await filter_sequences_by_risk_for_date( - fetched_sequences, from_date, organization_id=token_payload.organization_id - ) + scores = await risk_service.get_scores_for_date(from_date, organization_id=token_payload.organization_id) + classes = {seq.camera_id: scores.get(seq.camera_id) for seq in fetched_sequences} + fetched_sequences = filter_by_class_per_camera(fetched_sequences, classes) counts = await 
get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences]) return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences] diff --git a/src/app/services/risk.py b/src/app/services/risk.py index cebcfd0c..28e33a49 100644 --- a/src/app/services/risk.py +++ b/src/app/services/risk.py @@ -15,21 +15,26 @@ __all__ = ["min_confidence_for_class", "risk_service"] -# EFFIS classes that should trigger filtering. Anything else (moderate+) → no filter. -_LOW = "low" -_VERY_LOW = "very_low" - def min_confidence_for_class(fwi_class: Union[str, None]) -> Union[float, None]: """Return the min confidence required for this FWI class, or None if no filter applies.""" if not fwi_class: return None - normalized = fwi_class.strip().lower().replace(" ", "_") - if normalized == _VERY_LOW: - return settings.FWI_VERY_LOW_MIN_CONF - if normalized == _LOW: - return settings.FWI_LOW_MIN_CONF - return None + table = {"very_low": settings.FWI_VERY_LOW_MIN_CONF, "low": settings.FWI_LOW_MIN_CONF} + return table.get(fwi_class.strip().lower().replace(" ", "_")) + + +def _parse_scores_payload(payload: object) -> dict[int, str]: + """Pull ``{camera_id: fwi_class}`` from a list of risk-api items. Skip malformed rows.""" + if not isinstance(payload, list): + return {} + scores: dict[int, str] = {} + for item in payload: + cid = item.get("id") or item.get("camera_id") + fwi = item.get("fwi_class") + if isinstance(cid, int) and isinstance(fwi, str): + scores[cid] = fwi + return scores class RiskService: @@ -46,11 +51,15 @@ def __init__(self) -> None: def is_configured(self) -> bool: return bool(settings.RISK_API_URL and settings.RISK_API_LOGIN and settings.RISK_API_PWD) + def class_for_camera(self, camera_id: int) -> Union[str, None]: + """Return today's cached FWI class for a camera, or None if unknown.""" + return self._scores.get(camera_id) + def min_confidence(self, camera_id: int) -> Union[float, None]: """Return the min confidence required for this camera (today), or None if no filter.""" return min_confidence_for_class(self._scores.get(camera_id)) - async def _fetch(self, path: str, params: Union[dict, None] = None) -> Union[list, dict, None]: + async def _fetch(self, path: str, params: Union[dict, None] = None) -> object: if not self.is_configured: return None host = settings.RISK_API_URL.rstrip("/") # type: ignore[union-attr] @@ -69,17 +78,10 @@ async def _fetch(self, path: str, params: Union[dict, None] = None) -> Union[lis async def refresh(self) -> None: """Fetch fresh FWI classes from the risk API. On error, keep the previous cache.""" - payload = await self._fetch("cameras") - if not isinstance(payload, list): + scores = _parse_scores_payload(await self._fetch("cameras")) + if not scores: logger.warning("Risk API refresh: keeping previous cache (%d entries)", len(self._scores)) return - - scores: dict[int, str] = {} - for item in payload: - camera_id = item.get("id") - fwi_class = item.get("fwi_class") - if isinstance(camera_id, int) and isinstance(fwi_class, str): - scores[camera_id] = fwi_class self._scores = scores logger.info("Risk API refresh: cached FWI class for %d camera(s)", len(scores)) @@ -88,17 +90,8 @@ async def get_scores_for_date(self, target_date: date, organization_id: Union[in Returns {} on error or when not configured. 
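+        Example return value (camera ids and classes are illustrative): ``{12: "very_low", 37: "low"}``.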
""" - params: Union[dict, None] = {"organization_id": organization_id} if organization_id is not None else None - payload = await self._fetch(f"scores/{target_date.isoformat()}", params=params) - if not isinstance(payload, list): - return {} - scores: dict[int, str] = {} - for item in payload: - camera_id = item.get("id") or item.get("camera_id") - fwi_class = item.get("fwi_class") - if isinstance(camera_id, int) and isinstance(fwi_class, str): - scores[camera_id] = fwi_class - return scores + params = {"organization_id": organization_id} if organization_id is not None else None + return _parse_scores_payload(await self._fetch(f"scores/{target_date.isoformat()}", params=params)) risk_service = RiskService() diff --git a/src/app/services/sequence_confidence.py b/src/app/services/sequence_confidence.py index 518ee35a..639d71bd 100644 --- a/src/app/services/sequence_confidence.py +++ b/src/app/services/sequence_confidence.py @@ -7,22 +7,16 @@ import logging import re from ast import literal_eval -from datetime import date from typing import Dict, List, Union from typing import Sequence as TypingSequence from app.models import Sequence from app.schemas.detections import BOX_PATTERN -from app.services.risk import min_confidence_for_class, risk_service +from app.services.risk import min_confidence_for_class logger = logging.getLogger("uvicorn.error") -__all__ = [ - "filter_sequences_by_class", - "filter_sequences_by_risk", - "filter_sequences_by_risk_for_date", - "max_conf_from_bboxes", -] +__all__ = ["filter_by_class_per_camera", "max_conf_from_bboxes"] def max_conf_from_bboxes(*bbox_strings: Union[str, None]) -> Union[float, None]: @@ -46,62 +40,22 @@ def max_conf_from_bboxes(*bbox_strings: Union[str, None]) -> Union[float, None]: return best -def _filter_sequences( +def filter_by_class_per_camera( sequences: TypingSequence[Sequence], - thresholds: Dict[int, Union[float, None]], + class_per_camera: Dict[int, Union[str, None]], ) -> List[Sequence]: - if all(threshold is None for threshold in thresholds.values()): - return list(sequences) - - kept: List[Sequence] = [] - for seq in sequences: - threshold = thresholds.get(seq.camera_id) - if threshold is None or seq.max_conf is None or seq.max_conf >= threshold: - kept.append(seq) - return kept - - -def filter_sequences_by_class( - sequences: TypingSequence[Sequence], - fwi_class: Union[str, None], -) -> List[Sequence]: - """Apply a single FWI class threshold to every sequence, regardless of camera. + """Drop sequences whose stored ``max_conf`` falls below the threshold for their camera's FWI class. - Used when callers pass an explicit ``risk_score`` override instead of consulting - the risk-api. ``moderate``/``high``/etc. yield no filtering (returns the input). + Fail-open: a sequence is kept when its camera has no FWI class (moderate+ or unknown) + or when ``seq.max_conf`` is NULL. """ if not sequences: return [] - threshold = min_confidence_for_class(fwi_class) - if threshold is None: + thresholds = {cid: min_confidence_for_class(c) for cid, c in class_per_camera.items()} + if all(t is None for t in thresholds.values()): return list(sequences) - thresholds: Dict[int, Union[float, None]] = {seq.camera_id: threshold for seq in sequences} - return _filter_sequences(sequences, thresholds) - - -def filter_sequences_by_risk(sequences: TypingSequence[Sequence]) -> List[Sequence]: - """Drop sequences whose stored ``max_conf`` is below today's risk-driven threshold for their camera. 
- - Fail-open: a sequence is kept if either the camera has no FWI score (moderate+ or unknown) - or ``seq.max_conf`` is NULL. - """ - if not sequences: - return [] - thresholds = {seq.camera_id: risk_service.min_confidence(seq.camera_id) for seq in sequences} - return _filter_sequences(sequences, thresholds) - - -async def filter_sequences_by_risk_for_date( - sequences: TypingSequence[Sequence], - target_date: date, - organization_id: Union[int, None] = None, -) -> List[Sequence]: - """Like ``filter_sequences_by_risk``, but uses the FWI class persisted for a specific date. - - When ``organization_id`` is provided, the risk-api call is scoped to that organization. - """ - if not sequences: - return [] - scores = await risk_service.get_scores_for_date(target_date, organization_id=organization_id) - thresholds = {seq.camera_id: min_confidence_for_class(scores.get(seq.camera_id)) for seq in sequences} - return _filter_sequences(sequences, thresholds) + return [ + seq + for seq in sequences + if (t := thresholds.get(seq.camera_id)) is None or seq.max_conf is None or seq.max_conf >= t + ] From 36353bce9400e94254d58a837bfa866a1b41bfd3 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 04:37:49 +0200 Subject: [PATCH 18/36] fix: restore refresh() cache-replace on empty list and harden payload parser --- src/app/services/risk.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/app/services/risk.py b/src/app/services/risk.py index 28e33a49..b823fe19 100644 --- a/src/app/services/risk.py +++ b/src/app/services/risk.py @@ -30,6 +30,8 @@ def _parse_scores_payload(payload: object) -> dict[int, str]: return {} scores: dict[int, str] = {} for item in payload: + if not isinstance(item, dict): + continue cid = item.get("id") or item.get("camera_id") fwi = item.get("fwi_class") if isinstance(cid, int) and isinstance(fwi, str): @@ -77,13 +79,13 @@ async def _fetch(self, path: str, params: Union[dict, None] = None) -> object: return None async def refresh(self) -> None: - """Fetch fresh FWI classes from the risk API. On error, keep the previous cache.""" - scores = _parse_scores_payload(await self._fetch("cameras")) - if not scores: - logger.warning("Risk API refresh: keeping previous cache (%d entries)", len(self._scores)) + """Fetch fresh FWI classes from the risk API. On network/HTTP failure, keep the previous cache.""" + payload = await self._fetch("cameras") + if payload is None: + logger.warning("Risk API refresh failed; keeping previous cache (%d entries)", len(self._scores)) return - self._scores = scores - logger.info("Risk API refresh: cached FWI class for %d camera(s)", len(scores)) + self._scores = _parse_scores_payload(payload) + logger.info("Risk API refresh: cached FWI class for %d camera(s)", len(self._scores)) async def get_scores_for_date(self, target_date: date, organization_id: Union[int, None] = None) -> dict[int, str]: """Fetch persisted FWI classes for a specific date, optionally scoped to an organization. 
From 386911365d575b538ea72f32cdcc6c0f8e5763f0 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 06:04:35 +0200 Subject: [PATCH 19/36] feat: push risk filter into SQL WHERE for exact pagination --- src/app/api/api_v1/endpoints/alerts.py | 112 +++++++++++----------- src/app/api/api_v1/endpoints/sequences.py | 72 ++++++++------ src/app/services/sequence_confidence.py | 25 ++++- 3 files changed, 122 insertions(+), 87 deletions(-) diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index 864d9e54..c74a9b7d 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -5,10 +5,11 @@ from datetime import date, timedelta -from typing import Any, Dict, List, Literal, Union, cast +from typing import Any, Dict, List, Union, cast from fastapi import APIRouter, Depends, HTTPException, Path, Query, Security, status from sqlalchemy import asc, desc +from sqlalchemy.sql import ColumnElement from sqlmodel import delete, func, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -16,19 +17,15 @@ from app.core.time import utcnow from app.crud import AlertCRUD from app.db import get_session -from app.models import Alert, AlertSequence, Sequence, UserRole +from app.models import Alert, AlertSequence, Camera, Sequence, UserRole from app.schemas.alerts import AlertReadWithSequences from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceRead -from app.services.risk import risk_service -from app.services.sequence_confidence import filter_by_class_per_camera +from app.services.risk import FwiClass, risk_service +from app.services.sequence_confidence import max_conf_filter_clause from app.services.sequence_counts import get_detection_counts_by_sequence_ids from app.services.telemetry import telemetry_client -# FWI classes accepted as a manual ``risk_score`` override on the listing endpoints. 
-FwiClass = Literal["very_low", "low", "moderate", "high", "very_high", "extreme"] - - router = APIRouter() @@ -37,7 +34,11 @@ def verify_org_rights(organization_id: int, alert: Alert) -> None: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access forbidden.") -async def _fetch_sequences_by_alert_ids(session: AsyncSession, alert_ids: List[int]) -> Dict[int, List[Sequence]]: +async def _fetch_sequences_by_alert_ids( + session: AsyncSession, + alert_ids: List[int], + seq_filter: Union[ColumnElement[bool], None] = None, +) -> Dict[int, List[Sequence]]: mapping: Dict[int, List[Sequence]] = {} if not alert_ids: return mapping @@ -45,44 +46,30 @@ async def _fetch_sequences_by_alert_ids(session: AsyncSession, alert_ids: List[i select(AlertSequence.alert_id, Sequence) .join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)) .where(AlertSequence.alert_id.in_(alert_ids)) # type: ignore[attr-defined] - .order_by(cast(Any, AlertSequence.alert_id), desc(cast(Any, Sequence.last_seen_at))) ) + if seq_filter is not None: + seq_stmt = seq_stmt.where(seq_filter) + seq_stmt = seq_stmt.order_by(cast(Any, AlertSequence.alert_id), desc(cast(Any, Sequence.last_seen_at))) res = await session.exec(seq_stmt) for alert_id, sequence in res.all(): mapping.setdefault(int(alert_id), []).append(sequence) return mapping -async def _apply_risk_filter_to_alerts( - alerts: List[Alert], - seq_map: Dict[int, List[Sequence]], +async def _resolve_class_per_camera( + session: AsyncSession, + organization_id: int, target_date: Union[date, None] = None, - organization_id: Union[int, None] = None, override_class: Union[str, None] = None, -) -> List[Alert]: - """Drop sequences below the risk threshold and alerts that end up empty. - - Resolution priority for the per-camera FWI class: - ``override_class`` (single value applied to all cameras) → risk-api ``/scores/{date}`` - when ``target_date`` is set → today's RAM cache. 
- """ - all_seqs = [seq for seqs in seq_map.values() for seq in seqs] +) -> Dict[int, Union[str, None]]: + """Resolve ``{camera_id: fwi_class}`` for the org, picking override → per-date → today's cache.""" if override_class is not None: - class_per_camera: Dict[int, Union[str, None]] = {seq.camera_id: override_class for seq in all_seqs} - elif target_date is not None: + cam_ids = (await session.exec(select(Camera.id).where(Camera.organization_id == organization_id))).all() + return dict.fromkeys(cam_ids, override_class) + if target_date is not None: scores = await risk_service.get_scores_for_date(target_date, organization_id=organization_id) - class_per_camera = {seq.camera_id: scores.get(seq.camera_id) for seq in all_seqs} - else: - class_per_camera = {seq.camera_id: risk_service.class_for_camera(seq.camera_id) for seq in all_seqs} - - kept_ids = {seq.id for seq in filter_by_class_per_camera(all_seqs, class_per_camera)} - kept_alerts: List[Alert] = [] - for alert in alerts: - filtered = [seq for seq in seq_map.get(alert.id, []) if seq.id in kept_ids] - if filtered: - seq_map[alert.id] = filtered - kept_alerts.append(alert) - return kept_alerts + return {cid: cls for cid, cls in scores.items()} + return {cid: cls for cid, cls in risk_service.scores().items()} def _serialize_sequence(sequence: Sequence, detections_count: int = 0) -> SequenceRead: @@ -159,20 +146,28 @@ async def fetch_latest_unlabeled_alerts( ) -> List[AlertReadWithSequences]: telemetry_client.capture(token_payload.sub, event="alerts-fetch-latest") - alerts_stmt: Any = select(Alert).join(AlertSequence, cast(Any, AlertSequence.alert_id == Alert.id)) - alerts_stmt = alerts_stmt.join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)) - alerts_stmt = ( - alerts_stmt.where(Alert.organization_id == token_payload.organization_id) + classes = await _resolve_class_per_camera(session, token_payload.organization_id, override_class=risk_score) + seq_filter = max_conf_filter_clause(classes) + + seq_match: Any = ( + select(AlertSequence.alert_id) + .join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)) .where(Sequence.last_seen_at > utcnow() - timedelta(hours=24)) .where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr] + ) + if seq_filter is not None: + seq_match = seq_match.where(seq_filter) + + alerts_stmt: Any = ( + select(Alert) + .where(Alert.organization_id == token_payload.organization_id) + .where(cast(Any, Alert.id).in_(seq_match)) .order_by(Alert.started_at.desc()) # type: ignore[attr-defined] .limit(15) ) - alerts_res = await session.exec(alerts_stmt) - alerts = list(alerts_res.unique().all()) + alerts = list((await session.exec(alerts_stmt)).all()) alert_ids = [alert.id for alert in alerts] - seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids) - alerts = await _apply_risk_filter_to_alerts(alerts, seq_map, override_class=risk_score) + seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids, seq_filter) detection_counts = await get_detection_counts_by_sequence_ids( session, list({sequence.id for sequences in seq_map.values() for sequence in sequences}), @@ -193,25 +188,28 @@ async def fetch_alerts_from_date( ) -> List[AlertReadWithSequences]: telemetry_client.capture(token_payload.sub, event="alerts-fetch-from-date") + classes = await _resolve_class_per_camera( + session, token_payload.organization_id, target_date=from_date, override_class=risk_score + ) + seq_filter = max_conf_filter_clause(classes) + alerts_stmt: Any = ( select(Alert) .where(Alert.organization_id == 
token_payload.organization_id) .where(func.date(Alert.started_at) == from_date) - .order_by(Alert.started_at.desc()) # type: ignore[attr-defined] - .limit(limit) - .offset(offset) ) - alerts_res = await session.exec(alerts_stmt) - alerts = list(alerts_res.all()) + if seq_filter is not None: + seq_match: Any = ( + select(AlertSequence.alert_id) + .join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)) + .where(seq_filter) + ) + alerts_stmt = alerts_stmt.where(cast(Any, Alert.id).in_(seq_match)) + alerts_stmt = alerts_stmt.order_by(Alert.started_at.desc()).limit(limit).offset(offset) # type: ignore[attr-defined] + + alerts = list((await session.exec(alerts_stmt)).all()) alert_ids = [alert.id for alert in alerts] - seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids) - alerts = await _apply_risk_filter_to_alerts( - alerts, - seq_map, - target_date=from_date, - organization_id=token_payload.organization_id, - override_class=risk_score, - ) + seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids, seq_filter) detection_counts = await get_detection_counts_by_sequence_ids( session, list({sequence.id for sequences in seq_map.values() for sequence in sequences}), diff --git a/src/app/api/api_v1/endpoints/sequences.py b/src/app/api/api_v1/endpoints/sequences.py index 90652587..8bc8c7e5 100644 --- a/src/app/api/api_v1/endpoints/sequences.py +++ b/src/app/api/api_v1/endpoints/sequences.py @@ -22,8 +22,8 @@ from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceLabel, SequenceRead from app.services.overlap import compute_overlap -from app.services.risk import risk_service -from app.services.sequence_confidence import filter_by_class_per_camera +from app.services.risk import FwiClass, risk_service +from app.services.sequence_confidence import max_conf_filter_clause from app.services.sequence_counts import get_detection_counts_by_sequence_ids from app.services.storage import s3_service from app.services.telemetry import telemetry_client @@ -148,24 +148,32 @@ async def fetch_sequence_detections( summary="Fetch all the unlabeled sequences from the last 24 hours", ) async def fetch_latest_unlabeled_sequences( + risk_score: Union[FwiClass, None] = Query( + None, description="Override FWI class applied to every sequence; bypasses risk-api lookup." 
+ ), session: AsyncSession = Depends(get_session), token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]), ) -> List[SequenceRead]: telemetry_client.capture(token_payload.sub, event="sequence-fetch-latest") - camera_ids = await session.exec(select(Camera.id).where(Camera.organization_id == token_payload.organization_id)) - - fetched_sequences = ( - await session.exec( - select(Sequence) - .where(Sequence.started_at > utcnow() - timedelta(hours=24)) - .where(Sequence.camera_id.in_(camera_ids.all())) # type: ignore[attr-defined] - .where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr] - .order_by(Sequence.started_at.desc()) # type: ignore[attr-defined] - .limit(15) - ) + camera_ids = ( + await session.exec(select(Camera.id).where(Camera.organization_id == token_payload.organization_id)) ).all() - classes = {seq.camera_id: risk_service.class_for_camera(seq.camera_id) for seq in fetched_sequences} - fetched_sequences = filter_by_class_per_camera(fetched_sequences, classes) + classes: dict[int, Union[str, None]] = ( + dict.fromkeys(camera_ids, risk_score) if risk_score is not None else dict(risk_service.scores()) + ) + + stmt: Any = ( + select(Sequence) + .where(Sequence.started_at > utcnow() - timedelta(hours=24)) + .where(Sequence.camera_id.in_(camera_ids)) # type: ignore[attr-defined] + .where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr] + ) + seq_filter = max_conf_filter_clause(classes) + if seq_filter is not None: + stmt = stmt.where(seq_filter) + stmt = stmt.order_by(Sequence.started_at.desc()).limit(15) # type: ignore[attr-defined] + + fetched_sequences = (await session.exec(stmt)).all() counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences]) return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences] @@ -175,26 +183,32 @@ async def fetch_sequences_from_date( from_date: date = Query(), limit: Union[int, None] = Query(15, description="Maximum number of sequences to fetch"), offset: Union[int, None] = Query(0, description="Number of sequences to skip before starting to fetch"), + risk_score: Union[FwiClass, None] = Query( + None, description="Override FWI class applied to every sequence; bypasses risk-api lookup." 
+ ), session: AsyncSession = Depends(get_session), token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]), ) -> List[SequenceRead]: telemetry_client.capture(token_payload.sub, event="sequence-fetch-from-date") # Limit to cameras in the same organization - camera_ids = await session.exec(select(Camera.id).where(Camera.organization_id == token_payload.organization_id)) - # Identify the sequences from that day - fetched_sequences = ( - await session.exec( - select(Sequence) - .where(func.date(Sequence.started_at) == from_date) - .where(Sequence.camera_id.in_(camera_ids.all())) # type: ignore[attr-defined] - .order_by(Sequence.started_at.desc()) # type: ignore[attr-defined] - .limit(limit) - .offset(offset) - ) + camera_ids = ( + await session.exec(select(Camera.id).where(Camera.organization_id == token_payload.organization_id)) ).all() - scores = await risk_service.get_scores_for_date(from_date, organization_id=token_payload.organization_id) - classes = {seq.camera_id: scores.get(seq.camera_id) for seq in fetched_sequences} - fetched_sequences = filter_by_class_per_camera(fetched_sequences, classes) + if risk_score is not None: + classes: dict[int, Union[str, None]] = dict.fromkeys(camera_ids, risk_score) + else: + scores = await risk_service.get_scores_for_date(from_date, organization_id=token_payload.organization_id) + classes = dict(scores) + + stmt: Any = ( + select(Sequence).where(func.date(Sequence.started_at) == from_date).where(Sequence.camera_id.in_(camera_ids)) # type: ignore[attr-defined] + ) + seq_filter = max_conf_filter_clause(classes) + if seq_filter is not None: + stmt = stmt.where(seq_filter) + stmt = stmt.order_by(Sequence.started_at.desc()).limit(limit).offset(offset) # type: ignore[attr-defined] + + fetched_sequences = (await session.exec(stmt)).all() counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences]) return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences] diff --git a/src/app/services/sequence_confidence.py b/src/app/services/sequence_confidence.py index 639d71bd..1397fb49 100644 --- a/src/app/services/sequence_confidence.py +++ b/src/app/services/sequence_confidence.py @@ -10,13 +10,16 @@ from typing import Dict, List, Union from typing import Sequence as TypingSequence +from sqlalchemy import case, or_ +from sqlalchemy.sql import ColumnElement + from app.models import Sequence from app.schemas.detections import BOX_PATTERN from app.services.risk import min_confidence_for_class logger = logging.getLogger("uvicorn.error") -__all__ = ["filter_by_class_per_camera", "max_conf_from_bboxes"] +__all__ = ["filter_by_class_per_camera", "max_conf_filter_clause", "max_conf_from_bboxes"] def max_conf_from_bboxes(*bbox_strings: Union[str, None]) -> Union[float, None]: @@ -40,6 +43,26 @@ def max_conf_from_bboxes(*bbox_strings: Union[str, None]) -> Union[float, None]: return best +def max_conf_filter_clause(class_per_camera: Dict[int, Union[str, None]]) -> Union[ColumnElement[bool], None]: + """SQL ``WHERE`` clause keeping sequences whose ``max_conf`` passes their camera's threshold. + + Returns ``None`` when no camera has an active threshold (caller should skip the filter). + Fail-open: rows with ``max_conf IS NULL`` are kept; cameras without an entry default to 0. + Collapses to a constant comparison when all active thresholds are equal (avoids a per-camera + ``CASE`` with one arm per camera). 
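+
+    Illustrative result for ``{1: "very_low", 2: "low", 3: "moderate"}`` (made-up camera ids; 0.6 and 0.45
+    are the configured "very_low"/"low" thresholds, camera 3 is left unfiltered):
+
+        or_(Sequence.max_conf.is_(None),
+            Sequence.max_conf >= case((Sequence.camera_id == 1, 0.6), (Sequence.camera_id == 2, 0.45), else_=0.0))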
+    """
+    thresholds = {
+        cid: t for cid, klass in class_per_camera.items() if (t := min_confidence_for_class(klass)) is not None
+    }
+    if not thresholds:
+        return None
+    distinct = set(thresholds.values())
+    if len(distinct) == 1:
+        # Keep cameras without an active threshold fail-open even in the collapsed form.
+        return or_(
+            Sequence.camera_id.not_in(list(thresholds)),  # type: ignore[attr-defined]
+            Sequence.max_conf.is_(None),  # type: ignore[union-attr]
+            Sequence.max_conf >= distinct.pop(),  # type: ignore[union-attr]
+        )
+    threshold_expr = case(*[(Sequence.camera_id == cid, t) for cid, t in thresholds.items()], else_=0.0)
+    return or_(Sequence.max_conf.is_(None), Sequence.max_conf >= threshold_expr)  # type: ignore[union-attr]
+
+
 def filter_by_class_per_camera(
     sequences: TypingSequence[Sequence],
     class_per_camera: Dict[int, Union[str, None]],

From be1af2cb1dc7de53716173089ce70e4e254ab521 Mon Sep 17 00:00:00 2001
From: Mateo
Date: Tue, 5 May 2026 06:04:36 +0200
Subject: [PATCH 20/36] refactor: replace fwi conf settings with FWI_MIN_CONF dict in risk module

---
 .env.example                    |  2 --
 docker-compose.yml              |  2 --
 src/app/core/config.py          |  9 ---------
 src/app/services/risk.py        | 26 ++++++++++++++++++++++----
 src/tests/services/test_risk.py | 16 ++++++++--------
 5 files changed, 30 insertions(+), 25 deletions(-)

diff --git a/.env.example b/.env.example
index 494b3ab6..84aecc0d 100644
--- a/.env.example
+++ b/.env.example
@@ -28,8 +28,6 @@ RISK_API_URL=
 RISK_API_LOGIN=
 RISK_API_PWD=
 RISK_REFRESH_HOUR_UTC=4
-FWI_LOW_MIN_CONF=0.45
-FWI_VERY_LOW_MIN_CONF=0.6
 
 # Production-only
 ACME_EMAIL=
diff --git a/docker-compose.yml b/docker-compose.yml
index 7405255f..a1b79277 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -69,8 +69,6 @@ services:
       - RISK_API_LOGIN=${RISK_API_LOGIN}
       - RISK_API_PWD=${RISK_API_PWD}
       - RISK_REFRESH_HOUR_UTC=${RISK_REFRESH_HOUR_UTC:-4}
-      - FWI_LOW_MIN_CONF=${FWI_LOW_MIN_CONF:-0.45}
-      - FWI_VERY_LOW_MIN_CONF=${FWI_VERY_LOW_MIN_CONF:-0.6}
     volumes:
       - ./src/:/app/
     command: "sh -c 'alembic upgrade head && python app/db.py && uvicorn app.main:app --reload --host 0.0.0.0 --port 5050 --proxy-headers'"
diff --git a/src/app/core/config.py b/src/app/core/config.py
index c407e33b..08b2e57e 100644
--- a/src/app/core/config.py
+++ b/src/app/core/config.py
@@ -82,15 +82,6 @@ def sqlachmey_uri(cls, v: str) -> str:
     RISK_API_LOGIN: Union[str, None] = os.environ.get("RISK_API_LOGIN")
     RISK_API_PWD: Union[str, None] = os.environ.get("RISK_API_PWD")
     RISK_REFRESH_HOUR_UTC: int = int(os.environ.get("RISK_REFRESH_HOUR_UTC") or 4)
-    FWI_LOW_MIN_CONF: float = float(os.environ.get("FWI_LOW_MIN_CONF") or 0.45)
-    FWI_VERY_LOW_MIN_CONF: float = float(os.environ.get("FWI_VERY_LOW_MIN_CONF") or 0.6)
-
-    @field_validator("FWI_LOW_MIN_CONF", "FWI_VERY_LOW_MIN_CONF")
-    @classmethod
-    def fwi_min_conf_in_unit_range(cls, v: float) -> float:
-        if not 0.0 <= v <= 1.0:
-            raise ValueError("FWI confidence thresholds must lie in [0, 1]")
-        return v
 
     # Error monitoring
     SENTRY_DSN: Union[str, None] = os.environ.get("SENTRY_DSN")
diff --git a/src/app/services/risk.py b/src/app/services/risk.py
index b823fe19..21ccc245 100644
--- a/src/app/services/risk.py
+++ b/src/app/services/risk.py
@@ -5,7 +5,7 @@
 
 import logging
 from datetime import date
-from typing import Union
+from typing import Literal, Union
 
 import httpx
 
@@ -13,15 +13,29 @@
 
 logger = logging.getLogger("uvicorn.error")
 
-__all__ = ["min_confidence_for_class", "risk_service"]
+__all__ = ["FWI_MIN_CONF", "FwiClass", "min_confidence_for_class", "risk_service"]
+
+# FWI classes accepted by the risk-api and as a manual ``risk_score`` override.
+FwiClass = Literal["very_low", "low", "moderate", "high", "very_high", "extreme"] + +# Minimum sequence ``max_conf`` required per FWI class. Zero or absent → no filter. +# All EFFIS classes are listed even when unused so the table stays explicit and easy to tune. +FWI_MIN_CONF: dict[str, float] = { + "very_low": 0.6, + "low": 0.45, + "moderate": 0.0, + "high": 0.0, + "very_high": 0.0, + "extreme": 0.0, +} def min_confidence_for_class(fwi_class: Union[str, None]) -> Union[float, None]: """Return the min confidence required for this FWI class, or None if no filter applies.""" if not fwi_class: return None - table = {"very_low": settings.FWI_VERY_LOW_MIN_CONF, "low": settings.FWI_LOW_MIN_CONF} - return table.get(fwi_class.strip().lower().replace(" ", "_")) + threshold = FWI_MIN_CONF.get(fwi_class.strip().lower().replace(" ", "_")) + return threshold or None def _parse_scores_payload(payload: object) -> dict[int, str]: @@ -57,6 +71,10 @@ def class_for_camera(self, camera_id: int) -> Union[str, None]: """Return today's cached FWI class for a camera, or None if unknown.""" return self._scores.get(camera_id) + def scores(self) -> dict[int, str]: + """Return a copy of the full ``{camera_id: fwi_class}`` cache.""" + return dict(self._scores) + def min_confidence(self, camera_id: int) -> Union[float, None]: """Return the min confidence required for this camera (today), or None if no filter.""" return min_confidence_for_class(self._scores.get(camera_id)) diff --git a/src/tests/services/test_risk.py b/src/tests/services/test_risk.py index 49af7d45..cefdded0 100644 --- a/src/tests/services/test_risk.py +++ b/src/tests/services/test_risk.py @@ -6,12 +6,12 @@ import pytest from app.core.config import settings -from app.services.risk import RiskService, min_confidence_for_class +from app.services.risk import FWI_MIN_CONF, RiskService, min_confidence_for_class def test_min_confidence_for_class(): - assert min_confidence_for_class("very_low") == settings.FWI_VERY_LOW_MIN_CONF - assert min_confidence_for_class("low") == settings.FWI_LOW_MIN_CONF + assert min_confidence_for_class("very_low") == FWI_MIN_CONF["very_low"] + assert min_confidence_for_class("low") == FWI_MIN_CONF["low"] assert min_confidence_for_class("moderate") is None assert min_confidence_for_class("high") is None assert min_confidence_for_class("very_high") is None @@ -21,16 +21,16 @@ def test_min_confidence_for_class(): def test_min_confidence_for_class_normalizes_casing(): - assert min_confidence_for_class("Very Low") == settings.FWI_VERY_LOW_MIN_CONF - assert min_confidence_for_class("LOW") == settings.FWI_LOW_MIN_CONF - assert min_confidence_for_class(" very_low ") == settings.FWI_VERY_LOW_MIN_CONF + assert min_confidence_for_class("Very Low") == FWI_MIN_CONF["very_low"] + assert min_confidence_for_class("LOW") == FWI_MIN_CONF["low"] + assert min_confidence_for_class(" very_low ") == FWI_MIN_CONF["very_low"] def test_risk_service_min_confidence_uses_cached_class(): service = RiskService() service._scores = {1: "very_low", 2: "low", 3: "moderate"} - assert service.min_confidence(1) == settings.FWI_VERY_LOW_MIN_CONF - assert service.min_confidence(2) == settings.FWI_LOW_MIN_CONF + assert service.min_confidence(1) == FWI_MIN_CONF["very_low"] + assert service.min_confidence(2) == FWI_MIN_CONF["low"] assert service.min_confidence(3) is None assert service.min_confidence(99) is None From 27f7cf75c6654042ba94abc7ed7942e24bc5ce24 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 06:10:34 +0200 Subject: [PATCH 21/36] test: pagination on 
/sequences/all/fromdate keeps page full when filter applies --- src/tests/endpoints/test_risk_filter.py | 32 +++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/tests/endpoints/test_risk_filter.py b/src/tests/endpoints/test_risk_filter.py index 564eca18..d21080e3 100644 --- a/src/tests/endpoints/test_risk_filter.py +++ b/src/tests/endpoints/test_risk_filter.py @@ -202,3 +202,35 @@ async def test_alerts_unlabeled_latest_risk_score_invalid_value_returns_422( ) response = await async_client.get("/alerts/unlabeled/latest?risk_score=bogus", headers=auth) assert response.status_code == 422 + + +@pytest.mark.asyncio +async def test_sequences_fromdate_pagination_filters_before_limit( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + """A page must come back full when the filter would otherwise consume the whole limit.""" + camera_id = pytest.camera_table[1]["id"] + pose_id = pytest.pose_table[2]["id"] + target_date = utcnow().date().isoformat() + + # 2 below threshold + 3 above. With LIMIT 3, naive post-filter would return at most 1. + for max_conf, minutes_ago in [(0.10, 50), (0.20, 45), (0.50, 40), (0.60, 35), (0.80, 30)]: + await _seed_unlabeled_sequence( + detection_session, camera_id, pose_id, max_conf=max_conf, minutes_ago=minutes_ago + ) + + risk_service._scores = {camera_id: "low"} # threshold 0.45 + + auth = pytest.get_token( + pytest.user_table[2]["id"], + pytest.user_table[2]["role"].split(), + pytest.user_table[2]["organization_id"], + ) + + response = await async_client.get( + f"/sequences/all/fromdate?from_date={target_date}&limit=3", headers=auth + ) + assert response.status_code == 200, print(response.__dict__) + page = response.json() + assert len(page) == 3 + assert all(seq["max_conf"] >= 0.45 for seq in page) From aa20bdba3bd0b913706c7a7f432fc69790110988 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 06:14:18 +0200 Subject: [PATCH 22/36] fix: drop max_conf clause collapse and route pagination test through override --- src/app/services/sequence_confidence.py | 8 ++------ src/tests/endpoints/test_risk_filter.py | 5 ++--- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/app/services/sequence_confidence.py b/src/app/services/sequence_confidence.py index 1397fb49..b8050b21 100644 --- a/src/app/services/sequence_confidence.py +++ b/src/app/services/sequence_confidence.py @@ -47,18 +47,14 @@ def max_conf_filter_clause(class_per_camera: Dict[int, Union[str, None]]) -> Uni """SQL ``WHERE`` clause keeping sequences whose ``max_conf`` passes their camera's threshold. Returns ``None`` when no camera has an active threshold (caller should skip the filter). - Fail-open: rows with ``max_conf IS NULL`` are kept; cameras without an entry default to 0. - Collapses to a constant comparison when all active thresholds are equal (avoids a per-camera - ``CASE`` with one arm per camera). + Fail-open: rows with ``max_conf IS NULL`` are kept, and cameras absent from ``class_per_camera`` + default to threshold 0 (everything passes) via the ``CASE`` ``else_`` clause. 
""" thresholds = { cid: t for cid, klass in class_per_camera.items() if (t := min_confidence_for_class(klass)) is not None } if not thresholds: return None - distinct = set(thresholds.values()) - if len(distinct) == 1: - return or_(Sequence.max_conf.is_(None), Sequence.max_conf >= distinct.pop()) # type: ignore[union-attr] threshold_expr = case(*[(Sequence.camera_id == cid, t) for cid, t in thresholds.items()], else_=0.0) return or_(Sequence.max_conf.is_(None), Sequence.max_conf >= threshold_expr) # type: ignore[union-attr] diff --git a/src/tests/endpoints/test_risk_filter.py b/src/tests/endpoints/test_risk_filter.py index d21080e3..2779c786 100644 --- a/src/tests/endpoints/test_risk_filter.py +++ b/src/tests/endpoints/test_risk_filter.py @@ -219,16 +219,15 @@ async def test_sequences_fromdate_pagination_filters_before_limit( detection_session, camera_id, pose_id, max_conf=max_conf, minutes_ago=minutes_ago ) - risk_service._scores = {camera_id: "low"} # threshold 0.45 - auth = pytest.get_token( pytest.user_table[2]["id"], pytest.user_table[2]["role"].split(), pytest.user_table[2]["organization_id"], ) + # The ``risk_score=low`` override drives the threshold without hitting the risk API. response = await async_client.get( - f"/sequences/all/fromdate?from_date={target_date}&limit=3", headers=auth + f"/sequences/all/fromdate?from_date={target_date}&limit=3&risk_score=low", headers=auth ) assert response.status_code == 200, print(response.__dict__) page = response.json() From fb0745a09f77857468e956a462c82cdf523340ef Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 06:20:44 +0200 Subject: [PATCH 23/36] fix: compute sequence max_conf from primary bbox only, ignore sibling detections --- src/app/api/api_v1/endpoints/detections.py | 8 +++----- ...6_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py | 7 ++++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/app/api/api_v1/endpoints/detections.py b/src/app/api/api_v1/endpoints/detections.py index 48b3b835..a8364b96 100644 --- a/src/app/api/api_v1/endpoints/detections.py +++ b/src/app/api/api_v1/endpoints/detections.py @@ -432,7 +432,8 @@ async def create_detection( if matched_sequence is not None: await sequences.update(matched_sequence.id, SequenceUpdate(last_seen_at=det.created_at)) det = await detections.update(det.id, DetectionSequence(sequence_id=matched_sequence.id)) - det_max_conf = max_conf_from_bboxes(det.bbox, det.others_bboxes) + # Only the primary bbox tracks the sequence; siblings in others_bboxes are unrelated detections. 
+ det_max_conf = max_conf_from_bboxes(det.bbox) if det_max_conf is not None: await sequences.bump_max_conf(matched_sequence.id, det_max_conf) else: @@ -463,10 +464,7 @@ async def create_detection( if len(overlapping_dets) >= settings.SEQUENCE_MIN_INTERVAL_DETS: first_det = min(overlapping_dets, key=lambda item: item.created_at) cone_azimuth, cone_angle = resolve_cone(pose.azimuth, first_det.bbox, camera.angle_of_view) - seq_max_conf = max_conf_from_bboxes( - *[d.bbox for d in overlapping_dets], - *[d.others_bboxes for d in overlapping_dets], - ) + seq_max_conf = max_conf_from_bboxes(*[d.bbox for d in overlapping_dets]) sequence_ = await sequences.create( Sequence( camera_id=token_payload.sub, diff --git a/src/migrations/versions/2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py b/src/migrations/versions/2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py index 9d77dcb7..d4703d9c 100644 --- a/src/migrations/versions/2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py +++ b/src/migrations/versions/2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py @@ -50,13 +50,14 @@ def upgrade() -> None: op.add_column("sequences", sa.Column("max_conf", sa.Float(), nullable=True)) bind = op.get_bind() + # Only the primary bbox tracks the sequence; siblings in others_bboxes are unrelated detections. rows = bind.execute( - sa.text("SELECT sequence_id, bbox, others_bboxes FROM detections WHERE sequence_id IS NOT NULL") + sa.text("SELECT sequence_id, bbox FROM detections WHERE sequence_id IS NOT NULL") ).fetchall() seq_max: dict[int, float] = {} - for sequence_id, bbox, others in rows: - conf = _max_conf(bbox, others) + for sequence_id, bbox in rows: + conf = _max_conf(bbox) if conf is None: continue current = seq_max.get(sequence_id) From 9c14549b32a6c972bb06cef143dd48af85998f2f Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 07:48:34 +0200 Subject: [PATCH 24/36] chore: apply ruff format --- .../2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/migrations/versions/2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py b/src/migrations/versions/2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py index d4703d9c..3fa62461 100644 --- a/src/migrations/versions/2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py +++ b/src/migrations/versions/2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py @@ -51,9 +51,7 @@ def upgrade() -> None: bind = op.get_bind() # Only the primary bbox tracks the sequence; siblings in others_bboxes are unrelated detections. 
- rows = bind.execute( - sa.text("SELECT sequence_id, bbox FROM detections WHERE sequence_id IS NOT NULL") - ).fetchall() + rows = bind.execute(sa.text("SELECT sequence_id, bbox FROM detections WHERE sequence_id IS NOT NULL")).fetchall() seq_max: dict[int, float] = {} for sequence_id, bbox in rows: From 3bd419a57adf7c56d896594a8e21b7f93f6cedf8 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 11:01:15 +0200 Subject: [PATCH 25/36] fix: silence mypy on case() and chained where() over join() --- src/app/api/api_v1/endpoints/alerts.py | 16 +++++++--------- src/app/crud/crud_sequence.py | 6 ++---- src/app/services/sequence_confidence.py | 8 +++++--- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index c74a9b7d..54634c05 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -149,12 +149,11 @@ async def fetch_latest_unlabeled_alerts( classes = await _resolve_class_per_camera(session, token_payload.organization_id, override_class=risk_score) seq_filter = max_conf_filter_clause(classes) - seq_match: Any = ( - select(AlertSequence.alert_id) - .join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)) - .where(Sequence.last_seen_at > utcnow() - timedelta(hours=24)) - .where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr] + seq_match: Any = select(AlertSequence.alert_id).join( + Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id) ) + seq_match = seq_match.where(Sequence.last_seen_at > utcnow() - timedelta(hours=24)) + seq_match = seq_match.where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr] if seq_filter is not None: seq_match = seq_match.where(seq_filter) @@ -199,11 +198,10 @@ async def fetch_alerts_from_date( .where(func.date(Alert.started_at) == from_date) ) if seq_filter is not None: - seq_match: Any = ( - select(AlertSequence.alert_id) - .join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)) - .where(seq_filter) + seq_match: Any = select(AlertSequence.alert_id).join( + Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id) ) + seq_match = seq_match.where(seq_filter) alerts_stmt = alerts_stmt.where(cast(Any, Alert.id).in_(seq_match)) alerts_stmt = alerts_stmt.order_by(Alert.started_at.desc()).limit(limit).offset(offset) # type: ignore[attr-defined] diff --git a/src/app/crud/crud_sequence.py b/src/app/crud/crud_sequence.py index a493f741..2a89b122 100644 --- a/src/app/crud/crud_sequence.py +++ b/src/app/crud/crud_sequence.py @@ -25,10 +25,8 @@ async def bump_max_conf(self, sequence_id: int, candidate: float) -> None: Uses a portable CASE expression so it runs on SQLite as well as Postgres. 
""" - bumped = case( - (or_(Sequence.max_conf.is_(None), Sequence.max_conf < candidate), candidate), # type: ignore[union-attr] - else_=Sequence.max_conf, - ) + max_conf_col = cast(Any, Sequence.max_conf) + bumped = case((or_(max_conf_col.is_(None), max_conf_col < candidate), candidate), else_=max_conf_col) stmt: Any = update(Sequence).where(cast(Any, Sequence.id) == sequence_id).values(max_conf=bumped) await self.session.exec(stmt) await self.session.commit() diff --git a/src/app/services/sequence_confidence.py b/src/app/services/sequence_confidence.py index b8050b21..5a21b8f9 100644 --- a/src/app/services/sequence_confidence.py +++ b/src/app/services/sequence_confidence.py @@ -7,7 +7,7 @@ import logging import re from ast import literal_eval -from typing import Dict, List, Union +from typing import Any, Dict, List, Union, cast from typing import Sequence as TypingSequence from sqlalchemy import case, or_ @@ -55,8 +55,10 @@ def max_conf_filter_clause(class_per_camera: Dict[int, Union[str, None]]) -> Uni } if not thresholds: return None - threshold_expr = case(*[(Sequence.camera_id == cid, t) for cid, t in thresholds.items()], else_=0.0) - return or_(Sequence.max_conf.is_(None), Sequence.max_conf >= threshold_expr) # type: ignore[union-attr] + max_conf_col = cast(Any, Sequence.max_conf) + whens: List[Any] = [(Sequence.camera_id == cid, t) for cid, t in thresholds.items()] + threshold_expr = case(*whens, else_=0.0) + return or_(max_conf_col.is_(None), max_conf_col >= threshold_expr) def filter_by_class_per_camera( From 0607a81f869afaa3b06a7562102e9c136d1de060 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 12:12:12 +0200 Subject: [PATCH 26/36] fix: annotate case() result as Any to satisfy mypy and reformat with ruff --- src/app/api/api_v1/endpoints/alerts.py | 4 +--- src/app/crud/crud_sequence.py | 5 ++++- src/app/services/sequence_confidence.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index 54634c05..b0e4923d 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -149,9 +149,7 @@ async def fetch_latest_unlabeled_alerts( classes = await _resolve_class_per_camera(session, token_payload.organization_id, override_class=risk_score) seq_filter = max_conf_filter_clause(classes) - seq_match: Any = select(AlertSequence.alert_id).join( - Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id) - ) + seq_match: Any = select(AlertSequence.alert_id).join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)) seq_match = seq_match.where(Sequence.last_seen_at > utcnow() - timedelta(hours=24)) seq_match = seq_match.where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr] if seq_filter is not None: diff --git a/src/app/crud/crud_sequence.py b/src/app/crud/crud_sequence.py index 2a89b122..58e5271e 100644 --- a/src/app/crud/crud_sequence.py +++ b/src/app/crud/crud_sequence.py @@ -26,7 +26,10 @@ async def bump_max_conf(self, sequence_id: int, candidate: float) -> None: Uses a portable CASE expression so it runs on SQLite as well as Postgres. 
""" max_conf_col = cast(Any, Sequence.max_conf) - bumped = case((or_(max_conf_col.is_(None), max_conf_col < candidate), candidate), else_=max_conf_col) + bumped: Any = case( + (or_(max_conf_col.is_(None), max_conf_col < candidate), candidate), + else_=max_conf_col, + ) stmt: Any = update(Sequence).where(cast(Any, Sequence.id) == sequence_id).values(max_conf=bumped) await self.session.exec(stmt) await self.session.commit() diff --git a/src/app/services/sequence_confidence.py b/src/app/services/sequence_confidence.py index 5a21b8f9..531a1d94 100644 --- a/src/app/services/sequence_confidence.py +++ b/src/app/services/sequence_confidence.py @@ -57,7 +57,7 @@ def max_conf_filter_clause(class_per_camera: Dict[int, Union[str, None]]) -> Uni return None max_conf_col = cast(Any, Sequence.max_conf) whens: List[Any] = [(Sequence.camera_id == cid, t) for cid, t in thresholds.items()] - threshold_expr = case(*whens, else_=0.0) + threshold_expr: Any = case(*whens, else_=0.0) return or_(max_conf_col.is_(None), max_conf_col >= threshold_expr) From cbdd6422186276550df4b7a596065bd56bb40a9f Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 13:03:01 +0200 Subject: [PATCH 27/36] test: cover _seconds_until_next_utc_hour and risk_score override on /sequences/* --- src/tests/endpoints/test_risk_filter.py | 63 +++++++++++++++++++++++++ src/tests/test_main.py | 59 +++++++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 src/tests/test_main.py diff --git a/src/tests/endpoints/test_risk_filter.py b/src/tests/endpoints/test_risk_filter.py index 2779c786..3d49ed62 100644 --- a/src/tests/endpoints/test_risk_filter.py +++ b/src/tests/endpoints/test_risk_filter.py @@ -204,6 +204,69 @@ async def test_alerts_unlabeled_latest_risk_score_invalid_value_returns_422( assert response.status_code == 422 +@pytest.mark.asyncio +async def test_sequences_unlabeled_latest_risk_score_override_drops_low_conf( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + camera_id = pytest.camera_table[1]["id"] + pose_id = pytest.pose_table[2]["id"] + low_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.30, minutes_ago=20) + high_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.55, minutes_ago=15) + + auth = pytest.get_token( + pytest.user_table[2]["id"], + pytest.user_table[2]["role"].split(), + pytest.user_table[2]["organization_id"], + ) + response = await async_client.get("/sequences/unlabeled/latest?risk_score=low", headers=auth) + assert response.status_code == 200, print(response.__dict__) + returned_ids = {item["id"] for item in response.json()} + assert low_seq.id not in returned_ids + assert high_seq.id in returned_ids + + +@pytest.mark.asyncio +async def test_sequences_unlabeled_latest_risk_score_invalid_value_returns_422( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + auth = pytest.get_token( + pytest.user_table[2]["id"], + pytest.user_table[2]["role"].split(), + pytest.user_table[2]["organization_id"], + ) + response = await async_client.get("/sequences/unlabeled/latest?risk_score=bogus", headers=auth) + assert response.status_code == 422 + + +@pytest.mark.asyncio +async def test_sequences_fromdate_risk_score_moderate_disables_filter( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + """``risk_score=moderate`` is the kill switch: every seed seq comes back regardless of max_conf.""" + camera_id = pytest.camera_table[1]["id"] + 
pose_id = pytest.pose_table[2]["id"] + target_date = utcnow().date().isoformat() + + seeded = [] + for max_conf, minutes_ago in [(0.05, 50), (0.20, 45), (0.55, 30)]: + seq = await _seed_unlabeled_sequence( + detection_session, camera_id, pose_id, max_conf=max_conf, minutes_ago=minutes_ago + ) + seeded.append(seq.id) + + auth = pytest.get_token( + pytest.user_table[2]["id"], + pytest.user_table[2]["role"].split(), + pytest.user_table[2]["organization_id"], + ) + response = await async_client.get( + f"/sequences/all/fromdate?from_date={target_date}&limit=200&risk_score=moderate", headers=auth + ) + assert response.status_code == 200, print(response.__dict__) + returned_ids = {item["id"] for item in response.json()} + assert set(seeded).issubset(returned_ids) + + @pytest.mark.asyncio async def test_sequences_fromdate_pagination_filters_before_limit( async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache diff --git a/src/tests/test_main.py b/src/tests/test_main.py new file mode 100644 index 00000000..d4c74082 --- /dev/null +++ b/src/tests/test_main.py @@ -0,0 +1,59 @@ +# Copyright (C) 2024-2026, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from datetime import datetime, timedelta, timezone +from unittest.mock import patch + +from app.main import _seconds_until_next_utc_hour + + +def test_seconds_until_next_utc_hour_future_today(): + fake_now = datetime(2026, 5, 5, 1, 30, 0, tzinfo=timezone.utc) + with patch("app.main.datetime") as mock_dt: + mock_dt.now.return_value = fake_now + # datetime.replace + timedelta still need to work — point them at the real type. + mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw) + seconds = _seconds_until_next_utc_hour(4) + # 04:00 - 01:30 = 2h30m = 9000s + assert seconds == 2 * 3600 + 30 * 60 + + +def test_seconds_until_next_utc_hour_rolls_to_next_day_when_passed(): + fake_now = datetime(2026, 5, 5, 5, 0, 0, tzinfo=timezone.utc) + with patch("app.main.datetime") as mock_dt: + mock_dt.now.return_value = fake_now + mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw) + seconds = _seconds_until_next_utc_hour(4) + # next 04:00 is tomorrow → 23h + assert seconds == 23 * 3600 + + +def test_seconds_until_next_utc_hour_clamps_negative_hour(): + fake_now = datetime(2026, 5, 5, 12, 0, 0, tzinfo=timezone.utc) + with patch("app.main.datetime") as mock_dt: + mock_dt.now.return_value = fake_now + mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw) + seconds = _seconds_until_next_utc_hour(-5) # clamped to 0 + # next 00:00 is tomorrow → 12h + assert seconds == 12 * 3600 + + +def test_seconds_until_next_utc_hour_clamps_overflow_hour(): + fake_now = datetime(2026, 5, 5, 12, 0, 0, tzinfo=timezone.utc) + with patch("app.main.datetime") as mock_dt: + mock_dt.now.return_value = fake_now + mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw) + seconds = _seconds_until_next_utc_hour(99) # clamped to 23 + # next 23:00 today → 11h + assert seconds == 11 * 3600 + + +def test_seconds_until_next_utc_hour_returns_full_day_when_now_equals_target(): + fake_now = datetime(2026, 5, 5, 4, 0, 0, tzinfo=timezone.utc) + with patch("app.main.datetime") as mock_dt: + mock_dt.now.return_value = fake_now + mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw) + seconds = _seconds_until_next_utc_hour(4) + assert seconds == timedelta(days=1).total_seconds() From 712d579db88a6828b5093568a65144755d7e17e2 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 13:14:07 +0200 
Subject: [PATCH 28/36] chore: address risk filter review comments --- src/app/api/api_v1/endpoints/alerts.py | 23 ++++++++++++++--------- src/app/models.py | 11 ++++++++--- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index b0e4923d..dff0c20c 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -56,13 +56,13 @@ async def _fetch_sequences_by_alert_ids( return mapping -async def _resolve_class_per_camera( +async def _resolve_fwi_class_per_camera( session: AsyncSession, organization_id: int, target_date: Union[date, None] = None, override_class: Union[str, None] = None, ) -> Dict[int, Union[str, None]]: - """Resolve ``{camera_id: fwi_class}`` for the org, picking override → per-date → today's cache.""" + """Resolve ``{camera_id: fwi_class}`` for the org, picking override -> per-date -> today's cache.""" if override_class is not None: cam_ids = (await session.exec(select(Camera.id).where(Camera.organization_id == organization_id))).all() return dict.fromkeys(cam_ids, override_class) @@ -146,12 +146,17 @@ async def fetch_latest_unlabeled_alerts( ) -> List[AlertReadWithSequences]: telemetry_client.capture(token_payload.sub, event="alerts-fetch-latest") - classes = await _resolve_class_per_camera(session, token_payload.organization_id, override_class=risk_score) - seq_filter = max_conf_filter_clause(classes) + fwi_classes_by_camera = await _resolve_fwi_class_per_camera( + session, token_payload.organization_id, override_class=risk_score + ) + seq_filter = max_conf_filter_clause(fwi_classes_by_camera) - seq_match: Any = select(AlertSequence.alert_id).join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)) - seq_match = seq_match.where(Sequence.last_seen_at > utcnow() - timedelta(hours=24)) - seq_match = seq_match.where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr] + seq_match: Any = ( + select(AlertSequence.alert_id) + .join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)) + .where(Sequence.last_seen_at > utcnow() - timedelta(hours=24)) + .where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr] + ) if seq_filter is not None: seq_match = seq_match.where(seq_filter) @@ -185,10 +190,10 @@ async def fetch_alerts_from_date( ) -> List[AlertReadWithSequences]: telemetry_client.capture(token_payload.sub, event="alerts-fetch-from-date") - classes = await _resolve_class_per_camera( + fwi_classes_by_camera = await _resolve_fwi_class_per_camera( session, token_payload.organization_id, target_date=from_date, override_class=risk_score ) - seq_filter = max_conf_filter_clause(classes) + seq_filter = max_conf_filter_clause(fwi_classes_by_camera) alerts_stmt: Any = ( select(Alert) diff --git a/src/app/models.py b/src/app/models.py index 04e231d1..3dea578d 100644 --- a/src/app/models.py +++ b/src/app/models.py @@ -104,9 +104,14 @@ class Sequence(SQLModel, table=True): cone_angle: Union[float, None] = Field(None, nullable=True) started_at: datetime = Field(..., nullable=False) last_seen_at: datetime = Field(..., nullable=False) - # Highest detection confidence ever attached to this sequence. - # Monotonic: never recomputed downward when detections are deleted/reassigned. - max_conf: Union[float, None] = Field(None, nullable=True) + max_conf: Union[float, None] = Field( + None, + nullable=True, + description=( + "Highest detection confidence ever attached to this sequence. 
" + "Monotonic: not recomputed downward when detections are deleted or reassigned." + ), + ) class Alert(SQLModel, table=True): From 416e2c3e69a84b5f90c59dbf67895ed2d04677f2 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 13:25:44 +0200 Subject: [PATCH 29/36] fix: satisfy mypy for risk filter queries --- src/app/api/api_v1/endpoints/alerts.py | 11 ++++++----- src/app/crud/crud_sequence.py | 2 +- src/app/services/sequence_confidence.py | 2 +- src/app/services/sequence_counts.py | 2 +- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index dff0c20c..b860ed56 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -151,11 +151,12 @@ async def fetch_latest_unlabeled_alerts( ) seq_filter = max_conf_filter_clause(fwi_classes_by_camera) - seq_match: Any = ( - select(AlertSequence.alert_id) - .join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)) - .where(Sequence.last_seen_at > utcnow() - timedelta(hours=24)) - .where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr] + seq_match: Any = cast( + Any, + select(AlertSequence.alert_id).join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)), + ) + seq_match = ( + seq_match.where(Sequence.last_seen_at > utcnow() - timedelta(hours=24)).where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr] ) if seq_filter is not None: seq_match = seq_match.where(seq_filter) diff --git a/src/app/crud/crud_sequence.py b/src/app/crud/crud_sequence.py index 58e5271e..0b1e4e0d 100644 --- a/src/app/crud/crud_sequence.py +++ b/src/app/crud/crud_sequence.py @@ -26,7 +26,7 @@ async def bump_max_conf(self, sequence_id: int, candidate: float) -> None: Uses a portable CASE expression so it runs on SQLite as well as Postgres. 
""" max_conf_col = cast(Any, Sequence.max_conf) - bumped: Any = case( + bumped: Any = cast(Any, case)( (or_(max_conf_col.is_(None), max_conf_col < candidate), candidate), else_=max_conf_col, ) diff --git a/src/app/services/sequence_confidence.py b/src/app/services/sequence_confidence.py index 531a1d94..3c3da1ab 100644 --- a/src/app/services/sequence_confidence.py +++ b/src/app/services/sequence_confidence.py @@ -57,7 +57,7 @@ def max_conf_filter_clause(class_per_camera: Dict[int, Union[str, None]]) -> Uni return None max_conf_col = cast(Any, Sequence.max_conf) whens: List[Any] = [(Sequence.camera_id == cid, t) for cid, t in thresholds.items()] - threshold_expr: Any = case(*whens, else_=0.0) + threshold_expr: Any = cast(Any, case)(*whens, else_=0.0) return or_(max_conf_col.is_(None), max_conf_col >= threshold_expr) diff --git a/src/app/services/sequence_counts.py b/src/app/services/sequence_counts.py index 195c8b52..75667c9b 100644 --- a/src/app/services/sequence_counts.py +++ b/src/app/services/sequence_counts.py @@ -17,7 +17,7 @@ async def get_detection_counts_by_sequence_ids(session: AsyncSession, sequence_i return {} stmt: Any = ( - select(cast(Any, Detection.sequence_id), func.count(Detection.id)) + select(cast(Any, Detection.sequence_id), func.count(cast(Any, Detection.id))) .where(cast(Any, Detection.sequence_id).in_(sequence_ids)) .group_by(cast(Any, Detection.sequence_id)) ) From 5da133db96cc0ec94e09236585a4f926c7a134c9 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 13:57:15 +0200 Subject: [PATCH 30/36] test: cover risk refresh lifecycle --- src/tests/test_main.py | 92 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 90 insertions(+), 2 deletions(-) diff --git a/src/tests/test_main.py b/src/tests/test_main.py index d4c74082..561cf208 100644 --- a/src/tests/test_main.py +++ b/src/tests/test_main.py @@ -3,10 +3,29 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. 
+import asyncio +from collections.abc import Generator from datetime import datetime, timedelta, timezone -from unittest.mock import patch +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch -from app.main import _seconds_until_next_utc_hour +import pytest +from fastapi import FastAPI + +from app.main import _risk_refresh_loop, _seconds_until_next_utc_hour, lifespan + + +class _CancelledTask: + def __init__(self) -> None: + self.cancel_called = False + + def cancel(self) -> None: + self.cancel_called = True + + def __await__(self) -> Generator[None, None, None]: + if not self.cancel_called: + yield None + raise asyncio.CancelledError def test_seconds_until_next_utc_hour_future_today(): @@ -57,3 +76,72 @@ def test_seconds_until_next_utc_hour_returns_full_day_when_now_equals_target(): mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw) seconds = _seconds_until_next_utc_hour(4) assert seconds == timedelta(days=1).total_seconds() + + +@pytest.mark.asyncio +async def test_risk_refresh_loop_calls_refresh_then_cancels_cleanly(): + sleep_mock = AsyncMock(side_effect=[None, asyncio.CancelledError()]) + refresh_mock = AsyncMock() + + with ( + patch("app.main.asyncio.sleep", sleep_mock), + patch("app.main.risk_service.refresh", new=refresh_mock), + pytest.raises(asyncio.CancelledError), + ): + await _risk_refresh_loop() + + refresh_mock.assert_awaited_once() + assert sleep_mock.await_count == 2 + + +@pytest.mark.asyncio +async def test_risk_refresh_loop_swallows_refresh_errors_and_continues(): + sleep_mock = AsyncMock(side_effect=[None, None, asyncio.CancelledError()]) + refresh_mock = AsyncMock(side_effect=[RuntimeError("boom"), None]) + + with ( + patch("app.main.asyncio.sleep", sleep_mock), + patch("app.main.risk_service.refresh", new=refresh_mock), + patch("app.main.logger.exception") as exception_mock, + pytest.raises(asyncio.CancelledError), + ): + await _risk_refresh_loop() + + assert refresh_mock.await_count == 2 + assert sleep_mock.await_count == 3 + exception_mock.assert_called_once_with("Risk refresh loop iteration failed; continuing") + + +@pytest.mark.asyncio +async def test_lifespan_refreshes_and_cancels_daily_task_when_risk_api_configured(): + fake_service = SimpleNamespace(is_configured=True, refresh=AsyncMock()) + fake_task = _CancelledTask() + + def fake_create_task(coro): + coro.close() + return fake_task + + create_task_mock = MagicMock(side_effect=fake_create_task) + with ( + patch("app.main.risk_service", fake_service), + patch("app.main.asyncio.create_task", create_task_mock), + ): + async with lifespan(FastAPI()): + fake_service.refresh.assert_awaited_once() + create_task_mock.assert_called_once() + assert fake_task.cancel_called is False + + assert fake_task.cancel_called is True + + +@pytest.mark.asyncio +async def test_lifespan_skips_risk_refresh_when_risk_api_not_configured(): + fake_service = SimpleNamespace(is_configured=False, refresh=AsyncMock()) + + with ( + patch("app.main.risk_service", fake_service), + patch("app.main.asyncio.create_task") as create_task_mock, + ): + async with lifespan(FastAPI()): + fake_service.refresh.assert_not_awaited() + create_task_mock.assert_not_called() From 06d540a8d8823e98d5671a9ce33331abdef9f564 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 14:00:21 +0200 Subject: [PATCH 31/36] test: cover sequence risk filter endpoints --- src/tests/endpoints/test_sequences.py | 119 +++++++++++++++++++++++++- 1 file changed, 117 insertions(+), 2 deletions(-) diff --git 
a/src/tests/endpoints/test_sequences.py b/src/tests/endpoints/test_sequences.py index 1d593382..5860e155 100644 --- a/src/tests/endpoints/test_sequences.py +++ b/src/tests/endpoints/test_sequences.py @@ -1,4 +1,4 @@ -from datetime import timedelta +from datetime import date, timedelta from typing import Any, Dict, List, Union from unittest.mock import AsyncMock, MagicMock, patch @@ -8,13 +8,55 @@ from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession -from app.api.api_v1.endpoints.sequences import label_sequence +from app.api.api_v1.endpoints.sequences import ( + fetch_latest_unlabeled_sequences, + fetch_sequences_from_date, + label_sequence, +) from app.core.time import utcnow from app.models import Alert, AlertSequence, AnnotationType, Camera, Detection, Pose, Sequence, UserRole from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceLabel +class _ExecResult: + def __init__(self, rows) -> None: + self._rows = rows + + def all(self): + return self._rows + + +class _FakeSequenceSession: + def __init__(self, *results) -> None: + self._results = list(results) + self.statements = [] + + async def exec(self, stmt): + self.statements.append(stmt) + return self._results.pop(0) + + +def _unit_sequence(sequence_id: int = 101, camera_id: int = 1, max_conf: float = 0.75) -> Sequence: + now = utcnow() + return Sequence( + id=sequence_id, + camera_id=camera_id, + pose_id=1, + camera_azimuth=180.0, + sequence_azimuth=175.0, + cone_angle=5.0, + is_wildfire=None, + started_at=now - timedelta(minutes=20), + last_seen_at=now - timedelta(minutes=5), + max_conf=max_conf, + ) + + +def _unit_token(organization_id: int = 1) -> TokenPayload: + return TokenPayload(sub=1, scopes=[UserRole.USER], organization_id=organization_id) + + @pytest.mark.parametrize( ("sequence_id", "expected_idx", "expected_detections_count"), [ @@ -678,3 +720,76 @@ async def test_unit_label_sequence_forbidden_for_wrong_org(): ) assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_unit_fetch_latest_unlabeled_sequences_uses_risk_score_override(): + sequence = _unit_sequence(sequence_id=201, camera_id=1, max_conf=0.55) + session = _FakeSequenceSession(_ExecResult([1]), _ExecResult([sequence])) + counts_mock = AsyncMock(return_value={sequence.id: 4}) + + with patch("app.api.api_v1.endpoints.sequences.get_detection_counts_by_sequence_ids", new=counts_mock): + result = await fetch_latest_unlabeled_sequences( + risk_score="low", + session=session, + token_payload=_unit_token(), + ) + + assert len(session.statements) == 2 + counts_mock.assert_awaited_once_with(session, [sequence.id]) + assert [item.id for item in result] == [sequence.id] + assert result[0].detections_count == 4 + + +@pytest.mark.asyncio +async def test_unit_fetch_sequences_from_date_gets_risk_scores_for_requested_date(): + target_date = date(2026, 5, 5) + sequence = _unit_sequence(sequence_id=202, camera_id=1, max_conf=0.65) + session = _FakeSequenceSession(_ExecResult([1]), _ExecResult([sequence])) + risk_scores_mock = AsyncMock(return_value={1: "very_low"}) + counts_mock = AsyncMock(return_value={sequence.id: 2}) + + with ( + patch("app.api.api_v1.endpoints.sequences.risk_service.get_scores_for_date", new=risk_scores_mock), + patch("app.api.api_v1.endpoints.sequences.get_detection_counts_by_sequence_ids", new=counts_mock), + ): + result = await fetch_sequences_from_date( + from_date=target_date, + limit=10, + offset=0, + risk_score=None, + session=session, + token_payload=_unit_token(), + ) + 
+ risk_scores_mock.assert_awaited_once_with(target_date, organization_id=1) + counts_mock.assert_awaited_once_with(session, [sequence.id]) + assert [item.id for item in result] == [sequence.id] + assert result[0].detections_count == 2 + + +@pytest.mark.asyncio +async def test_unit_fetch_sequences_from_date_risk_score_override_bypasses_risk_api(): + target_date = date(2026, 5, 5) + sequence = _unit_sequence(sequence_id=203, camera_id=1, max_conf=0.50) + session = _FakeSequenceSession(_ExecResult([1]), _ExecResult([sequence])) + risk_scores_mock = AsyncMock() + counts_mock = AsyncMock(return_value={sequence.id: 1}) + + with ( + patch("app.api.api_v1.endpoints.sequences.risk_service.get_scores_for_date", new=risk_scores_mock), + patch("app.api.api_v1.endpoints.sequences.get_detection_counts_by_sequence_ids", new=counts_mock), + ): + result = await fetch_sequences_from_date( + from_date=target_date, + limit=10, + offset=0, + risk_score="low", + session=session, + token_payload=_unit_token(), + ) + + risk_scores_mock.assert_not_awaited() + counts_mock.assert_awaited_once_with(session, [sequence.id]) + assert [item.id for item in result] == [sequence.id] + assert result[0].detections_count == 1 From e77bd896b337662da2bdac20e3eb9a1d86329486 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 14:02:07 +0200 Subject: [PATCH 32/36] test: cover risk service http paths --- src/tests/services/test_risk.py | 134 ++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) diff --git a/src/tests/services/test_risk.py b/src/tests/services/test_risk.py index cefdded0..5f5085e2 100644 --- a/src/tests/services/test_risk.py +++ b/src/tests/services/test_risk.py @@ -3,12 +3,58 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. +from datetime import date +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx import pytest from app.core.config import settings from app.services.risk import FWI_MIN_CONF, RiskService, min_confidence_for_class +def _fake_httpx_client(*, json_data=None, raise_exc=None, side_effect_per_call=None): + """Return a context manager replacing ``app.services.risk.httpx.AsyncClient``. + + Either pass ``json_data`` (single response), ``raise_exc`` (raised on get()), or + ``side_effect_per_call`` (list of (json_or_exc) values returned in order). 
+ """ + response = MagicMock() + response.raise_for_status = MagicMock() + response.json = MagicMock(return_value=json_data) + + inner = MagicMock() + if raise_exc is not None: + inner.get = AsyncMock(side_effect=raise_exc) + elif side_effect_per_call is not None: + responses = [] + for item in side_effect_per_call: + if isinstance(item, Exception): + responses.append(item) + else: + r = MagicMock() + r.raise_for_status = MagicMock() + r.json = MagicMock(return_value=item) + responses.append(r) + inner.get = AsyncMock(side_effect=responses) + else: + inner.get = AsyncMock(return_value=response) + + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=inner) + cm.__aexit__ = AsyncMock(return_value=None) + + factory = MagicMock(return_value=cm) + return factory, inner + + +@pytest.fixture +def configured_risk(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(settings, "RISK_API_URL", "http://risk.test") + monkeypatch.setattr(settings, "RISK_API_LOGIN", "u") + monkeypatch.setattr(settings, "RISK_API_PWD", "p") + + def test_min_confidence_for_class(): assert min_confidence_for_class("very_low") == FWI_MIN_CONF["very_low"] assert min_confidence_for_class("low") == FWI_MIN_CONF["low"] @@ -42,3 +88,91 @@ async def test_refresh_no_op_when_not_configured(monkeypatch: pytest.MonkeyPatch service._scores = {1: "low"} await service.refresh() assert service._scores == {1: "low"} + + +@pytest.mark.asyncio +async def test_refresh_replaces_cache_on_success(configured_risk): + service = RiskService() + service._scores = {99: "very_low"} # stale entry that should disappear + factory, inner = _fake_httpx_client( + json_data=[ + {"id": 1, "fwi_class": "very_low"}, + {"id": 2, "fwi_class": "moderate"}, + ] + ) + with patch("app.services.risk.httpx.AsyncClient", factory): + await service.refresh() + assert service._scores == {1: "very_low", 2: "moderate"} + inner.get.assert_awaited_once() + args, kwargs = inner.get.await_args + assert args[0].endswith("/cameras") + assert kwargs["auth"] == ("u", "p") + + +@pytest.mark.asyncio +async def test_refresh_keeps_cache_on_http_error(configured_risk): + service = RiskService() + service._scores = {1: "low"} + factory, _ = _fake_httpx_client(raise_exc=httpx.ConnectError("boom")) + with patch("app.services.risk.httpx.AsyncClient", factory): + await service.refresh() + assert service._scores == {1: "low"} + + +@pytest.mark.asyncio +async def test_refresh_replaces_with_empty_when_payload_is_empty_list(configured_risk): + """Empty list from upstream wipes the cache; only a network/HTTP failure preserves it.""" + service = RiskService() + service._scores = {1: "low"} + factory, _ = _fake_httpx_client(json_data=[]) + with patch("app.services.risk.httpx.AsyncClient", factory): + await service.refresh() + assert service._scores == {} + + +@pytest.mark.asyncio +async def test_refresh_skips_malformed_rows(configured_risk): + service = RiskService() + factory, _ = _fake_httpx_client( + json_data=[ + {"id": 1, "fwi_class": "low"}, + "not-a-dict", + {"id": "bad", "fwi_class": "low"}, + {"id": 2}, + {"camera_id": 3, "fwi_class": "very_low"}, + ] + ) + with patch("app.services.risk.httpx.AsyncClient", factory): + await service.refresh() + assert service._scores == {1: "low", 3: "very_low"} + + +@pytest.mark.asyncio +async def test_get_scores_for_date_passes_organization_param(configured_risk): + service = RiskService() + factory, inner = _fake_httpx_client(json_data=[{"camera_id": 7, "fwi_class": "high"}]) + with patch("app.services.risk.httpx.AsyncClient", factory): + scores 
= await service.get_scores_for_date(date(2026, 5, 5), organization_id=42) + assert scores == {7: "high"} + args, kwargs = inner.get.await_args + assert args[0].endswith("/scores/2026-05-05") + assert kwargs["params"] == {"organization_id": 42} + + +@pytest.mark.asyncio +async def test_get_scores_for_date_returns_empty_on_failure(configured_risk): + service = RiskService() + factory, _ = _fake_httpx_client(raise_exc=httpx.ReadTimeout("slow")) + with patch("app.services.risk.httpx.AsyncClient", factory): + scores = await service.get_scores_for_date(date(2026, 5, 5)) + assert scores == {} + + +@pytest.mark.asyncio +async def test_get_scores_for_date_no_org_param_when_none(configured_risk): + service = RiskService() + factory, inner = _fake_httpx_client(json_data=[]) + with patch("app.services.risk.httpx.AsyncClient", factory): + await service.get_scores_for_date(date(2026, 5, 5)) + _, kwargs = inner.get.await_args + assert kwargs["params"] is None From addca3bb658be4bc1758a81eca12b00fdbe09571 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 15:26:52 +0200 Subject: [PATCH 33/36] test: cover alerts/fromdate risk filter and mixed-seq alert --- src/tests/endpoints/test_risk_filter.py | 213 +++++++++++++++++++++++- 1 file changed, 211 insertions(+), 2 deletions(-) diff --git a/src/tests/endpoints/test_risk_filter.py b/src/tests/endpoints/test_risk_filter.py index 3d49ed62..1130123b 100644 --- a/src/tests/endpoints/test_risk_filter.py +++ b/src/tests/endpoints/test_risk_filter.py @@ -3,14 +3,17 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from datetime import timedelta +from datetime import date, timedelta +from unittest.mock import AsyncMock, patch import pytest # type: ignore from httpx import AsyncClient from sqlmodel.ext.asyncio.session import AsyncSession +from app.api.api_v1.endpoints.alerts import fetch_alerts_from_date from app.core.time import utcnow -from app.models import Alert, AlertSequence, Sequence +from app.models import Alert, AlertSequence, Sequence, UserRole +from app.schemas.login import TokenPayload from app.services.risk import risk_service @@ -267,6 +270,212 @@ async def test_sequences_fromdate_risk_score_moderate_disables_filter( assert set(seeded).issubset(returned_ids) +@pytest.mark.asyncio +async def test_alerts_unlabeled_latest_keeps_alert_with_mixed_seqs( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + """An alert mixing one passing and one failing sequence stays, with only the passing seq in payload.""" + camera_id = pytest.camera_table[1]["id"] + pose_id = pytest.pose_table[2]["id"] + low_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.30, minutes_ago=25) + high_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.70, minutes_ago=15) + + now = utcnow() + alert = Alert( + organization_id=2, + started_at=now - timedelta(minutes=30), + last_seen_at=now - timedelta(minutes=10), + ) + detection_session.add(alert) + await detection_session.commit() + await detection_session.refresh(alert) + detection_session.add(AlertSequence(alert_id=alert.id, sequence_id=low_seq.id)) + detection_session.add(AlertSequence(alert_id=alert.id, sequence_id=high_seq.id)) + await detection_session.commit() + + risk_service._scores = {camera_id: "low"} # 0.45 threshold + + auth = pytest.get_token( + pytest.user_table[2]["id"], + pytest.user_table[2]["role"].split(), + pytest.user_table[2]["organization_id"], + ) + 
response = await async_client.get("/alerts/unlabeled/latest", headers=auth) + assert response.status_code == 200, print(response.__dict__) + + payload = next((a for a in response.json() if a["id"] == alert.id), None) + assert payload is not None, "alert with at least one passing sequence must remain" + seq_ids = {s["id"] for s in payload["sequences"]} + assert high_seq.id in seq_ids + assert low_seq.id not in seq_ids + + +@pytest.mark.asyncio +async def test_alerts_fromdate_risk_score_override_drops_low_conf_alert( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + camera_id = pytest.camera_table[1]["id"] + pose_id = pytest.pose_table[2]["id"] + target_date = utcnow().date().isoformat() + seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.30, minutes_ago=20) + alert = await _seed_alert_with_sequence(detection_session, organization_id=2, seq=seq) + + auth = pytest.get_token( + pytest.user_table[2]["id"], + pytest.user_table[2]["role"].split(), + pytest.user_table[2]["organization_id"], + ) + response = await async_client.get(f"/alerts/all/fromdate?from_date={target_date}&risk_score=low", headers=auth) + assert response.status_code == 200, print(response.__dict__) + assert alert.id not in {item["id"] for item in response.json()} + + +@pytest.mark.asyncio +async def test_alerts_fromdate_risk_score_moderate_keeps_alert( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + """``risk_score=moderate`` is the kill switch on the from_date endpoint too.""" + camera_id = pytest.camera_table[1]["id"] + pose_id = pytest.pose_table[2]["id"] + target_date = utcnow().date().isoformat() + seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.10, minutes_ago=20) + alert = await _seed_alert_with_sequence(detection_session, organization_id=2, seq=seq) + + # Cache says very_low (would drop everything), but the override forces moderate. 
+ risk_service._scores = {camera_id: "very_low"} + + auth = pytest.get_token( + pytest.user_table[2]["id"], + pytest.user_table[2]["role"].split(), + pytest.user_table[2]["organization_id"], + ) + response = await async_client.get(f"/alerts/all/fromdate?from_date={target_date}&risk_score=moderate", headers=auth) + assert response.status_code == 200, print(response.__dict__) + assert alert.id in {item["id"] for item in response.json()} + + +@pytest.mark.asyncio +async def test_alerts_fromdate_risk_score_invalid_value_returns_422( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + target_date = utcnow().date().isoformat() + auth = pytest.get_token( + pytest.user_table[2]["id"], + pytest.user_table[2]["role"].split(), + pytest.user_table[2]["organization_id"], + ) + response = await async_client.get(f"/alerts/all/fromdate?from_date={target_date}&risk_score=bogus", headers=auth) + assert response.status_code == 422 + + +class _ExecResult: + def __init__(self, rows) -> None: + self._rows = rows + + def all(self): + return self._rows + + +class _FakeAlertsSession: + def __init__(self, *results) -> None: + self._results = list(results) + self.statements = [] + + async def exec(self, stmt): + self.statements.append(stmt) + return self._results.pop(0) + + +def _unit_token(organization_id: int = 1) -> TokenPayload: + return TokenPayload(sub=1, scopes=[UserRole.USER], organization_id=organization_id) + + +@pytest.mark.asyncio +async def test_unit_fetch_alerts_from_date_calls_risk_scores_for_requested_date(): + """The per-date branch must dispatch ``get_scores_for_date`` with the requested date and org.""" + target_date = date(2026, 5, 5) + now = utcnow() + alert = Alert(id=501, organization_id=1, started_at=now, last_seen_at=now) + seq = Sequence( + id=601, + camera_id=1, + pose_id=1, + camera_azimuth=180.0, + sequence_azimuth=175.0, + cone_angle=5.0, + is_wildfire=None, + started_at=now - timedelta(minutes=20), + last_seen_at=now - timedelta(minutes=5), + max_conf=0.70, + ) + session = _FakeAlertsSession(_ExecResult([alert]), _ExecResult([(alert.id, seq)])) + risk_scores_mock = AsyncMock(return_value={}) # empty -> no SQL filter, simpler stmt order + counts_mock = AsyncMock(return_value={seq.id: 3}) + + with ( + patch("app.api.api_v1.endpoints.alerts.risk_service.get_scores_for_date", new=risk_scores_mock), + patch("app.api.api_v1.endpoints.alerts.get_detection_counts_by_sequence_ids", new=counts_mock), + ): + result = await fetch_alerts_from_date( + from_date=target_date, + limit=10, + offset=0, + risk_score=None, + session=session, + token_payload=_unit_token(), + ) + + risk_scores_mock.assert_awaited_once_with(target_date, organization_id=1) + counts_mock.assert_awaited_once_with(session, [seq.id]) + assert [item.id for item in result] == [alert.id] + assert [s.id for s in result[0].sequences] == [seq.id] + + +@pytest.mark.asyncio +async def test_unit_fetch_alerts_from_date_risk_score_override_bypasses_risk_api(): + """The override branch must NOT hit the risk API and must apply the override class to all org cameras.""" + target_date = date(2026, 5, 5) + now = utcnow() + alert = Alert(id=502, organization_id=1, started_at=now, last_seen_at=now) + seq = Sequence( + id=602, + camera_id=1, + pose_id=1, + camera_azimuth=180.0, + sequence_azimuth=175.0, + cone_angle=5.0, + is_wildfire=None, + started_at=now - timedelta(minutes=20), + last_seen_at=now - timedelta(minutes=5), + max_conf=0.55, + ) + # Override path -> first exec: select Camera.id, then alerts_stmt, then 
alert_seq map + session = _FakeAlertsSession( + _ExecResult([1]), + _ExecResult([alert]), + _ExecResult([(alert.id, seq)]), + ) + risk_scores_mock = AsyncMock() + counts_mock = AsyncMock(return_value={seq.id: 1}) + + with ( + patch("app.api.api_v1.endpoints.alerts.risk_service.get_scores_for_date", new=risk_scores_mock), + patch("app.api.api_v1.endpoints.alerts.get_detection_counts_by_sequence_ids", new=counts_mock), + ): + result = await fetch_alerts_from_date( + from_date=target_date, + limit=10, + offset=0, + risk_score="low", + session=session, + token_payload=_unit_token(), + ) + + risk_scores_mock.assert_not_awaited() + assert [item.id for item in result] == [alert.id] + assert [s.id for s in result[0].sequences] == [seq.id] + + @pytest.mark.asyncio async def test_sequences_fromdate_pagination_filters_before_limit( async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache From d12e74ce7a7399964d9c1c062f6da441178bccbf Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 15:29:50 +0200 Subject: [PATCH 34/36] test: parametrize keep-all assertion across moderate/high/very_high/extreme --- src/tests/endpoints/test_risk_filter.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/tests/endpoints/test_risk_filter.py b/src/tests/endpoints/test_risk_filter.py index 1130123b..68353997 100644 --- a/src/tests/endpoints/test_risk_filter.py +++ b/src/tests/endpoints/test_risk_filter.py @@ -95,14 +95,16 @@ async def test_unlabeled_latest_drops_below_very_low_threshold( @pytest.mark.asyncio -async def test_unlabeled_latest_keeps_all_when_class_is_moderate( - async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +@pytest.mark.parametrize("fwi_class", ["moderate", "high", "very_high", "extreme"]) +async def test_unlabeled_latest_keeps_all_when_class_is_moderate_or_above( + fwi_class: str, async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache ): + """No filter is applied for ``moderate`` and above; pin every class so future tweaks stay covered.""" camera_id = pytest.camera_table[0]["id"] pose_id = pytest.pose_table[0]["id"] low_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.10, minutes_ago=30) - risk_service._scores = {camera_id: "moderate"} + risk_service._scores = {camera_id: fwi_class} auth = pytest.get_token( pytest.user_table[0]["id"], From 2b5e9f2fdbea22df5117105c2fd18027e8b0d2e7 Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 5 May 2026 15:34:44 +0200 Subject: [PATCH 35/36] test: pin fail-open on null max_conf and unknown cameras --- src/tests/endpoints/test_risk_filter.py | 59 +++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/src/tests/endpoints/test_risk_filter.py b/src/tests/endpoints/test_risk_filter.py index 68353997..6d146f60 100644 --- a/src/tests/endpoints/test_risk_filter.py +++ b/src/tests/endpoints/test_risk_filter.py @@ -116,6 +116,65 @@ async def test_unlabeled_latest_keeps_all_when_class_is_moderate_or_above( assert low_seq.id in {item["id"] for item in response.json()} +@pytest.mark.asyncio +async def test_unlabeled_latest_keeps_seq_with_null_max_conf_under_filter( + async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache +): + """Sequences with NULL max_conf (legacy / unparseable bbox) must pass even under an active filter.""" + camera_id = pytest.camera_table[0]["id"] + pose_id = pytest.pose_table[0]["id"] + null_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, 
From 2b5e9f2fdbea22df5117105c2fd18027e8b0d2e7 Mon Sep 17 00:00:00 2001
From: Mateo
Date: Tue, 5 May 2026 15:34:44 +0200
Subject: [PATCH 35/36] test: pin fail-open on null max_conf and unknown cameras

---
 src/tests/endpoints/test_risk_filter.py | 59 +++++++++++++++++++++++++
 1 file changed, 59 insertions(+)

diff --git a/src/tests/endpoints/test_risk_filter.py b/src/tests/endpoints/test_risk_filter.py
index 68353997..6d146f60 100644
--- a/src/tests/endpoints/test_risk_filter.py
+++ b/src/tests/endpoints/test_risk_filter.py
@@ -116,6 +116,65 @@ async def test_unlabeled_latest_keeps_all_when_class_is_moderate_or_above(
     assert low_seq.id in {item["id"] for item in response.json()}
 
 
+@pytest.mark.asyncio
+async def test_unlabeled_latest_keeps_seq_with_null_max_conf_under_filter(
+    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
+):
+    """Sequences with NULL max_conf (legacy / unparseable bbox) must pass even under an active filter."""
+    camera_id = pytest.camera_table[0]["id"]
+    pose_id = pytest.pose_table[0]["id"]
+    null_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=None, minutes_ago=20)  # type: ignore[arg-type]
+
+    risk_service._scores = {camera_id: "low"}  # 0.45 threshold, would normally drop
+
+    auth = pytest.get_token(
+        pytest.user_table[0]["id"],
+        pytest.user_table[0]["role"].split(),
+        pytest.user_table[0]["organization_id"],
+    )
+    response = await async_client.get("/sequences/unlabeled/latest", headers=auth)
+    assert response.status_code == 200, print(response.__dict__)
+    assert null_seq.id in {item["id"] for item in response.json()}
+
+
+@pytest.mark.asyncio
+async def test_unlabeled_latest_keeps_seq_for_camera_unknown_to_risk_api(
+    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
+):
+    """A camera absent from the risk-api cache stays unfiltered even when sibling cameras are filtered."""
+    known_cam = pytest.camera_table[0]["id"]
+    unknown_cam = pytest.camera_table[1]["id"]
+    known_pose = pytest.pose_table[0]["id"]
+    unknown_pose = pytest.pose_table[2]["id"]
+
+    # Cache only knows about ``known_cam`` and flags it ``low`` (0.45 threshold).
+    # ``unknown_cam`` has no entry -> CASE else_=0.0 -> any max_conf passes.
+    risk_service._scores = {known_cam: "low"}
+
+    known_dropped = await _seed_unlabeled_sequence(detection_session, known_cam, known_pose, max_conf=0.20)
+    unknown_kept = await _seed_unlabeled_sequence(detection_session, unknown_cam, unknown_pose, max_conf=0.10)
+
+    # known_cam belongs to org 1; unknown_cam belongs to org 2 -> query both orgs.
+    auth_org1 = pytest.get_token(
+        pytest.user_table[0]["id"],
+        pytest.user_table[0]["role"].split(),
+        pytest.user_table[0]["organization_id"],
+    )
+    auth_org2 = pytest.get_token(
+        pytest.user_table[2]["id"],
+        pytest.user_table[2]["role"].split(),
+        pytest.user_table[2]["organization_id"],
+    )
+
+    resp1 = await async_client.get("/sequences/unlabeled/latest", headers=auth_org1)
+    assert resp1.status_code == 200, print(resp1.__dict__)
+    assert known_dropped.id not in {item["id"] for item in resp1.json()}
+
+    resp2 = await async_client.get("/sequences/unlabeled/latest", headers=auth_org2)
+    assert resp2.status_code == 200, print(resp2.__dict__)
+    assert unknown_kept.id in {item["id"] for item in resp2.json()}
+
+
 async def _seed_alert_with_sequence(session: AsyncSession, organization_id: int, seq: Sequence) -> Alert:
     now = utcnow()
     alert = Alert(

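Both fail-open tests above depend on the shape of the confidence filter added earlier in the series: per-camera thresholds are folded into a SQL CASE whose else_ is 0.0, and a NULL max_conf is let through explicitly. That filter is not visible in this excerpt; a minimal sketch of how it can be built with SQLAlchemy's case(), with the helper name apply_risk_filter invented for the example, looks like this:

# Minimal sketch of the CASE-based threshold filter the tests above rely on; only the
# observable behaviour (unknown camera -> else_=0.0, NULL max_conf kept) is pinned here.
from sqlalchemy import case, or_

from app.models import Sequence


def apply_risk_filter(stmt, thresholds: dict[int, float]):
    """Keep sequences whose max_conf is NULL or at least the per-camera threshold."""
    if not thresholds:
        return stmt  # empty cache: no filter at all, fail open
    per_camera_threshold = case(
        thresholds,
        value=Sequence.camera_id,
        else_=0.0,  # camera unknown to the risk API: threshold 0.0, everything passes
    )
    return stmt.where(or_(Sequence.max_conf.is_(None), Sequence.max_conf >= per_camera_threshold))

With such a filter, the 0.10 sequence seeded for unknown_cam is compared against 0.0 and survives, and a NULL max_conf bypasses the comparison entirely, which is exactly what the two tests assert.
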
From ce064b859803c635531a8ede36fec6701d2d4a03 Mon Sep 17 00:00:00 2001
From: Mateo
Date: Tue, 5 May 2026 15:39:44 +0200
Subject: [PATCH 36/36] test: parametrize alerts no-filter override across moderate/high/very_high/extreme

---
 src/tests/endpoints/test_risk_filter.py | 23 ++++++++++++++---------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/src/tests/endpoints/test_risk_filter.py b/src/tests/endpoints/test_risk_filter.py
index 6d146f60..b2260e10 100644
--- a/src/tests/endpoints/test_risk_filter.py
+++ b/src/tests/endpoints/test_risk_filter.py
@@ -234,15 +234,17 @@ async def test_alerts_unlabeled_latest_risk_score_override(
 
 
 @pytest.mark.asyncio
-async def test_alerts_unlabeled_latest_risk_score_moderate_keeps_everything(
-    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
+@pytest.mark.parametrize("fwi_class", ["moderate", "high", "very_high", "extreme"])
+async def test_alerts_unlabeled_latest_risk_score_moderate_or_above_keeps_everything(
+    fwi_class: str, async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
 ):
+    """Override at any moderate-or-above class disables the filter even when the cache would drop the alert."""
     camera_id = pytest.camera_table[1]["id"]
     pose_id = pytest.pose_table[2]["id"]
     seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.10, minutes_ago=20)
     alert = await _seed_alert_with_sequence(detection_session, organization_id=2, seq=seq)
 
-    # Cache says very_low (would drop everything), but the override forces moderate.
+    # Cache says very_low (would drop everything), but the override forces a no-filter class.
     risk_service._scores = {camera_id: "very_low"}
 
     auth = pytest.get_token(
@@ -250,7 +252,7 @@ async def test_alerts_unlabeled_latest_risk_score_moderate_keeps_everything(
         pytest.user_table[2]["role"].split(),
         pytest.user_table[2]["organization_id"],
     )
-    response = await async_client.get("/alerts/unlabeled/latest?risk_score=moderate", headers=auth)
+    response = await async_client.get(f"/alerts/unlabeled/latest?risk_score={fwi_class}", headers=auth)
     assert response.status_code == 200, print(response.__dict__)
     assert alert.id in {item["id"] for item in response.json()}
 
@@ -392,17 +394,18 @@ async def test_alerts_fromdate_risk_score_override_drops_low_conf_alert(
 
 
 @pytest.mark.asyncio
-async def test_alerts_fromdate_risk_score_moderate_keeps_alert(
-    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
+@pytest.mark.parametrize("fwi_class", ["moderate", "high", "very_high", "extreme"])
+async def test_alerts_fromdate_risk_score_moderate_or_above_keeps_alert(
+    fwi_class: str, async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
 ):
-    """``risk_score=moderate`` is the kill switch on the from_date endpoint too."""
+    """No-filter override (``moderate`` and above) keeps the alert on ``/alerts/all/fromdate`` too."""
     camera_id = pytest.camera_table[1]["id"]
     pose_id = pytest.pose_table[2]["id"]
     target_date = utcnow().date().isoformat()
     seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.10, minutes_ago=20)
     alert = await _seed_alert_with_sequence(detection_session, organization_id=2, seq=seq)
 
-    # Cache says very_low (would drop everything), but the override forces moderate.
+    # Cache says very_low (would drop everything), but the override forces a no-filter class.
     risk_service._scores = {camera_id: "very_low"}
 
     auth = pytest.get_token(
@@ -410,7 +413,9 @@ async def test_alerts_fromdate_risk_score_moderate_keeps_alert(
         pytest.user_table[2]["role"].split(),
         pytest.user_table[2]["organization_id"],
     )
-    response = await async_client.get(f"/alerts/all/fromdate?from_date={target_date}&risk_score=moderate", headers=auth)
+    response = await async_client.get(
+        f"/alerts/all/fromdate?from_date={target_date}&risk_score={fwi_class}", headers=auth
+    )
     assert response.status_code == 200, print(response.__dict__)
     assert alert.id in {item["id"] for item in response.json()}
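
The invalid-value test earlier in the series expects 422 for risk_score=bogus, while the parametrized tests above accept the six EFFIS class names. One way to get that contract, assuming the endpoints declare the parameter as a Literal (the actual declaration is not shown in this excerpt), is:

# Stub route, not the real endpoint: shows a risk_score query parameter that FastAPI
# rejects with 422 for unknown values and passes through for the six known classes.
from typing import Literal, Optional

from fastapi import FastAPI, Query

app = FastAPI()

RiskScore = Literal["very_low", "low", "moderate", "high", "very_high", "extreme"]


@app.get("/alerts/all/fromdate")
async def fetch_alerts_from_date_stub(from_date: str, risk_score: Optional[RiskScore] = Query(default=None)):
    # moderate and above act as the no-filter override; very_low/low force a threshold;
    # any other value never reaches the handler, FastAPI returns 422 at validation time.
    return {"from_date": from_date, "risk_score": risk_score}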