diff --git a/pyroengine/engine.py b/pyroengine/engine.py index e7aa0eb3..d9fbca82 100644 --- a/pyroengine/engine.py +++ b/pyroengine/engine.py @@ -132,11 +132,14 @@ def __init__( # Local backup self._backup_size = backup_size - # Augment states with API-specific fields + # Augment states with API-specific fields. Anchor the daily pose timestamp at + # construction so a startup after noon does not trigger an immediate send; + # the next noon crossing is the first one that fires. + init_now = datetime.now() for state in self._states.values(): state["last_image_sent"] = None state["last_bbox_mask_fetch"] = None - state["last_pose_image_sent"] = None + state["last_pose_image_sent"] = init_now # Occlusion masks: cam_id -> dict of bboxes (keyed by mask id) self.occlusion_masks: Dict[str, Dict[Any, Any]] = {} @@ -151,7 +154,7 @@ def _new_state(self) -> Dict[str, Any]: state = super()._new_state() state["last_image_sent"] = None state["last_bbox_mask_fetch"] = None - state["last_pose_image_sent"] = None + state["last_pose_image_sent"] = datetime.now() return state def heartbeat(self, cam_id: str) -> Response: @@ -216,7 +219,7 @@ def predict( now = datetime.now() today_noon = now.replace(hour=12, minute=0, second=0, microsecond=0) last_pose_sent = self._states[cam_key]["last_pose_image_sent"] - if now >= today_noon and (last_pose_sent is None or last_pose_sent < today_noon): + if now >= today_noon and last_pose_sent < today_noon: _, pose_id = self.cam_creds[cam_id] ip = cam_id.split("_")[0] if ip in self.api_client: diff --git a/tests/test_engine.py b/tests/test_engine.py index 6d2ed194..d359a711 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -1,9 +1,9 @@ import os import tempfile import time -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from pathlib import Path -from unittest.mock import patch +from unittest.mock import MagicMock, patch import onnx import pytest @@ -266,6 +266,82 @@ def test_fill_empty_bboxes_all_empty_for_cam(tmp_path): assert all(not alert["bboxes"] for alert in engine._alerts) +def _build_engine_with_pose_stub(tmp_path, init_clock): + """Build an Engine with the api_client stubbed and datetime.now() pinned to init_clock.""" + cam_id = "169.254.7.3_3" + cam_creds = {cam_id: ("dummy_token", 7)} + + class _FrozenDateTime(datetime): + @classmethod + def now(cls, tz=None) -> datetime: + return init_clock if tz is None else init_clock.replace(tzinfo=tz) + + fake_client = MagicMock() + fake_client.update_pose_image.return_value = MagicMock(text="ok") + fake_client.update_last_image.return_value = MagicMock(text="ok") + fake_client.list_pose_masks.return_value = MagicMock( + raise_for_status=MagicMock(), + json=MagicMock(return_value=[]), + ) + fake_client.heartbeat.return_value = MagicMock() + + with ( + patch("pyroengine.engine.datetime", _FrozenDateTime), + patch("pyroengine.engine.client.Client", return_value=fake_client), + ): + engine = Engine(api_url="http://stub", cache_folder=str(tmp_path), cam_creds=cam_creds) + + return engine, fake_client, cam_id + + +def _run_predict_at(engine, cam_id, image, run_clock): + class _RunDateTime(datetime): + @classmethod + def now(cls, tz=None) -> datetime: + return run_clock if tz is None else run_clock.replace(tzinfo=tz) + + with patch("pyroengine.engine.datetime", _RunDateTime): + engine.predict(image, cam_id) + + +def test_pose_image_skipped_when_engine_starts_after_noon(tmp_path, mock_forest_image): + init_clock = datetime(2026, 5, 1, 14, 0, 0) + engine, fake_client, cam_id = _build_engine_with_pose_stub(tmp_path, init_clock) + + _run_predict_at(engine, cam_id, mock_forest_image, init_clock + timedelta(seconds=1)) + + fake_client.update_pose_image.assert_not_called() + + +def test_pose_image_sent_at_noon_crossing(tmp_path, mock_forest_image): + init_clock = datetime(2026, 5, 1, 11, 30, 0) + engine, fake_client, cam_id = _build_engine_with_pose_stub(tmp_path, init_clock) + + # Before noon: no send. + _run_predict_at(engine, cam_id, mock_forest_image, datetime(2026, 5, 1, 11, 59, 0)) + fake_client.update_pose_image.assert_not_called() + + # After noon: one send. + _run_predict_at(engine, cam_id, mock_forest_image, datetime(2026, 5, 1, 12, 0, 5)) + assert fake_client.update_pose_image.call_count == 1 + + # Same day, no resend. + _run_predict_at(engine, cam_id, mock_forest_image, datetime(2026, 5, 1, 13, 0, 0)) + assert fake_client.update_pose_image.call_count == 1 + + +def test_pose_image_sent_again_next_day(tmp_path, mock_forest_image): + init_clock = datetime(2026, 5, 1, 9, 0, 0) + engine, fake_client, cam_id = _build_engine_with_pose_stub(tmp_path, init_clock) + + _run_predict_at(engine, cam_id, mock_forest_image, datetime(2026, 5, 1, 12, 0, 5)) + assert fake_client.update_pose_image.call_count == 1 + + # Next day at noon: another send. + _run_predict_at(engine, cam_id, mock_forest_image, datetime(2026, 5, 2, 12, 0, 5)) + assert fake_client.update_pose_image.call_count == 2 + + def test_engine_occlusion(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image): # Cache folder = str(tmpdir_factory.mktemp("engine_cache"))