diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index 4175c907..130639cd 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -208,6 +208,8 @@ async def fetch_alert_sequences( summary="Fetch all the alerts with unlabeled sequences from the last 24 hours", ) async def fetch_latest_unlabeled_alerts( + limit: Union[int, None] = Query(15, description="Maximum number of alerts to fetch"), + offset: Union[int, None] = Query(0, description="Number of alerts to skip before starting to fetch"), risk_score: Union[FwiClass, None] = Query( None, description="Override FWI class applied to every sequence; bypasses risk-api lookup." ), @@ -236,7 +238,8 @@ async def fetch_latest_unlabeled_alerts( .where(Alert.organization_id == token_payload.organization_id) .where(cast(Any, Alert.id).in_(seq_match)) .order_by(Alert.started_at.desc()) # type: ignore[attr-defined] - .limit(15) + .limit(limit) + .offset(offset) ) alerts = list((await session.exec(alerts_stmt)).all()) alert_ids = [alert.id for alert in alerts] diff --git a/src/tests/endpoints/test_alerts.py b/src/tests/endpoints/test_alerts.py index f0f6d12f..feb8005c 100644 --- a/src/tests/endpoints/test_alerts.py +++ b/src/tests/endpoints/test_alerts.py @@ -154,6 +154,34 @@ async def test_alerts_unlabeled_latest(async_client: AsyncClient, detection_sess assert any(seq["detections_count"] == 0 for seq in returned["sequences"]) +@pytest.mark.asyncio +async def test_alerts_unlabeled_latest_pagination(async_client: AsyncClient, detection_session: AsyncSession): + alert_a, _, _ = await _create_alert_with_sequences(detection_session, org_id=1, camera_id=1, lat=48.0, lon=2.0) + alert_b, _, _ = await _create_alert_with_sequences(detection_session, org_id=1, camera_id=1, lat=48.1, lon=2.1) + + auth = pytest.get_token( + pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] + ) + + resp = await async_client.get("/alerts/unlabeled/latest?limit=10&offset=0", headers=auth) + assert resp.status_code == 200, resp.text + full = resp.json() + full_ids = [item["id"] for item in full] + assert {alert_a.id, alert_b.id}.issubset(full_ids) + + resp = await async_client.get("/alerts/unlabeled/latest?limit=1&offset=0", headers=auth) + assert resp.status_code == 200, resp.text + page_one = resp.json() + assert len(page_one) == 1 + assert page_one[0]["id"] == full_ids[0] + + resp = await async_client.get("/alerts/unlabeled/latest?limit=1&offset=1", headers=auth) + assert resp.status_code == 200, resp.text + page_two = resp.json() + assert len(page_two) == 1 + assert page_two[0]["id"] == full_ids[1] + + @pytest.mark.asyncio async def test_alerts_from_date(async_client: AsyncClient, detection_session: AsyncSession): alert, seq_ids, detections_count_by_sequence = await _create_alert_with_sequences(