diff --git a/.env.example b/.env.example index 7cbc4733..84aecc0d 100644 --- a/.env.example +++ b/.env.example @@ -23,6 +23,12 @@ POSTHOG_KEY= SUPPORT_EMAIL= TELEGRAM_TOKEN= +# Risk API (daily fire-weather index per camera) +RISK_API_URL= +RISK_API_LOGIN= +RISK_API_PWD= +RISK_REFRESH_HOUR_UTC=4 + # Production-only ACME_EMAIL= BACKEND_HOST= diff --git a/docker-compose.yml b/docker-compose.yml index 6088b9d1..a1b79277 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -65,6 +65,10 @@ services: - S3_PROXY_URL=${S3_PROXY_URL} - SERVER_NAME=${SERVER_NAME} - PLATFORM_URL=${PLATFORM_URL:-https://platform.pyronear.org} + - RISK_API_URL=${RISK_API_URL} + - RISK_API_LOGIN=${RISK_API_LOGIN} + - RISK_API_PWD=${RISK_API_PWD} + - RISK_REFRESH_HOUR_UTC=${RISK_REFRESH_HOUR_UTC:-4} volumes: - ./src/:/app/ command: "sh -c 'alembic upgrade head && python app/db.py && uvicorn app.main:app --reload --host 0.0.0.0 --port 5050 --proxy-headers'" diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index 7cf64074..b860ed56 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -9,6 +9,7 @@ from fastapi import APIRouter, Depends, HTTPException, Path, Query, Security, status from sqlalchemy import asc, desc +from sqlalchemy.sql import ColumnElement from sqlmodel import delete, func, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -16,10 +17,12 @@ from app.core.time import utcnow from app.crud import AlertCRUD from app.db import get_session -from app.models import Alert, AlertSequence, Sequence, UserRole +from app.models import Alert, AlertSequence, Camera, Sequence, UserRole from app.schemas.alerts import AlertReadWithSequences from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceRead +from app.services.risk import FwiClass, risk_service +from app.services.sequence_confidence import max_conf_filter_clause from app.services.sequence_counts import 
get_detection_counts_by_sequence_ids from app.services.telemetry import telemetry_client @@ -31,7 +34,11 @@ def verify_org_rights(organization_id: int, alert: Alert) -> None: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access forbidden.") -async def _fetch_sequences_by_alert_ids(session: AsyncSession, alert_ids: List[int]) -> Dict[int, List[Sequence]]: +async def _fetch_sequences_by_alert_ids( + session: AsyncSession, + alert_ids: List[int], + seq_filter: Union[ColumnElement[bool], None] = None, +) -> Dict[int, List[Sequence]]: mapping: Dict[int, List[Sequence]] = {} if not alert_ids: return mapping @@ -39,14 +46,32 @@ async def _fetch_sequences_by_alert_ids(session: AsyncSession, alert_ids: List[i select(AlertSequence.alert_id, Sequence) .join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)) .where(AlertSequence.alert_id.in_(alert_ids)) # type: ignore[attr-defined] - .order_by(cast(Any, AlertSequence.alert_id), desc(cast(Any, Sequence.last_seen_at))) ) + if seq_filter is not None: + seq_stmt = seq_stmt.where(seq_filter) + seq_stmt = seq_stmt.order_by(cast(Any, AlertSequence.alert_id), desc(cast(Any, Sequence.last_seen_at))) res = await session.exec(seq_stmt) for alert_id, sequence in res.all(): mapping.setdefault(int(alert_id), []).append(sequence) return mapping +async def _resolve_fwi_class_per_camera( + session: AsyncSession, + organization_id: int, + target_date: Union[date, None] = None, + override_class: Union[str, None] = None, +) -> Dict[int, Union[str, None]]: + """Resolve ``{camera_id: fwi_class}`` for the org, picking override -> per-date -> today's cache.""" + if override_class is not None: + cam_ids = (await session.exec(select(Camera.id).where(Camera.organization_id == organization_id))).all() + return dict.fromkeys(cam_ids, override_class) + if target_date is not None: + scores = await risk_service.get_scores_for_date(target_date, organization_id=organization_id) + return {cid: cls for cid, cls in scores.items()} 
+ return {cid: cls for cid, cls in risk_service.scores().items()} + + def _serialize_sequence(sequence: Sequence, detections_count: int = 0) -> SequenceRead: return SequenceRead(**sequence.model_dump(), detections_count=detections_count) @@ -113,24 +138,39 @@ async def fetch_alert_sequences( summary="Fetch all the alerts with unlabeled sequences from the last 24 hours", ) async def fetch_latest_unlabeled_alerts( + risk_score: Union[FwiClass, None] = Query( + None, description="Override FWI class applied to every sequence; bypasses risk-api lookup." + ), session: AsyncSession = Depends(get_session), token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]), ) -> List[AlertReadWithSequences]: telemetry_client.capture(token_payload.sub, event="alerts-fetch-latest") - alerts_stmt: Any = select(Alert).join(AlertSequence, cast(Any, AlertSequence.alert_id == Alert.id)) - alerts_stmt = alerts_stmt.join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)) - alerts_stmt = ( - alerts_stmt.where(Alert.organization_id == token_payload.organization_id) - .where(Sequence.last_seen_at > utcnow() - timedelta(hours=24)) - .where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr] + fwi_classes_by_camera = await _resolve_fwi_class_per_camera( + session, token_payload.organization_id, override_class=risk_score + ) + seq_filter = max_conf_filter_clause(fwi_classes_by_camera) + + seq_match: Any = cast( + Any, + select(AlertSequence.alert_id).join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)), + ) + seq_match = ( + seq_match.where(Sequence.last_seen_at > utcnow() - timedelta(hours=24)).where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr] + ) + if seq_filter is not None: + seq_match = seq_match.where(seq_filter) + + alerts_stmt: Any = ( + select(Alert) + .where(Alert.organization_id == token_payload.organization_id) + .where(cast(Any, Alert.id).in_(seq_match)) .order_by(Alert.started_at.desc()) # 
type: ignore[attr-defined] .limit(15) ) - alerts_res = await session.exec(alerts_stmt) - alerts = alerts_res.unique().all() + alerts = list((await session.exec(alerts_stmt)).all()) alert_ids = [alert.id for alert in alerts] - seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids) + seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids, seq_filter) detection_counts = await get_detection_counts_by_sequence_ids( session, list({sequence.id for sequences in seq_map.values() for sequence in sequences}), @@ -143,23 +183,35 @@ async def fetch_alerts_from_date( from_date: date = Query(), limit: Union[int, None] = Query(15, description="Maximum number of alerts to fetch"), offset: Union[int, None] = Query(0, description="Number of alerts to skip before starting to fetch"), + risk_score: Union[FwiClass, None] = Query( + None, description="Override FWI class applied to every sequence; bypasses risk-api lookup." + ), session: AsyncSession = Depends(get_session), token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]), ) -> List[AlertReadWithSequences]: telemetry_client.capture(token_payload.sub, event="alerts-fetch-from-date") + fwi_classes_by_camera = await _resolve_fwi_class_per_camera( + session, token_payload.organization_id, target_date=from_date, override_class=risk_score + ) + seq_filter = max_conf_filter_clause(fwi_classes_by_camera) + alerts_stmt: Any = ( select(Alert) .where(Alert.organization_id == token_payload.organization_id) .where(func.date(Alert.started_at) == from_date) - .order_by(Alert.started_at.desc()) # type: ignore[attr-defined] - .limit(limit) - .offset(offset) ) - alerts_res = await session.exec(alerts_stmt) - alerts = alerts_res.all() + if seq_filter is not None: + seq_match: Any = select(AlertSequence.alert_id).join( + Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id) + ) + seq_match = seq_match.where(seq_filter) + alerts_stmt = alerts_stmt.where(cast(Any, 
Alert.id).in_(seq_match)) + alerts_stmt = alerts_stmt.order_by(Alert.started_at.desc()).limit(limit).offset(offset) # type: ignore[attr-defined] + + alerts = list((await session.exec(alerts_stmt)).all()) alert_ids = [alert.id for alert in alerts] - seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids) + seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids, seq_filter) detection_counts = await get_detection_counts_by_sequence_ids( session, list({sequence.id for sequences in seq_map.values() for sequence in sequences}), diff --git a/src/app/api/api_v1/endpoints/detections.py b/src/app/api/api_v1/endpoints/detections.py index 60ae87ed..a8364b96 100644 --- a/src/app/api/api_v1/endpoints/detections.py +++ b/src/app/api/api_v1/endpoints/detections.py @@ -5,6 +5,7 @@ import json +import logging import re from ast import literal_eval from datetime import datetime, timedelta @@ -56,11 +57,15 @@ from app.schemas.sequences import SequenceUpdate from app.services.cones import resolve_cone from app.services.overlap import compute_overlap, haversine_km +from app.services.risk import risk_service +from app.services.sequence_confidence import max_conf_from_bboxes from app.services.slack import slack_client from app.services.storage import s3_service, upload_file from app.services.telegram import telegram_client from app.services.telemetry import telemetry_client +logger = logging.getLogger("uvicorn.error") + router = APIRouter() @@ -427,6 +432,10 @@ async def create_detection( if matched_sequence is not None: await sequences.update(matched_sequence.id, SequenceUpdate(last_seen_at=det.created_at)) det = await detections.update(det.id, DetectionSequence(sequence_id=matched_sequence.id)) + # Only the primary bbox tracks the sequence; siblings in others_bboxes are unrelated detections. 
+ det_max_conf = max_conf_from_bboxes(det.bbox) + if det_max_conf is not None: + await sequences.bump_max_conf(matched_sequence.id, det_max_conf) else: det_filters: List[tuple[str, Any]] = [ ("camera_id", token_payload.sub), @@ -455,6 +464,7 @@ async def create_detection( if len(overlapping_dets) >= settings.SEQUENCE_MIN_INTERVAL_DETS: first_det = min(overlapping_dets, key=lambda item: item.created_at) cone_azimuth, cone_angle = resolve_cone(pose.azimuth, first_det.bbox, camera.angle_of_view) + seq_max_conf = max_conf_from_bboxes(*[d.bbox for d in overlapping_dets]) sequence_ = await sequences.create( Sequence( camera_id=token_payload.sub, @@ -464,6 +474,7 @@ async def create_detection( cone_angle=cone_angle, started_at=first_det.created_at, last_seen_at=det.created_at, + max_conf=seq_max_conf, ) ) for det_ in overlapping_dets: @@ -490,11 +501,20 @@ async def create_detection( if org is None: org = cast(Organization, await organizations.get(token_payload.organization_id, strict=True)) if org.slack_hook: - slack_payload = jsonable_encoder(det) - slack_payload["sequence_azimuth"] = sequence_.sequence_azimuth - background_tasks.add_task( - slack_client.notify, org.slack_hook, json.dumps(slack_payload), camera.name, alert_id - ) + min_conf = risk_service.min_confidence(camera.id) + if min_conf is None or sequence_.max_conf is None or sequence_.max_conf >= min_conf: + slack_payload = jsonable_encoder(det) + slack_payload["sequence_azimuth"] = sequence_.sequence_azimuth + background_tasks.add_task( + slack_client.notify, org.slack_hook, json.dumps(slack_payload), camera.name, alert_id + ) + else: + logger.info( + "Skipping Slack notification for camera %s: max conf %.3f < threshold %.3f", + camera.name, + sequence_.max_conf, + min_conf, + ) created.append(det) diff --git a/src/app/api/api_v1/endpoints/sequences.py b/src/app/api/api_v1/endpoints/sequences.py index edacaa46..8bc8c7e5 100644 --- a/src/app/api/api_v1/endpoints/sequences.py +++ 
b/src/app/api/api_v1/endpoints/sequences.py @@ -22,6 +22,8 @@ from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceLabel, SequenceRead from app.services.overlap import compute_overlap +from app.services.risk import FwiClass, risk_service +from app.services.sequence_confidence import max_conf_filter_clause from app.services.sequence_counts import get_detection_counts_by_sequence_ids from app.services.storage import s3_service from app.services.telemetry import telemetry_client @@ -146,22 +148,32 @@ async def fetch_sequence_detections( summary="Fetch all the unlabeled sequences from the last 24 hours", ) async def fetch_latest_unlabeled_sequences( + risk_score: Union[FwiClass, None] = Query( + None, description="Override FWI class applied to every sequence; bypasses risk-api lookup." + ), session: AsyncSession = Depends(get_session), token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]), ) -> List[SequenceRead]: telemetry_client.capture(token_payload.sub, event="sequence-fetch-latest") - camera_ids = await session.exec(select(Camera.id).where(Camera.organization_id == token_payload.organization_id)) - - fetched_sequences = ( - await session.exec( - select(Sequence) - .where(Sequence.started_at > utcnow() - timedelta(hours=24)) - .where(Sequence.camera_id.in_(camera_ids.all())) # type: ignore[attr-defined] - .where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr] - .order_by(Sequence.started_at.desc()) # type: ignore[attr-defined] - .limit(15) - ) + camera_ids = ( + await session.exec(select(Camera.id).where(Camera.organization_id == token_payload.organization_id)) ).all() + classes: dict[int, Union[str, None]] = ( + dict.fromkeys(camera_ids, risk_score) if risk_score is not None else dict(risk_service.scores()) + ) + + stmt: Any = ( + select(Sequence) + .where(Sequence.started_at > utcnow() - timedelta(hours=24)) + .where(Sequence.camera_id.in_(camera_ids)) # type: 
ignore[attr-defined] + .where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr] + ) + seq_filter = max_conf_filter_clause(classes) + if seq_filter is not None: + stmt = stmt.where(seq_filter) + stmt = stmt.order_by(Sequence.started_at.desc()).limit(15) # type: ignore[attr-defined] + + fetched_sequences = (await session.exec(stmt)).all() counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences]) return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences] @@ -171,23 +183,32 @@ async def fetch_sequences_from_date( from_date: date = Query(), limit: Union[int, None] = Query(15, description="Maximum number of sequences to fetch"), offset: Union[int, None] = Query(0, description="Number of sequences to skip before starting to fetch"), + risk_score: Union[FwiClass, None] = Query( + None, description="Override FWI class applied to every sequence; bypasses risk-api lookup." + ), session: AsyncSession = Depends(get_session), token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]), ) -> List[SequenceRead]: telemetry_client.capture(token_payload.sub, event="sequence-fetch-from-date") # Limit to cameras in the same organization - camera_ids = await session.exec(select(Camera.id).where(Camera.organization_id == token_payload.organization_id)) - # Identify the sequences from that day - fetched_sequences = ( - await session.exec( - select(Sequence) - .where(func.date(Sequence.started_at) == from_date) - .where(Sequence.camera_id.in_(camera_ids.all())) # type: ignore[attr-defined] - .order_by(Sequence.started_at.desc()) # type: ignore[attr-defined] - .limit(limit) - .offset(offset) - ) + camera_ids = ( + await session.exec(select(Camera.id).where(Camera.organization_id == token_payload.organization_id)) ).all() + if risk_score is not None: + classes: dict[int, Union[str, None]] = dict.fromkeys(camera_ids, risk_score) + else: + scores = 
await risk_service.get_scores_for_date(from_date, organization_id=token_payload.organization_id) + classes = dict(scores) + + stmt: Any = ( + select(Sequence).where(func.date(Sequence.started_at) == from_date).where(Sequence.camera_id.in_(camera_ids)) # type: ignore[attr-defined] + ) + seq_filter = max_conf_filter_clause(classes) + if seq_filter is not None: + stmt = stmt.where(seq_filter) + stmt = stmt.order_by(Sequence.started_at.desc()).limit(limit).offset(offset) # type: ignore[attr-defined] + + fetched_sequences = (await session.exec(stmt)).all() counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences]) return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences] diff --git a/src/app/core/config.py b/src/app/core/config.py index f4c0f3e5..08b2e57e 100644 --- a/src/app/core/config.py +++ b/src/app/core/config.py @@ -77,6 +77,12 @@ def sqlachmey_uri(cls, v: str) -> str: TELEGRAM_TOKEN: Union[str, None] = os.environ.get("TELEGRAM_TOKEN") PLATFORM_URL: str = os.environ.get("PLATFORM_URL", "") + # Risk API (daily fire-weather index per camera) + RISK_API_URL: Union[str, None] = os.environ.get("RISK_API_URL") + RISK_API_LOGIN: Union[str, None] = os.environ.get("RISK_API_LOGIN") + RISK_API_PWD: Union[str, None] = os.environ.get("RISK_API_PWD") + RISK_REFRESH_HOUR_UTC: int = int(os.environ.get("RISK_REFRESH_HOUR_UTC") or 4) + # Error monitoring SENTRY_DSN: Union[str, None] = os.environ.get("SENTRY_DSN") SERVER_NAME: str = os.environ.get("SERVER_NAME", socket.gethostname()) diff --git a/src/app/crud/crud_sequence.py b/src/app/crud/crud_sequence.py index ed622e06..0b1e4e0d 100644 --- a/src/app/crud/crud_sequence.py +++ b/src/app/crud/crud_sequence.py @@ -4,8 +4,9 @@ # See LICENSE or go to for full license details. 
-from typing import Union +from typing import Any, Union, cast +from sqlalchemy import case, or_, update from sqlmodel.ext.asyncio.session import AsyncSession from app.crud.base import BaseCRUD @@ -18,3 +19,17 @@ class SequenceCRUD(BaseCRUD[Sequence, Sequence, Union[SequenceUpdate, SequenceLabel]]): def __init__(self, session: AsyncSession) -> None: super().__init__(session, Sequence) + + async def bump_max_conf(self, sequence_id: int, candidate: float) -> None: + """Atomically raise sequences.max_conf to candidate if higher (or set if NULL). + + Uses a portable CASE expression so it runs on SQLite as well as Postgres. + """ + max_conf_col = cast(Any, Sequence.max_conf) + bumped: Any = cast(Any, case)( + (or_(max_conf_col.is_(None), max_conf_col < candidate), candidate), + else_=max_conf_col, + ) + stmt: Any = update(Sequence).where(cast(Any, Sequence.id) == sequence_id).values(max_conf=bumped) + await self.session.exec(stmt) + await self.session.commit() diff --git a/src/app/main.py b/src/app/main.py index 63065c1f..b20bd155 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -3,8 +3,12 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. 
+import asyncio +import contextlib import logging import time +from contextlib import asynccontextmanager +from datetime import datetime, timedelta, timezone import sentry_sdk from fastapi import FastAPI, Request, status @@ -19,6 +23,7 @@ from app.api.api_v1.router import api_router from app.core.config import settings from app.schemas.base import Status +from app.services.risk import risk_service logger = logging.getLogger("uvicorn.error") @@ -40,6 +45,43 @@ logger.info(f"Sentry middleware enabled on server {settings.SERVER_NAME}") +def _seconds_until_next_utc_hour(target_hour: int) -> float: + hour = max(0, min(23, target_hour)) + now = datetime.now(tz=timezone.utc) + target = now.replace(hour=hour, minute=0, second=0, microsecond=0) + if target <= now: + target += timedelta(days=1) + return (target - now).total_seconds() + + +async def _risk_refresh_loop() -> None: + while True: + try: + await asyncio.sleep(_seconds_until_next_utc_hour(settings.RISK_REFRESH_HOUR_UTC)) + await risk_service.refresh() + except asyncio.CancelledError: + raise + except Exception: + logger.exception("Risk refresh loop iteration failed; continuing") + + +@asynccontextmanager +async def lifespan(_: FastAPI): + if risk_service.is_configured: + await risk_service.refresh() + task = asyncio.create_task(_risk_refresh_loop()) + else: + task = None + logger.info("Risk API not configured; skipping daily refresh") + try: + yield + finally: + if task is not None: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + + app = FastAPI( title=settings.PROJECT_NAME, description=settings.PROJECT_DESCRIPTION, @@ -47,6 +89,7 @@ version=settings.VERSION, openapi_url=f"{settings.API_V1_STR}/openapi.json", docs_url=None, + lifespan=lifespan, ) diff --git a/src/app/models.py b/src/app/models.py index c9eed1b1..3dea578d 100644 --- a/src/app/models.py +++ b/src/app/models.py @@ -104,6 +104,14 @@ class Sequence(SQLModel, table=True): cone_angle: Union[float, None] = Field(None, 
nullable=True) started_at: datetime = Field(..., nullable=False) last_seen_at: datetime = Field(..., nullable=False) + max_conf: Union[float, None] = Field( + None, + nullable=True, + description=( + "Highest detection confidence ever attached to this sequence. " + "Monotonic: not recomputed downward when detections are deleted or reassigned." + ), + ) class Alert(SQLModel, table=True): diff --git a/src/app/services/risk.py b/src/app/services/risk.py new file mode 100644 index 00000000..21ccc245 --- /dev/null +++ b/src/app/services/risk.py @@ -0,0 +1,117 @@ +# Copyright (C) 2024-2026, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import logging +from datetime import date +from typing import Literal, Union + +import httpx + +from app.core.config import settings + +logger = logging.getLogger("uvicorn.error") + +__all__ = ["FWI_MIN_CONF", "FwiClass", "min_confidence_for_class", "risk_service"] + +# FWI classes accepted by the risk-api and as a manual ``risk_score`` override. +FwiClass = Literal["very_low", "low", "moderate", "high", "very_high", "extreme"] + +# Minimum sequence ``max_conf`` required per FWI class. Zero or absent → no filter. +# All EFFIS classes are listed even when unused so the table stays explicit and easy to tune. +FWI_MIN_CONF: dict[str, float] = { + "very_low": 0.6, + "low": 0.45, + "moderate": 0.0, + "high": 0.0, + "very_high": 0.0, + "extreme": 0.0, +} + + +def min_confidence_for_class(fwi_class: Union[str, None]) -> Union[float, None]: + """Return the min confidence required for this FWI class, or None if no filter applies.""" + if not fwi_class: + return None + threshold = FWI_MIN_CONF.get(fwi_class.strip().lower().replace(" ", "_")) + return threshold or None + + +def _parse_scores_payload(payload: object) -> dict[int, str]: + """Pull ``{camera_id: fwi_class}`` from a list of risk-api items. 
Skip malformed rows.""" + if not isinstance(payload, list): + return {} + scores: dict[int, str] = {} + for item in payload: + if not isinstance(item, dict): + continue + cid = item.get("id") or item.get("camera_id") + fwi = item.get("fwi_class") + if isinstance(cid, int) and isinstance(fwi, str): + scores[cid] = fwi + return scores + + +class RiskService: + """In-memory cache of the daily FWI class per camera. + + Refreshed once at startup and once a day from the pyro-risk-api service. + Fail-open: if the API is unreachable or a camera is unknown, no filter is applied. + """ + + def __init__(self) -> None: + self._scores: dict[int, str] = {} + + @property + def is_configured(self) -> bool: + return bool(settings.RISK_API_URL and settings.RISK_API_LOGIN and settings.RISK_API_PWD) + + def class_for_camera(self, camera_id: int) -> Union[str, None]: + """Return today's cached FWI class for a camera, or None if unknown.""" + return self._scores.get(camera_id) + + def scores(self) -> dict[int, str]: + """Return a copy of the full ``{camera_id: fwi_class}`` cache.""" + return dict(self._scores) + + def min_confidence(self, camera_id: int) -> Union[float, None]: + """Return the min confidence required for this camera (today), or None if no filter.""" + return min_confidence_for_class(self._scores.get(camera_id)) + + async def _fetch(self, path: str, params: Union[dict, None] = None) -> object: + if not self.is_configured: + return None + host = settings.RISK_API_URL.rstrip("/") # type: ignore[union-attr] + try: + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.get( + f"{host}/{path.lstrip('/')}", + params=params, + auth=(settings.RISK_API_LOGIN, settings.RISK_API_PWD), # type: ignore[arg-type] + ) + response.raise_for_status() + return response.json() + except (httpx.HTTPError, ValueError) as exc: + logger.warning("Risk API call %s failed (%s)", path, exc) + return None + + async def refresh(self) -> None: + """Fetch fresh FWI classes from 
the risk API. On network/HTTP failure, keep the previous cache.""" + payload = await self._fetch("cameras") + if payload is None: + logger.warning("Risk API refresh failed; keeping previous cache (%d entries)", len(self._scores)) + return + self._scores = _parse_scores_payload(payload) + logger.info("Risk API refresh: cached FWI class for %d camera(s)", len(self._scores)) + + async def get_scores_for_date(self, target_date: date, organization_id: Union[int, None] = None) -> dict[int, str]: + """Fetch persisted FWI classes for a specific date, optionally scoped to an organization. + + Returns {} on error or when not configured. + """ + params = {"organization_id": organization_id} if organization_id is not None else None + return _parse_scores_payload(await self._fetch(f"scores/{target_date.isoformat()}", params=params)) + + +risk_service = RiskService() diff --git a/src/app/services/sequence_confidence.py b/src/app/services/sequence_confidence.py new file mode 100644 index 00000000..3c3da1ab --- /dev/null +++ b/src/app/services/sequence_confidence.py @@ -0,0 +1,82 @@ +# Copyright (C) 2024-2026, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
+ + +import logging +import re +from ast import literal_eval +from typing import Any, Dict, List, Union, cast +from typing import Sequence as TypingSequence + +from sqlalchemy import case, or_ +from sqlalchemy.sql import ColumnElement + +from app.models import Sequence +from app.schemas.detections import BOX_PATTERN +from app.services.risk import min_confidence_for_class + +logger = logging.getLogger("uvicorn.error") + +__all__ = ["filter_by_class_per_camera", "max_conf_filter_clause", "max_conf_from_bboxes"] + + +def max_conf_from_bboxes(*bbox_strings: Union[str, None]) -> Union[float, None]: + """Return the highest confidence found across the given bbox strings, or None if none parse.""" + best: Union[float, None] = None + for raw in bbox_strings: + if not raw: + continue + for match in re.finditer(BOX_PATTERN, raw): + try: + bbox = literal_eval(match.group(0)) + except (SyntaxError, ValueError): + continue + if not isinstance(bbox, tuple) or len(bbox) != 5: + continue + conf = bbox[4] + if not isinstance(conf, (int, float)): + continue + if best is None or conf > best: + best = float(conf) + return best + + +def max_conf_filter_clause(class_per_camera: Dict[int, Union[str, None]]) -> Union[ColumnElement[bool], None]: + """SQL ``WHERE`` clause keeping sequences whose ``max_conf`` passes their camera's threshold. + + Returns ``None`` when no camera has an active threshold (caller should skip the filter). + Fail-open: rows with ``max_conf IS NULL`` are kept, and cameras absent from ``class_per_camera`` + default to threshold 0 (everything passes) via the ``CASE`` ``else_`` clause. 
+ """ + thresholds = { + cid: t for cid, klass in class_per_camera.items() if (t := min_confidence_for_class(klass)) is not None + } + if not thresholds: + return None + max_conf_col = cast(Any, Sequence.max_conf) + whens: List[Any] = [(Sequence.camera_id == cid, t) for cid, t in thresholds.items()] + threshold_expr: Any = cast(Any, case)(*whens, else_=0.0) + return or_(max_conf_col.is_(None), max_conf_col >= threshold_expr) + + +def filter_by_class_per_camera( + sequences: TypingSequence[Sequence], + class_per_camera: Dict[int, Union[str, None]], +) -> List[Sequence]: + """Drop sequences whose stored ``max_conf`` falls below the threshold for their camera's FWI class. + + Fail-open: a sequence is kept when its camera has no FWI class (moderate+ or unknown) + or when ``seq.max_conf`` is NULL. + """ + if not sequences: + return [] + thresholds = {cid: min_confidence_for_class(c) for cid, c in class_per_camera.items()} + if all(t is None for t in thresholds.values()): + return list(sequences) + return [ + seq + for seq in sequences + if (t := thresholds.get(seq.camera_id)) is None or seq.max_conf is None or seq.max_conf >= t + ] diff --git a/src/app/services/sequence_counts.py b/src/app/services/sequence_counts.py index 195c8b52..75667c9b 100644 --- a/src/app/services/sequence_counts.py +++ b/src/app/services/sequence_counts.py @@ -17,7 +17,7 @@ async def get_detection_counts_by_sequence_ids(session: AsyncSession, sequence_i return {} stmt: Any = ( - select(cast(Any, Detection.sequence_id), func.count(Detection.id)) + select(cast(Any, Detection.sequence_id), func.count(cast(Any, Detection.id))) .where(cast(Any, Detection.sequence_id).in_(sequence_ids)) .group_by(cast(Any, Detection.sequence_id)) ) diff --git a/src/migrations/versions/2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py b/src/migrations/versions/2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py new file mode 100644 index 00000000..3fa62461 --- /dev/null +++ 
b/src/migrations/versions/2026_05_05_0930-b3d8a9c1e2f4_add_max_conf_to_sequences.py @@ -0,0 +1,74 @@ +"""add max_conf column to sequences and backfill from detections + +Revision ID: b3d8a9c1e2f4 +Revises: a1b2c3d4e5f6 +Create Date: 2026-05-05 09:30:00.000000 + +""" + +import logging +import re +from ast import literal_eval +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "b3d8a9c1e2f4" +down_revision: Union[str, None] = "a1b2c3d4e5f6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger("alembic.runtime.migration") + +_FLOAT = r"-?\d+(?:\.\d+)?|-?\.\d+" +_BOX_PATTERN = rf"\({_FLOAT},{_FLOAT},{_FLOAT},{_FLOAT},{_FLOAT}\)" + + +def _max_conf(*bbox_strings: Union[str, None]) -> Union[float, None]: + best: Union[float, None] = None + for raw in bbox_strings: + if not raw: + continue + for match in re.finditer(_BOX_PATTERN, raw): + try: + bbox = literal_eval(match.group(0)) + except (SyntaxError, ValueError): + continue + if not isinstance(bbox, tuple) or len(bbox) != 5: + continue + conf = bbox[4] + if not isinstance(conf, (int, float)): + continue + if best is None or conf > best: + best = float(conf) + return best + + +def upgrade() -> None: + op.add_column("sequences", sa.Column("max_conf", sa.Float(), nullable=True)) + + bind = op.get_bind() + # Only the primary bbox tracks the sequence; siblings in others_bboxes are unrelated detections. 
+ rows = bind.execute(sa.text("SELECT sequence_id, bbox FROM detections WHERE sequence_id IS NOT NULL")).fetchall() + + seq_max: dict[int, float] = {} + for sequence_id, bbox in rows: + conf = _max_conf(bbox) + if conf is None: + continue + current = seq_max.get(sequence_id) + if current is None or conf > current: + seq_max[sequence_id] = conf + + if seq_max: + bind.execute( + sa.text("UPDATE sequences SET max_conf = :conf WHERE id = :sid"), + [{"sid": sid, "conf": conf} for sid, conf in seq_max.items()], + ) + logger.info("Backfilled max_conf for %d sequence(s)", len(seq_max)) + + +def downgrade() -> None: + op.drop_column("sequences", "max_conf") diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 91851241..aef86314 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -199,6 +199,7 @@ "cone_angle": 54.8, "started_at": datetime.strptime("2023-11-07T15:08:19.226673", dt_format), "last_seen_at": datetime.strptime("2023-11-07T15:28:19.226673", dt_format), + "max_conf": None, }, { "id": 2, @@ -210,6 +211,7 @@ "cone_angle": 54.8, "started_at": datetime.strptime("2023-11-07T16:08:19.226673", dt_format), "last_seen_at": datetime.strptime("2023-11-07T16:08:19.226673", dt_format), + "max_conf": None, }, ] diff --git a/src/tests/endpoints/test_risk_filter.py b/src/tests/endpoints/test_risk_filter.py new file mode 100644 index 00000000..b2260e10 --- /dev/null +++ b/src/tests/endpoints/test_risk_filter.py @@ -0,0 +1,573 @@ +# Copyright (C) 2024-2026, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
# Copyright (C) 2024-2026, Pyronear.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to for full license details.

from datetime import date, timedelta
from unittest.mock import AsyncMock, patch

import pytest  # type: ignore
from httpx import AsyncClient
from sqlmodel.ext.asyncio.session import AsyncSession

from app.api.api_v1.endpoints.alerts import fetch_alerts_from_date
from app.core.time import utcnow
from app.models import Alert, AlertSequence, Sequence, UserRole
from app.schemas.login import TokenPayload
from app.services.risk import risk_service


@pytest.fixture
def reset_risk_cache():
    """Snapshot the risk cache, start each test with it empty, restore afterwards."""
    previous = dict(risk_service._scores)
    risk_service._scores = {}
    yield
    risk_service._scores = previous


async def _seed_unlabeled_sequence(
    session: AsyncSession,
    camera_id: int,
    pose_id: int,
    max_conf: float,
    minutes_ago: int = 30,
) -> Sequence:
    """Insert and return an unlabeled (``is_wildfire=None``) sequence ending ``minutes_ago`` minutes ago."""
    now = utcnow()
    seq = Sequence(
        camera_id=camera_id,
        pose_id=pose_id,
        camera_azimuth=180.0,
        sequence_azimuth=175.0,
        cone_angle=5.0,
        is_wildfire=None,
        started_at=now - timedelta(minutes=minutes_ago),
        last_seen_at=now - timedelta(minutes=minutes_ago - 1),
        max_conf=max_conf,
    )
    session.add(seq)
    await session.commit()
    await session.refresh(seq)
    return seq


@pytest.mark.asyncio
async def test_unlabeled_latest_drops_low_conf_when_camera_is_low_risk(
    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    """A camera flagged ``low`` (0.45 cutoff) drops a 0.40 sequence and keeps a 0.55 one."""
    camera_id = pytest.camera_table[0]["id"]
    pose_id = pytest.pose_table[0]["id"]
    low_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.40, minutes_ago=30)
    high_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.55, minutes_ago=20)

    risk_service._scores = {camera_id: "low"}

    auth = pytest.get_token(
        pytest.user_table[0]["id"],
        pytest.user_table[0]["role"].split(),
        pytest.user_table[0]["organization_id"],
    )
    response = await async_client.get("/sequences/unlabeled/latest", headers=auth)
    # FIX: the failure message used to be ``print(response.__dict__)`` which
    # evaluates to None — use the response body so failures are self-explanatory.
    assert response.status_code == 200, response.text
    returned_ids = {item["id"] for item in response.json()}
    assert low_seq.id not in returned_ids
    assert high_seq.id in returned_ids
@pytest.mark.asyncio
async def test_unlabeled_latest_drops_below_very_low_threshold(
    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    """``very_low`` raises the cutoff to 0.6, dropping a sequence that ``low`` would keep."""
    camera_id = pytest.camera_table[0]["id"]
    pose_id = pytest.pose_table[0]["id"]
    # 0.55 passes the low threshold (0.45) but fails very_low (0.6)
    seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.55, minutes_ago=25)

    risk_service._scores = {camera_id: "very_low"}

    auth = pytest.get_token(
        pytest.user_table[0]["id"],
        pytest.user_table[0]["role"].split(),
        pytest.user_table[0]["organization_id"],
    )
    response = await async_client.get("/sequences/unlabeled/latest", headers=auth)
    # FIX: assertion message was print(...) -> None; use the body instead.
    assert response.status_code == 200, response.text
    assert seq.id not in {item["id"] for item in response.json()}


@pytest.mark.asyncio
@pytest.mark.parametrize("fwi_class", ["moderate", "high", "very_high", "extreme"])
async def test_unlabeled_latest_keeps_all_when_class_is_moderate_or_above(
    fwi_class: str, async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    """No filter is applied for ``moderate`` and above; pin every class so future tweaks stay covered."""
    camera_id = pytest.camera_table[0]["id"]
    pose_id = pytest.pose_table[0]["id"]
    low_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.10, minutes_ago=30)

    risk_service._scores = {camera_id: fwi_class}

    auth = pytest.get_token(
        pytest.user_table[0]["id"],
        pytest.user_table[0]["role"].split(),
        pytest.user_table[0]["organization_id"],
    )
    response = await async_client.get("/sequences/unlabeled/latest", headers=auth)
    assert response.status_code == 200, response.text
    assert low_seq.id in {item["id"] for item in response.json()}


@pytest.mark.asyncio
async def test_unlabeled_latest_keeps_seq_with_null_max_conf_under_filter(
    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    """Sequences with NULL max_conf (legacy / unparseable bbox) must pass even under an active filter."""
    camera_id = pytest.camera_table[0]["id"]
    pose_id = pytest.pose_table[0]["id"]
    null_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=None, minutes_ago=20)  # type: ignore[arg-type]

    risk_service._scores = {camera_id: "low"}  # 0.45 threshold, would normally drop

    auth = pytest.get_token(
        pytest.user_table[0]["id"],
        pytest.user_table[0]["role"].split(),
        pytest.user_table[0]["organization_id"],
    )
    response = await async_client.get("/sequences/unlabeled/latest", headers=auth)
    assert response.status_code == 200, response.text
    assert null_seq.id in {item["id"] for item in response.json()}


@pytest.mark.asyncio
async def test_unlabeled_latest_keeps_seq_for_camera_unknown_to_risk_api(
    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    """A camera absent from the risk-api cache stays unfiltered even when sibling cameras are filtered."""
    known_cam = pytest.camera_table[0]["id"]
    unknown_cam = pytest.camera_table[1]["id"]
    known_pose = pytest.pose_table[0]["id"]
    unknown_pose = pytest.pose_table[2]["id"]

    # Cache only knows about ``known_cam`` and flags it ``low`` (0.45 threshold).
    # ``unknown_cam`` has no entry -> CASE else_=0.0 -> any max_conf passes.
    risk_service._scores = {known_cam: "low"}

    known_dropped = await _seed_unlabeled_sequence(detection_session, known_cam, known_pose, max_conf=0.20)
    unknown_kept = await _seed_unlabeled_sequence(detection_session, unknown_cam, unknown_pose, max_conf=0.10)

    # known_cam belongs to org 1; unknown_cam belongs to org 2 -> query both orgs.
    auth_org1 = pytest.get_token(
        pytest.user_table[0]["id"],
        pytest.user_table[0]["role"].split(),
        pytest.user_table[0]["organization_id"],
    )
    auth_org2 = pytest.get_token(
        pytest.user_table[2]["id"],
        pytest.user_table[2]["role"].split(),
        pytest.user_table[2]["organization_id"],
    )

    resp1 = await async_client.get("/sequences/unlabeled/latest", headers=auth_org1)
    assert resp1.status_code == 200, resp1.text
    assert known_dropped.id not in {item["id"] for item in resp1.json()}

    resp2 = await async_client.get("/sequences/unlabeled/latest", headers=auth_org2)
    assert resp2.status_code == 200, resp2.text
    assert unknown_kept.id in {item["id"] for item in resp2.json()}
async def _seed_alert_with_sequence(session: AsyncSession, organization_id: int, seq: Sequence) -> Alert:
    """Insert an alert for ``organization_id`` linked to ``seq`` via AlertSequence."""
    now = utcnow()
    alert = Alert(
        organization_id=organization_id,
        started_at=now - timedelta(minutes=20),
        last_seen_at=now - timedelta(minutes=19),
    )
    session.add(alert)
    await session.commit()
    await session.refresh(alert)
    session.add(AlertSequence(alert_id=alert.id, sequence_id=seq.id))
    await session.commit()
    return alert


@pytest.mark.asyncio
async def test_alerts_unlabeled_latest_drops_alert_when_all_seqs_below_threshold(
    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    """An alert whose every sequence fails the cutoff disappears from the listing."""
    camera_id = pytest.camera_table[1]["id"]  # belongs to org 2 (user_idx 2)
    pose_id = pytest.pose_table[2]["id"]
    seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.30, minutes_ago=20)
    alert = await _seed_alert_with_sequence(detection_session, organization_id=2, seq=seq)

    risk_service._scores = {camera_id: "low"}

    auth = pytest.get_token(
        pytest.user_table[2]["id"],
        pytest.user_table[2]["role"].split(),
        pytest.user_table[2]["organization_id"],
    )
    response = await async_client.get("/alerts/unlabeled/latest", headers=auth)
    # FIX: assertion message was print(...) -> None; use the body instead.
    assert response.status_code == 200, response.text
    assert alert.id not in {item["id"] for item in response.json()}


@pytest.mark.asyncio
async def test_alerts_unlabeled_latest_risk_score_override(
    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    """The ``risk_score`` query param wins over whatever the risk-api cache says."""
    camera_id = pytest.camera_table[1]["id"]
    pose_id = pytest.pose_table[2]["id"]
    seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.30, minutes_ago=20)
    alert = await _seed_alert_with_sequence(detection_session, organization_id=2, seq=seq)

    # Risk-api would say "moderate" (no filter), but the override forces "low" -> 0.45 threshold drops it.
    risk_service._scores = {camera_id: "moderate"}

    auth = pytest.get_token(
        pytest.user_table[2]["id"],
        pytest.user_table[2]["role"].split(),
        pytest.user_table[2]["organization_id"],
    )
    response = await async_client.get("/alerts/unlabeled/latest?risk_score=low", headers=auth)
    assert response.status_code == 200, response.text
    assert alert.id not in {item["id"] for item in response.json()}


@pytest.mark.asyncio
@pytest.mark.parametrize("fwi_class", ["moderate", "high", "very_high", "extreme"])
async def test_alerts_unlabeled_latest_risk_score_moderate_or_above_keeps_everything(
    fwi_class: str, async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    """Override at any moderate-or-above class disables the filter even when the cache would drop the alert."""
    camera_id = pytest.camera_table[1]["id"]
    pose_id = pytest.pose_table[2]["id"]
    seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.10, minutes_ago=20)
    alert = await _seed_alert_with_sequence(detection_session, organization_id=2, seq=seq)

    # Cache says very_low (would drop everything), but the override forces a no-filter class.
    risk_service._scores = {camera_id: "very_low"}

    auth = pytest.get_token(
        pytest.user_table[2]["id"],
        pytest.user_table[2]["role"].split(),
        pytest.user_table[2]["organization_id"],
    )
    response = await async_client.get(f"/alerts/unlabeled/latest?risk_score={fwi_class}", headers=auth)
    assert response.status_code == 200, response.text
    assert alert.id in {item["id"] for item in response.json()}
@pytest.mark.asyncio
async def test_alerts_unlabeled_latest_risk_score_invalid_value_returns_422(
    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    """An unknown ``risk_score`` value is rejected by request validation."""
    auth = pytest.get_token(
        pytest.user_table[2]["id"],
        pytest.user_table[2]["role"].split(),
        pytest.user_table[2]["organization_id"],
    )
    response = await async_client.get("/alerts/unlabeled/latest?risk_score=bogus", headers=auth)
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_sequences_unlabeled_latest_risk_score_override_drops_low_conf(
    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    """The override drives the threshold on /sequences too, without any cache entry."""
    camera_id = pytest.camera_table[1]["id"]
    pose_id = pytest.pose_table[2]["id"]
    low_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.30, minutes_ago=20)
    high_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.55, minutes_ago=15)

    auth = pytest.get_token(
        pytest.user_table[2]["id"],
        pytest.user_table[2]["role"].split(),
        pytest.user_table[2]["organization_id"],
    )
    response = await async_client.get("/sequences/unlabeled/latest?risk_score=low", headers=auth)
    # FIX: assertion message was print(...) -> None; use the body instead.
    assert response.status_code == 200, response.text
    returned_ids = {item["id"] for item in response.json()}
    assert low_seq.id not in returned_ids
    assert high_seq.id in returned_ids


@pytest.mark.asyncio
async def test_sequences_unlabeled_latest_risk_score_invalid_value_returns_422(
    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    auth = pytest.get_token(
        pytest.user_table[2]["id"],
        pytest.user_table[2]["role"].split(),
        pytest.user_table[2]["organization_id"],
    )
    response = await async_client.get("/sequences/unlabeled/latest?risk_score=bogus", headers=auth)
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_sequences_fromdate_risk_score_moderate_disables_filter(
    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    """``risk_score=moderate`` is the kill switch: every seed seq comes back regardless of max_conf."""
    camera_id = pytest.camera_table[1]["id"]
    pose_id = pytest.pose_table[2]["id"]
    target_date = utcnow().date().isoformat()

    seeded = []
    for max_conf, minutes_ago in [(0.05, 50), (0.20, 45), (0.55, 30)]:
        seq = await _seed_unlabeled_sequence(
            detection_session, camera_id, pose_id, max_conf=max_conf, minutes_ago=minutes_ago
        )
        seeded.append(seq.id)

    auth = pytest.get_token(
        pytest.user_table[2]["id"],
        pytest.user_table[2]["role"].split(),
        pytest.user_table[2]["organization_id"],
    )
    response = await async_client.get(
        f"/sequences/all/fromdate?from_date={target_date}&limit=200&risk_score=moderate", headers=auth
    )
    assert response.status_code == 200, response.text
    returned_ids = {item["id"] for item in response.json()}
    assert set(seeded).issubset(returned_ids)


@pytest.mark.asyncio
async def test_alerts_unlabeled_latest_keeps_alert_with_mixed_seqs(
    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    """An alert mixing one passing and one failing sequence stays, with only the passing seq in payload."""
    camera_id = pytest.camera_table[1]["id"]
    pose_id = pytest.pose_table[2]["id"]
    low_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.30, minutes_ago=25)
    high_seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.70, minutes_ago=15)

    now = utcnow()
    alert = Alert(
        organization_id=2,
        started_at=now - timedelta(minutes=30),
        last_seen_at=now - timedelta(minutes=10),
    )
    detection_session.add(alert)
    await detection_session.commit()
    await detection_session.refresh(alert)
    detection_session.add(AlertSequence(alert_id=alert.id, sequence_id=low_seq.id))
    detection_session.add(AlertSequence(alert_id=alert.id, sequence_id=high_seq.id))
    await detection_session.commit()

    risk_service._scores = {camera_id: "low"}  # 0.45 threshold

    auth = pytest.get_token(
        pytest.user_table[2]["id"],
        pytest.user_table[2]["role"].split(),
        pytest.user_table[2]["organization_id"],
    )
    response = await async_client.get("/alerts/unlabeled/latest", headers=auth)
    assert response.status_code == 200, response.text

    payload = next((a for a in response.json() if a["id"] == alert.id), None)
    assert payload is not None, "alert with at least one passing sequence must remain"
    seq_ids = {s["id"] for s in payload["sequences"]}
    assert high_seq.id in seq_ids
    assert low_seq.id not in seq_ids


@pytest.mark.asyncio
async def test_alerts_fromdate_risk_score_override_drops_low_conf_alert(
    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    camera_id = pytest.camera_table[1]["id"]
    pose_id = pytest.pose_table[2]["id"]
    target_date = utcnow().date().isoformat()
    seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.30, minutes_ago=20)
    alert = await _seed_alert_with_sequence(detection_session, organization_id=2, seq=seq)

    auth = pytest.get_token(
        pytest.user_table[2]["id"],
        pytest.user_table[2]["role"].split(),
        pytest.user_table[2]["organization_id"],
    )
    response = await async_client.get(f"/alerts/all/fromdate?from_date={target_date}&risk_score=low", headers=auth)
    assert response.status_code == 200, response.text
    assert alert.id not in {item["id"] for item in response.json()}
@pytest.mark.asyncio
@pytest.mark.parametrize("fwi_class", ["moderate", "high", "very_high", "extreme"])
async def test_alerts_fromdate_risk_score_moderate_or_above_keeps_alert(
    fwi_class: str, async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    """No-filter override (``moderate`` and above) keeps the alert on ``/alerts/all/fromdate`` too."""
    camera_id = pytest.camera_table[1]["id"]
    pose_id = pytest.pose_table[2]["id"]
    target_date = utcnow().date().isoformat()
    seq = await _seed_unlabeled_sequence(detection_session, camera_id, pose_id, max_conf=0.10, minutes_ago=20)
    alert = await _seed_alert_with_sequence(detection_session, organization_id=2, seq=seq)

    # Cache says very_low (would drop everything), but the override forces a no-filter class.
    risk_service._scores = {camera_id: "very_low"}

    auth = pytest.get_token(
        pytest.user_table[2]["id"],
        pytest.user_table[2]["role"].split(),
        pytest.user_table[2]["organization_id"],
    )
    response = await async_client.get(
        f"/alerts/all/fromdate?from_date={target_date}&risk_score={fwi_class}", headers=auth
    )
    # FIX: assertion message was print(...) -> None; use the body instead.
    assert response.status_code == 200, response.text
    assert alert.id in {item["id"] for item in response.json()}


@pytest.mark.asyncio
async def test_alerts_fromdate_risk_score_invalid_value_returns_422(
    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    target_date = utcnow().date().isoformat()
    auth = pytest.get_token(
        pytest.user_table[2]["id"],
        pytest.user_table[2]["role"].split(),
        pytest.user_table[2]["organization_id"],
    )
    response = await async_client.get(f"/alerts/all/fromdate?from_date={target_date}&risk_score=bogus", headers=auth)
    assert response.status_code == 422


class _ExecResult:
    """Minimal stand-in for the result object returned by ``session.exec``."""

    def __init__(self, rows) -> None:
        self._rows = rows

    def all(self):
        return self._rows


class _FakeAlertsSession:
    """Fake AsyncSession yielding canned results in FIFO order and recording statements."""

    def __init__(self, *results) -> None:
        self._results = list(results)
        self.statements = []

    async def exec(self, stmt):
        self.statements.append(stmt)
        return self._results.pop(0)


def _unit_token(organization_id: int = 1) -> TokenPayload:
    """Token payload for a plain USER in the given organization."""
    return TokenPayload(sub=1, scopes=[UserRole.USER], organization_id=organization_id)


@pytest.mark.asyncio
async def test_unit_fetch_alerts_from_date_calls_risk_scores_for_requested_date():
    """The per-date branch must dispatch ``get_scores_for_date`` with the requested date and org."""
    target_date = date(2026, 5, 5)
    now = utcnow()
    alert = Alert(id=501, organization_id=1, started_at=now, last_seen_at=now)
    seq = Sequence(
        id=601,
        camera_id=1,
        pose_id=1,
        camera_azimuth=180.0,
        sequence_azimuth=175.0,
        cone_angle=5.0,
        is_wildfire=None,
        started_at=now - timedelta(minutes=20),
        last_seen_at=now - timedelta(minutes=5),
        max_conf=0.70,
    )
    session = _FakeAlertsSession(_ExecResult([alert]), _ExecResult([(alert.id, seq)]))
    risk_scores_mock = AsyncMock(return_value={})  # empty -> no SQL filter, simpler stmt order
    counts_mock = AsyncMock(return_value={seq.id: 3})

    with (
        patch("app.api.api_v1.endpoints.alerts.risk_service.get_scores_for_date", new=risk_scores_mock),
        patch("app.api.api_v1.endpoints.alerts.get_detection_counts_by_sequence_ids", new=counts_mock),
    ):
        result = await fetch_alerts_from_date(
            from_date=target_date,
            limit=10,
            offset=0,
            risk_score=None,
            session=session,
            token_payload=_unit_token(),
        )

    risk_scores_mock.assert_awaited_once_with(target_date, organization_id=1)
    counts_mock.assert_awaited_once_with(session, [seq.id])
    assert [item.id for item in result] == [alert.id]
    assert [s.id for s in result[0].sequences] == [seq.id]


@pytest.mark.asyncio
async def test_unit_fetch_alerts_from_date_risk_score_override_bypasses_risk_api():
    """The override branch must NOT hit the risk API and must apply the override class to all org cameras."""
    target_date = date(2026, 5, 5)
    now = utcnow()
    alert = Alert(id=502, organization_id=1, started_at=now, last_seen_at=now)
    seq = Sequence(
        id=602,
        camera_id=1,
        pose_id=1,
        camera_azimuth=180.0,
        sequence_azimuth=175.0,
        cone_angle=5.0,
        is_wildfire=None,
        started_at=now - timedelta(minutes=20),
        last_seen_at=now - timedelta(minutes=5),
        max_conf=0.55,
    )
    # Override path -> first exec: select Camera.id, then alerts_stmt, then alert_seq map
    session = _FakeAlertsSession(
        _ExecResult([1]),
        _ExecResult([alert]),
        _ExecResult([(alert.id, seq)]),
    )
    risk_scores_mock = AsyncMock()
    counts_mock = AsyncMock(return_value={seq.id: 1})

    with (
        patch("app.api.api_v1.endpoints.alerts.risk_service.get_scores_for_date", new=risk_scores_mock),
        patch("app.api.api_v1.endpoints.alerts.get_detection_counts_by_sequence_ids", new=counts_mock),
    ):
        result = await fetch_alerts_from_date(
            from_date=target_date,
            limit=10,
            offset=0,
            risk_score="low",
            session=session,
            token_payload=_unit_token(),
        )

    risk_scores_mock.assert_not_awaited()
    assert [item.id for item in result] == [alert.id]
    assert [s.id for s in result[0].sequences] == [seq.id]


@pytest.mark.asyncio
async def test_sequences_fromdate_pagination_filters_before_limit(
    async_client: AsyncClient, detection_session: AsyncSession, reset_risk_cache
):
    """A page must come back full when the filter would otherwise consume the whole limit."""
    camera_id = pytest.camera_table[1]["id"]
    pose_id = pytest.pose_table[2]["id"]
    target_date = utcnow().date().isoformat()

    # 2 below threshold + 3 above. With LIMIT 3, naive post-filter would return at most 1.
    for max_conf, minutes_ago in [(0.10, 50), (0.20, 45), (0.50, 40), (0.60, 35), (0.80, 30)]:
        await _seed_unlabeled_sequence(
            detection_session, camera_id, pose_id, max_conf=max_conf, minutes_ago=minutes_ago
        )

    auth = pytest.get_token(
        pytest.user_table[2]["id"],
        pytest.user_table[2]["role"].split(),
        pytest.user_table[2]["organization_id"],
    )

    # The ``risk_score=low`` override drives the threshold without hitting the risk API.
    response = await async_client.get(
        f"/sequences/all/fromdate?from_date={target_date}&limit=3&risk_score=low", headers=auth
    )
    assert response.status_code == 200, response.text
    page = response.json()
    assert len(page) == 3
    assert all(seq["max_conf"] >= 0.45 for seq in page)
class _ExecResult:
    """Tiny stand-in for the object ``session.exec`` returns."""

    def __init__(self, rows) -> None:
        self._rows = rows

    def all(self):
        return self._rows


class _FakeSequenceSession:
    """Fake AsyncSession: hands back canned results in order, recording each statement."""

    def __init__(self, *results) -> None:
        self._results = list(results)
        self.statements = []

    async def exec(self, stmt):
        self.statements.append(stmt)
        return self._results.pop(0)


def _unit_sequence(sequence_id: int = 101, camera_id: int = 1, max_conf: float = 0.75) -> Sequence:
    """Build an in-memory unlabeled Sequence for unit tests (no DB round-trip)."""
    now = utcnow()
    return Sequence(
        id=sequence_id,
        camera_id=camera_id,
        pose_id=1,
        camera_azimuth=180.0,
        sequence_azimuth=175.0,
        cone_angle=5.0,
        is_wildfire=None,
        started_at=now - timedelta(minutes=20),
        last_seen_at=now - timedelta(minutes=5),
        max_conf=max_conf,
    )


def _unit_token(organization_id: int = 1) -> TokenPayload:
    """Token payload for a plain USER of the given organization."""
    return TokenPayload(sub=1, scopes=[UserRole.USER], organization_id=organization_id)


@pytest.mark.asyncio
async def test_unit_fetch_latest_unlabeled_sequences_uses_risk_score_override():
    # Override path: first exec resolves the org's camera ids, second the sequences.
    seq = _unit_sequence(sequence_id=201, camera_id=1, max_conf=0.55)
    fake_session = _FakeSequenceSession(_ExecResult([1]), _ExecResult([seq]))
    counts_mock = AsyncMock(return_value={seq.id: 4})

    with patch("app.api.api_v1.endpoints.sequences.get_detection_counts_by_sequence_ids", new=counts_mock):
        result = await fetch_latest_unlabeled_sequences(
            risk_score="low",
            session=fake_session,
            token_payload=_unit_token(),
        )

    assert len(fake_session.statements) == 2
    counts_mock.assert_awaited_once_with(fake_session, [seq.id])
    assert [item.id for item in result] == [seq.id]
    assert result[0].detections_count == 4


@pytest.mark.asyncio
async def test_unit_fetch_sequences_from_date_gets_risk_scores_for_requested_date():
    # Per-date path: the risk API must be queried with the requested date + org.
    target_date = date(2026, 5, 5)
    seq = _unit_sequence(sequence_id=202, camera_id=1, max_conf=0.65)
    fake_session = _FakeSequenceSession(_ExecResult([1]), _ExecResult([seq]))
    risk_scores_mock = AsyncMock(return_value={1: "very_low"})
    counts_mock = AsyncMock(return_value={seq.id: 2})

    with (
        patch("app.api.api_v1.endpoints.sequences.risk_service.get_scores_for_date", new=risk_scores_mock),
        patch("app.api.api_v1.endpoints.sequences.get_detection_counts_by_sequence_ids", new=counts_mock),
    ):
        result = await fetch_sequences_from_date(
            from_date=target_date,
            limit=10,
            offset=0,
            risk_score=None,
            session=fake_session,
            token_payload=_unit_token(),
        )

    risk_scores_mock.assert_awaited_once_with(target_date, organization_id=1)
    counts_mock.assert_awaited_once_with(fake_session, [seq.id])
    assert [item.id for item in result] == [seq.id]
    assert result[0].detections_count == 2


@pytest.mark.asyncio
async def test_unit_fetch_sequences_from_date_risk_score_override_bypasses_risk_api():
    # Override path: the risk API must never be awaited.
    target_date = date(2026, 5, 5)
    seq = _unit_sequence(sequence_id=203, camera_id=1, max_conf=0.50)
    fake_session = _FakeSequenceSession(_ExecResult([1]), _ExecResult([seq]))
    risk_scores_mock = AsyncMock()
    counts_mock = AsyncMock(return_value={seq.id: 1})

    with (
        patch("app.api.api_v1.endpoints.sequences.risk_service.get_scores_for_date", new=risk_scores_mock),
        patch("app.api.api_v1.endpoints.sequences.get_detection_counts_by_sequence_ids", new=counts_mock),
    ):
        result = await fetch_sequences_from_date(
            from_date=target_date,
            limit=10,
            offset=0,
            risk_score="low",
            session=fake_session,
            token_payload=_unit_token(),
        )

    risk_scores_mock.assert_not_awaited()
    counts_mock.assert_awaited_once_with(fake_session, [seq.id])
    assert [item.id for item in result] == [seq.id]
    assert result[0].detections_count == 1
# Copyright (C) 2024-2026, Pyronear.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to for full license details.

from datetime import date
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest

from app.core.config import settings
from app.services.risk import FWI_MIN_CONF, RiskService, min_confidence_for_class


def _fake_httpx_client(*, json_data=None, raise_exc=None, side_effect_per_call=None):
    """Return a context manager replacing ``app.services.risk.httpx.AsyncClient``.

    Either pass ``json_data`` (single response), ``raise_exc`` (raised on get()), or
    ``side_effect_per_call`` (list of (json_or_exc) values returned in order).
    """

    def _response(payload):
        resp = MagicMock()
        resp.raise_for_status = MagicMock()
        resp.json = MagicMock(return_value=payload)
        return resp

    client_mock = MagicMock()
    if raise_exc is not None:
        client_mock.get = AsyncMock(side_effect=raise_exc)
    elif side_effect_per_call is not None:
        effects = [item if isinstance(item, Exception) else _response(item) for item in side_effect_per_call]
        client_mock.get = AsyncMock(side_effect=effects)
    else:
        client_mock.get = AsyncMock(return_value=_response(json_data))

    ctx = MagicMock()
    ctx.__aenter__ = AsyncMock(return_value=client_mock)
    ctx.__aexit__ = AsyncMock(return_value=None)
    return MagicMock(return_value=ctx), client_mock


@pytest.fixture
def configured_risk(monkeypatch: pytest.MonkeyPatch):
    """Point the risk service at a fake upstream so refresh() takes the configured path."""
    monkeypatch.setattr(settings, "RISK_API_URL", "http://risk.test")
    monkeypatch.setattr(settings, "RISK_API_LOGIN", "u")
    monkeypatch.setattr(settings, "RISK_API_PWD", "p")


def test_min_confidence_for_class():
    # Filtering classes map to their threshold; every other input maps to None.
    assert min_confidence_for_class("very_low") == FWI_MIN_CONF["very_low"]
    assert min_confidence_for_class("low") == FWI_MIN_CONF["low"]
    for no_filter_value in ("moderate", "high", "very_high", "extreme", None, "unexpected"):
        assert min_confidence_for_class(no_filter_value) is None


def test_min_confidence_for_class_normalizes_casing():
    for spelling in ("Very Low", " very_low "):
        assert min_confidence_for_class(spelling) == FWI_MIN_CONF["very_low"]
    assert min_confidence_for_class("LOW") == FWI_MIN_CONF["low"]


def test_risk_service_min_confidence_uses_cached_class():
    svc = RiskService()
    svc._scores = {1: "very_low", 2: "low", 3: "moderate"}
    assert svc.min_confidence(1) == FWI_MIN_CONF["very_low"]
    assert svc.min_confidence(2) == FWI_MIN_CONF["low"]
    assert svc.min_confidence(3) is None
    assert svc.min_confidence(99) is None  # unknown camera -> no filter


@pytest.mark.asyncio
async def test_refresh_no_op_when_not_configured(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setattr(settings, "RISK_API_URL", None)
    svc = RiskService()
    svc._scores = {1: "low"}
    await svc.refresh()
    # Without an API URL configured, refresh must leave the cache untouched.
    assert svc._scores == {1: "low"}


@pytest.mark.asyncio
async def test_refresh_replaces_cache_on_success(configured_risk):
    svc = RiskService()
    svc._scores = {99: "very_low"}  # stale entry that should disappear
    factory, client_mock = _fake_httpx_client(
        json_data=[
            {"id": 1, "fwi_class": "very_low"},
            {"id": 2, "fwi_class": "moderate"},
        ]
    )
    with patch("app.services.risk.httpx.AsyncClient", factory):
        await svc.refresh()
    assert svc._scores == {1: "very_low", 2: "moderate"}
    client_mock.get.assert_awaited_once()
    args, kwargs = client_mock.get.await_args
    assert args[0].endswith("/cameras")
    assert kwargs["auth"] == ("u", "p")


@pytest.mark.asyncio
async def test_refresh_keeps_cache_on_http_error(configured_risk):
    svc = RiskService()
    svc._scores = {1: "low"}
    factory, _ = _fake_httpx_client(raise_exc=httpx.ConnectError("boom"))
    with patch("app.services.risk.httpx.AsyncClient", factory):
        await svc.refresh()
    assert svc._scores == {1: "low"}


@pytest.mark.asyncio
async def test_refresh_replaces_with_empty_when_payload_is_empty_list(configured_risk):
    """Empty list from upstream wipes the cache; only a network/HTTP failure preserves it."""
    svc = RiskService()
    svc._scores = {1: "low"}
    factory, _ = _fake_httpx_client(json_data=[])
    with patch("app.services.risk.httpx.AsyncClient", factory):
        await svc.refresh()
    assert svc._scores == {}


@pytest.mark.asyncio
async def test_refresh_skips_malformed_rows(configured_risk):
    svc = RiskService()
    factory, _ = _fake_httpx_client(
        json_data=[
            {"id": 1, "fwi_class": "low"},
            "not-a-dict",
            {"id": "bad", "fwi_class": "low"},
            {"id": 2},
            {"camera_id": 3, "fwi_class": "very_low"},
        ]
    )
    with patch("app.services.risk.httpx.AsyncClient", factory):
        await svc.refresh()
    assert svc._scores == {1: "low", 3: "very_low"}


@pytest.mark.asyncio
async def test_get_scores_for_date_passes_organization_param(configured_risk):
    svc = RiskService()
    factory, client_mock = _fake_httpx_client(json_data=[{"camera_id": 7, "fwi_class": "high"}])
    with patch("app.services.risk.httpx.AsyncClient", factory):
        scores = await svc.get_scores_for_date(date(2026, 5, 5), organization_id=42)
    assert scores == {7: "high"}
    args, kwargs = client_mock.get.await_args
    assert args[0].endswith("/scores/2026-05-05")
    assert kwargs["params"] == {"organization_id": 42}


@pytest.mark.asyncio
async def test_get_scores_for_date_returns_empty_on_failure(configured_risk):
    svc = RiskService()
    factory, _ = _fake_httpx_client(raise_exc=httpx.ReadTimeout("slow"))
    with patch("app.services.risk.httpx.AsyncClient", factory):
        scores = await svc.get_scores_for_date(date(2026, 5, 5))
    assert scores == {}


@pytest.mark.asyncio
async def test_get_scores_for_date_no_org_param_when_none(configured_risk):
    svc = RiskService()
    factory, client_mock = _fake_httpx_client(json_data=[])
    with patch("app.services.risk.httpx.AsyncClient", factory):
        await svc.get_scores_for_date(date(2026, 5, 5))
    _, kwargs = client_mock.get.await_args
    assert kwargs["params"] is None
patch("app.services.risk.httpx.AsyncClient", factory): + await service.get_scores_for_date(date(2026, 5, 5)) + _, kwargs = inner.get.await_args + assert kwargs["params"] is None diff --git a/src/tests/services/test_sequence_confidence.py b/src/tests/services/test_sequence_confidence.py new file mode 100644 index 00000000..0bbd3e9c --- /dev/null +++ b/src/tests/services/test_sequence_confidence.py @@ -0,0 +1,27 @@ +# Copyright (C) 2024-2026, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details. + +from app.services.sequence_confidence import max_conf_from_bboxes + + +def test_max_conf_from_single_bbox(): + assert max_conf_from_bboxes("[(.1,.1,.7,.8,.42)]") == 0.42 + + +def test_max_conf_picks_max_across_bboxes_and_strings(): + bbox = "[(.1,.1,.7,.8,.3)]" + others = "[(.1,.1,.5,.5,.7),(.2,.2,.6,.6,.55)]" + assert max_conf_from_bboxes(bbox, others) == 0.7 + + +def test_max_conf_handles_none_and_empty(): + assert max_conf_from_bboxes(None) is None + assert max_conf_from_bboxes("") is None + assert max_conf_from_bboxes(None, None) is None + + +def test_max_conf_skips_unparseable(): + assert max_conf_from_bboxes("[invalid]") is None + assert max_conf_from_bboxes("[(.1,.1,.7,.8,.5)]", "[garbage]") == 0.5 diff --git a/src/tests/test_main.py b/src/tests/test_main.py new file mode 100644 index 00000000..561cf208 --- /dev/null +++ b/src/tests/test_main.py @@ -0,0 +1,147 @@ +# Copyright (C) 2024-2026, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details. 
+ +import asyncio +from collections.abc import Generator +from datetime import datetime, timedelta, timezone +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI + +from app.main import _risk_refresh_loop, _seconds_until_next_utc_hour, lifespan + + +class _CancelledTask: + def __init__(self) -> None: + self.cancel_called = False + + def cancel(self) -> None: + self.cancel_called = True + + def __await__(self) -> Generator[None, None, None]: + if not self.cancel_called: + yield None + raise asyncio.CancelledError + + +def test_seconds_until_next_utc_hour_future_today(): + fake_now = datetime(2026, 5, 5, 1, 30, 0, tzinfo=timezone.utc) + with patch("app.main.datetime") as mock_dt: + mock_dt.now.return_value = fake_now + # datetime.replace + timedelta still need to work — point them at the real type. + mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw) + seconds = _seconds_until_next_utc_hour(4) + # 04:00 - 01:30 = 2h30m = 9000s + assert seconds == 2 * 3600 + 30 * 60 + + +def test_seconds_until_next_utc_hour_rolls_to_next_day_when_passed(): + fake_now = datetime(2026, 5, 5, 5, 0, 0, tzinfo=timezone.utc) + with patch("app.main.datetime") as mock_dt: + mock_dt.now.return_value = fake_now + mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw) + seconds = _seconds_until_next_utc_hour(4) + # next 04:00 is tomorrow → 23h + assert seconds == 23 * 3600 + + +def test_seconds_until_next_utc_hour_clamps_negative_hour(): + fake_now = datetime(2026, 5, 5, 12, 0, 0, tzinfo=timezone.utc) + with patch("app.main.datetime") as mock_dt: + mock_dt.now.return_value = fake_now + mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw) + seconds = _seconds_until_next_utc_hour(-5) # clamped to 0 + # next 00:00 is tomorrow → 12h + assert seconds == 12 * 3600 + + +def test_seconds_until_next_utc_hour_clamps_overflow_hour(): + fake_now = datetime(2026, 5, 5, 12, 0, 0, tzinfo=timezone.utc) + with 
patch("app.main.datetime") as mock_dt: + mock_dt.now.return_value = fake_now + mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw) + seconds = _seconds_until_next_utc_hour(99) # clamped to 23 + # next 23:00 today → 11h + assert seconds == 11 * 3600 + + +def test_seconds_until_next_utc_hour_returns_full_day_when_now_equals_target(): + fake_now = datetime(2026, 5, 5, 4, 0, 0, tzinfo=timezone.utc) + with patch("app.main.datetime") as mock_dt: + mock_dt.now.return_value = fake_now + mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw) + seconds = _seconds_until_next_utc_hour(4) + assert seconds == timedelta(days=1).total_seconds() + + +@pytest.mark.asyncio +async def test_risk_refresh_loop_calls_refresh_then_cancels_cleanly(): + sleep_mock = AsyncMock(side_effect=[None, asyncio.CancelledError()]) + refresh_mock = AsyncMock() + + with ( + patch("app.main.asyncio.sleep", sleep_mock), + patch("app.main.risk_service.refresh", new=refresh_mock), + pytest.raises(asyncio.CancelledError), + ): + await _risk_refresh_loop() + + refresh_mock.assert_awaited_once() + assert sleep_mock.await_count == 2 + + +@pytest.mark.asyncio +async def test_risk_refresh_loop_swallows_refresh_errors_and_continues(): + sleep_mock = AsyncMock(side_effect=[None, None, asyncio.CancelledError()]) + refresh_mock = AsyncMock(side_effect=[RuntimeError("boom"), None]) + + with ( + patch("app.main.asyncio.sleep", sleep_mock), + patch("app.main.risk_service.refresh", new=refresh_mock), + patch("app.main.logger.exception") as exception_mock, + pytest.raises(asyncio.CancelledError), + ): + await _risk_refresh_loop() + + assert refresh_mock.await_count == 2 + assert sleep_mock.await_count == 3 + exception_mock.assert_called_once_with("Risk refresh loop iteration failed; continuing") + + +@pytest.mark.asyncio +async def test_lifespan_refreshes_and_cancels_daily_task_when_risk_api_configured(): + fake_service = SimpleNamespace(is_configured=True, refresh=AsyncMock()) + fake_task = _CancelledTask() + + 
def fake_create_task(coro): + coro.close() + return fake_task + + create_task_mock = MagicMock(side_effect=fake_create_task) + with ( + patch("app.main.risk_service", fake_service), + patch("app.main.asyncio.create_task", create_task_mock), + ): + async with lifespan(FastAPI()): + fake_service.refresh.assert_awaited_once() + create_task_mock.assert_called_once() + assert fake_task.cancel_called is False + + assert fake_task.cancel_called is True + + +@pytest.mark.asyncio +async def test_lifespan_skips_risk_refresh_when_risk_api_not_configured(): + fake_service = SimpleNamespace(is_configured=False, refresh=AsyncMock()) + + with ( + patch("app.main.risk_service", fake_service), + patch("app.main.asyncio.create_task") as create_task_mock, + ): + async with lifespan(FastAPI()): + fake_service.refresh.assert_not_awaited() + create_task_mock.assert_not_called()