Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 30 additions & 11 deletions src/app/api/api_v1/endpoints/alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from app.schemas.alerts import AlertReadWithSequences
from app.schemas.login import TokenPayload
from app.schemas.sequences import SequenceRead
from app.services.sequence_counts import get_detection_counts_by_sequence_ids
from app.services.telemetry import telemetry_client

router = APIRouter()
Expand All @@ -46,10 +47,16 @@ async def _fetch_sequences_by_alert_ids(session: AsyncSession, alert_ids: List[i
return mapping


def _serialize_alert(alert: Alert, sequences: List[Sequence]) -> AlertReadWithSequences:
def _serialize_sequence(sequence: Sequence, detections_count: int = 0) -> SequenceRead:
return SequenceRead(**sequence.model_dump(), detections_count=detections_count)


def _serialize_alert(
alert: Alert, sequences: List[Sequence], detection_counts: Dict[int, int]
) -> AlertReadWithSequences:
return AlertReadWithSequences(
**alert.model_dump(),
sequences=[SequenceRead(**seq.model_dump()) for seq in sequences],
sequences=[_serialize_sequence(sequence, detection_counts.get(sequence.id, 0)) for sequence in sequences],
)


Expand All @@ -66,9 +73,11 @@ async def get_alert(
if UserRole.ADMIN not in token_payload.scopes:
verify_org_rights(token_payload.organization_id, alert)

alert_id_int = int(alert.id)
seq_map = await _fetch_sequences_by_alert_ids(session, [alert_id_int])
return _serialize_alert(alert, seq_map.get(alert_id_int, []))
seq_map = await _fetch_sequences_by_alert_ids(session, [alert.id])
detection_counts = await get_detection_counts_by_sequence_ids(
session, [sequence.id for sequence in seq_map.get(alert.id, [])]
)
return _serialize_alert(alert, seq_map.get(alert.id, []), detection_counts)


@router.get(
Expand All @@ -81,7 +90,7 @@ async def fetch_alert_sequences(
alerts: AlertCRUD = Depends(get_alert_crud),
session: AsyncSession = Depends(get_session),
token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]),
) -> List[Sequence]:
) -> List[SequenceRead]:
telemetry_client.capture(token_payload.sub, event="alerts-sequences-get", properties={"alert_id": alert_id})
alert = cast(Alert, await alerts.get(alert_id, strict=True))
if UserRole.ADMIN not in token_payload.scopes:
Expand All @@ -93,7 +102,9 @@ async def fetch_alert_sequences(
seq_stmt = seq_stmt.where(AlertSequence.alert_id == alert_id).order_by(order_clause).limit(limit)

res = await session.exec(seq_stmt)
return list(res.all())
sequences = list(res.all())
detection_counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in sequences])
return [_serialize_sequence(sequence, detection_counts.get(sequence.id, 0)) for sequence in sequences]


@router.get(
Expand All @@ -118,9 +129,13 @@ async def fetch_latest_unlabeled_alerts(
)
alerts_res = await session.exec(alerts_stmt)
alerts = alerts_res.unique().all()
alert_ids = [int(alert.id) for alert in alerts]
alert_ids = [alert.id for alert in alerts]
seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids)
return [_serialize_alert(alert, seq_map.get(int(alert.id), [])) for alert in alerts]
detection_counts = await get_detection_counts_by_sequence_ids(
session,
list({sequence.id for sequences in seq_map.values() for sequence in sequences}),
)
return [_serialize_alert(alert, seq_map.get(alert.id, []), detection_counts) for alert in alerts]


@router.get("/all/fromdate", status_code=status.HTTP_200_OK, summary="Fetch all the alerts for a specific date")
Expand All @@ -143,9 +158,13 @@ async def fetch_alerts_from_date(
)
alerts_res = await session.exec(alerts_stmt)
alerts = alerts_res.all()
alert_ids = [int(alert.id) for alert in alerts]
alert_ids = [alert.id for alert in alerts]
seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids)
return [_serialize_alert(alert, seq_map.get(int(alert.id), [])) for alert in alerts]
detection_counts = await get_detection_counts_by_sequence_ids(
session,
list({sequence.id for sequences in seq_map.values() for sequence in sequences}),
)
return [_serialize_alert(alert, seq_map.get(alert.id, []), detection_counts) for alert in alerts]


@router.delete("/{alert_id}", status_code=status.HTTP_200_OK, summary="Delete an alert")
Expand Down
2 changes: 1 addition & 1 deletion src/app/api/api_v1/endpoints/detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ async def fetch_detections(
) -> List[DetectionRead]:
telemetry_client.capture(token_payload.sub, event="detections-fetch")
if UserRole.ADMIN in token_payload.scopes:
return [DetectionRead(**elt.model_dump()) for elt in await detections.fetch_all()]
return [DetectionRead(**elt.model_dump()) for elt in await detections.fetch_all(order_by="id")]

cameras_list = await cameras.fetch_all(filters=("organization_id", token_payload.organization_id))
camera_ids = [camera.id for camera in cameras_list]
Expand Down
17 changes: 13 additions & 4 deletions src/app/api/api_v1/endpoints/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from app.schemas.login import TokenPayload
from app.schemas.sequences import SequenceLabel, SequenceRead
from app.services.overlap import compute_overlap
from app.services.sequence_counts import get_detection_counts_by_sequence_ids
from app.services.storage import s3_service
from app.services.telemetry import telemetry_client

Expand Down Expand Up @@ -83,20 +84,26 @@ async def _refresh_alert_state(alert_id: int, session: AsyncSession, alerts: Ale
)


def _serialize_sequence(sequence: Sequence, detections_count: int = 0) -> SequenceRead:
return SequenceRead(**sequence.model_dump(), detections_count=detections_count)


@router.get("/{sequence_id}", status_code=status.HTTP_200_OK, summary="Fetch the information of a specific sequence")
async def get_sequence(
sequence_id: int = Path(..., gt=0),
cameras: CameraCRUD = Depends(get_camera_crud),
sequences: SequenceCRUD = Depends(get_sequence_crud),
session: AsyncSession = Depends(get_session),
token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]),
) -> Sequence:
) -> SequenceRead:
telemetry_client.capture(token_payload.sub, event="sequences-get", properties={"sequence_id": sequence_id})
sequence = cast(Sequence, await sequences.get(sequence_id, strict=True))

if UserRole.ADMIN not in token_payload.scopes:
await verify_org_rights(token_payload.organization_id, sequence.camera_id, cameras)

return SequenceRead(**sequence.model_dump())
counts = await get_detection_counts_by_sequence_ids(session, [sequence.id])
return _serialize_sequence(sequence, counts.get(sequence.id, 0))


@router.get(
Expand Down Expand Up @@ -155,7 +162,8 @@ async def fetch_latest_unlabeled_sequences(
.limit(15)
)
).all()
return [SequenceRead(**elt.model_dump()) for elt in fetched_sequences]
counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences])
return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences]


@router.get("/all/fromdate", status_code=status.HTTP_200_OK, summary="Fetch all the sequences for a specific date")
Expand All @@ -180,7 +188,8 @@ async def fetch_sequences_from_date(
.offset(offset)
)
).all()
return [SequenceRead(**elt.model_dump()) for elt in fetched_sequences]
counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences])
return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences]


@router.delete("/{sequence_id}", status_code=status.HTTP_200_OK, summary="Delete a sequence")
Expand Down
2 changes: 1 addition & 1 deletion src/app/schemas/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ class SequenceLabel(BaseModel):


class SequenceRead(Sequence):
pass
detections_count: int = 0
25 changes: 25 additions & 0 deletions src/app/services/sequence_counts.py
Comment thread
fe51 marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (C) 2025-2026, Pyronear.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.


from typing import Any, Dict, List, cast

from sqlmodel import func, select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.models import Detection


async def get_detection_counts_by_sequence_ids(session: AsyncSession, sequence_ids: List[int]) -> Dict[int, int]:
if not sequence_ids:
return {}

stmt: Any = (
select(cast(Any, Detection.sequence_id), func.count(Detection.id))
.where(cast(Any, Detection.sequence_id).in_(sequence_ids))
.group_by(cast(Any, Detection.sequence_id))
)
res = await session.exec(stmt)
return {sequence_id: detections_count for sequence_id, detections_count in res.all() if sequence_id is not None}
41 changes: 35 additions & 6 deletions src/tests/endpoints/test_alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,18 @@

from app.core.config import settings
from app.core.time import utcnow
from app.models import Alert, AlertSequence, AnnotationType, Camera, Organization, Pose, Sequence
from app.models import Alert, AlertSequence, AnnotationType, Camera, Detection, Organization, Pose, Sequence
from app.services.overlap import compute_overlap


async def _create_alert_with_sequences(
session: AsyncSession, org_id: int, camera_id: int, lat: float, lon: float
) -> Tuple[Alert, List[int]]:
) -> Tuple[Alert, List[int], List[int]]:
now = utcnow()
pose = (
await session.exec(select(Pose).where(Pose.camera_id == camera_id).order_by(Pose.id)) # type: ignore[attr-defined]
).first()
assert pose is not None
seq_payloads = [
{
"camera_id": camera_id,
Expand All @@ -48,6 +52,7 @@ async def _create_alert_with_sequences(
"cone_angle": 3.0,
},
]
detections_count_by_sequence = [2, 1, 0]
sequences: List[Sequence] = []
for idx, payload in enumerate(seq_payloads):
seq = Sequence(
Expand All @@ -60,6 +65,20 @@ async def _create_alert_with_sequences(
await session.commit()
for seq in sequences:
await session.refresh(seq)
for sequence, detections_count in zip(sequences, detections_count_by_sequence, strict=True):
for det_idx in range(detections_count):
session.add(
Detection(
camera_id=sequence.camera_id,
pose_id=pose.id,
sequence_id=sequence.id,
bucket_key=f"alert-seq-{sequence.id}-{det_idx}.jpg",
bbox="[(.1,.1,.7,.8,.9)]",
others_bboxes=None,
created_at=now - timedelta(seconds=det_idx),
)
)
await session.commit()

alert = Alert(
organization_id=org_id,
Expand All @@ -75,14 +94,15 @@ async def _create_alert_with_sequences(
for seq in sequences:
session.add(AlertSequence(alert_id=alert.id, sequence_id=seq.id))
await session.commit()
return alert, [seq.id for seq in sequences]
return alert, [seq.id for seq in sequences], detections_count_by_sequence


@pytest.mark.asyncio
async def test_get_alert_and_sequences(async_client: AsyncClient, detection_session: AsyncSession):
alert, seq_ids = await _create_alert_with_sequences(
alert, seq_ids, detections_count_by_sequence = await _create_alert_with_sequences(
detection_session, org_id=1, camera_id=1, lat=48.3856355, lon=2.7323256
)
expected_counts = dict(zip(seq_ids, detections_count_by_sequence, strict=False))

auth = pytest.get_token(
pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"]
Expand All @@ -97,19 +117,23 @@ async def test_get_alert_and_sequences(async_client: AsyncClient, detection_sess
assert payload["started_at"] == alert.started_at.isoformat()
assert payload["last_seen_at"] == alert.last_seen_at.isoformat()
assert {seq["id"] for seq in payload["sequences"]} == set(seq_ids)
assert {seq["id"]: seq["detections_count"] for seq in payload["sequences"]} == expected_counts

resp = await async_client.get(f"/alerts/{alert.id}/sequences?limit=5&desc=true", headers=auth)
assert resp.status_code == 200, resp.text
returned = resp.json()
last_seen_times = [item["last_seen_at"] for item in returned]
assert last_seen_times == sorted(last_seen_times, reverse=True)
assert {sequence["id"]: sequence["detections_count"] for sequence in returned} == expected_counts
assert any(sequence["detections_count"] == 0 for sequence in returned)


@pytest.mark.asyncio
async def test_alerts_unlabeled_latest(async_client: AsyncClient, detection_session: AsyncSession):
alert, seq_ids = await _create_alert_with_sequences(
alert, seq_ids, detections_count_by_sequence = await _create_alert_with_sequences(
detection_session, org_id=1, camera_id=1, lat=48.3856355, lon=2.7323256
)
expected_counts = dict(zip(seq_ids, detections_count_by_sequence, strict=False))

auth = pytest.get_token(
pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"]
Expand All @@ -124,13 +148,16 @@ async def test_alerts_unlabeled_latest(async_client: AsyncClient, detection_sess
assert returned["started_at"] == alert.started_at.isoformat()
assert returned["last_seen_at"] == alert.last_seen_at.isoformat()
assert {seq["id"] for seq in returned["sequences"]} == set(seq_ids)
assert {seq["id"]: seq["detections_count"] for seq in returned["sequences"]} == expected_counts
assert any(seq["detections_count"] == 0 for seq in returned["sequences"])


@pytest.mark.asyncio
async def test_alerts_from_date(async_client: AsyncClient, detection_session: AsyncSession):
alert, seq_ids = await _create_alert_with_sequences(
alert, seq_ids, detections_count_by_sequence = await _create_alert_with_sequences(
detection_session, org_id=1, camera_id=1, lat=48.3856355, lon=2.7323256
)
expected_counts = dict(zip(seq_ids, detections_count_by_sequence, strict=False))
date_str = alert.started_at.date().isoformat()

auth = pytest.get_token(
Expand All @@ -146,6 +173,8 @@ async def test_alerts_from_date(async_client: AsyncClient, detection_session: As
assert started_times == sorted(started_times, reverse=True)
alert_payload = next(item for item in returned if item["id"] == alert.id)
assert {seq["id"] for seq in alert_payload["sequences"]} == set(seq_ids)
assert {seq["id"]: seq["detections_count"] for seq in alert_payload["sequences"]} == expected_counts
assert any(seq["detections_count"] == 0 for seq in alert_payload["sequences"])


@pytest.mark.asyncio
Expand Down
Loading
Loading