Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/app/api/api_v1/endpoints/detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,15 @@ async def _get_recent_sequences(
sequences: SequenceCRUD,
camera_ids: List[int],
sequence_: Sequence,
reference_time: Optional[datetime] = None,
) -> List[Sequence]:
anchor = reference_time if reference_time is not None else utcnow()
recent_sequences = await sequences.fetch_all(
in_pair=("camera_id", camera_ids),
inequality_pair=(
"last_seen_at",
">",
utcnow() - timedelta(seconds=settings.SEQUENCE_RELAXATION_SECONDS),
anchor - timedelta(seconds=settings.SEQUENCE_RELAXATION_SECONDS),
),
)
if all(seq.id != sequence_.id for seq in recent_sequences):
Expand Down Expand Up @@ -294,18 +296,19 @@ def _build_links_for_group(
return links


async def _attach_sequence_to_alert(
async def attach_sequence_to_alert(
sequence_: Sequence,
camera: Camera,
cameras: CameraCRUD,
sequences: SequenceCRUD,
alerts: AlertCRUD,
reference_time: Optional[datetime] = None,
) -> Optional[int]:
"""Assign the given sequence to an alert based on cone/time overlap."""
camera_by_id = await _get_camera_by_id(camera, cameras, sequence_.camera_id)

# Fetch recent sequences for the organization based on recency of last_seen_at
recent_sequences = await _get_recent_sequences(sequences, list(camera_by_id.keys()), sequence_)
recent_sequences = await _get_recent_sequences(sequences, list(camera_by_id.keys()), sequence_, reference_time)

# Build DataFrame for overlap computation
records = _build_overlap_records(recent_sequences, camera_by_id)
Expand Down Expand Up @@ -482,7 +485,7 @@ async def create_detection(
if det_.id == det.id:
det = updated

alert_id = await _attach_sequence_to_alert(sequence_, camera, cameras, sequences, alerts)
alert_id = await attach_sequence_to_alert(sequence_, camera, cameras, sequences, alerts)

# Webhooks
whs = await webhooks.fetch_all()
Expand Down
37 changes: 27 additions & 10 deletions src/app/api/api_v1/endpoints/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sqlmodel import delete, func, select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.api.api_v1.endpoints.detections import attach_sequence_to_alert
from app.api.dependencies import get_alert_crud, get_camera_crud, get_detection_crud, get_jwt, get_sequence_crud
from app.core.time import utcnow
from app.crud import AlertCRUD, CameraCRUD, DetectionCRUD, SequenceCRUD
Expand Down Expand Up @@ -86,6 +87,18 @@ async def _refresh_alert_state(alert_id: int, session: AsyncSession, alerts: Ale
)


async def _detach_sequence_from_alerts(sequence_id: int, session: AsyncSession, alerts: AlertCRUD) -> None:
alert_ids_res = await session.exec(select(AlertSequence.alert_id).where(AlertSequence.sequence_id == sequence_id))
alert_ids = list(alert_ids_res.all())
if not alert_ids:
return
delete_links: Any = delete(AlertSequence).where(cast(Any, AlertSequence.sequence_id) == sequence_id)
await session.exec(delete_links)
await session.commit()
for aid in alert_ids:
await _refresh_alert_state(aid, session, alerts)


def _serialize_sequence(sequence: Sequence, detections_count: int = 0) -> SequenceRead:
return SequenceRead(**sequence.model_dump(), detections_count=detections_count)

Expand Down Expand Up @@ -256,20 +269,12 @@ async def label_sequence(
if UserRole.ADMIN not in token_payload.scopes:
await verify_org_rights(token_payload.organization_id, sequence.camera_id, cameras)

previous_label = sequence.is_wildfire
updated = await sequences.update(sequence_id, payload)

# If sequence is labeled as non-wildfire, remove it from alerts and refresh those alerts
if payload.is_wildfire is not None and payload.is_wildfire != AnnotationType.WILDFIRE_SMOKE:
alert_ids_res = await session.exec(
select(AlertSequence.alert_id).where(AlertSequence.sequence_id == sequence_id)
)
alert_ids = list(alert_ids_res.all())
if alert_ids:
delete_links: Any = delete(AlertSequence).where(cast(Any, AlertSequence.sequence_id) == sequence_id)
await session.exec(delete_links)
await session.commit()
for aid in alert_ids:
await _refresh_alert_state(aid, session, alerts)
await _detach_sequence_from_alerts(sequence_id, session, alerts)
# Create a fresh alert for this sequence alone
camera = cast(Camera, await cameras.get(sequence.camera_id, strict=True))
new_alert = await alerts.create(
Expand All @@ -283,5 +288,17 @@ async def label_sequence(
)
session.add(AlertSequence(alert_id=new_alert.id, sequence_id=sequence_id))
await session.commit()
# Reverting a previously non-wildfire label back to wildfire_smoke: re-run cone matching
elif (
payload.is_wildfire == AnnotationType.WILDFIRE_SMOKE
and previous_label is not None
and previous_label != AnnotationType.WILDFIRE_SMOKE
):
await _detach_sequence_from_alerts(sequence_id, session, alerts)
camera = cast(Camera, await cameras.get(sequence.camera_id, strict=True))
# Anchor the candidate window on the sequence's own time so old relabels still merge
await attach_sequence_to_alert(
updated, camera, cameras, sequences, alerts, reference_time=sequence.last_seen_at
)

return updated
10 changes: 5 additions & 5 deletions src/tests/endpoints/test_detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from app.api.api_v1.endpoints import detections as detections_api
from app.api.api_v1.endpoints.detections import (
_attach_sequence_to_alert,
_build_links_for_group,
_build_overlap_records,
_fetch_alert_mapping,
Expand All @@ -25,6 +24,7 @@
_maybe_update_alert,
_parse_bbox,
_resolve_groups_and_locations,
attach_sequence_to_alert,
create_detection,
)
from app.core.config import settings
Expand Down Expand Up @@ -663,7 +663,7 @@ async def test_attach_sequence_to_alert_returns_without_overlap_records(detectio
await detection_session.commit()
await detection_session.refresh(sequence)

await _attach_sequence_to_alert(sequence, camera, cam_crud, seq_crud, alert_crud)
await attach_sequence_to_alert(sequence, camera, cam_crud, seq_crud, alert_crud)

alerts = await alert_crud.fetch_all()
assert alerts == []
Expand Down Expand Up @@ -1225,7 +1225,7 @@ async def test_attach_sequence_to_alert_creates_alert(detection_session: AsyncSe
await detection_session.refresh(seq1)
await detection_session.refresh(seq2)

await _attach_sequence_to_alert(seq2, cam2, cam_crud, seq_crud, alert_crud)
await attach_sequence_to_alert(seq2, cam2, cam_crud, seq_crud, alert_crud)

alerts = await alert_crud.fetch_all()
assert len(alerts) == 1
Expand Down Expand Up @@ -1327,7 +1327,7 @@ async def test_attach_sequence_does_not_bridge_to_distant_alert(detection_sessio
await detection_session.refresh(seq_cam5)

# Step 1 — attach cam5 sequence triangulates with cam7, creates smoke-A alert.
smoke_a_alert_id = await _attach_sequence_to_alert(seq_cam5, cam5, cam_crud, seq_crud, alert_crud)
smoke_a_alert_id = await attach_sequence_to_alert(seq_cam5, cam5, cam_crud, seq_crud, alert_crud)
assert smoke_a_alert_id is not None
smoke_a = await alert_crud.get(smoke_a_alert_id, strict=True)
assert smoke_a.lat is not None
Expand All @@ -1348,7 +1348,7 @@ async def test_attach_sequence_does_not_bridge_to_distant_alert(detection_sessio
await detection_session.commit()
await detection_session.refresh(seq_cam2)

target_id = await _attach_sequence_to_alert(seq_cam2, cam2, cam_crud, seq_crud, alert_crud)
target_id = await attach_sequence_to_alert(seq_cam2, cam2, cam_crud, seq_crud, alert_crud)

# The cam2 sequence must land on a NEW alert, not the smoke-A one.
assert target_id is not None
Expand Down
67 changes: 67 additions & 0 deletions src/tests/endpoints/test_sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,73 @@
assert updated_sequence.is_wildfire == AnnotationType.OTHER_SMOKE


@pytest.mark.asyncio
@patch("app.api.api_v1.endpoints.sequences.attach_sequence_to_alert", new_callable=AsyncMock)
@patch("app.api.api_v1.endpoints.sequences._refresh_alert_state", new_callable=AsyncMock)
async def test_unit_relabel_sequence_to_wildfire_smoke_reattaches(
mock_refresh_alert_state: AsyncMock,
mock_attach_sequence_to_alert: AsyncMock,
):
"""Reverting a non-wildfire label back to wildfire_smoke should drop the lonely alert

Check notice on line 667 in src/tests/endpoints/test_sequences.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

src/tests/endpoints/test_sequences.py#L667

1 blank line required between summary line and description (found 0) (D205)

Check notice on line 667 in src/tests/endpoints/test_sequences.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

src/tests/endpoints/test_sequences.py#L667

First line should end with a period, question mark, or exclamation point (not 't') (D415)

Check notice on line 667 in src/tests/endpoints/test_sequences.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

src/tests/endpoints/test_sequences.py#L667

Multi-line docstring closing quotes should be on a separate line (D209)

Check notice on line 667 in src/tests/endpoints/test_sequences.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

src/tests/endpoints/test_sequences.py#L667

Multi-line docstring summary should start at the second line (D213)
and re-run cone matching to merge the sequence into an overlapping alert."""
mock_sequence = Sequence(
id=1,
camera_id=1,
is_wildfire=AnnotationType.OTHER_SMOKE,
started_at=utcnow(),
last_seen_at=utcnow(),
)
mock_camera = Camera(id=1, organization_id=1)

mock_sequences_crud = AsyncMock()
mock_sequences_crud.get.return_value = mock_sequence
updated_seq = Sequence(
id=1,
camera_id=1,
is_wildfire=AnnotationType.WILDFIRE_SMOKE,
started_at=mock_sequence.started_at,
last_seen_at=mock_sequence.last_seen_at,
)
mock_sequences_crud.update.return_value = updated_seq

mock_cameras_crud = AsyncMock()
mock_cameras_crud.get.return_value = mock_camera

mock_alerts_crud = AsyncMock()

mock_session = AsyncMock()
mock_exec_result = MagicMock()
mock_exec_result.all.return_value = [202] # Currently linked to lonely alert 202
mock_session.exec.return_value = mock_exec_result

mock_token_payload = TokenPayload(sub=1, scopes=[UserRole.AGENT], organization_id=1)
payload = SequenceLabel(is_wildfire=AnnotationType.WILDFIRE_SMOKE)

result = await label_sequence(
payload=payload,
sequence_id=1,
cameras=mock_cameras_crud,
sequences=mock_sequences_crud,
alerts=mock_alerts_crud,
session=mock_session,
token_payload=mock_token_payload,
)

mock_sequences_crud.update.assert_called_once_with(1, payload)
assert mock_session.exec.call_count == 2 # fetch alert_ids + delete links
mock_refresh_alert_state.assert_called_once_with(202, mock_session, mock_alerts_crud)
mock_attach_sequence_to_alert.assert_awaited_once_with(
updated_seq,
mock_camera,
mock_cameras_crud,
mock_sequences_crud,
mock_alerts_crud,
reference_time=mock_sequence.last_seen_at,
)
mock_alerts_crud.create.assert_not_called()
assert result.is_wildfire == AnnotationType.WILDFIRE_SMOKE


@pytest.mark.asyncio
async def test_unit_label_sequence_as_wildfire_smoke_does_not_refresh():
"""Verify that labeling a sequence as wildfire smoke does NOT trigger an alert refresh."""
Expand Down
Loading