Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions pyroengine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,14 @@ def __init__(
# Local backup
self._backup_size = backup_size

# Augment states with API-specific fields
# Augment states with API-specific fields. Anchor the daily pose timestamp at
# construction so a startup after noon does not trigger an immediate send;
# the next noon crossing is the first one that fires.
init_now = datetime.now()
for state in self._states.values():
state["last_image_sent"] = None
state["last_bbox_mask_fetch"] = None
state["last_pose_image_sent"] = None
state["last_pose_image_sent"] = init_now

# Occlusion masks: cam_id -> dict of bboxes (keyed by mask id)
self.occlusion_masks: Dict[str, Dict[Any, Any]] = {}
Expand All @@ -151,7 +154,7 @@ def _new_state(self) -> Dict[str, Any]:
state = super()._new_state()
state["last_image_sent"] = None
state["last_bbox_mask_fetch"] = None
state["last_pose_image_sent"] = None
state["last_pose_image_sent"] = datetime.now()
return state

def heartbeat(self, cam_id: str) -> Response:
Expand Down Expand Up @@ -216,7 +219,7 @@ def predict(
now = datetime.now()
today_noon = now.replace(hour=12, minute=0, second=0, microsecond=0)
last_pose_sent = self._states[cam_key]["last_pose_image_sent"]
if now >= today_noon and (last_pose_sent is None or last_pose_sent < today_noon):
if now >= today_noon and last_pose_sent < today_noon:
_, pose_id = self.cam_creds[cam_id]
ip = cam_id.split("_")[0]
if ip in self.api_client:
Expand Down
80 changes: 78 additions & 2 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
import tempfile
import time
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import onnx
import pytest
Expand Down Expand Up @@ -266,6 +266,82 @@ def test_fill_empty_bboxes_all_empty_for_cam(tmp_path):
assert all(not alert["bboxes"] for alert in engine._alerts)


def _build_engine_with_pose_stub(tmp_path, init_clock):
"""Build an Engine with the api_client stubbed and datetime.now() pinned to init_clock."""
cam_id = "169.254.7.3_3"
cam_creds = {cam_id: ("dummy_token", 7)}

class _FrozenDateTime(datetime):
@classmethod
def now(cls, tz=None) -> datetime:
return init_clock if tz is None else init_clock.replace(tzinfo=tz)

fake_client = MagicMock()
fake_client.update_pose_image.return_value = MagicMock(text="ok")
fake_client.update_last_image.return_value = MagicMock(text="ok")
fake_client.list_pose_masks.return_value = MagicMock(
raise_for_status=MagicMock(),
json=MagicMock(return_value=[]),
)
fake_client.heartbeat.return_value = MagicMock()

with (
patch("pyroengine.engine.datetime", _FrozenDateTime),
patch("pyroengine.engine.client.Client", return_value=fake_client),
):
engine = Engine(api_url="http://stub", cache_folder=str(tmp_path), cam_creds=cam_creds)

return engine, fake_client, cam_id


def _run_predict_at(engine, cam_id, image, run_clock):
class _RunDateTime(datetime):
@classmethod
def now(cls, tz=None) -> datetime:
return run_clock if tz is None else run_clock.replace(tzinfo=tz)

with patch("pyroengine.engine.datetime", _RunDateTime):
engine.predict(image, cam_id)


def test_pose_image_skipped_when_engine_starts_after_noon(tmp_path, mock_forest_image):
init_clock = datetime(2026, 5, 1, 14, 0, 0)
engine, fake_client, cam_id = _build_engine_with_pose_stub(tmp_path, init_clock)

_run_predict_at(engine, cam_id, mock_forest_image, init_clock + timedelta(seconds=1))

fake_client.update_pose_image.assert_not_called()


def test_pose_image_sent_at_noon_crossing(tmp_path, mock_forest_image):
init_clock = datetime(2026, 5, 1, 11, 30, 0)
engine, fake_client, cam_id = _build_engine_with_pose_stub(tmp_path, init_clock)

# Before noon: no send.
_run_predict_at(engine, cam_id, mock_forest_image, datetime(2026, 5, 1, 11, 59, 0))
fake_client.update_pose_image.assert_not_called()

# After noon: one send.
_run_predict_at(engine, cam_id, mock_forest_image, datetime(2026, 5, 1, 12, 0, 5))
assert fake_client.update_pose_image.call_count == 1

# Same day, no resend.
_run_predict_at(engine, cam_id, mock_forest_image, datetime(2026, 5, 1, 13, 0, 0))
assert fake_client.update_pose_image.call_count == 1


def test_pose_image_sent_again_next_day(tmp_path, mock_forest_image):
init_clock = datetime(2026, 5, 1, 9, 0, 0)
engine, fake_client, cam_id = _build_engine_with_pose_stub(tmp_path, init_clock)

_run_predict_at(engine, cam_id, mock_forest_image, datetime(2026, 5, 1, 12, 0, 5))
assert fake_client.update_pose_image.call_count == 1

# Next day at noon: another send.
_run_predict_at(engine, cam_id, mock_forest_image, datetime(2026, 5, 2, 12, 0, 5))
assert fake_client.update_pose_image.call_count == 2


def test_engine_occlusion(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image):
# Cache
folder = str(tmpdir_factory.mktemp("engine_cache"))
Expand Down
Loading