From c9c55e6ef13b671b5bf68005ae54cedb8a200b5a Mon Sep 17 00:00:00 2001 From: YoungKameSennin Date: Sat, 31 May 2025 14:01:12 -0700 Subject: [PATCH 1/7] naive frame skipping --- sam2/sam2_video_predictor.py | 279 ++++++++++++++--------------------- 1 file changed, 109 insertions(+), 170 deletions(-) diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index c88e111..1166c50 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -9,14 +9,29 @@ import torch import torch.nn.functional as F -import numpy as np + from tqdm import tqdm from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames -from sam2.utils.change_detection import should_skip_frame -from sam2.utils.optical_flow import warp_mask_forward + +# MGFK Helpers import cv2 +import numpy as np +import time + +def _mean_abs_diff(img1: torch.Tensor, img2: torch.Tensor) -> float: + """ + Compute mean absolute pixel difference on two HWC images in uint8 range [0,255]. + Inputs are 3-D CPU tensors (H,W,C) with dtype uint8. + """ + # convert to NumPy for speed + a = img1.cpu().numpy().astype(np.int16) + b = img2.cpu().numpy().astype(np.int16) + return float(np.mean(np.abs(a - b))) + + + class SAM2VideoPredictor(SAM2Base): """The predictor class to handle user interactions and manage inference states.""" @@ -32,6 +47,7 @@ def __init__( # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames add_all_frames_to_correct_as_cond=False, + skip_mad_threshold: float = 0.05, **kwargs, ): super().__init__(**kwargs) @@ -39,6 +55,7 @@ def __init__( self.non_overlap_masks = non_overlap_masks self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.skip_mad_threshold = skip_mad_threshold @torch.inference_mode() def init_state( @@ -96,20 +113,13 @@ def init_state( # (we directly use their consolidated outputs during tracking) # metadata for each tracking frame (e.g. which direction it's tracked) inference_state["frames_tracked_per_obj"] = {} - # Warm up the visual backbone and cache the image feature on frame 0 - self._get_image_feature(inference_state, frame_idx=0, batch_size=1) - first = inference_state["images"][0] - # If CHW, swap dims - if first.ndim == 3 and first.shape[0] in (1,3): - _, h, w = first.shape - elif first.ndim == 3 and first.shape[2] in (1,3): - h, w, _ = first.shape - else: - # Fallback: assume H x W - h, w = first.shape[-2], first.shape[-1] - inference_state["video_size"] = (h, w) + # Add two slots to the inference state so they survive across frames + inference_state["last_processed_frame"] = None # uint8 HWC tensor + inference_state["last_processed_output"] = None # compact_current_out + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) return inference_state @classmethod @@ -293,6 +303,11 @@ def add_new_points_or_box( # Add the output to the output dict (to be used as future memory) obj_temp_output_dict[storage_key][frame_idx] = current_out + # Any user interaction invalidates reuse cache + inference_state["last_processed_frame"] = None + inference_state["last_processed_output"] = None + + # Resize the output mask to the original video resolution obj_ids = inference_state["obj_ids"] consolidated_out = self._consolidate_temp_output_across_obj( @@ -381,6 +396,10 @@ def add_new_mask( # Add the output to the output dict (to be used as future memory) obj_temp_output_dict[storage_key][frame_idx] = current_out + # Any user interaction invalidates reuse cache + inference_state["last_processed_frame"] = None + inference_state["last_processed_output"] = None + # Resize the output mask to the original video resolution obj_ids = inference_state["obj_ids"] consolidated_out = self._consolidate_temp_output_across_obj( @@ -564,201 +583,121 @@ def propagate_in_video( max_frame_num_to_track=None, reverse=False, ): - """ - Propagate the input prompts (points/masks) across all frames - in the video, optionally skipping low-change frames. - Yields (frame_idx, obj_ids, masks) once per frame. - """ - # Prepare various buffers and metadata + """Propagate the input points across frames to track in the entire video.""" self.propagate_in_video_preflight(inference_state) - obj_ids = inference_state["obj_ids"] # list of object IDs - num_frames = inference_state["num_frames"] # total frames - batch_size = len(obj_ids) # number of objects + obj_ids = inference_state["obj_ids"] + num_frames = inference_state["num_frames"] + batch_size = self._get_obj_num(inference_state) - # Determine start/end indices + # set start index, end index, and processing order if start_frame_idx is None: - # first conditioning frame across all objects + # default: start from the earliest frame with input points start_frame_idx = min( t - for obj_dict in inference_state["output_dict_per_obj"].values() - for t in obj_dict["cond_frame_outputs"] + for obj_output_dict in inference_state["output_dict_per_obj"].values() + for t in obj_output_dict["cond_frame_outputs"] ) if max_frame_num_to_track is None: + # default: track all the frames in the video max_frame_num_to_track = num_frames - if reverse: end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) - processing_order = ( - range(start_frame_idx, end_frame_idx - 1, -1) - if start_frame_idx > 0 - else [] - ) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + processing_order = [] # skip reverse tracking if starting from frame 0 else: - end_frame_idx = min(start_frame_idx + max_frame_num_to_track, num_frames - 1) + end_frame_idx = min( + start_frame_idx + max_frame_num_to_track, num_frames - 1 + ) processing_order = range(start_frame_idx, end_frame_idx + 1) - # Track the last full-inference frame and mask for skipping - prev_frame_idx = None - prev_pred_masks = None - - # Iterate through each frame index + t0 = time.time() + count = 0 for frame_idx in tqdm(processing_order, desc="propagate in video"): - # Container for this frame’s predicted masks per object pred_masks_per_obj = [None] * batch_size - - # Loop over each object independently for obj_idx in range(batch_size): obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] - storage_key = None - pred_masks = None - - # 1) Conditioning frame (manual prompt) + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. if frame_idx in obj_output_dict["cond_frame_outputs"]: storage_key = "cond_frame_outputs" current_out = obj_output_dict[storage_key][frame_idx] device = inference_state["device"] pred_masks = current_out["pred_masks"].to(device, non_blocking=True) - - prev_frame_idx = frame_idx - prev_pred_masks = pred_masks - - # Optional: clear nearby non_cond memory if self.clear_non_cond_mem_around_input: + # clear non-conditioning memory of the surrounding frames self._clear_obj_non_cond_mem_around_input( inference_state, frame_idx, obj_idx ) - - # 2) Already computed non-conditioning frame - elif frame_idx in obj_output_dict["non_cond_frame_outputs"]: - storage_key = "non_cond_frame_outputs" - current_out = obj_output_dict[storage_key][frame_idx] - pred_masks = current_out["pred_masks"] - - prev_frame_idx = frame_idx - prev_pred_masks = pred_masks - - # 3) New frame: decide skip vs full inference else: storage_key = "non_cond_frame_outputs" - - # --- a) Memory-guided skip? - if ( - self.skip_threshold is not None - and self.skip_threshold > 0 - and prev_frame_idx is not None - ): - prev_frame = inference_state["images"][prev_frame_idx].cpu() - curr_frame = inference_state["images"][frame_idx].cpu() - if should_skip_frame( - prev_frame, curr_frame, threshold=self.skip_threshold - ): - # Warp last mask and upsample - print(f"Skipping frame {frame_idx} (reuse frame {prev_frame_idx}), skip threshold {self.skip_threshold}") - warped_pred = warp_mask_forward( - prev_pred_masks.cpu(), - prev_frame.cpu().numpy(), - curr_frame.cpu().numpy(), + # ========= FRAME-SKIP CHECK (NEW) ========= + if inference_state["last_processed_frame"] is not None: + prev_img = inference_state["last_processed_frame"] + curr_img = inference_state["images"][frame_idx].cpu() # uint8, HWC + mad = _mean_abs_diff(prev_img, curr_img) + if mad < self.skip_mad_threshold: + print( + f"Skipping frame {frame_idx} due to low MAD ({mad:.2f}) ") + # Skip heavy inference: reuse previous output + count += 1 + current_out = inference_state["last_processed_output"] + pred_masks = current_out["pred_masks"].to( + inference_state["device"], non_blocking=True ) + # still need to record bookkeeping so downstream code works + obj_output_dict[storage_key][frame_idx] = current_out + else: + current_out, pred_masks = self._run_single_frame_inference( + inference_state, obj_output_dict, frame_idx, 1, + is_init_cond_frame=False, point_inputs=None, mask_inputs=None, + reverse=reverse, run_mem_encoder=True + ) + obj_output_dict[storage_key][frame_idx] = current_out + # update cache + inference_state["last_processed_frame"] = curr_img + inference_state["last_processed_output"] = current_out + else: + # first time – must run inference + current_out, pred_masks = self._run_single_frame_inference( + inference_state, obj_output_dict, frame_idx, 1, + is_init_cond_frame=False, point_inputs=None, mask_inputs=None, + reverse=reverse, run_mem_encoder=True + ) + obj_output_dict[storage_key][frame_idx] = current_out + # update cache + inference_state["last_processed_frame"] = inference_state["images"][frame_idx].cpu() + inference_state["last_processed_output"] = current_out + # ========= END FRAME-SKIP CHECK ========= + - # --------------------------------------------- - # 2) Shallow‐copy prior full‐inference dict so - # maskmem_features (and other keys) exist - # --------------------------------------------- - prev_mem_out = obj_output_dict["non_cond_frame_outputs"].get(prev_frame_idx) - if prev_mem_out is None: - prev_mem_out = obj_output_dict["cond_frame_outputs"][prev_frame_idx] - obj_output_dict[storage_key][frame_idx] = dict(prev_mem_out) - obj_output_dict[storage_key][frame_idx]["pred_masks"] = warped_pred - - # Now the memory bank for frame_idx has both key and mask features. - - # Pack into (1,C,Hmask,Wmask) - wp = warped_pred - if wp.ndim == 2: - wp = wp.unsqueeze(0).unsqueeze(0) - elif wp.ndim == 3: - wp = wp.unsqueeze(0) - - # Upsample → (1,C,Hvideo,Wvideo) - video_h, video_w = inference_state["video_size"] - upsamp = F.interpolate( - wp, - size=(video_h, video_w), - mode="bilinear", - align_corners=False, - ).squeeze(0) # → (C,Hvideo,Wvideo) - - inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = { - "reverse": reverse - } - - # Save full-res mask for concatenation - pred_masks_per_obj[obj_idx] = upsamp - - prev_frame_idx = frame_idx - prev_pred_masks = warped_pred - # done with this object - continue - - # --- b) Fallback: full SAM 2 inference on this object - current_out, pred_masks = self._run_single_frame_inference( - inference_state=inference_state, - output_dict=obj_output_dict, - frame_idx=frame_idx, - batch_size=1, - is_init_cond_frame=False, - point_inputs=None, - mask_inputs=None, - reverse=reverse, - run_mem_encoder=True, - ) - obj_output_dict[storage_key][frame_idx] = current_out - - prev_frame_idx = frame_idx - prev_pred_masks = pred_masks - - # Mark this object as tracked and record its prediction inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = { "reverse": reverse } pred_masks_per_obj[obj_idx] = pred_masks - # --- After processing all objects for this frame --- - # Concatenate or select single-object masks - if batch_size > 1: + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + if len(pred_masks_per_obj) > 1: all_pred_masks = torch.cat(pred_masks_per_obj, dim=0) else: all_pred_masks = pred_masks_per_obj[0] + _, video_res_masks = self._get_orig_video_res_output( + inference_state, all_pred_masks + ) - video_H, video_W = inference_state["video_size"] - - # 1) Move to CPU numpy - npm = all_pred_masks.detach().cpu().numpy() - # 2) Squeeze out any leading singleton batch dims until ndim <= 3 - while npm.ndim > 3: - npm = np.squeeze(npm, axis=0) - # 3) Ensure channel axis: if 2-D, add channel dim - if npm.ndim == 2: - npm = npm[np.newaxis, ...] # → shape (1, Hmask, Wmask) - - # Now npm.ndim == 3: (C, Hmask, Wmask) - C, Hm, Wm = npm.shape - # 4) Prepare output array - resized = np.zeros((C, video_H, video_W), dtype=npm.dtype) - # 5) Resize each channel - for c in range(C): - resized[c] = cv2.resize( - npm[c], - dsize=(video_W, video_H), - interpolation=cv2.INTER_LINEAR, - ) - - # 6) Convert back to torch.Tensor on the predictor’s device - device = inference_state["device"] - am_up = torch.from_numpy(resized).to(device) - # 7) Yield one mask-per-frame tensor - yield frame_idx, obj_ids, am_up + if frame_idx % 30 == 0 and frame_idx > 0: + dt = time.time() - t0 + print(f"Avg FPS last 30 frames: {30/dt:.2f}") + t0 = time.time() + yield frame_idx, obj_ids, video_res_masks + # pront the number of skipped frames over the entire video + if count > 0: + print(f"Skipped {count} frame due to low MAD.") @torch.inference_mode() def clear_all_prompts_in_frame( From e0173f7d6110a1c780a80c6a36adedce055b8de6 Mon Sep 17 00:00:00 2001 From: YoungKameSennin Date: Sat, 31 May 2025 22:34:21 -0700 Subject: [PATCH 2/7] Multi-sub updates --- sam2/sam2_video_predictor.py | 152 ++++++++++++++++++++--------------- 1 file changed, 87 insertions(+), 65 deletions(-) diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index 1166c50..07709ae 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -614,78 +614,96 @@ def propagate_in_video( processing_order = range(start_frame_idx, end_frame_idx + 1) t0 = time.time() - count = 0 + skipped_ctr = 0 + for frame_idx in tqdm(processing_order, desc="propagate in video"): + + # ========================================================= + # 1) FRAME-LEVEL SKIP DECISION (runs once per frame) + # ========================================================= + reuse_this_frame = False + if ( + inference_state["last_processed_frame"] is not None + and all( + frame_idx not in inference_state["output_dict_per_obj"][i]["cond_frame_outputs"] + for i in range(batch_size) + ) + ): + prev = inference_state["last_processed_frame"] + curr = inference_state["images"][frame_idx].cpu() + mad = _mean_abs_diff(prev, curr) + if mad < self.skip_mad_threshold: + reuse_this_frame = True + skipped_ctr += 1 + print(f"Skipping frame {frame_idx} due to low MAD ({mad:.2f})") + pred_masks_per_obj = [None] * batch_size - for obj_idx in range(batch_size): - obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] - # We skip those frames already in consolidated outputs (these are frames - # that received input clicks or mask). Note that we cannot directly run - # batched forward on them via `_run_single_frame_inference` because the - # number of clicks on each object might be different. - if frame_idx in obj_output_dict["cond_frame_outputs"]: - storage_key = "cond_frame_outputs" - current_out = obj_output_dict[storage_key][frame_idx] - device = inference_state["device"] - pred_masks = current_out["pred_masks"].to(device, non_blocking=True) - if self.clear_non_cond_mem_around_input: - # clear non-conditioning memory of the surrounding frames - self._clear_obj_non_cond_mem_around_input( - inference_state, frame_idx, obj_idx - ) - else: - storage_key = "non_cond_frame_outputs" - # ========= FRAME-SKIP CHECK (NEW) ========= - if inference_state["last_processed_frame"] is not None: - prev_img = inference_state["last_processed_frame"] - curr_img = inference_state["images"][frame_idx].cpu() # uint8, HWC - mad = _mean_abs_diff(prev_img, curr_img) - if mad < self.skip_mad_threshold: - print( - f"Skipping frame {frame_idx} due to low MAD ({mad:.2f}) ") - # Skip heavy inference: reuse previous output - count += 1 - current_out = inference_state["last_processed_output"] - pred_masks = current_out["pred_masks"].to( - inference_state["device"], non_blocking=True - ) - # still need to record bookkeeping so downstream code works - obj_output_dict[storage_key][frame_idx] = current_out - else: - current_out, pred_masks = self._run_single_frame_inference( - inference_state, obj_output_dict, frame_idx, 1, - is_init_cond_frame=False, point_inputs=None, mask_inputs=None, - reverse=reverse, run_mem_encoder=True + + # ========================================================= + # 2) FAST PATH – reuse cached outputs for *all* objects + # ========================================================= + if reuse_this_frame: + cached_list = inference_state["last_processed_outputs"] # len == batch_size + device = inference_state["device"] + + for obj_idx, cached_out in enumerate(cached_list): + obj_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_dict["non_cond_frame_outputs"][frame_idx] = cached_out + inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = { + "reverse": reverse + } + pred_masks_per_obj[obj_idx] = cached_out["pred_masks"].to( + device, non_blocking=True + ) + + # ========================================================= + # 3) SLOW PATH – run the model once per object + # ========================================================= + else: + frame_outputs = [] + for obj_idx in range(batch_size): + obj_dict = inference_state["output_dict_per_obj"][obj_idx] + + # 3a) conditioning frame? + if frame_idx in obj_dict["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = obj_dict[storage_key][frame_idx] + device = inference_state["device"] + pred_masks = current_out["pred_masks"].to(device, non_blocking=True) + if self.clear_non_cond_mem_around_input: + self._clear_obj_non_cond_mem_around_input( + inference_state, frame_idx, obj_idx ) - obj_output_dict[storage_key][frame_idx] = current_out - # update cache - inference_state["last_processed_frame"] = curr_img - inference_state["last_processed_output"] = current_out + + # 3b) plain frame – heavy inference else: - # first time – must run inference + storage_key = "non_cond_frame_outputs" current_out, pred_masks = self._run_single_frame_inference( - inference_state, obj_output_dict, frame_idx, 1, - is_init_cond_frame=False, point_inputs=None, mask_inputs=None, + inference_state, obj_dict, frame_idx, 1, + is_init_cond_frame=False, + point_inputs=None, mask_inputs=None, reverse=reverse, run_mem_encoder=True ) - obj_output_dict[storage_key][frame_idx] = current_out - # update cache - inference_state["last_processed_frame"] = inference_state["images"][frame_idx].cpu() - inference_state["last_processed_output"] = current_out - # ========= END FRAME-SKIP CHECK ========= - - - inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = { - "reverse": reverse - } - pred_masks_per_obj[obj_idx] = pred_masks - - # Resize the output mask to the original video resolution (we directly use - # the mask scores on GPU for output to avoid any CPU conversion in between) - if len(pred_masks_per_obj) > 1: + obj_dict[storage_key][frame_idx] = current_out + + inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = { + "reverse": reverse + } + pred_masks_per_obj[obj_idx] = pred_masks + frame_outputs.append(current_out) + + # ---- cache the full-frame outputs for the next frame ---- + inference_state["last_processed_frame"] = inference_state["images"][frame_idx].cpu() + inference_state["last_processed_outputs"] = frame_outputs + + # ========================================================= + # 4) COMMON TAIL – resize masks & yield + # ========================================================= + if batch_size > 1: all_pred_masks = torch.cat(pred_masks_per_obj, dim=0) else: all_pred_masks = pred_masks_per_obj[0] + _, video_res_masks = self._get_orig_video_res_output( inference_state, all_pred_masks ) @@ -694,10 +712,11 @@ def propagate_in_video( dt = time.time() - t0 print(f"Avg FPS last 30 frames: {30/dt:.2f}") t0 = time.time() + yield frame_idx, obj_ids, video_res_masks - # pront the number of skipped frames over the entire video - if count > 0: - print(f"Skipped {count} frame due to low MAD.") + + if skipped_ctr: + print(f"Skipped {skipped_ctr} frames due to low MAD.") @torch.inference_mode() def clear_all_prompts_in_frame( @@ -755,6 +774,9 @@ def reset_state(self, inference_state): inference_state["output_dict_per_obj"].clear() inference_state["temp_output_dict_per_obj"].clear() inference_state["frames_tracked_per_obj"].clear() + inference_state["last_processed_frame"] = None + inference_state["last_processed_outputs"] = None + def _reset_tracking_results(self, inference_state): """Reset all tracking inputs and results across the videos.""" From fef2b11e88653711bbc95c71fd2a670431f2a376 Mon Sep 17 00:00:00 2001 From: YoungKameSennin Date: Mon, 2 Jun 2025 15:39:34 -0700 Subject: [PATCH 3/7] davis2017 FPS eval --- run_mgfs_davis.ipynb | 177 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 run_mgfs_davis.ipynb diff --git a/run_mgfs_davis.ipynb b/run_mgfs_davis.ipynb new file mode 100644 index 0000000..09158a9 --- /dev/null +++ b/run_mgfs_davis.ipynb @@ -0,0 +1,177 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "acc92286", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded 30 validation sequences\n" + ] + } + ], + "source": [ + "import sys, os, json, time, subprocess, pathlib\n", + "from pathlib import Path\n", + "from davis2017.davis import DAVIS\n", + "import imageio.v3 as iio\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "import torch\n", + "\n", + "DAVIS_ROOT = Path(\"./data/davis/DAVIS\") # ← adjust\n", + "OUT_DIR = Path(\"./data/sam2_preds\") # where we’ll save PNGs\n", + "OUT_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "# DAVIS helper (semi-supervised = first-frame GT mask)\n", + "ds = DAVIS(str(DAVIS_ROOT), task=\"semi-supervised\", subset=\"val\", resolution=\"480p\")\n", + "print(f\"Loaded {len(ds.sequences)} validation sequences\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "302a2bb5", + "metadata": {}, + "outputs": [], + "source": [ + "from sam2.build_sam import build_sam2_video_predictor\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "\n", + "sam2_checkpoint = \"./checkpoints/sam2.1_hiera_large.pt\"\n", + "model_cfg = \"configs/sam2.1/sam2.1_hiera_l.yaml\"\n", + "\n", + "predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c4bb2a9b", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "def show_mask(mask, ax, obj_id=None, random_color=False):\n", + " if random_color:\n", + " color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n", + " else:\n", + " cmap = plt.get_cmap(\"tab10\")\n", + " cmap_idx = 0 if obj_id is None else obj_id\n", + " color = np.array([*cmap(cmap_idx)[:3], 0.6])\n", + " h, w = mask.shape[-2:]\n", + " mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n", + " ax.imshow(mask_image)\n", + "\n", + "def run_sequence(seq_name: str) -> float:\n", + " img_dir = Path(ds.sequences[seq_name]['images'][0]).parent\n", + " inference_state = predictor.init_state(str(img_dir))\n", + " predictor.reset_state(inference_state)\n", + "\n", + " # --- first-frame GT mask (ensure 2-D) ---\n", + " first_gt = iio.imread(ds.sequences[seq_name]['masks'][0])\n", + " if first_gt.ndim == 3: # palette PNG → RGB/RGBA\n", + " first_gt = first_gt[..., 0]\n", + "\n", + " for k in range(1, int(first_gt.max()) + 1):\n", + " predictor.add_new_mask(\n", + " inference_state,\n", + " frame_idx=0,\n", + " obj_id=f\"obj-{k}\",\n", + " mask=(first_gt == k).astype(\"uint8\"),\n", + " )\n", + "\n", + " # --- propagate & save ---\n", + " t0 = time.time()\n", + " n_frames = len(ds.sequences[seq_name]['images'])\n", + " # run propagation throughout the video and collect the results in a dict\n", + " video_segments = {} # video_segments contains the per-frame segmentation results\n", + " for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):\n", + " video_segments[out_frame_idx] = {\n", + " out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()\n", + " for i, out_obj_id in enumerate(out_obj_ids)\n", + " }\n", + " # # save the results to PNGs\n", + " # for out_obj_id, out_mask in video_segments[out_frame_idx].items():\n", + " # out_mask = out_mask.astype(\"uint8\") * 255\n", + " # out_path = OUT_DIR / f\"{seq_name}_{out_frame_idx:04d}_{out_obj_id}.png\"\n", + " # iio.imwrite(out_path, out_mask)\n", + "\n", + " vis_frame_stride = 1\n", + " plt.close(\"all\")\n", + " for out_frame_idx in range(0, len(n_frames), vis_frame_stride):\n", + " plt.figure(figsize=(6, 4))\n", + " plt.title(f\"frame {out_frame_idx}\")\n", + " plt.imshow(Image.open(ds.sequences[seq_name]['images'][out_frame_idx]))\n", + " for out_obj_id, out_mask in video_segments[out_frame_idx].items():\n", + " show_mask(out_mask, plt.gca(), obj_id=out_obj_id)\n", + "\n", + " return n_frames / (time.time() - t0)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b365167", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 69/69 [00:01<00:00, 60.71it/s]\n", + "/home/wei/FrameSkipSAM/sam2/sam2_video_predictor.py:878: UserWarning: cannot import name '_C' from 'sam2' (/home/wei/FrameSkipSAM/sam2/__init__.py)\n", + "\n", + "Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n", + " pred_masks_gpu = fill_holes_in_mask_scores(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 1 due to low MAD (0.04)\n", + "Skipping frame 2 due to low MAD (0.06)\n", + "Skipping frame 3 due to low MAD (0.07)\n", + "Skipping frame 4 due to low MAD (0.09)\n" + ] + } + ], + "source": [ + "fps_vals = {}\n", + "for seq in tqdm(ds.get_sequences(), desc=\"SAM2 on DAVIS-val\"):\n", + " fps_vals[seq] = run_sequence(seq)\n", + "\n", + "print(f\"Mean FPS: {sum(fps_vals.values())/len(fps_vals):.2f}\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sam2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 781b9a3ee89685aeb2a03375b4c49014be0aff64 Mon Sep 17 00:00:00 2001 From: YoungKameSennin Date: Mon, 2 Jun 2025 15:45:45 -0700 Subject: [PATCH 4/7] bug fix --- run_mgfs_davis.ipynb | 253 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 250 insertions(+), 3 deletions(-) diff --git a/run_mgfs_davis.ipynb b/run_mgfs_davis.ipynb index 09158a9..17948fe 100644 --- a/run_mgfs_davis.ipynb +++ b/run_mgfs_davis.ipynb @@ -52,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "c4bb2a9b", "metadata": {}, "outputs": [], @@ -106,7 +106,7 @@ "\n", " vis_frame_stride = 1\n", " plt.close(\"all\")\n", - " for out_frame_idx in range(0, len(n_frames), vis_frame_stride):\n", + " for out_frame_idx in range(0, n_frames, vis_frame_stride):\n", " plt.figure(figsize=(6, 4))\n", " plt.title(f\"frame {out_frame_idx}\")\n", " plt.imshow(Image.open(ds.sequences[seq_name]['images'][out_frame_idx]))\n", @@ -118,7 +118,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "9b365167", "metadata": {}, "outputs": [ @@ -142,6 +142,253 @@ "Skipping frame 3 due to low MAD (0.07)\n", "Skipping frame 4 due to low MAD (0.09)\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 6 due to low MAD (0.07)\n", + "Skipping frame 7 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 9 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 13 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 15 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 17 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 23 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 26 due to low MAD (0.08)\n", + "Skipping frame 27 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 29 due to low MAD (0.04)\n", + "Skipping frame 30 due to low MAD (0.05)\n", + "Avg FPS last 30 frames: 0.26\n", + "Skipping frame 31 due to low MAD (0.05)\n", + "Skipping frame 32 due to low MAD (0.06)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 33 due to low MAD (0.08)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 35 due to low MAD (0.07)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 37 due to low MAD (0.07)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 44 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 55 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 57 due to low MAD (0.08)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 0.18\n", + "Skipping frame 61 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 63 due to low MAD (0.07)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 65 due to low MAD (0.07)\n", + "Skipping frame 66 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 69/69 [04:58<00:00, 4.33s/it]\n", + "SAM2 on DAVIS-val: 0it [05:02, ?it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 68 due to low MAD (0.04)\n", + "Skipped 28 frames due to low MAD.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "ename": "TypeError", + "evalue": "object of type 'int' has no len()", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m fps_vals = {}\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m seq \u001b[38;5;129;01min\u001b[39;00m tqdm(ds.get_sequences(), desc=\u001b[33m\"\u001b[39m\u001b[33mSAM2 on DAVIS-val\u001b[39m\u001b[33m\"\u001b[39m):\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m fps_vals[seq] = \u001b[43mrun_sequence\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseq\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 5\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mMean FPS: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28msum\u001b[39m(fps_vals.values())/\u001b[38;5;28mlen\u001b[39m(fps_vals)\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.2f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 50\u001b[39m, in \u001b[36mrun_sequence\u001b[39m\u001b[34m(seq_name)\u001b[39m\n\u001b[32m 48\u001b[39m vis_frame_stride = \u001b[32m1\u001b[39m\n\u001b[32m 49\u001b[39m plt.close(\u001b[33m\"\u001b[39m\u001b[33mall\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m50\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m out_frame_idx \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[32m0\u001b[39m, \u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mn_frames\u001b[49m\u001b[43m)\u001b[49m, vis_frame_stride):\n\u001b[32m 51\u001b[39m plt.figure(figsize=(\u001b[32m6\u001b[39m, \u001b[32m4\u001b[39m))\n\u001b[32m 52\u001b[39m plt.title(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mframe \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mout_frame_idx\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", + "\u001b[31mTypeError\u001b[39m: object of type 'int' has no len()" + ] } ], "source": [ From e5472cfa0c4c6f3c8e387f8cff9158ce2aefc6dd Mon Sep 17 00:00:00 2001 From: YoungKameSennin Date: Tue, 3 Jun 2025 01:52:31 -0700 Subject: [PATCH 5/7] saves pred --- run_mgfs_davis.ipynb | 4180 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 4035 insertions(+), 145 deletions(-) diff --git a/run_mgfs_davis.ipynb b/run_mgfs_davis.ipynb index 17948fe..ef7a093 100644 --- a/run_mgfs_davis.ipynb +++ b/run_mgfs_davis.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 15, "id": "acc92286", "metadata": {}, "outputs": [ @@ -12,6 +12,331 @@ "text": [ "Loaded 30 validation sequences\n" ] + }, + { + "data": { + "text/plain": [ + "SAM2VideoPredictor(\n", + " (image_encoder): ImageEncoder(\n", + " (trunk): Hiera(\n", + " (patch_embed): PatchEmbed(\n", + " (proj): Conv2d(3, 144, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))\n", + " )\n", + " (blocks): ModuleList(\n", + " (0-1): 2 x MultiScaleBlock(\n", + " (norm1): LayerNorm((144,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MultiScaleAttention(\n", + " (qkv): Linear(in_features=144, out_features=432, bias=True)\n", + " (proj): Linear(in_features=144, out_features=144, bias=True)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((144,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=144, out_features=576, bias=True)\n", + " (1): Linear(in_features=576, out_features=144, bias=True)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " )\n", + " )\n", + " (2): MultiScaleBlock(\n", + " (norm1): LayerNorm((144,), eps=1e-06, elementwise_affine=True)\n", + " (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n", + " (attn): MultiScaleAttention(\n", + " (q_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n", + " (qkv): Linear(in_features=144, out_features=864, bias=True)\n", + " (proj): Linear(in_features=288, out_features=288, bias=True)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((288,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=288, out_features=1152, bias=True)\n", + " (1): Linear(in_features=1152, out_features=288, bias=True)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " )\n", + " (proj): Linear(in_features=144, out_features=288, bias=True)\n", + " )\n", + " (3-7): 5 x MultiScaleBlock(\n", + " (norm1): LayerNorm((288,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MultiScaleAttention(\n", + " (qkv): Linear(in_features=288, out_features=864, bias=True)\n", + " (proj): Linear(in_features=288, out_features=288, bias=True)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((288,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=288, out_features=1152, bias=True)\n", + " (1): Linear(in_features=1152, out_features=288, bias=True)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " )\n", + " )\n", + " (8): MultiScaleBlock(\n", + " (norm1): LayerNorm((288,), eps=1e-06, elementwise_affine=True)\n", + " (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n", + " (attn): MultiScaleAttention(\n", + " (q_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n", + " (qkv): Linear(in_features=288, out_features=1728, bias=True)\n", + " (proj): Linear(in_features=576, out_features=576, bias=True)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((576,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=576, out_features=2304, bias=True)\n", + " (1): Linear(in_features=2304, out_features=576, bias=True)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " )\n", + " (proj): Linear(in_features=288, out_features=576, bias=True)\n", + " )\n", + " (9-43): 35 x MultiScaleBlock(\n", + " (norm1): LayerNorm((576,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MultiScaleAttention(\n", + " (qkv): Linear(in_features=576, out_features=1728, bias=True)\n", + " (proj): Linear(in_features=576, out_features=576, bias=True)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((576,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=576, out_features=2304, bias=True)\n", + " (1): Linear(in_features=2304, out_features=576, bias=True)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " )\n", + " )\n", + " (44): MultiScaleBlock(\n", + " (norm1): LayerNorm((576,), eps=1e-06, elementwise_affine=True)\n", + " (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n", + " (attn): MultiScaleAttention(\n", + " (q_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n", + " (qkv): Linear(in_features=576, out_features=3456, bias=True)\n", + " (proj): Linear(in_features=1152, out_features=1152, bias=True)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=1152, out_features=4608, bias=True)\n", + " (1): Linear(in_features=4608, out_features=1152, bias=True)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " )\n", + " (proj): Linear(in_features=576, out_features=1152, bias=True)\n", + " )\n", + " (45-47): 3 x MultiScaleBlock(\n", + " (norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MultiScaleAttention(\n", + " (qkv): Linear(in_features=1152, out_features=3456, bias=True)\n", + " (proj): Linear(in_features=1152, out_features=1152, bias=True)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=1152, out_features=4608, bias=True)\n", + " (1): Linear(in_features=4608, out_features=1152, bias=True)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (neck): FpnNeck(\n", + " (position_encoding): PositionEmbeddingSine()\n", + " (convs): ModuleList(\n", + " (0): Sequential(\n", + " (conv): Conv2d(1152, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (1): Sequential(\n", + " (conv): Conv2d(576, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (2): Sequential(\n", + " (conv): Conv2d(288, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (3): Sequential(\n", + " (conv): Conv2d(144, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (mask_downsample): Conv2d(1, 1, kernel_size=(4, 4), stride=(4, 4))\n", + " (memory_attention): MemoryAttention(\n", + " (layers): ModuleList(\n", + " (0-3): 4 x MemoryAttentionLayer(\n", + " (self_attn): RoPEAttention(\n", + " (q_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " (k_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " (v_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " (out_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " )\n", + " (cross_attn_image): RoPEAttention(\n", + " (q_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " (k_proj): Linear(in_features=64, out_features=256, bias=True)\n", + " (v_proj): Linear(in_features=64, out_features=256, bias=True)\n", + " (out_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " )\n", + " (linear1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (linear2): Linear(in_features=2048, out_features=256, bias=True)\n", + " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " (dropout3): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (memory_encoder): MemoryEncoder(\n", + " (mask_downsampler): MaskDownSampler(\n", + " (encoder): Sequential(\n", + " (0): Conv2d(1, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (1): LayerNorm2d()\n", + " (2): GELU(approximate='none')\n", + " (3): Conv2d(4, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (4): LayerNorm2d()\n", + " (5): GELU(approximate='none')\n", + " (6): Conv2d(16, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (7): LayerNorm2d()\n", + " (8): GELU(approximate='none')\n", + " (9): Conv2d(64, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (10): LayerNorm2d()\n", + " (11): GELU(approximate='none')\n", + " (12): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " )\n", + " (pix_feat_proj): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " (fuser): Fuser(\n", + " (proj): Identity()\n", + " (layers): ModuleList(\n", + " (0-1): 2 x CXBlock(\n", + " (dwconv): Conv2d(256, 256, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=256)\n", + " (norm): LayerNorm2d()\n", + " (pwconv1): Linear(in_features=256, out_features=1024, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (pwconv2): Linear(in_features=1024, out_features=256, bias=True)\n", + " (drop_path): Identity()\n", + " )\n", + " )\n", + " )\n", + " (position_encoding): PositionEmbeddingSine()\n", + " (out_proj): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (sam_prompt_encoder): PromptEncoder(\n", + " (pe_layer): PositionEmbeddingRandom()\n", + " (point_embeddings): ModuleList(\n", + " (0-3): 4 x Embedding(1, 256)\n", + " )\n", + " (not_a_point_embed): Embedding(1, 256)\n", + " (mask_downscaling): Sequential(\n", + " (0): Conv2d(1, 4, kernel_size=(2, 2), stride=(2, 2))\n", + " (1): LayerNorm2d()\n", + " (2): GELU(approximate='none')\n", + " (3): Conv2d(4, 16, kernel_size=(2, 2), stride=(2, 2))\n", + " (4): LayerNorm2d()\n", + " (5): GELU(approximate='none')\n", + " (6): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (no_mask_embed): Embedding(1, 256)\n", + " )\n", + " (sam_mask_decoder): MaskDecoder(\n", + " (transformer): TwoWayTransformer(\n", + " (layers): ModuleList(\n", + " (0-1): 2 x TwoWayAttentionBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " (k_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " (v_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " (out_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " )\n", + " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (cross_attn_token_to_image): Attention(\n", + " (q_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (k_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (v_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (out_proj): Linear(in_features=128, out_features=256, bias=True)\n", + " )\n", + " (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=256, out_features=2048, bias=True)\n", + " (1): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (act): ReLU()\n", + " )\n", + " (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (norm4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (cross_attn_image_to_token): Attention(\n", + " (q_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (k_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (v_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (out_proj): Linear(in_features=128, out_features=256, bias=True)\n", + " )\n", + " )\n", + " )\n", + " (final_attn_token_to_image): Attention(\n", + " (q_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (k_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (v_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (out_proj): Linear(in_features=128, out_features=256, bias=True)\n", + " )\n", + " (norm_final_attn): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (iou_token): Embedding(1, 256)\n", + " (mask_tokens): Embedding(4, 256)\n", + " (obj_score_token): Embedding(1, 256)\n", + " (output_upscaling): Sequential(\n", + " (0): ConvTranspose2d(256, 64, kernel_size=(2, 2), stride=(2, 2))\n", + " (1): LayerNorm2d()\n", + " (2): GELU(approximate='none')\n", + " (3): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))\n", + " (4): GELU(approximate='none')\n", + " )\n", + " (conv_s0): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_s1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (output_hypernetworks_mlps): ModuleList(\n", + " (0-3): 4 x MLP(\n", + " (layers): ModuleList(\n", + " (0-1): 2 x Linear(in_features=256, out_features=256, bias=True)\n", + " (2): Linear(in_features=256, out_features=32, bias=True)\n", + " )\n", + " (act): ReLU()\n", + " )\n", + " )\n", + " (iou_prediction_head): MLP(\n", + " (layers): ModuleList(\n", + " (0-1): 2 x Linear(in_features=256, out_features=256, bias=True)\n", + " (2): Linear(in_features=256, out_features=4, bias=True)\n", + " )\n", + " (act): ReLU()\n", + " )\n", + " (pred_obj_score_head): MLP(\n", + " (layers): ModuleList(\n", + " (0-1): 2 x Linear(in_features=256, out_features=256, bias=True)\n", + " (2): Linear(in_features=256, out_features=1, bias=True)\n", + " )\n", + " (act): ReLU()\n", + " )\n", + " )\n", + " (obj_ptr_proj): MLP(\n", + " (layers): ModuleList(\n", + " (0-2): 3 x Linear(in_features=256, out_features=256, bias=True)\n", + " )\n", + " (act): ReLU()\n", + " )\n", + " (obj_ptr_tpos_proj): Linear(in_features=256, out_features=64, bias=True)\n", + ")" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -23,135 +348,69 @@ "from tqdm import tqdm\n", "import torch\n", "\n", - "DAVIS_ROOT = Path(\"./data/davis/DAVIS\") # ← adjust\n", - "OUT_DIR = Path(\"./data/sam2_preds\") # where we’ll save PNGs\n", + "# ── USER‐CONFIGURABLE PATHS ──────────────────────────────────────────────────\n", + "DAVIS_ROOT = Path(\"./data/davis/DAVIS\") # ← point this at your DAVIS folder\n", + "OUT_DIR = Path(\"./data/sam2_preds\") # ← where we’ll write out PNGs\n", "OUT_DIR.mkdir(parents=True, exist_ok=True)\n", "\n", - "# DAVIS helper (semi-supervised = first-frame GT mask)\n", + "# ── STEP 1: load the DAVIS “val” split (semi‐supervised task) ───────────────────\n", "ds = DAVIS(str(DAVIS_ROOT), task=\"semi-supervised\", subset=\"val\", resolution=\"480p\")\n", - "print(f\"Loaded {len(ds.sequences)} validation sequences\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "302a2bb5", - "metadata": {}, - "outputs": [], - "source": [ + "print(f\"Loaded {len(ds.sequences)} validation sequences\")\n", + "\n", + "# ── STEP 2: build SAM 2 video predictor ─────────────────────────────────────────\n", "from sam2.build_sam import build_sam2_video_predictor\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "sam2_checkpoint = \"./checkpoints/sam2.1_hiera_large.pt\" # ← adjust if needed\n", + "model_cfg = \"configs/sam2.1/sam2.1_hiera_l.yaml\" # ← adjust if needed\n", "\n", - "\n", - "sam2_checkpoint = \"./checkpoints/sam2.1_hiera_large.pt\"\n", - "model_cfg = \"configs/sam2.1/sam2.1_hiera_l.yaml\"\n", - "\n", - "predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c4bb2a9b", - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "def show_mask(mask, ax, obj_id=None, random_color=False):\n", - " if random_color:\n", - " color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n", - " else:\n", - " cmap = plt.get_cmap(\"tab10\")\n", - " cmap_idx = 0 if obj_id is None else obj_id\n", - " color = np.array([*cmap(cmap_idx)[:3], 0.6])\n", - " h, w = mask.shape[-2:]\n", - " mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n", - " ax.imshow(mask_image)\n", - "\n", - "def run_sequence(seq_name: str) -> float:\n", - " img_dir = Path(ds.sequences[seq_name]['images'][0]).parent\n", - " inference_state = predictor.init_state(str(img_dir))\n", - " predictor.reset_state(inference_state)\n", - "\n", - " # --- first-frame GT mask (ensure 2-D) ---\n", - " first_gt = iio.imread(ds.sequences[seq_name]['masks'][0])\n", - " if first_gt.ndim == 3: # palette PNG → RGB/RGBA\n", - " first_gt = first_gt[..., 0]\n", - "\n", - " for k in range(1, int(first_gt.max()) + 1):\n", - " predictor.add_new_mask(\n", - " inference_state,\n", - " frame_idx=0,\n", - " obj_id=f\"obj-{k}\",\n", - " mask=(first_gt == k).astype(\"uint8\"),\n", - " )\n", - "\n", - " # --- propagate & save ---\n", - " t0 = time.time()\n", - " n_frames = len(ds.sequences[seq_name]['images'])\n", - " # run propagation throughout the video and collect the results in a dict\n", - " video_segments = {} # video_segments contains the per-frame segmentation results\n", - " for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):\n", - " video_segments[out_frame_idx] = {\n", - " out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()\n", - " for i, out_obj_id in enumerate(out_obj_ids)\n", - " }\n", - " # # save the results to PNGs\n", - " # for out_obj_id, out_mask in video_segments[out_frame_idx].items():\n", - " # out_mask = out_mask.astype(\"uint8\") * 255\n", - " # out_path = OUT_DIR / f\"{seq_name}_{out_frame_idx:04d}_{out_obj_id}.png\"\n", - " # iio.imwrite(out_path, out_mask)\n", - "\n", - " vis_frame_stride = 1\n", - " plt.close(\"all\")\n", - " for out_frame_idx in range(0, n_frames, vis_frame_stride):\n", - " plt.figure(figsize=(6, 4))\n", - " plt.title(f\"frame {out_frame_idx}\")\n", - " plt.imshow(Image.open(ds.sequences[seq_name]['images'][out_frame_idx]))\n", - " for out_obj_id, out_mask in video_segments[out_frame_idx].items():\n", - " show_mask(out_mask, plt.gca(), obj_id=out_obj_id)\n", - "\n", - " return n_frames / (time.time() - t0)\n" + "predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)\n", + "predictor.to(device) # make sure model is on CUDA if available" ] }, { "cell_type": "code", - "execution_count": 4, - "id": "9b365167", + "execution_count": 31, + "id": "bdd669c2", "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "frame loading (JPEG): 100%|██████████| 69/69 [00:01<00:00, 60.71it/s]\n", - "/home/wei/FrameSkipSAM/sam2/sam2_video_predictor.py:878: UserWarning: cannot import name '_C' from 'sam2' (/home/wei/FrameSkipSAM/sam2/__init__.py)\n", "\n", - "Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n", - " pred_masks_gpu = fill_holes_in_mask_scores(\n" + "=== Processing sequence: bike-packing ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 69/69 [00:01<00:00, 60.43it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Skipping frame 1 due to low MAD (0.04)\n", - "Skipping frame 2 due to low MAD (0.06)\n", - "Skipping frame 3 due to low MAD (0.07)\n", - "Skipping frame 4 due to low MAD (0.09)\n" + "Found 2 unique non‐black colors in bike-packing/00000.png\n" ] }, { "name": "stderr", "output_type": "stream", - "text": [] + "text": [ + "Propagating bike-packing: 7%|▋ | 5/68 [00:00<00:01, 45.88it/s]" + ] }, { "name": "stdout", "output_type": "stream", "text": [ + "Skipping frame 1 due to low MAD (0.04)\n", + "Skipping frame 2 due to low MAD (0.06)\n", + "Skipping frame 3 due to low MAD (0.07)\n", + "Skipping frame 4 due to low MAD (0.09)\n", "Skipping frame 6 due to low MAD (0.07)\n", "Skipping frame 7 due to low MAD (0.09)\n" ] @@ -159,7 +418,9 @@ { "name": "stderr", "output_type": "stream", - "text": [] + "text": [ + "Propagating bike-packing: 15%|█▍ | 10/68 [00:00<00:04, 13.88it/s]" + ] }, { "name": "stdout", @@ -171,7 +432,9 @@ { "name": "stderr", "output_type": "stream", - "text": [] + "text": [ + "Propagating bike-packing: 19%|█▉ | 13/68 [00:01<00:07, 7.50it/s]" + ] }, { "name": "stdout", @@ -183,7 +446,9 @@ { "name": "stderr", "output_type": "stream", - "text": [] + "text": [ + "Propagating bike-packing: 22%|██▏ | 15/68 [00:01<00:07, 7.41it/s]" + ] }, { "name": "stdout", @@ -195,7 +460,9 @@ { "name": "stderr", "output_type": "stream", - "text": [] + "text": [ + "Propagating bike-packing: 25%|██▌ | 17/68 [00:01<00:06, 7.35it/s]" + ] }, { "name": "stdout", @@ -207,7 +474,9 @@ { "name": "stderr", "output_type": "stream", - "text": [] + "text": [ + "Propagating bike-packing: 34%|███▍ | 23/68 [00:03<00:09, 4.83it/s]" + ] }, { "name": "stdout", @@ -219,7 +488,9 @@ { "name": "stderr", "output_type": "stream", - "text": [] + "text": [ + "Propagating bike-packing: 38%|███▊ | 26/68 [00:03<00:08, 5.06it/s]" + ] }, { "name": "stdout", @@ -232,7 +503,9 @@ { "name": "stderr", "output_type": "stream", - "text": [] + "text": [ + "Propagating bike-packing: 50%|█████ | 34/68 [00:04<00:02, 12.47it/s]" + ] }, { "name": "stdout", @@ -240,28 +513,19 @@ "text": [ "Skipping frame 29 due to low MAD (0.04)\n", "Skipping frame 30 due to low MAD (0.05)\n", - "Avg FPS last 30 frames: 0.26\n", + "Avg FPS last 30 frames: 7.27\n", "Skipping frame 31 due to low MAD (0.05)\n", - "Skipping frame 32 due to low MAD (0.06)\n" + "Skipping frame 32 due to low MAD (0.06)\n", + "Skipping frame 33 due to low MAD (0.08)\n" ] }, { "name": "stderr", "output_type": "stream", - "text": [] - }, - { - "name": "stdout", - "output_type": "stream", "text": [ - "Skipping frame 33 due to low MAD (0.08)\n" + "Propagating bike-packing: 53%|█████▎ | 36/68 [00:04<00:02, 10.67it/s]" ] }, - { - "name": "stderr", - "output_type": "stream", - "text": [] - }, { "name": "stdout", "output_type": "stream", @@ -272,7 +536,9 @@ { "name": "stderr", "output_type": "stream", - "text": [] + "text": [ + "Propagating bike-packing: 56%|█████▌ | 38/68 [00:04<00:03, 9.55it/s]" + ] }, { "name": "stdout", @@ -284,7 +550,9 @@ { "name": "stderr", "output_type": "stream", - "text": [] + "text": [ + "Propagating bike-packing: 62%|██████▏ | 42/68 [00:05<00:04, 5.72it/s]" + ] }, { "name": "stdout", @@ -296,7 +564,9 @@ { "name": "stderr", "output_type": "stream", - "text": [] + "text": [ + "Propagating bike-packing: 81%|████████ | 55/68 [00:07<00:02, 5.75it/s]" + ] }, { "name": "stdout", @@ -308,7 +578,9 @@ { "name": "stderr", "output_type": "stream", - "text": [] + "text": [ + "Propagating bike-packing: 84%|████████▍ | 57/68 [00:07<00:01, 6.06it/s]" + ] }, { "name": "stdout", @@ -320,20 +592,24 @@ { "name": "stderr", "output_type": "stream", - "text": [] + "text": [ + "Propagating bike-packing: 90%|████████▉ | 61/68 [00:08<00:01, 5.29it/s]" + ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Avg FPS last 30 frames: 0.18\n", + "Avg FPS last 30 frames: 6.83\n", "Skipping frame 61 due to low MAD (0.09)\n" ] }, { "name": "stderr", "output_type": "stream", - "text": [] + "text": [ + "Propagating bike-packing: 93%|█████████▎| 63/68 [00:08<00:00, 5.84it/s]" + ] }, { "name": "stdout", @@ -345,7 +621,9 @@ { "name": "stderr", "output_type": "stream", - "text": [] + "text": [ + "Propagating bike-packing: 96%|█████████▌| 65/68 [00:09<00:00, 6.24it/s]" + ] }, { "name": "stdout", @@ -359,8 +637,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "propagate in video: 100%|██████████| 69/69 [04:58<00:00, 4.33s/it]\n", - "SAM2 on DAVIS-val: 0it [05:02, ?it/s]" + "propagate in video: 100%|██████████| 69/69 [00:09<00:00, 7.34it/s]2it/s]\n", + "Propagating bike-packing: 69it [00:09, 7.33it/s] \n" ] }, { @@ -368,35 +646,3647 @@ "output_type": "stream", "text": [ "Skipping frame 68 due to low MAD (0.04)\n", - "Skipped 28 frames due to low MAD.\n" + "Skipped 28 frames due to low MAD.\n", + "→ Saved all predicted masks for bike-packing in data/sam2_preds_multi/bike-packing\n", + "\n", + "=== Processing sequence: blackswan ===\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "\n" + "frame loading (JPEG): 100%|██████████| 50/50 [00:00<00:00, 60.41it/s]\n" ] }, { - "ename": "TypeError", - "evalue": "object of type 'int' has no len()", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m fps_vals = {}\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m seq \u001b[38;5;129;01min\u001b[39;00m tqdm(ds.get_sequences(), desc=\u001b[33m\"\u001b[39m\u001b[33mSAM2 on DAVIS-val\u001b[39m\u001b[33m\"\u001b[39m):\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m fps_vals[seq] = \u001b[43mrun_sequence\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseq\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 5\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mMean FPS: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28msum\u001b[39m(fps_vals.values())/\u001b[38;5;28mlen\u001b[39m(fps_vals)\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.2f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 50\u001b[39m, in \u001b[36mrun_sequence\u001b[39m\u001b[34m(seq_name)\u001b[39m\n\u001b[32m 48\u001b[39m vis_frame_stride = \u001b[32m1\u001b[39m\n\u001b[32m 49\u001b[39m plt.close(\u001b[33m\"\u001b[39m\u001b[33mall\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m50\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m out_frame_idx \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[32m0\u001b[39m, \u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mn_frames\u001b[49m\u001b[43m)\u001b[49m, vis_frame_stride):\n\u001b[32m 51\u001b[39m plt.figure(figsize=(\u001b[32m6\u001b[39m, \u001b[32m4\u001b[39m))\n\u001b[32m 52\u001b[39m plt.title(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mframe \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mout_frame_idx\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", - "\u001b[31mTypeError\u001b[39m: object of type 'int' has no len()" + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in blackswan/00000.png\n" ] - } - ], - "source": [ - "fps_vals = {}\n", - "for seq in tqdm(ds.get_sequences(), desc=\"SAM2 on DAVIS-val\"):\n", - " fps_vals[seq] = run_sequence(seq)\n", + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating blackswan: 65%|██████▌ | 32/49 [00:05<00:03, 5.29it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.39\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 50/50 [00:09<00:00, 5.55it/s] \n", + "Propagating blackswan: 50it [00:09, 5.55it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "→ Saved all predicted masks for blackswan in data/sam2_preds_multi/blackswan\n", + "\n", + "=== Processing sequence: bmx-trees ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 80/80 [00:01<00:00, 60.17it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 2 unique non‐black colors in bmx-trees/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bmx-trees: 39%|███▉ | 31/79 [00:07<00:12, 3.95it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 4.10\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bmx-trees: 77%|███████▋ | 61/79 [00:14<00:04, 3.97it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 3.96\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 80/80 [00:18<00:00, 4.39it/s] \n", + "Propagating bmx-trees: 80it [00:18, 4.39it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "→ Saved all predicted masks for bmx-trees in data/sam2_preds_multi/bmx-trees\n", + "\n", + "=== Processing sequence: breakdance ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 84/84 [00:01<00:00, 63.11it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in breakdance/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 4%|▎ | 3/83 [00:00<00:04, 16.40it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 1 due to low MAD (0.09)\n", + "Skipping frame 3 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 11%|█ | 9/83 [00:00<00:07, 9.26it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 7 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 19%|█▉ | 16/83 [00:01<00:08, 7.55it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 14 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 29%|██▉ | 24/83 [00:03<00:08, 7.07it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 22 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 31%|███▏ | 26/83 [00:03<00:07, 7.90it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 24 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 39%|███▊ | 32/83 [00:04<00:07, 7.18it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 30 due to low MAD (0.10)\n", + "Avg FPS last 30 frames: 7.06\n", + "Skipping frame 32 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 43%|████▎ | 36/83 [00:04<00:05, 8.72it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 34 due to low MAD (0.10)\n", + "Skipping frame 36 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 48%|████▊ | 40/83 [00:05<00:04, 9.53it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 38 due to low MAD (0.09)\n", + "Skipping frame 40 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 58%|█████▊ | 48/83 [00:06<00:04, 7.82it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 46 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 69%|██████▊ | 57/83 [00:07<00:03, 7.27it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 55 due to low MAD (0.08)\n", + "Skipping frame 57 due to low MAD (0.08)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 73%|███████▎ | 61/83 [00:08<00:02, 8.87it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 59 due to low MAD (0.10)\n", + "Avg FPS last 30 frames: 7.71\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 88%|████████▊ | 73/83 [00:10<00:01, 7.11it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 71 due to low MAD (0.09)\n", + "Skipping frame 73 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 93%|█████████▎| 77/83 [00:10<00:00, 8.80it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 75 due to low MAD (0.09)\n", + "Skipping frame 77 due to low MAD (0.08)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 98%|█████████▊| 81/83 [00:10<00:00, 9.59it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 79 due to low MAD (0.09)\n", + "Skipping frame 81 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 84/84 [00:11<00:00, 7.55it/s]t/s]\n", + "Propagating breakdance: 84it [00:11, 7.54it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 83 due to low MAD (0.09)\n", + "Skipped 23 frames due to low MAD.\n", + "→ Saved all predicted masks for breakdance in data/sam2_preds_multi/breakdance\n", + "\n", + "=== Processing sequence: camel ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 90/90 [00:01<00:00, 61.51it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in camel/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating camel: 36%|███▌ | 32/89 [00:05<00:10, 5.48it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.62\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating camel: 70%|██████▉ | 62/89 [00:10<00:04, 5.52it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.50\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 90/90 [00:14<00:00, 6.15it/s]\n", + "Propagating camel: 90it [00:14, 6.15it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "→ Saved all predicted masks for camel in data/sam2_preds_multi/camel\n", + "\n", + "=== Processing sequence: car-roundabout ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 75/75 [00:01<00:00, 61.59it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in car-roundabout/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating car-roundabout: 43%|████▎ | 32/74 [00:05<00:07, 5.54it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.67\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating car-roundabout: 84%|████████▍ | 62/74 [00:11<00:02, 5.32it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.43\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 75/75 [00:13<00:00, 5.55it/s] \n", + "Propagating car-roundabout: 75it [00:13, 5.55it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "→ Saved all predicted masks for car-roundabout in data/sam2_preds_multi/car-roundabout\n", + "\n", + "=== Processing sequence: car-shadow ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 40/40 [00:00<00:00, 61.89it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in car-shadow/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating car-shadow: 82%|████████▏ | 32/39 [00:05<00:01, 5.22it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.34\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 40/40 [00:07<00:00, 5.44it/s] \n", + "Propagating car-shadow: 40it [00:07, 5.44it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "→ Saved all predicted masks for car-shadow in data/sam2_preds_multi/car-shadow\n", + "\n", + "=== Processing sequence: cows ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 104/104 [00:01<00:00, 58.65it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in cows/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating cows: 12%|█▏ | 12/103 [00:01<00:16, 5.37it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 12 due to low MAD (0.09)\n", + "Skipping frame 13 due to low MAD (0.08)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating cows: 29%|██▉ | 30/103 [00:03<00:04, 17.70it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 7.76\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating cows: 34%|███▍ | 35/103 [00:04<00:06, 10.31it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 33 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating cows: 60%|██████ | 62/103 [00:09<00:07, 5.25it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.40\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating cows: 79%|███████▊ | 81/103 [00:13<00:04, 5.34it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 81 due to low MAD (0.07)\n", + "Skipping frame 82 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating cows: 89%|████████▉ | 92/103 [00:14<00:01, 5.52it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.66\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 104/104 [00:17<00:00, 6.06it/s]\n", + "Propagating cows: 104it [00:17, 6.06it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipped 5 frames due to low MAD.\n", + "→ Saved all predicted masks for cows in data/sam2_preds_multi/cows\n", + "\n", + "=== Processing sequence: dance-twirl ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 90/90 [00:01<00:00, 60.09it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in dance-twirl/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating dance-twirl: 36%|███▌ | 32/89 [00:05<00:10, 5.32it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.45\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating dance-twirl: 55%|█████▌ | 49/89 [00:08<00:05, 6.72it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 47 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating dance-twirl: 70%|██████▉ | 62/89 [00:11<00:05, 5.35it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.48\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 90/90 [00:15<00:00, 5.98it/s]it/s]\n", + "Propagating dance-twirl: 90it [00:15, 5.98it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipped 1 frames due to low MAD.\n", + "→ Saved all predicted masks for dance-twirl in data/sam2_preds_multi/dance-twirl\n", + "\n", + "=== Processing sequence: dog ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 60/60 [00:01<00:00, 59.17it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in dog/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating dog: 54%|█████▍ | 32/59 [00:05<00:05, 5.36it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.46\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 60/60 [00:10<00:00, 5.54it/s]\n", + "Propagating dog: 60it [00:10, 5.53it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "→ Saved all predicted masks for dog in data/sam2_preds_multi/dog\n", + "\n", + "=== Processing sequence: dogs-jump ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 66/66 [00:01<00:00, 58.85it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 3 unique non‐black colors in dogs-jump/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating dogs-jump: 0%| | 0/65 [00:00\"\n", + " inference_state = predictor.init_state(\n", + " video_path=video_dir,\n", + " offload_video_to_cpu=False,\n", + " offload_state_to_cpu=False,\n", + " async_loading_frames=False\n", + " )\n", + "\n", + "# (b) load the single “00000.png” which contains two different colored regions\n", + " rgb = iio.imread(str(mask_dir / \"00000.png\")) # shape (H, W, 3)\n", + " H, W, C = rgb.shape\n", + " assert C == 3, \"Expected a 3‐channel (RGB) first‐frame mask.\"\n", + "\n", + " # (c) find all unique RGB colors except black\n", + " flat = rgb.reshape(-1, 3) # shape (H*W, 3)\n", + " uniq_colors = np.unique(flat, axis=0) # shape (K, 3), where K ≤ (H*W)\n", + " # Remove the black color (0,0,0) if present\n", + " non_black = [tuple(c) for c in uniq_colors if not np.all(c == 0)]\n", + " if len(non_black) == 0:\n", + " raise RuntimeError(f\"No non‐black colors found in {seq}/00000.png\")\n", + "\n", + " # (d) for each unique non‐black color, build a 2D boolean mask and register it\n", + " print(f\"Found {len(non_black)} unique non‐black colors in {seq}/00000.png\")\n", + " for idx, color in enumerate(non_black):\n", + " # color is something like (200, 0, 0) or (0, 200, 0)\n", + " R, G, B = color\n", + " # build a binary mask: True where pixel == this color\n", + " bin_mask = np.logical_and.reduce([\n", + " rgb[:, :, 0] == R,\n", + " rgb[:, :, 1] == G,\n", + " rgb[:, :, 2] == B\n", + " ]) # shape (H, W), dtype=bool\n", + "\n", + " # wrap as torch.bool on the same device as SAM 2\n", + " mask2d = torch.from_numpy(bin_mask).to(device)\n", + "\n", + " # register this mask as object `idx`\n", + " predictor.add_new_mask(\n", + " inference_state=inference_state,\n", + " frame_idx=0,\n", + " obj_id=idx, # choose 0,1,2,… per color\n", + " mask=mask2d\n", + " )\n", + "\n", + " # 3e) now propagate through all frames. As each new frame is processed,\n", + " # propagate_in_video yields (frame_idx, [obj_ids], video_res_masks).\n", + " #\n", + " # We’ll save each mask as “00000.png”, “00001.png”, … under OUT_DIR//\n", + " seq_out_dir = OUT_DIR / seq\n", + " seq_out_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + " for frame_idx, obj_ids, video_res_masks in tqdm(\n", + " predictor.propagate_in_video(inference_state),\n", + " total=len(img_paths)-1,\n", + " desc=f\"Propagating {seq}\"\n", + " ):\n", + " # # ‣ frame_idx is an integer (1,2,3,…). video_res_masks is a tensor of shape\n", + " # # (num_objects, H, W). For DAVIS, num_objects==1.\n", + " # #\n", + " # # ‣ Thresholding has already happened internally; `video_res_masks` is\n", + " # # a float‐tensor where positive values correspond to predicted “object.”\n", + " # mask_np = (video_res_masks[0].cpu().numpy() > 0.0).astype(np.uint8) * 255\n", + "\n", + " # # Save with zero‐padded five digits to match DAVIS naming:\n", + " # save_name = f\"{frame_idx:05d}.png\"\n", + " # save_path = seq_out_dir / save_name\n", + " # iio.imwrite(str(save_path), mask_np)\n", + "\n", + " # Suppose `video_res_masks` is whatever you get from propagate_in_video:\n", + " # • If there is only one object, it may be a 2D tensor of shape (H, W)\n", + " # • If there are multiple objects, it will be a 3D tensor of shape (O, H, W)\n", + "\n", + " pred_np = video_res_masks.cpu().numpy() # dtype=float32 or float; # ───────────────────────────────────────────────────────────────\n", + " # Assume you already did:\n", + " # pred_np = video_res_masks.cpu().numpy()\n", + "\n", + " # 1) Check how many dimensions `pred_np` has:\n", + " if pred_np.ndim == 2:\n", + " # Case A: single object, shape = (H, W)\n", + " H, W = pred_np.shape\n", + " O = 1\n", + " pred_np = pred_np[np.newaxis, ...] # -> now shape (1, H, W)\n", + "\n", + " elif pred_np.ndim == 3:\n", + " # Could be either:\n", + " # (A) shape = (1, H, W) ← single object with a leading axis\n", + " # (B) shape = (O, H, W) ← multiple objects, no extra channel axis\n", + " if pred_np.shape[0] == 1:\n", + " # Treat as “one‐object” → squeeze to (1, H, W) (already fits our convention)\n", + " O, H, W = pred_np.shape\n", + " else:\n", + " # Multi‐object already: (O, H, W)\n", + " O, H, W = pred_np.shape\n", + " # (no need to reshape because it’s already (O, H, W))\n", + "\n", + " elif pred_np.ndim == 4:\n", + " # Some SAM 2 builds return (O, 1, H, W). In that case:\n", + " # • pred_np.shape = (O, 1, H, W)\n", + " # → we want to drop the “channel” dimension (axis=1).\n", + " O = pred_np.shape[0]\n", + " H = pred_np.shape[2]\n", + " W = pred_np.shape[3]\n", + " pred_np = pred_np[:, 0, :, :] # now shape (O, H, W)\n", + "\n", + " else:\n", + " raise RuntimeError(f\"Unexpected mask array with ndim={pred_np.ndim}, shape={pred_np.shape}\")\n", + "\n", + " # At this point:\n", + " # • pred_np is guaranteed to have shape (O, H, W)\n", + " # • O, H, W are set correctly\n", + " # ───────────────────────────────────────────────────────────────\n", + "\n", + " # Now you can build your colored output exactly as before:\n", + "\n", + " colored = np.zeros((H, W, 3), dtype=np.uint8)\n", + "\n", + " for i in range(O):\n", + " mask_i = (pred_np[i] > 0.0) # boolean mask (H, W)\n", + " if not mask_i.any():\n", + " continue\n", + " R, G, B = non_black[i] # the original RGB for object i\n", + " colored[mask_i, 0] = R\n", + " colored[mask_i, 1] = G\n", + " colored[mask_i, 2] = B\n", + "\n", + " save_name = f\"{frame_idx:05d}.png\"\n", + " save_path = seq_out_dir / save_name\n", + " iio.imwrite(str(save_path), colored)\n", + "\n", + "\n", + " print(f\"→ Saved all predicted masks for {seq} in {seq_out_dir}\")\n", "\n", - "print(f\"Mean FPS: {sum(fps_vals.values())/len(fps_vals):.2f}\")\n" + "print(\"\\nAll sequences processed.\")\n", + "print(f\"Your SAM 2 masks live under: {OUT_DIR}\")\n" ] } ], From bc38cab338c5aadda46a319398fb6626fd20a7eb Mon Sep 17 00:00:00 2001 From: YoungKameSennin Date: Tue, 3 Jun 2025 01:55:20 -0700 Subject: [PATCH 6/7] ignore ./data/ --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index af958d8..cf012c8 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ checkpoints/*.pt demo/backend/checkpoints/*.pt datasets/* *.zip -*.txt \ No newline at end of file +*.txt +data \ No newline at end of file From 6dfba9092cbc9b90e068d94a3469b347b1f81e5c Mon Sep 17 00:00:00 2001 From: YoungKameSennin Date: Tue, 3 Jun 2025 20:28:43 -0700 Subject: [PATCH 7/7] convert images to eval compatible format --- convert.ipynb | 179 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 convert.ipynb diff --git a/convert.ipynb b/convert.ipynb new file mode 100644 index 0000000..40a10b2 --- /dev/null +++ b/convert.ipynb @@ -0,0 +1,179 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "84f466df", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Converting sequence: bike-packing\n", + "→ Saved gray masks in data/sam2_preds_og_gray/bike-packing\n", + "Converting sequence: blackswan\n", + "→ Saved gray masks in data/sam2_preds_og_gray/blackswan\n", + "Converting sequence: bmx-trees\n", + "→ Saved gray masks in data/sam2_preds_og_gray/bmx-trees\n", + "Converting sequence: breakdance\n", + "→ Saved gray masks in data/sam2_preds_og_gray/breakdance\n", + "Converting sequence: camel\n", + "→ Saved gray masks in data/sam2_preds_og_gray/camel\n", + "Converting sequence: car-roundabout\n", + "→ Saved gray masks in data/sam2_preds_og_gray/car-roundabout\n", + "Converting sequence: car-shadow\n", + "→ Saved gray masks in data/sam2_preds_og_gray/car-shadow\n", + "Converting sequence: cows\n", + "→ Saved gray masks in data/sam2_preds_og_gray/cows\n", + "Converting sequence: dance-twirl\n", + "→ Saved gray masks in data/sam2_preds_og_gray/dance-twirl\n", + "Converting sequence: dog\n", + "→ Saved gray masks in data/sam2_preds_og_gray/dog\n", + "Converting sequence: dogs-jump\n", + "→ Saved gray masks in data/sam2_preds_og_gray/dogs-jump\n", + "Converting sequence: drift-chicane\n", + "→ Saved gray masks in data/sam2_preds_og_gray/drift-chicane\n", + "Converting sequence: drift-straight\n", + "→ Saved gray masks in data/sam2_preds_og_gray/drift-straight\n", + "Converting sequence: goat\n", + "→ Saved gray masks in data/sam2_preds_og_gray/goat\n", + "Converting sequence: gold-fish\n", + "→ Saved gray masks in data/sam2_preds_og_gray/gold-fish\n", + "Converting sequence: horsejump-high\n", + "→ Saved gray masks in data/sam2_preds_og_gray/horsejump-high\n", + "Converting sequence: india\n", + "→ Saved gray masks in data/sam2_preds_og_gray/india\n", + "Converting sequence: judo\n", + "→ Saved gray masks in data/sam2_preds_og_gray/judo\n", + "Converting sequence: kite-surf\n", + "→ Saved gray masks in data/sam2_preds_og_gray/kite-surf\n", + "Converting sequence: lab-coat\n", + "→ Saved gray masks in data/sam2_preds_og_gray/lab-coat\n", + "Converting sequence: libby\n", + "→ Saved gray masks in data/sam2_preds_og_gray/libby\n", + "Converting sequence: loading\n", + "→ Saved gray masks in data/sam2_preds_og_gray/loading\n", + "Converting sequence: mbike-trick\n", + "→ Saved gray masks in data/sam2_preds_og_gray/mbike-trick\n", + "Converting sequence: motocross-jump\n", + "→ Saved gray masks in data/sam2_preds_og_gray/motocross-jump\n", + "Converting sequence: paragliding-launch\n", + "→ Saved gray masks in data/sam2_preds_og_gray/paragliding-launch\n", + "Converting sequence: parkour\n", + "→ Saved gray masks in data/sam2_preds_og_gray/parkour\n", + "Converting sequence: pigs\n", + "→ Saved gray masks in data/sam2_preds_og_gray/pigs\n", + "Converting sequence: scooter-black\n", + "→ Saved gray masks in data/sam2_preds_og_gray/scooter-black\n", + "Converting sequence: shooting\n", + "→ Saved gray masks in data/sam2_preds_og_gray/shooting\n", + "Converting sequence: soapbox\n", + "→ Saved gray masks in data/sam2_preds_og_gray/soapbox\n", + "All sequences converted. Now run DAVIS evaluation on data/sam2_preds_og_gray\n" + ] + } + ], + "source": [ + "import os\n", + "from pathlib import Path\n", + "import imageio.v3 as iio\n", + "import numpy as np\n", + "\n", + "# ── USER CONFIG ─────────────────────────────────────────────────────────\n", + "SOURCE_DIR = Path(\"./data/sam2_preds_og\") # your colored‐RGB predictions\n", + "TARGET_DIR = Path(\"./data/sam2_preds_og_gray\") # where we'll save single‐channel PNGs\n", + "\n", + "TARGET_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "# ── HELPER: find the “object‐colors → ID” mapping from frame 0 ─────────────\n", + "def extract_color_to_id_map(first_color_png):\n", + " \"\"\"\n", + " Read the RGB image first_color_png (H×W×3), find all non‐black colors,\n", + " and assign them IDs 1,2,3,… in sorted order. Returns:\n", + " • color2id: dict mapping (R,G,B) tuples → integer ID\n", + " • H, W: image dimensions\n", + " \"\"\"\n", + " rgb = iio.imread(str(first_color_png))\n", + " if rgb.ndim != 3 or rgb.shape[2] != 3:\n", + " raise RuntimeError(f\"Expected a (H, W, 3) image at {first_color_png}\")\n", + "\n", + " flat = rgb.reshape(-1, 3)\n", + " uniq = np.unique(flat, axis=0) # (K,3) array of all colors present\n", + " # Drop black (0,0,0):\n", + " non_black = [tuple(c) for c in uniq if not np.all(c == 0)]\n", + " if len(non_black) == 0:\n", + " raise RuntimeError(f\"No non‐black colors found in {first_color_png}\")\n", + "\n", + " # Sort by RGB lex order (optional) to assign stable IDs:\n", + " non_black.sort()\n", + " color2id = {color: (i + 1) for i, color in enumerate(non_black)}\n", + " return color2id, rgb.shape[0], rgb.shape[1]\n", + "\n", + "# ── MAIN LOOP: for each sequence, convert every frame’s RGB to single‐channel IDs ───\n", + "for seq_folder in sorted(SOURCE_DIR.iterdir()):\n", + " if not seq_folder.is_dir():\n", + " continue\n", + "\n", + " print(f\"Converting sequence: {seq_folder.name}\")\n", + " out_seq = TARGET_DIR / seq_folder.name\n", + " out_seq.mkdir(parents=True, exist_ok=True)\n", + "\n", + " # (1) find frame 00000.png and build the color→ID lookup\n", + " first_frame = seq_folder / \"00000.png\"\n", + " if not first_frame.exists():\n", + " raise RuntimeError(f\"Cannot find {first_frame}\")\n", + "\n", + " color2id, H, W = extract_color_to_id_map(first_frame)\n", + " # Example: color2id might be {(200,0,0): 1, (0,200,0): 2}\n", + "\n", + " # (2) iterate over all PNGs in this sequence folder (00000.png, 00001.png, …)\n", + " all_frames = sorted(seq_folder.glob(\"*.png\"))\n", + "\n", + " for frame_path in all_frames:\n", + " rgb = iio.imread(str(frame_path))\n", + " if rgb.ndim != 3 or rgb.shape[:2] != (H, W):\n", + " raise RuntimeError(f\"Unexpected image shape in {frame_path}: {rgb.shape}\")\n", + "\n", + " # Build a blank H×W array of uint8\n", + " id_map = np.zeros((H, W), dtype=np.uint8)\n", + "\n", + " # For each distinct color in color2id, mask and assign ID\n", + " # (Pixels that remain black → ID=0)\n", + " for (R, G, B), obj_id in color2id.items():\n", + " mask = (rgb[:, :, 0] == R) & (rgb[:, :, 1] == G) & (rgb[:, :, 2] == B)\n", + " if mask.any():\n", + " id_map[mask] = obj_id\n", + "\n", + " # Save the new single‐channel PNG\n", + " out_path = out_seq / frame_path.name\n", + " iio.imwrite(str(out_path), id_map)\n", + "\n", + " print(f\"→ Saved gray masks in {out_seq}\")\n", + "\n", + "print(\"All sequences converted. Now run DAVIS evaluation on\", TARGET_DIR)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sam2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}