Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions pyroengine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def predict(
if cam_key not in self._states:
self._states[cam_key] = self._new_state()

# Keep the pre-resize frame so detection crops can be taken at original resolution.
original_frame = frame

# Reduce image size to save bandwidth
if isinstance(self.frame_size, tuple):
target = (self.frame_size[1], self.frame_size[0]) # PIL expects (W, H)
Expand Down Expand Up @@ -265,7 +268,8 @@ def predict(
preds = np.reshape(preds, (-1, 5))

logger.info(f"pred for {cam_key} : {preds}")
conf = self._update_states(frame, preds, cam_key, encoded_bytes=encoded_bytes)
# Store the original frame in state so _process_alerts can crop at full resolution.
conf = self._update_states(original_frame, preds, cam_key, encoded_bytes=encoded_bytes)

if self.save_captured_frames:
self._local_backup(frame, cam_id, is_alert=False, encoded_bytes=encoded_bytes)
Expand Down Expand Up @@ -294,6 +298,63 @@ def predict(

return float(conf)

@staticmethod
def _compute_crop_box(
bboxes: list,
img_w: int,
img_h: int,
padding: float = 0.20,
) -> Tuple[int, int, int, int]:
"""Square crop covering all bboxes (normalized coords) with `padding` on the largest dim."""
arr = np.asarray(bboxes, dtype=float)
x1 = float(arr[:, 0].min()) * img_w
y1 = float(arr[:, 1].min()) * img_h
x2 = float(arr[:, 2].max()) * img_w
y2 = float(arr[:, 3].max()) * img_h

side = max(x2 - x1, y2 - y1) * (1.0 + padding)
side = min(side, float(min(img_w, img_h)))

cx = (x1 + x2) / 2.0
cy = (y1 + y2) / 2.0
half = side / 2.0
left, top, right, bottom = cx - half, cy - half, cx + half, cy + half

# Shift back inside the image to keep the crop square instead of clipping.
if left < 0:
right -= left
left = 0
if top < 0:
bottom -= top
top = 0
if right > img_w:
left -= right - img_w
right = img_w
if bottom > img_h:
top -= bottom - img_h
bottom = img_h

return round(left), round(top), round(right), round(bottom)

def _encode_detection_crop(self, frame: Image.Image, bboxes: list) -> Optional[bytes]:
    """Crop `frame` around the union of `bboxes` and return a 224x224 JPEG.

    Args:
        frame: original (pre-resize) PIL image.
        bboxes: detection boxes in normalized coordinates; empty list yields None.

    Returns:
        Encoded JPEG bytes of the 224x224 crop, or None when `bboxes` is empty.
    """
    if not bboxes:
        return None

    width, height = frame.size
    crop_box = self._compute_crop_box(bboxes, width, height, padding=0.20)
    detection_crop = frame.crop(crop_box)

    crop_width, crop_height = detection_crop.size
    was_larger = crop_width > 224 or crop_height > 224
    if (crop_width, crop_height) != (224, 224):
        detection_crop = detection_crop.resize((224, 224), Image.LANCZOS)  # type: ignore[attr-defined]

    buffer = io.BytesIO()
    if was_larger:
        # Downscaled crop: standard high-quality JPEG suffices.
        detection_crop.save(buffer, format="JPEG", quality=95)
    else:
        # Crop was at or below 224 — preserve detail with no chroma subsampling.
        detection_crop.save(buffer, format="JPEG", quality=100, subsampling=0, optimize=True)
    return buffer.getvalue()

def _stage_alert(
self,
frame: Image.Image,
Expand Down Expand Up @@ -363,9 +424,10 @@ def _process_alerts(self) -> None:
frame_info["frame"].save(stream, format="JPEG", quality=self.jpeg_quality)
jpeg_bytes = stream.getvalue()
bboxes = [tuple(bboxe) for bboxe in bboxes]
crop_bytes = self._encode_detection_crop(frame_info["frame"], bboxes)
_, pose_id = self.cam_creds[cam_id]
ip = cam_id.split("_")[0]
response = self.api_client[ip].create_detection(jpeg_bytes, bboxes, pose_id)
response = self.api_client[ip].create_detection(jpeg_bytes, bboxes, pose_id, crop=crop_bytes)

try:
response.json()["id"]
Expand Down
2 changes: 1 addition & 1 deletion requirements-git.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pyroclient @ git+https://github.com/pyronear/pyro-api.git@0790626447bab7c5be86322759975c33f0e54784#subdirectory=client
pyroclient @ git+https://github.com/pyronear/pyro-api.git@6facf2f780e5ab0bb594f88b9f50118bd6f23c81#subdirectory=client
Loading