diff --git a/client/pyroclient/client.py b/client/pyroclient/client.py index ea4f695b..c1f459d9 100644 --- a/client/pyroclient/client.py +++ b/client/pyroclient/client.py @@ -40,6 +40,8 @@ class ClientRoute(str, Enum): SEQUENCES_FETCH_DETECTIONS = "sequences/{seq_id}/detections" SEQUENCES_FETCH_LATEST = "sequences/unlabeled/latest" SEQUENCES_FETCH_FROMDATE = "sequences/all/fromdate" + # ALERTS + ALERTS_UNMATCH_SEQUENCE = "alerts/{alert_id}/sequences/{seq_id}/unmatch" # ORGS ORGS_FETCH = "organizations" @@ -496,6 +498,32 @@ def fetch_sequences_detections(self, sequence_id: int, limit: int = 10, desc: bo timeout=self.timeout, ) + # ALERTS + + def unmatch_alert_sequence(self, alert_id: int, sequence_id: int) -> Response: + """Detach a sequence from an alert. If the sequence is no longer linked to any alert, + a new alert is created for it. + + >>> from pyroclient import client + >>> api_client = Client("MY_USER_TOKEN") + >>> response = api_client.unmatch_alert_sequence(1, 2) + + Args: + alert_id: ID of the alert the sequence should be detached from + sequence_id: ID of the sequence to detach + + Returns: + HTTP response + """ + return requests.post( + urljoin( + self._route_prefix, + ClientRoute.ALERTS_UNMATCH_SEQUENCE.format(alert_id=alert_id, seq_id=sequence_id), + ), + headers=self.headers, + timeout=self.timeout, + ) + # ORGANIZATIONS def fetch_organizations(self) -> Response: diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index 8af09c05..e7c36aac 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -12,14 +12,15 @@ from sqlmodel import delete, func, select from sqlmodel.ext.asyncio.session import AsyncSession -from app.api.dependencies import get_alert_crud, get_jwt +from app.api.dependencies import get_alert_crud, get_camera_crud, get_jwt, get_sequence_crud from app.core.time import utcnow -from app.crud import AlertCRUD +from app.crud import AlertCRUD, CameraCRUD, SequenceCRUD from app.db import get_session -from app.models import Alert, AlertSequence, Sequence, UserRole -from app.schemas.alerts import AlertReadWithSequences +from app.models import Alert, AlertSequence, Camera, Sequence, UserRole +from app.schemas.alerts import AlertCreate, AlertReadWithSequences from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceRead +from app.services.alerts import refresh_alert_state from app.services.telemetry import telemetry_client router = APIRouter() @@ -148,6 +149,82 @@ async def fetch_alerts_from_date( return [_serialize_alert(alert, seq_map.get(int(alert.id), [])) for alert in alerts] +@router.post( + "/{alert_id}/sequences/{sequence_id}/unmatch", + status_code=status.HTTP_200_OK, + summary="Detach a sequence from an alert; create a fresh alert if the sequence becomes orphaned", +) +async def unmatch_alert_sequence( + alert_id: int = Path(..., gt=0), + sequence_id: int = Path(..., gt=0), + alerts: AlertCRUD = Depends(get_alert_crud), + sequences: SequenceCRUD = Depends(get_sequence_crud), + cameras: CameraCRUD = Depends(get_camera_crud), + session: AsyncSession = Depends(get_session), + token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT]), +) -> Union[AlertReadWithSequences, None]: + telemetry_client.capture( + token_payload.sub, + event="alerts-sequence-unmatch", + properties={"alert_id": alert_id, "sequence_id": sequence_id}, + ) + alert = cast(Alert, await alerts.get(alert_id, strict=True)) + if UserRole.ADMIN not in token_payload.scopes: + verify_org_rights(token_payload.organization_id, alert) + + link_stmt: Any = select(AlertSequence).where( + AlertSequence.alert_id == alert_id, AlertSequence.sequence_id == sequence_id + ) + link = (await session.exec(link_stmt)).first() + if link is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Sequence is not attached to this alert.") + + count_stmt: Any = select(func.count()).select_from(AlertSequence).where(AlertSequence.alert_id == alert_id) + sequence_count = int((await session.exec(count_stmt)).one()) + if sequence_count <= 1: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot unmatch the only sequence of an alert.", + ) + + delete_stmt: Any = delete(AlertSequence).where( + cast(Any, AlertSequence.alert_id) == alert_id, + cast(Any, AlertSequence.sequence_id) == sequence_id, + ) + await session.exec(delete_stmt) + await session.commit() + + await refresh_alert_state(alert_id, session, alerts) + + other_links_stmt: Any = ( + select(func.count()).select_from(AlertSequence).where(AlertSequence.sequence_id == sequence_id) + ) + other_links = int((await session.exec(other_links_stmt)).one()) + if other_links > 0: + return None + + sequence = cast(Sequence, await sequences.get(sequence_id, strict=True)) + camera = cast(Camera, await cameras.get(sequence.camera_id, strict=True)) + if camera.organization_id != alert.organization_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Sequence camera does not belong to the same organization as the alert.", + ) + new_alert = await alerts.create( + AlertCreate( + organization_id=alert.organization_id, + started_at=sequence.started_at, + last_seen_at=sequence.last_seen_at, + lat=None, + lon=None, + ) + ) + session.add(AlertSequence(alert_id=new_alert.id, sequence_id=sequence_id)) + await session.commit() + await session.refresh(new_alert) + return _serialize_alert(new_alert, [sequence]) + + @router.delete("/{alert_id}", status_code=status.HTTP_200_OK, summary="Delete an alert") async def delete_alert( alert_id: int = Path(..., gt=0), diff --git a/src/app/api/api_v1/endpoints/sequences.py b/src/app/api/api_v1/endpoints/sequences.py index 1f3e8841..397af63c 100644 --- a/src/app/api/api_v1/endpoints/sequences.py +++ b/src/app/api/api_v1/endpoints/sequences.py @@ -7,7 +7,6 @@ from datetime import date, timedelta from typing import Any, List, Union, cast -import pandas as pd from fastapi import APIRouter, Depends, HTTPException, Path, Query, Security, status from sqlmodel import delete, func, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -17,11 +16,11 @@ from app.crud import AlertCRUD, CameraCRUD, DetectionCRUD, SequenceCRUD from app.db import get_session from app.models import AlertSequence, AnnotationType, Camera, Detection, Sequence, UserRole -from app.schemas.alerts import AlertCreate, AlertUpdate +from app.schemas.alerts import AlertCreate from app.schemas.detections import DetectionRead, DetectionSequence, DetectionWithUrl from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceLabel, SequenceRead -from app.services.overlap import compute_overlap +from app.services.alerts import refresh_alert_state from app.services.storage import s3_service from app.services.telemetry import telemetry_client @@ -36,53 +35,6 @@ async def verify_org_rights( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access forbidden.") -async def _refresh_alert_state(alert_id: int, session: AsyncSession, alerts: AlertCRUD) -> None: - remaining_stmt: Any = ( - select(Sequence, Camera) - .join(AlertSequence, cast(Any, AlertSequence.sequence_id) == Sequence.id) - .join(Camera, cast(Any, Camera.id) == Sequence.camera_id) - ) - remaining_stmt = remaining_stmt.where(AlertSequence.alert_id == alert_id) - remaining_res = await session.exec(remaining_stmt) - rows = remaining_res.all() - if not rows: - await alerts.delete(alert_id) - return - - seqs = [row[0] for row in rows] - cams = [row[1] for row in rows] - new_start = min(seq.started_at for seq in seqs) - new_last = max(seq.last_seen_at for seq in seqs) - - loc: Union[tuple[float, float], None] = None - if len(rows) >= 2: - records = [] - for seq, cam in zip(seqs, cams, strict=False): - records.append({ - "id": seq.id, - "pose_id": seq.pose_id, - "lat": cam.lat, - "lon": cam.lon, - "sequence_azimuth": seq.sequence_azimuth, - "cone_angle": seq.cone_angle, - "is_wildfire": seq.is_wildfire, - "started_at": seq.started_at, - "last_seen_at": seq.last_seen_at, - }) - df = compute_overlap(pd.DataFrame.from_records(records)) - loc = next((loc for locs in df["event_smoke_locations"].tolist() for loc in locs if loc is not None), None) - - await alerts.update( - alert_id, - AlertUpdate( - started_at=new_start, - last_seen_at=new_last, - lat=loc[0] if loc else None, - lon=loc[1] if loc else None, - ), - ) - - @router.get("/{sequence_id}", status_code=status.HTTP_200_OK, summary="Fetch the information of a specific sequence") async def get_sequence( sequence_id: int = Path(..., gt=0), @@ -207,7 +159,7 @@ async def delete_sequence( await sequences.delete(sequence_id) # Refresh affected alerts for aid in alert_ids: - await _refresh_alert_state(aid, session, alerts) + await refresh_alert_state(aid, session, alerts) @router.patch("/{sequence_id}/label", status_code=status.HTTP_200_OK, summary="Label the nature of the sequence") @@ -239,7 +191,7 @@ async def label_sequence( await session.exec(delete_links) await session.commit() for aid in alert_ids: - await _refresh_alert_state(aid, session, alerts) + await refresh_alert_state(aid, session, alerts) # Create a fresh alert for this sequence alone camera = cast(Camera, await cameras.get(sequence.camera_id, strict=True)) new_alert = await alerts.create( diff --git a/src/app/services/alerts.py b/src/app/services/alerts.py new file mode 100644 index 00000000..85b2ebcb --- /dev/null +++ b/src/app/services/alerts.py @@ -0,0 +1,66 @@ +# Copyright (C) 2025-2026, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from typing import Any, Union, cast + +import pandas as pd +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.crud import AlertCRUD +from app.models import AlertSequence, Camera, Sequence +from app.schemas.alerts import AlertUpdate +from app.services.overlap import compute_overlap + +__all__ = ["refresh_alert_state"] + + +async def refresh_alert_state(alert_id: int, session: AsyncSession, alerts: AlertCRUD) -> None: + """Recompute an alert's bounds and location from its remaining sequences, or delete it if empty.""" + remaining_stmt: Any = ( + select(Sequence, Camera) + .join(AlertSequence, cast(Any, AlertSequence.sequence_id) == Sequence.id) + .join(Camera, cast(Any, Camera.id) == Sequence.camera_id) + ) + remaining_stmt = remaining_stmt.where(AlertSequence.alert_id == alert_id) + remaining_res = await session.exec(remaining_stmt) + rows = remaining_res.all() + if not rows: + await alerts.delete(alert_id) + return + + seqs = [row[0] for row in rows] + cams = [row[1] for row in rows] + new_start = min(seq.started_at for seq in seqs) + new_last = max(seq.last_seen_at for seq in seqs) + + loc: Union[tuple[float, float], None] = None + if len(rows) >= 2: + records = [] + for seq, cam in zip(seqs, cams, strict=False): + records.append({ + "id": seq.id, + "pose_id": seq.pose_id, + "lat": cam.lat, + "lon": cam.lon, + "sequence_azimuth": seq.sequence_azimuth, + "cone_angle": seq.cone_angle, + "is_wildfire": seq.is_wildfire, + "started_at": seq.started_at, + "last_seen_at": seq.last_seen_at, + }) + df = compute_overlap(pd.DataFrame.from_records(records)) + loc = next((loc for locs in df["event_smoke_locations"].tolist() for loc in locs if loc is not None), None) + + await alerts.update( + alert_id, + AlertUpdate( + started_at=new_start, + last_seen_at=new_last, + lat=loc[0] if loc else None, + lon=loc[1] if loc else None, + ), + ) diff --git a/src/tests/endpoints/test_alerts.py b/src/tests/endpoints/test_alerts.py index 77bb9724..d31c27b5 100644 --- a/src/tests/endpoints/test_alerts.py +++ b/src/tests/endpoints/test_alerts.py @@ -317,3 +317,158 @@ async def test_triangulation_creates_single_alert( remaining_ids = {seq.id for seq in sequences if seq.id != sequences[1].id} updated_mappings = {(aid, sid) for aid, sid in mappings_after_other if aid == initial_alert_id} assert updated_mappings == {(initial_alert_id, sid) for sid in remaining_ids} + + +@pytest.mark.asyncio +async def test_unmatch_creates_new_alert(async_client: AsyncClient, detection_session: AsyncSession): + alert, seq_ids = await _create_alert_with_sequences( + detection_session, org_id=1, camera_id=1, lat=48.3856355, lon=2.7323256 + ) + target_seq = seq_ids[0] + auth = pytest.get_token( + pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] + ) + + resp = await async_client.post(f"/alerts/{alert.id}/sequences/{target_seq}/unmatch", headers=auth) + assert resp.status_code == 200, resp.text + body = resp.json() + assert body is not None + assert body["id"] != alert.id + assert body["organization_id"] == alert.organization_id + assert body["lat"] is None + assert body["lon"] is None + assert {seq["id"] for seq in body["sequences"]} == {target_seq} + + mappings_res = await detection_session.exec( + select(AlertSequence.alert_id, AlertSequence.sequence_id).where( + cast(Any, AlertSequence.sequence_id).in_(seq_ids) + ) + ) + mappings = set(mappings_res.all()) + assert (alert.id, target_seq) not in mappings + assert (body["id"], target_seq) in mappings + remaining_for_original = {sid for aid, sid in mappings if aid == alert.id} + assert remaining_for_original == set(seq_ids[1:]) + + +@pytest.mark.asyncio +async def test_unmatch_keeps_sequence_when_already_linked_elsewhere( + async_client: AsyncClient, detection_session: AsyncSession +): + alert, seq_ids = await _create_alert_with_sequences( + detection_session, org_id=1, camera_id=1, lat=48.3856355, lon=2.7323256 + ) + target_seq = seq_ids[0] + + other_alert = Alert( + organization_id=alert.organization_id, + lat=None, + lon=None, + started_at=alert.started_at, + last_seen_at=alert.last_seen_at, + ) + detection_session.add(other_alert) + await detection_session.commit() + await detection_session.refresh(other_alert) + detection_session.add(AlertSequence(alert_id=other_alert.id, sequence_id=target_seq)) + await detection_session.commit() + + auth = pytest.get_token( + pytest.user_table[1]["id"], pytest.user_table[1]["role"].split(), pytest.user_table[1]["organization_id"] + ) + resp = await async_client.post(f"/alerts/{alert.id}/sequences/{target_seq}/unmatch", headers=auth) + assert resp.status_code == 200, resp.text + assert resp.json() is None + + mappings_res = await detection_session.exec( + select(AlertSequence.alert_id, AlertSequence.sequence_id).where(AlertSequence.sequence_id == target_seq) + ) + alert_ids_for_seq = {aid for aid, _ in mappings_res.all()} + assert alert_ids_for_seq == {other_alert.id} + + +@pytest.mark.asyncio +async def test_unmatch_rejects_single_sequence_alert(async_client: AsyncClient, detection_session: AsyncSession): + now = utcnow() + seq = Sequence( + camera_id=1, + pose_id=None, + camera_azimuth=10.0, + is_wildfire=None, + sequence_azimuth=10.0, + cone_angle=1.0, + started_at=now - timedelta(seconds=30), + last_seen_at=now, + ) + detection_session.add(seq) + await detection_session.commit() + await detection_session.refresh(seq) + + alert = Alert( + organization_id=1, + lat=None, + lon=None, + started_at=seq.started_at, + last_seen_at=seq.last_seen_at, + ) + detection_session.add(alert) + await detection_session.commit() + await detection_session.refresh(alert) + detection_session.add(AlertSequence(alert_id=alert.id, sequence_id=seq.id)) + await detection_session.commit() + + auth = pytest.get_token( + pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] + ) + resp = await async_client.post(f"/alerts/{alert.id}/sequences/{seq.id}/unmatch", headers=auth) + assert resp.status_code == 400, resp.text + + +@pytest.mark.asyncio +async def test_unmatch_returns_404_when_sequence_not_linked(async_client: AsyncClient, detection_session: AsyncSession): + alert, _ = await _create_alert_with_sequences( + detection_session, org_id=1, camera_id=1, lat=48.3856355, lon=2.7323256 + ) + other_seq = Sequence( + camera_id=1, + pose_id=None, + camera_azimuth=10.0, + is_wildfire=None, + sequence_azimuth=10.0, + cone_angle=1.0, + started_at=utcnow() - timedelta(seconds=30), + last_seen_at=utcnow(), + ) + detection_session.add(other_seq) + await detection_session.commit() + await detection_session.refresh(other_seq) + + auth = pytest.get_token( + pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] + ) + resp = await async_client.post(f"/alerts/{alert.id}/sequences/{other_seq.id}/unmatch", headers=auth) + assert resp.status_code == 404, resp.text + + +@pytest.mark.asyncio +async def test_unmatch_forbidden_for_user_role(async_client: AsyncClient, detection_session: AsyncSession): + alert, seq_ids = await _create_alert_with_sequences( + detection_session, org_id=2, camera_id=2, lat=48.3856355, lon=2.7323256 + ) + auth = pytest.get_token( + pytest.user_table[2]["id"], pytest.user_table[2]["role"].split(), pytest.user_table[2]["organization_id"] + ) + resp = await async_client.post(f"/alerts/{alert.id}/sequences/{seq_ids[0]}/unmatch", headers=auth) + assert resp.status_code == 403, resp.text + + +@pytest.mark.asyncio +async def test_unmatch_forbidden_cross_org(async_client: AsyncClient, detection_session: AsyncSession): + other_alert, other_seq_ids = await _create_alert_with_sequences( + detection_session, org_id=2, camera_id=2, lat=48.3856355, lon=2.7323256 + ) + auth = pytest.get_token( + pytest.user_table[1]["id"], pytest.user_table[1]["role"].split(), pytest.user_table[1]["organization_id"] + ) + resp = await async_client.post(f"/alerts/{other_alert.id}/sequences/{other_seq_ids[0]}/unmatch", headers=auth) + assert resp.status_code == 403, resp.text diff --git a/src/tests/endpoints/test_sequences.py b/src/tests/endpoints/test_sequences.py index 6e9f0d17..a0f115df 100644 --- a/src/tests/endpoints/test_sequences.py +++ b/src/tests/endpoints/test_sequences.py @@ -456,7 +456,7 @@ async def test_delete_sequence_refreshes_alert(async_client: AsyncClient, detect @pytest.mark.asyncio -@patch("app.api.api_v1.endpoints.sequences._refresh_alert_state", new_callable=AsyncMock) +@patch("app.api.api_v1.endpoints.sequences.refresh_alert_state", new_callable=AsyncMock) async def test_unit_label_sequence_as_other_smoke_refreshes_alert( mock_refresh_alert_state: AsyncMock, ):