diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index 8af09c05..a463c8a7 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -4,10 +4,13 @@ # See LICENSE or go to for full license details. -from datetime import date, timedelta -from typing import Any, Dict, List, Union, cast +import csv +import io +from datetime import date, datetime, time, timedelta +from typing import Any, Dict, Iterable, Iterator, List, Union, cast from fastapi import APIRouter, Depends, HTTPException, Path, Query, Security, status +from fastapi.responses import StreamingResponse from sqlalchemy import asc, desc from sqlmodel import delete, func, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -53,6 +56,73 @@ def _serialize_alert(alert: Alert, sequences: List[Sequence]) -> AlertReadWithSe ) +_ALERT_EXPORT_COLUMNS = ["id", "lat", "lon", "started_at", "last_seen_at"] + + +def _iter_alerts_csv(alerts: Iterable[Alert]) -> Iterator[str]: + buf = io.StringIO() + writer = csv.writer(buf) + writer.writerow(_ALERT_EXPORT_COLUMNS) + yield buf.getvalue() + buf.seek(0) + buf.truncate(0) + for a in alerts: + writer.writerow([ + a.id, + "" if a.lat is None else a.lat, + "" if a.lon is None else a.lon, + a.started_at.isoformat(), + a.last_seen_at.isoformat(), + ]) + yield buf.getvalue() + buf.seek(0) + buf.truncate(0) + + +def _build_alerts_csv_response(alerts: List[Alert], from_date: date, to_date: date) -> StreamingResponse: + filename = f"alerts_{from_date.isoformat()}_{to_date.isoformat()}.csv" + headers = {"Content-Disposition": f'attachment; filename="{filename}"'} + return StreamingResponse(_iter_alerts_csv(alerts), media_type="text/csv", headers=headers) + + +@router.get( + "/export", + status_code=status.HTTP_200_OK, + summary="Export alerts in a date range as CSV", + response_class=StreamingResponse, +) +async def export_alerts_csv( + from_date: date = Query(..., description="Inclusive lower bound on started_at (UTC date)"), + to_date: date = Query(..., description="Inclusive upper bound on started_at (UTC date)"), + session: AsyncSession = Depends(get_session), + token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]), +) -> StreamingResponse: + telemetry_client.capture( + token_payload.sub, + event="alerts-export", + properties={"from_date": from_date.isoformat(), "to_date": to_date.isoformat()}, + ) + + if to_date < from_date: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="to_date must be on or after from_date", + ) + + # DB columns store naive UTC datetimes (see app.core.time.utcnow), so we drop tzinfo here. + start_dt = datetime.combine(from_date, time.min) + end_dt = datetime.combine(to_date, time.max) + + stmt: Any = ( + select(Alert) + .where(Alert.organization_id == token_payload.organization_id) + .where(Alert.started_at >= start_dt) + .where(Alert.started_at <= end_dt) + .order_by(Alert.started_at.asc()) # type: ignore[attr-defined] + ) + return _build_alerts_csv_response(list((await session.exec(stmt)).all()), from_date, to_date) + + @router.get("/{alert_id}", status_code=status.HTTP_200_OK, summary="Fetch the information of a specific alert") async def get_alert( alert_id: int = Path(..., gt=0), diff --git a/src/tests/endpoints/test_alerts.py b/src/tests/endpoints/test_alerts.py index 77bb9724..add6098a 100644 --- a/src/tests/endpoints/test_alerts.py +++ b/src/tests/endpoints/test_alerts.py @@ -3,7 +3,9 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from datetime import timedelta +import csv +import io +from datetime import datetime, timedelta from typing import Any, List, Tuple, cast import pandas as pd @@ -317,3 +319,163 @@ async def test_triangulation_creates_single_alert( remaining_ids = {seq.id for seq in sequences if seq.id != sequences[1].id} updated_mappings = {(aid, sid) for aid, sid in mappings_after_other if aid == initial_alert_id} assert updated_mappings == {(initial_alert_id, sid) for sid in remaining_ids} + + +async def _create_alert( + session: AsyncSession, + org_id: int, + started_at: datetime, + last_seen_at: datetime, + lat: float | None = 48.0, + lon: float | None = 2.0, +) -> Alert: + alert = Alert( + organization_id=org_id, + lat=lat, + lon=lon, + started_at=started_at, + last_seen_at=last_seen_at, + ) + session.add(alert) + await session.commit() + await session.refresh(alert) + return alert + + +def _parse_csv_body(body: str) -> Tuple[List[str], List[List[str]]]: + reader = csv.reader(io.StringIO(body)) + rows = list(reader) + return rows[0], rows[1:] + + +@pytest.mark.asyncio +async def test_alerts_export_happy_path(async_client: AsyncClient, detection_session: AsyncSession): + base = datetime(2026, 4, 10, 12, 0, 0) + alerts = [ + await _create_alert(detection_session, 1, base, base + timedelta(minutes=5), 48.1, 2.1), + await _create_alert( + detection_session, 1, base + timedelta(days=1), base + timedelta(days=1, minutes=5), 48.2, 2.2 + ), + await _create_alert( + detection_session, 1, base + timedelta(days=2), base + timedelta(days=2, minutes=5), 48.3, 2.3 + ), + ] + + auth = pytest.get_token( + pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] + ) + resp = await async_client.get( + "/alerts/export?from_date=2026-04-10&to_date=2026-04-12", + headers=auth, + ) + assert resp.status_code == 200, resp.text + assert resp.headers["content-type"].startswith("text/csv") + assert "attachment" in resp.headers["content-disposition"] + assert "alerts_2026-04-10_2026-04-12.csv" in resp.headers["content-disposition"] + + header, data_rows = _parse_csv_body(resp.text) + assert header == ["id", "lat", "lon", "started_at", "last_seen_at"] + assert [int(r[0]) for r in data_rows] == [a.id for a in alerts] + # ordering is ascending by started_at + started_values = [r[3] for r in data_rows] + assert started_values == sorted(started_values) + # spot-check values for the first row + assert float(data_rows[0][1]) == pytest.approx(48.1) + assert float(data_rows[0][2]) == pytest.approx(2.1) + assert data_rows[0][3] == alerts[0].started_at.isoformat() + assert data_rows[0][4] == alerts[0].last_seen_at.isoformat() + + +@pytest.mark.asyncio +async def test_alerts_export_window_narrows(async_client: AsyncClient, detection_session: AsyncSession): + base = datetime(2026, 4, 10, 12, 0, 0) + await _create_alert(detection_session, 1, base, base + timedelta(minutes=5)) + a_in = await _create_alert(detection_session, 1, base + timedelta(days=1), base + timedelta(days=1, minutes=5)) + await _create_alert(detection_session, 1, base + timedelta(days=2), base + timedelta(days=2, minutes=5)) + + auth = pytest.get_token( + pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] + ) + resp = await async_client.get( + "/alerts/export?from_date=2026-04-11&to_date=2026-04-11", + headers=auth, + ) + assert resp.status_code == 200, resp.text + _, data_rows = _parse_csv_body(resp.text) + returned_ids = {int(r[0]) for r in data_rows} + assert returned_ids == {a_in.id} + + +@pytest.mark.asyncio +async def test_alerts_export_org_isolation(async_client: AsyncClient, detection_session: AsyncSession): + base = datetime(2026, 4, 10, 12, 0, 0) + org1_alert = await _create_alert(detection_session, 1, base, base + timedelta(minutes=5)) + org2_alert = await _create_alert(detection_session, 2, base, base + timedelta(minutes=5)) + + # Call as a non-admin user from org 1 + auth = pytest.get_token( + pytest.user_table[1]["id"], pytest.user_table[1]["role"].split(), pytest.user_table[1]["organization_id"] + ) + resp = await async_client.get( + "/alerts/export?from_date=2026-04-10&to_date=2026-04-10", + headers=auth, + ) + assert resp.status_code == 200, resp.text + _, data_rows = _parse_csv_body(resp.text) + returned_ids = {int(r[0]) for r in data_rows} + assert org1_alert.id in returned_ids + assert org2_alert.id not in returned_ids + + +@pytest.mark.asyncio +async def test_alerts_export_empty_range(async_client: AsyncClient, detection_session: AsyncSession): + auth = pytest.get_token( + pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] + ) + resp = await async_client.get( + "/alerts/export?from_date=2099-01-01&to_date=2099-01-31", + headers=auth, + ) + assert resp.status_code == 200, resp.text + header, data_rows = _parse_csv_body(resp.text) + assert header == ["id", "lat", "lon", "started_at", "last_seen_at"] + assert data_rows == [] + + +@pytest.mark.asyncio +async def test_alerts_export_renders_null_coordinates_as_empty( + async_client: AsyncClient, detection_session: AsyncSession +): + base = datetime(2026, 4, 10, 12, 0, 0) + alert = await _create_alert(detection_session, 1, base, base + timedelta(minutes=5), lat=None, lon=None) + + auth = pytest.get_token( + pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] + ) + resp = await async_client.get( + "/alerts/export?from_date=2026-04-10&to_date=2026-04-10", + headers=auth, + ) + assert resp.status_code == 200, resp.text + _, data_rows = _parse_csv_body(resp.text) + row = next(r for r in data_rows if int(r[0]) == alert.id) + assert row[1] == "" + assert row[2] == "" + + +@pytest.mark.asyncio +async def test_alerts_export_invalid_range(async_client: AsyncClient, detection_session: AsyncSession): + auth = pytest.get_token( + pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] + ) + resp = await async_client.get( + "/alerts/export?from_date=2026-04-12&to_date=2026-04-10", + headers=auth, + ) + assert resp.status_code == 422, resp.text + + +@pytest.mark.asyncio +async def test_alerts_export_unauthenticated(async_client: AsyncClient, detection_session: AsyncSession): + resp = await async_client.get("/alerts/export?from_date=2026-04-10&to_date=2026-04-12") + assert resp.status_code == 401, resp.text