diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index 8af09c05..7cf64074 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -20,6 +20,7 @@ from app.schemas.alerts import AlertReadWithSequences from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceRead +from app.services.sequence_counts import get_detection_counts_by_sequence_ids from app.services.telemetry import telemetry_client router = APIRouter() @@ -46,10 +47,16 @@ async def _fetch_sequences_by_alert_ids(session: AsyncSession, alert_ids: List[i return mapping -def _serialize_alert(alert: Alert, sequences: List[Sequence]) -> AlertReadWithSequences: +def _serialize_sequence(sequence: Sequence, detections_count: int = 0) -> SequenceRead: + return SequenceRead(**sequence.model_dump(), detections_count=detections_count) + + +def _serialize_alert( + alert: Alert, sequences: List[Sequence], detection_counts: Dict[int, int] +) -> AlertReadWithSequences: return AlertReadWithSequences( **alert.model_dump(), - sequences=[SequenceRead(**seq.model_dump()) for seq in sequences], + sequences=[_serialize_sequence(sequence, detection_counts.get(sequence.id, 0)) for sequence in sequences], ) @@ -66,9 +73,11 @@ async def get_alert( if UserRole.ADMIN not in token_payload.scopes: verify_org_rights(token_payload.organization_id, alert) - alert_id_int = int(alert.id) - seq_map = await _fetch_sequences_by_alert_ids(session, [alert_id_int]) - return _serialize_alert(alert, seq_map.get(alert_id_int, [])) + seq_map = await _fetch_sequences_by_alert_ids(session, [alert.id]) + detection_counts = await get_detection_counts_by_sequence_ids( + session, [sequence.id for sequence in seq_map.get(alert.id, [])] + ) + return _serialize_alert(alert, seq_map.get(alert.id, []), detection_counts) @router.get( @@ -81,7 +90,7 @@ async def fetch_alert_sequences( alerts: AlertCRUD = Depends(get_alert_crud), session: AsyncSession = Depends(get_session), token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]), -) -> List[Sequence]: +) -> List[SequenceRead]: telemetry_client.capture(token_payload.sub, event="alerts-sequences-get", properties={"alert_id": alert_id}) alert = cast(Alert, await alerts.get(alert_id, strict=True)) if UserRole.ADMIN not in token_payload.scopes: @@ -93,7 +102,9 @@ async def fetch_alert_sequences( seq_stmt = seq_stmt.where(AlertSequence.alert_id == alert_id).order_by(order_clause).limit(limit) res = await session.exec(seq_stmt) - return list(res.all()) + sequences = list(res.all()) + detection_counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in sequences]) + return [_serialize_sequence(sequence, detection_counts.get(sequence.id, 0)) for sequence in sequences] @router.get( @@ -118,9 +129,13 @@ async def fetch_latest_unlabeled_alerts( ) alerts_res = await session.exec(alerts_stmt) alerts = alerts_res.unique().all() - alert_ids = [int(alert.id) for alert in alerts] + alert_ids = [alert.id for alert in alerts] seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids) - return [_serialize_alert(alert, seq_map.get(int(alert.id), [])) for alert in alerts] + detection_counts = await get_detection_counts_by_sequence_ids( + session, + list({sequence.id for sequences in seq_map.values() for sequence in sequences}), + ) + return [_serialize_alert(alert, seq_map.get(alert.id, []), detection_counts) for alert in alerts] @router.get("/all/fromdate", status_code=status.HTTP_200_OK, summary="Fetch all the alerts for a specific date") @@ -143,9 +158,13 @@ async def fetch_alerts_from_date( ) alerts_res = await session.exec(alerts_stmt) alerts = alerts_res.all() - alert_ids = [int(alert.id) for alert in alerts] + alert_ids = [alert.id for alert in alerts] seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids) - return [_serialize_alert(alert, seq_map.get(int(alert.id), [])) for alert in alerts] + detection_counts = await get_detection_counts_by_sequence_ids( + session, + list({sequence.id for sequences in seq_map.values() for sequence in sequences}), + ) + return [_serialize_alert(alert, seq_map.get(alert.id, []), detection_counts) for alert in alerts] @router.delete("/{alert_id}", status_code=status.HTTP_200_OK, summary="Delete an alert") diff --git a/src/app/api/api_v1/endpoints/detections.py b/src/app/api/api_v1/endpoints/detections.py index 477aa57e..60ae87ed 100644 --- a/src/app/api/api_v1/endpoints/detections.py +++ b/src/app/api/api_v1/endpoints/detections.py @@ -554,7 +554,7 @@ async def fetch_detections( ) -> List[DetectionRead]: telemetry_client.capture(token_payload.sub, event="detections-fetch") if UserRole.ADMIN in token_payload.scopes: - return [DetectionRead(**elt.model_dump()) for elt in await detections.fetch_all()] + return [DetectionRead(**elt.model_dump()) for elt in await detections.fetch_all(order_by="id")] cameras_list = await cameras.fetch_all(filters=("organization_id", token_payload.organization_id)) camera_ids = [camera.id for camera in cameras_list] diff --git a/src/app/api/api_v1/endpoints/sequences.py b/src/app/api/api_v1/endpoints/sequences.py index 1f3e8841..edacaa46 100644 --- a/src/app/api/api_v1/endpoints/sequences.py +++ b/src/app/api/api_v1/endpoints/sequences.py @@ -22,6 +22,7 @@ from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceLabel, SequenceRead from app.services.overlap import compute_overlap +from app.services.sequence_counts import get_detection_counts_by_sequence_ids from app.services.storage import s3_service from app.services.telemetry import telemetry_client @@ -83,20 +84,26 @@ async def _refresh_alert_state(alert_id: int, session: AsyncSession, alerts: Ale ) +def _serialize_sequence(sequence: Sequence, detections_count: int = 0) -> SequenceRead: + return SequenceRead(**sequence.model_dump(), detections_count=detections_count) + + @router.get("/{sequence_id}", status_code=status.HTTP_200_OK, summary="Fetch the information of a specific sequence") async def get_sequence( sequence_id: int = Path(..., gt=0), cameras: CameraCRUD = Depends(get_camera_crud), sequences: SequenceCRUD = Depends(get_sequence_crud), + session: AsyncSession = Depends(get_session), token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]), -) -> Sequence: +) -> SequenceRead: telemetry_client.capture(token_payload.sub, event="sequences-get", properties={"sequence_id": sequence_id}) sequence = cast(Sequence, await sequences.get(sequence_id, strict=True)) if UserRole.ADMIN not in token_payload.scopes: await verify_org_rights(token_payload.organization_id, sequence.camera_id, cameras) - return SequenceRead(**sequence.model_dump()) + counts = await get_detection_counts_by_sequence_ids(session, [sequence.id]) + return _serialize_sequence(sequence, counts.get(sequence.id, 0)) @router.get( @@ -155,7 +162,8 @@ async def fetch_latest_unlabeled_sequences( .limit(15) ) ).all() - return [SequenceRead(**elt.model_dump()) for elt in fetched_sequences] + counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences]) + return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences] @router.get("/all/fromdate", status_code=status.HTTP_200_OK, summary="Fetch all the sequences for a specific date") @@ -180,7 +188,8 @@ async def fetch_sequences_from_date( .offset(offset) ) ).all() - return [SequenceRead(**elt.model_dump()) for elt in fetched_sequences] + counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences]) + return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences] @router.delete("/{sequence_id}", status_code=status.HTTP_200_OK, summary="Delete a sequence") diff --git a/src/app/schemas/sequences.py b/src/app/schemas/sequences.py index 2555dd57..3b0db886 100644 --- a/src/app/schemas/sequences.py +++ b/src/app/schemas/sequences.py @@ -22,4 +22,4 @@ class SequenceLabel(BaseModel): class SequenceRead(Sequence): - pass + detections_count: int = 0 diff --git a/src/app/services/sequence_counts.py b/src/app/services/sequence_counts.py new file mode 100644 index 00000000..195c8b52 --- /dev/null +++ b/src/app/services/sequence_counts.py @@ -0,0 +1,25 @@ +# Copyright (C) 2025-2026, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from typing import Any, Dict, List, cast + +from sqlmodel import func, select +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.models import Detection + + +async def get_detection_counts_by_sequence_ids(session: AsyncSession, sequence_ids: List[int]) -> Dict[int, int]: + if not sequence_ids: + return {} + + stmt: Any = ( + select(cast(Any, Detection.sequence_id), func.count(Detection.id)) + .where(cast(Any, Detection.sequence_id).in_(sequence_ids)) + .group_by(cast(Any, Detection.sequence_id)) + ) + res = await session.exec(stmt) + return {sequence_id: detections_count for sequence_id, detections_count in res.all() if sequence_id is not None} diff --git a/src/tests/endpoints/test_alerts.py b/src/tests/endpoints/test_alerts.py index 77bb9724..4c681630 100644 --- a/src/tests/endpoints/test_alerts.py +++ b/src/tests/endpoints/test_alerts.py @@ -14,14 +14,18 @@ from app.core.config import settings from app.core.time import utcnow -from app.models import Alert, AlertSequence, AnnotationType, Camera, Organization, Pose, Sequence +from app.models import Alert, AlertSequence, AnnotationType, Camera, Detection, Organization, Pose, Sequence from app.services.overlap import compute_overlap async def _create_alert_with_sequences( session: AsyncSession, org_id: int, camera_id: int, lat: float, lon: float -) -> Tuple[Alert, List[int]]: +) -> Tuple[Alert, List[int], List[int]]: now = utcnow() + pose = ( + await session.exec(select(Pose).where(Pose.camera_id == camera_id).order_by(Pose.id)) # type: ignore[attr-defined] + ).first() + assert pose is not None seq_payloads = [ { "camera_id": camera_id, @@ -48,6 +52,7 @@ async def _create_alert_with_sequences( "cone_angle": 3.0, }, ] + detections_count_by_sequence = [2, 1, 0] sequences: List[Sequence] = [] for idx, payload in enumerate(seq_payloads): seq = Sequence( @@ -60,6 +65,20 @@ async def _create_alert_with_sequences( await session.commit() for seq in sequences: await session.refresh(seq) + for sequence, detections_count in zip(sequences, detections_count_by_sequence, strict=True): + for det_idx in range(detections_count): + session.add( + Detection( + camera_id=sequence.camera_id, + pose_id=pose.id, + sequence_id=sequence.id, + bucket_key=f"alert-seq-{sequence.id}-{det_idx}.jpg", + bbox="[(.1,.1,.7,.8,.9)]", + others_bboxes=None, + created_at=now - timedelta(seconds=det_idx), + ) + ) + await session.commit() alert = Alert( organization_id=org_id, @@ -75,14 +94,15 @@ async def _create_alert_with_sequences( for seq in sequences: session.add(AlertSequence(alert_id=alert.id, sequence_id=seq.id)) await session.commit() - return alert, [seq.id for seq in sequences] + return alert, [seq.id for seq in sequences], detections_count_by_sequence @pytest.mark.asyncio async def test_get_alert_and_sequences(async_client: AsyncClient, detection_session: AsyncSession): - alert, seq_ids = await _create_alert_with_sequences( + alert, seq_ids, detections_count_by_sequence = await _create_alert_with_sequences( detection_session, org_id=1, camera_id=1, lat=48.3856355, lon=2.7323256 ) + expected_counts = dict(zip(seq_ids, detections_count_by_sequence, strict=False)) auth = pytest.get_token( pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] @@ -97,19 +117,23 @@ async def test_get_alert_and_sequences(async_client: AsyncClient, detection_sess assert payload["started_at"] == alert.started_at.isoformat() assert payload["last_seen_at"] == alert.last_seen_at.isoformat() assert {seq["id"] for seq in payload["sequences"]} == set(seq_ids) + assert {seq["id"]: seq["detections_count"] for seq in payload["sequences"]} == expected_counts resp = await async_client.get(f"/alerts/{alert.id}/sequences?limit=5&desc=true", headers=auth) assert resp.status_code == 200, resp.text returned = resp.json() last_seen_times = [item["last_seen_at"] for item in returned] assert last_seen_times == sorted(last_seen_times, reverse=True) + assert {sequence["id"]: sequence["detections_count"] for sequence in returned} == expected_counts + assert any(sequence["detections_count"] == 0 for sequence in returned) @pytest.mark.asyncio async def test_alerts_unlabeled_latest(async_client: AsyncClient, detection_session: AsyncSession): - alert, seq_ids = await _create_alert_with_sequences( + alert, seq_ids, detections_count_by_sequence = await _create_alert_with_sequences( detection_session, org_id=1, camera_id=1, lat=48.3856355, lon=2.7323256 ) + expected_counts = dict(zip(seq_ids, detections_count_by_sequence, strict=False)) auth = pytest.get_token( pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] @@ -124,13 +148,16 @@ async def test_alerts_unlabeled_latest(async_client: AsyncClient, detection_sess assert returned["started_at"] == alert.started_at.isoformat() assert returned["last_seen_at"] == alert.last_seen_at.isoformat() assert {seq["id"] for seq in returned["sequences"]} == set(seq_ids) + assert {seq["id"]: seq["detections_count"] for seq in returned["sequences"]} == expected_counts + assert any(seq["detections_count"] == 0 for seq in returned["sequences"]) @pytest.mark.asyncio async def test_alerts_from_date(async_client: AsyncClient, detection_session: AsyncSession): - alert, seq_ids = await _create_alert_with_sequences( + alert, seq_ids, detections_count_by_sequence = await _create_alert_with_sequences( detection_session, org_id=1, camera_id=1, lat=48.3856355, lon=2.7323256 ) + expected_counts = dict(zip(seq_ids, detections_count_by_sequence, strict=False)) date_str = alert.started_at.date().isoformat() auth = pytest.get_token( @@ -146,6 +173,8 @@ async def test_alerts_from_date(async_client: AsyncClient, detection_session: As assert started_times == sorted(started_times, reverse=True) alert_payload = next(item for item in returned if item["id"] == alert.id) assert {seq["id"] for seq in alert_payload["sequences"]} == set(seq_ids) + assert {seq["id"]: seq["detections_count"] for seq in alert_payload["sequences"]} == expected_counts + assert any(seq["detections_count"] == 0 for seq in alert_payload["sequences"]) @pytest.mark.asyncio diff --git a/src/tests/endpoints/test_sequences.py b/src/tests/endpoints/test_sequences.py index 6e9f0d17..1d593382 100644 --- a/src/tests/endpoints/test_sequences.py +++ b/src/tests/endpoints/test_sequences.py @@ -15,6 +15,36 @@ from app.schemas.sequences import SequenceLabel +@pytest.mark.parametrize( + ("sequence_id", "expected_idx", "expected_detections_count"), + [ + (1, 0, 3), + (2, 1, 1), + ], +) +@pytest.mark.asyncio +async def test_get_sequence( + async_client: AsyncClient, + detection_session: AsyncSession, + sequence_id: int, + expected_idx: int, + expected_detections_count: int, +): + auth = pytest.get_token( + pytest.user_table[0]["id"], + pytest.user_table[0]["role"].split(), + pytest.user_table[0]["organization_id"], + ) + + response = await async_client.get(f"/sequences/{sequence_id}", headers=auth) + + assert response.status_code == 200, print(response.__dict__) + assert response.json() == { + **pytest.sequence_table[expected_idx], + "detections_count": expected_detections_count, + } + + @pytest.mark.parametrize( ("user_idx", "sequence_id", "status_code", "status_detail", "expected_result"), [ @@ -158,9 +188,9 @@ async def test_label_sequence( # datetime != date, weird, but works (0, "2018-06-06T00:00:00", 200, None, []), (0, "2018-06-06", 200, None, []), - (0, "2023-11-07", 200, None, pytest.sequence_table[:1]), - (1, "2023-11-07", 200, None, pytest.sequence_table[:1]), - (2, "2023-11-07", 200, None, pytest.sequence_table[1:2]), + (0, "2023-11-07", 200, None, [{**pytest.sequence_table[0], "detections_count": 3}]), + (1, "2023-11-07", 200, None, [{**pytest.sequence_table[0], "detections_count": 3}]), + (2, "2023-11-07", 200, None, [{**pytest.sequence_table[1], "detections_count": 1}]), ], ) @pytest.mark.asyncio @@ -190,6 +220,7 @@ async def test_fetch_sequences_from_date( assert response.json() == expected_result assert all(isinstance(elt["sequence_azimuth"], float) for elt in response.json()) assert all(isinstance(elt["cone_angle"], float) for elt in response.json()) + assert all(isinstance(elt["detections_count"], int) for elt in response.json()) @pytest.mark.parametrize( @@ -229,6 +260,63 @@ async def test_latest_sequences( assert all(isinstance(elt["cone_angle"], float) for elt in response.json()) +@pytest.mark.asyncio +async def test_latest_sequences_include_detections_count(async_client: AsyncClient, detection_session: AsyncSession): + now = utcnow() + sequence_with_detections = Sequence( + camera_id=pytest.camera_table[0]["id"], + pose_id=pytest.pose_table[0]["id"], + camera_azimuth=180.0, + sequence_azimuth=175.0, + cone_angle=5.0, + is_wildfire=None, + started_at=now - timedelta(minutes=15), + last_seen_at=now - timedelta(minutes=5), + ) + sequence_without_detections = Sequence( + camera_id=pytest.camera_table[0]["id"], + pose_id=pytest.pose_table[0]["id"], + camera_azimuth=182.0, + sequence_azimuth=176.0, + cone_angle=6.0, + is_wildfire=None, + started_at=now - timedelta(minutes=10), + last_seen_at=now - timedelta(minutes=2), + ) + detection_session.add(sequence_with_detections) + detection_session.add(sequence_without_detections) + await detection_session.commit() + await detection_session.refresh(sequence_with_detections) + await detection_session.refresh(sequence_without_detections) + + for idx in range(2): + detection_session.add( + Detection( + camera_id=sequence_with_detections.camera_id, + pose_id=pytest.pose_table[0]["id"], + sequence_id=sequence_with_detections.id, + bucket_key=f"sequence-latest-{sequence_with_detections.id}-{idx}.jpg", + bbox="[(.1,.1,.7,.8,.9)]", + others_bboxes=None, + created_at=now - timedelta(minutes=4 - idx), + ) + ) + await detection_session.commit() + + auth = pytest.get_token( + pytest.user_table[0]["id"], + pytest.user_table[0]["role"].split(), + pytest.user_table[0]["organization_id"], + ) + response = await async_client.get("/sequences/unlabeled/latest", headers=auth) + + assert response.status_code == 200, print(response.__dict__) + returned = response.json() + counts_by_sequence_id = {item["id"]: item["detections_count"] for item in returned} + assert counts_by_sequence_id[sequence_with_detections.id] == 2 + assert counts_by_sequence_id[sequence_without_detections.id] == 0 + + @pytest.mark.asyncio async def test_sequence_label_updates_alerts(async_client: AsyncClient, detection_session: AsyncSession): # Create a sequence linked to a camera and an alert