Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 72 additions & 2 deletions src/app/api/api_v1/endpoints/alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.


from datetime import date, timedelta
from typing import Any, Dict, List, Union, cast
import csv
import io
from datetime import date, datetime, time, timedelta
from typing import Any, Dict, Iterable, Iterator, List, Union, cast

from fastapi import APIRouter, Depends, HTTPException, Path, Query, Security, status
from fastapi.responses import StreamingResponse
from sqlalchemy import asc, desc
from sqlmodel import delete, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
Expand Down Expand Up @@ -53,6 +56,73 @@ def _serialize_alert(alert: Alert, sequences: List[Sequence]) -> AlertReadWithSe
)


_ALERT_EXPORT_COLUMNS = ["id", "lat", "lon", "started_at", "last_seen_at"]


def _iter_alerts_csv(alerts: Iterable[Alert]) -> Iterator[str]:
buf = io.StringIO()
writer = csv.writer(buf)
writer.writerow(_ALERT_EXPORT_COLUMNS)
yield buf.getvalue()
buf.seek(0)
buf.truncate(0)
for a in alerts:
writer.writerow([
a.id,
"" if a.lat is None else a.lat,
"" if a.lon is None else a.lon,
a.started_at.isoformat(),
a.last_seen_at.isoformat(),
])
yield buf.getvalue()
buf.seek(0)
buf.truncate(0)


def _build_alerts_csv_response(alerts: List[Alert], from_date: date, to_date: date) -> StreamingResponse:
    """Wrap *alerts* in a streaming text/csv download response.

    The attachment filename encodes the requested date range, e.g.
    ``alerts_2026-04-10_2026-04-12.csv``.
    """
    filename = f"alerts_{from_date.isoformat()}_{to_date.isoformat()}.csv"
    # Bug fix: the computed filename was previously discarded and a literal
    # placeholder string was sent in the header instead; interpolate it here.
    headers = {"Content-Disposition": f'attachment; filename="{filename}"'}
    return StreamingResponse(_iter_alerts_csv(alerts), media_type="text/csv", headers=headers)


@router.get(
    "/export",
    status_code=status.HTTP_200_OK,
    summary="Export alerts in a date range as CSV",
    response_class=StreamingResponse,
)
async def export_alerts_csv(
    from_date: date = Query(..., description="Inclusive lower bound on started_at (UTC date)"),
    to_date: date = Query(..., description="Inclusive upper bound on started_at (UTC date)"),
    session: AsyncSession = Depends(get_session),
    token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]),
) -> StreamingResponse:
    """Stream the caller organization's alerts started in [from_date, to_date] as a CSV file."""
    telemetry_client.capture(
        token_payload.sub,
        event="alerts-export",
        properties={"from_date": from_date.isoformat(), "to_date": to_date.isoformat()},
    )

    if from_date > to_date:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="to_date must be on or after from_date",
        )

    # DB columns store naive UTC datetimes (see app.core.time.utcnow), so we drop tzinfo here.
    window_start = datetime.combine(from_date, time.min)
    window_end = datetime.combine(to_date, time.max)

    query: Any = select(Alert).where(
        Alert.organization_id == token_payload.organization_id,
        Alert.started_at >= window_start,
        Alert.started_at <= window_end,
    )
    query = query.order_by(Alert.started_at.asc())  # type: ignore[attr-defined]
    rows = (await session.exec(query)).all()
    return _build_alerts_csv_response(list(rows), from_date, to_date)


@router.get("/{alert_id}", status_code=status.HTTP_200_OK, summary="Fetch the information of a specific alert")
async def get_alert(
alert_id: int = Path(..., gt=0),
Expand Down
164 changes: 163 additions & 1 deletion src/tests/endpoints/test_alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from datetime import timedelta
import csv
import io
from datetime import datetime, timedelta
from typing import Any, List, Tuple, cast

import pandas as pd
Expand Down Expand Up @@ -317,3 +319,163 @@ async def test_triangulation_creates_single_alert(
remaining_ids = {seq.id for seq in sequences if seq.id != sequences[1].id}
updated_mappings = {(aid, sid) for aid, sid in mappings_after_other if aid == initial_alert_id}
assert updated_mappings == {(initial_alert_id, sid) for sid in remaining_ids}


async def _create_alert(
    session: AsyncSession,
    org_id: int,
    started_at: datetime,
    last_seen_at: datetime,
    lat: float | None = 48.0,
    lon: float | None = 2.0,
) -> Alert:
    """Persist and return a fresh Alert row for *org_id* covering the given time span."""
    new_alert = Alert(
        organization_id=org_id,
        lat=lat,
        lon=lon,
        started_at=started_at,
        last_seen_at=last_seen_at,
    )
    session.add(new_alert)
    await session.commit()
    # Reload so auto-generated fields (e.g. the primary key) are populated.
    await session.refresh(new_alert)
    return new_alert


def _parse_csv_body(body: str) -> Tuple[List[str], List[List[str]]]:
reader = csv.reader(io.StringIO(body))
rows = list(reader)
return rows[0], rows[1:]


@pytest.mark.asyncio
async def test_alerts_export_happy_path(async_client: AsyncClient, detection_session: AsyncSession):
    """Three alerts across three days all come back, correctly ordered and serialized."""
    t0 = datetime(2026, 4, 10, 12, 0, 0)
    coords = [(48.1, 2.1), (48.2, 2.2), (48.3, 2.3)]
    created = []
    for day_offset, (lat, lon) in enumerate(coords):
        begin = t0 + timedelta(days=day_offset)
        created.append(await _create_alert(detection_session, 1, begin, begin + timedelta(minutes=5), lat, lon))

    auth = pytest.get_token(
        pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"]
    )
    resp = await async_client.get(
        "/alerts/export", params={"from_date": "2026-04-10", "to_date": "2026-04-12"}, headers=auth
    )
    assert resp.status_code == 200, resp.text
    assert resp.headers["content-type"].startswith("text/csv")
    disposition = resp.headers["content-disposition"]
    assert "attachment" in disposition
    assert "alerts_2026-04-10_2026-04-12.csv" in disposition

    header, rows = _parse_csv_body(resp.text)
    assert header == ["id", "lat", "lon", "started_at", "last_seen_at"]
    assert [int(row[0]) for row in rows] == [alert.id for alert in created]
    # ordering is ascending by started_at
    assert [row[3] for row in rows] == sorted(row[3] for row in rows)
    # spot-check values for the first row
    first = rows[0]
    assert float(first[1]) == pytest.approx(48.1)
    assert float(first[2]) == pytest.approx(2.1)
    assert first[3] == created[0].started_at.isoformat()
    assert first[4] == created[0].last_seen_at.isoformat()


@pytest.mark.asyncio
async def test_alerts_export_window_narrows(async_client: AsyncClient, detection_session: AsyncSession):
    """Only alerts whose started_at falls inside the requested window are exported."""
    anchor = datetime(2026, 4, 10, 12, 0, 0)
    await _create_alert(detection_session, 1, anchor, anchor + timedelta(minutes=5))
    inside = await _create_alert(
        detection_session, 1, anchor + timedelta(days=1), anchor + timedelta(days=1, minutes=5)
    )
    await _create_alert(detection_session, 1, anchor + timedelta(days=2), anchor + timedelta(days=2, minutes=5))

    auth = pytest.get_token(
        pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"]
    )
    resp = await async_client.get(
        "/alerts/export", params={"from_date": "2026-04-11", "to_date": "2026-04-11"}, headers=auth
    )
    assert resp.status_code == 200, resp.text
    _, rows = _parse_csv_body(resp.text)
    assert {int(row[0]) for row in rows} == {inside.id}


@pytest.mark.asyncio
async def test_alerts_export_org_isolation(async_client: AsyncClient, detection_session: AsyncSession):
    """A user only ever sees their own organization's alerts in the export."""
    anchor = datetime(2026, 4, 10, 12, 0, 0)
    mine = await _create_alert(detection_session, 1, anchor, anchor + timedelta(minutes=5))
    theirs = await _create_alert(detection_session, 2, anchor, anchor + timedelta(minutes=5))

    # Call as a non-admin user from org 1
    auth = pytest.get_token(
        pytest.user_table[1]["id"], pytest.user_table[1]["role"].split(), pytest.user_table[1]["organization_id"]
    )
    resp = await async_client.get(
        "/alerts/export", params={"from_date": "2026-04-10", "to_date": "2026-04-10"}, headers=auth
    )
    assert resp.status_code == 200, resp.text
    _, rows = _parse_csv_body(resp.text)
    exported = {int(row[0]) for row in rows}
    assert mine.id in exported
    assert theirs.id not in exported


@pytest.mark.asyncio
async def test_alerts_export_empty_range(async_client: AsyncClient, detection_session: AsyncSession):
    """A window containing no alerts still yields a well-formed, header-only CSV."""
    auth = pytest.get_token(
        pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"]
    )
    resp = await async_client.get(
        "/alerts/export", params={"from_date": "2099-01-01", "to_date": "2099-01-31"}, headers=auth
    )
    assert resp.status_code == 200, resp.text
    header, rows = _parse_csv_body(resp.text)
    assert header == ["id", "lat", "lon", "started_at", "last_seen_at"]
    assert rows == []


@pytest.mark.asyncio
async def test_alerts_export_renders_null_coordinates_as_empty(
    async_client: AsyncClient, detection_session: AsyncSession
):
    """Missing lat/lon serialize as empty CSV cells, not the string 'None'."""
    anchor = datetime(2026, 4, 10, 12, 0, 0)
    no_coords = await _create_alert(detection_session, 1, anchor, anchor + timedelta(minutes=5), lat=None, lon=None)

    auth = pytest.get_token(
        pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"]
    )
    resp = await async_client.get(
        "/alerts/export", params={"from_date": "2026-04-10", "to_date": "2026-04-10"}, headers=auth
    )
    assert resp.status_code == 200, resp.text
    _, rows = _parse_csv_body(resp.text)
    target = next(row for row in rows if int(row[0]) == no_coords.id)
    assert (target[1], target[2]) == ("", "")


@pytest.mark.asyncio
async def test_alerts_export_invalid_range(async_client: AsyncClient, detection_session: AsyncSession):
    """A from_date after to_date is rejected with HTTP 422."""
    auth = pytest.get_token(
        pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"]
    )
    resp = await async_client.get(
        "/alerts/export", params={"from_date": "2026-04-12", "to_date": "2026-04-10"}, headers=auth
    )
    assert resp.status_code == 422, resp.text


@pytest.mark.asyncio
async def test_alerts_export_unauthenticated(async_client: AsyncClient, detection_session: AsyncSession):
    """Requests without credentials are rejected with HTTP 401."""
    resp = await async_client.get("/alerts/export", params={"from_date": "2026-04-10", "to_date": "2026-04-12"})
    assert resp.status_code == 401, resp.text
Loading