diff --git a/.gitignore b/.gitignore index af958d8..cf012c8 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ checkpoints/*.pt demo/backend/checkpoints/*.pt datasets/* *.zip -*.txt \ No newline at end of file +*.txt +data \ No newline at end of file diff --git a/convert.ipynb b/convert.ipynb new file mode 100644 index 0000000..40a10b2 --- /dev/null +++ b/convert.ipynb @@ -0,0 +1,179 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "84f466df", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Converting sequence: bike-packing\n", + "→ Saved gray masks in data/sam2_preds_og_gray/bike-packing\n", + "Converting sequence: blackswan\n", + "→ Saved gray masks in data/sam2_preds_og_gray/blackswan\n", + "Converting sequence: bmx-trees\n", + "→ Saved gray masks in data/sam2_preds_og_gray/bmx-trees\n", + "Converting sequence: breakdance\n", + "→ Saved gray masks in data/sam2_preds_og_gray/breakdance\n", + "Converting sequence: camel\n", + "→ Saved gray masks in data/sam2_preds_og_gray/camel\n", + "Converting sequence: car-roundabout\n", + "→ Saved gray masks in data/sam2_preds_og_gray/car-roundabout\n", + "Converting sequence: car-shadow\n", + "→ Saved gray masks in data/sam2_preds_og_gray/car-shadow\n", + "Converting sequence: cows\n", + "→ Saved gray masks in data/sam2_preds_og_gray/cows\n", + "Converting sequence: dance-twirl\n", + "→ Saved gray masks in data/sam2_preds_og_gray/dance-twirl\n", + "Converting sequence: dog\n", + "→ Saved gray masks in data/sam2_preds_og_gray/dog\n", + "Converting sequence: dogs-jump\n", + "→ Saved gray masks in data/sam2_preds_og_gray/dogs-jump\n", + "Converting sequence: drift-chicane\n", + "→ Saved gray masks in data/sam2_preds_og_gray/drift-chicane\n", + "Converting sequence: drift-straight\n", + "→ Saved gray masks in data/sam2_preds_og_gray/drift-straight\n", + "Converting sequence: goat\n", + "→ Saved gray masks in data/sam2_preds_og_gray/goat\n", + "Converting sequence: gold-fish\n", + "→ Saved gray masks in data/sam2_preds_og_gray/gold-fish\n", + "Converting sequence: horsejump-high\n", + "→ Saved gray masks in data/sam2_preds_og_gray/horsejump-high\n", + "Converting sequence: india\n", + "→ Saved gray masks in data/sam2_preds_og_gray/india\n", + "Converting sequence: judo\n", + "→ Saved gray masks in data/sam2_preds_og_gray/judo\n", + "Converting sequence: kite-surf\n", + "→ Saved gray masks in data/sam2_preds_og_gray/kite-surf\n", + "Converting sequence: lab-coat\n", + "→ Saved gray masks in data/sam2_preds_og_gray/lab-coat\n", + "Converting sequence: libby\n", + "→ Saved gray masks in data/sam2_preds_og_gray/libby\n", + "Converting sequence: loading\n", + "→ Saved gray masks in data/sam2_preds_og_gray/loading\n", + "Converting sequence: mbike-trick\n", + "→ Saved gray masks in data/sam2_preds_og_gray/mbike-trick\n", + "Converting sequence: motocross-jump\n", + "→ Saved gray masks in data/sam2_preds_og_gray/motocross-jump\n", + "Converting sequence: paragliding-launch\n", + "→ Saved gray masks in data/sam2_preds_og_gray/paragliding-launch\n", + "Converting sequence: parkour\n", + "→ Saved gray masks in data/sam2_preds_og_gray/parkour\n", + "Converting sequence: pigs\n", + "→ Saved gray masks in data/sam2_preds_og_gray/pigs\n", + "Converting sequence: scooter-black\n", + "→ Saved gray masks in data/sam2_preds_og_gray/scooter-black\n", + "Converting sequence: shooting\n", + "→ Saved gray masks in data/sam2_preds_og_gray/shooting\n", + "Converting sequence: soapbox\n", + "→ Saved gray masks in data/sam2_preds_og_gray/soapbox\n", + "All sequences converted. Now run DAVIS evaluation on data/sam2_preds_og_gray\n" + ] + } + ], + "source": [ + "import os\n", + "from pathlib import Path\n", + "import imageio.v3 as iio\n", + "import numpy as np\n", + "\n", + "# ── USER CONFIG ─────────────────────────────────────────────────────────\n", + "SOURCE_DIR = Path(\"./data/sam2_preds_og\") # your colored‐RGB predictions\n", + "TARGET_DIR = Path(\"./data/sam2_preds_og_gray\") # where we'll save single‐channel PNGs\n", + "\n", + "TARGET_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "# ── HELPER: find the “object‐colors → ID” mapping from frame 0 ─────────────\n", + "def extract_color_to_id_map(first_color_png):\n", + " \"\"\"\n", + " Read the RGB image first_color_png (H×W×3), find all non‐black colors,\n", + " and assign them IDs 1,2,3,… in sorted order. Returns:\n", + " • color2id: dict mapping (R,G,B) tuples → integer ID\n", + " • H, W: image dimensions\n", + " \"\"\"\n", + " rgb = iio.imread(str(first_color_png))\n", + " if rgb.ndim != 3 or rgb.shape[2] != 3:\n", + " raise RuntimeError(f\"Expected a (H, W, 3) image at {first_color_png}\")\n", + "\n", + " flat = rgb.reshape(-1, 3)\n", + " uniq = np.unique(flat, axis=0) # (K,3) array of all colors present\n", + " # Drop black (0,0,0):\n", + " non_black = [tuple(c) for c in uniq if not np.all(c == 0)]\n", + " if len(non_black) == 0:\n", + " raise RuntimeError(f\"No non‐black colors found in {first_color_png}\")\n", + "\n", + " # Sort by RGB lex order (optional) to assign stable IDs:\n", + " non_black.sort()\n", + " color2id = {color: (i + 1) for i, color in enumerate(non_black)}\n", + " return color2id, rgb.shape[0], rgb.shape[1]\n", + "\n", + "# ── MAIN LOOP: for each sequence, convert every frame’s RGB to single‐channel IDs ───\n", + "for seq_folder in sorted(SOURCE_DIR.iterdir()):\n", + " if not seq_folder.is_dir():\n", + " continue\n", + "\n", + " print(f\"Converting sequence: {seq_folder.name}\")\n", + " out_seq = TARGET_DIR / seq_folder.name\n", + " out_seq.mkdir(parents=True, exist_ok=True)\n", + "\n", + " # (1) find frame 00000.png and build the color→ID lookup\n", + " first_frame = seq_folder / \"00000.png\"\n", + " if not first_frame.exists():\n", + " raise RuntimeError(f\"Cannot find {first_frame}\")\n", + "\n", + " color2id, H, W = extract_color_to_id_map(first_frame)\n", + " # Example: color2id might be {(200,0,0): 1, (0,200,0): 2}\n", + "\n", + " # (2) iterate over all PNGs in this sequence folder (00000.png, 00001.png, …)\n", + " all_frames = sorted(seq_folder.glob(\"*.png\"))\n", + "\n", + " for frame_path in all_frames:\n", + " rgb = iio.imread(str(frame_path))\n", + " if rgb.ndim != 3 or rgb.shape[:2] != (H, W):\n", + " raise RuntimeError(f\"Unexpected image shape in {frame_path}: {rgb.shape}\")\n", + "\n", + " # Build a blank H×W array of uint8\n", + " id_map = np.zeros((H, W), dtype=np.uint8)\n", + "\n", + " # For each distinct color in color2id, mask and assign ID\n", + " # (Pixels that remain black → ID=0)\n", + " for (R, G, B), obj_id in color2id.items():\n", + " mask = (rgb[:, :, 0] == R) & (rgb[:, :, 1] == G) & (rgb[:, :, 2] == B)\n", + " if mask.any():\n", + " id_map[mask] = obj_id\n", + "\n", + " # Save the new single‐channel PNG\n", + " out_path = out_seq / frame_path.name\n", + " iio.imwrite(str(out_path), id_map)\n", + "\n", + " print(f\"→ Saved gray masks in {out_seq}\")\n", + "\n", + "print(\"All sequences converted. Now run DAVIS evaluation on\", TARGET_DIR)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sam2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/run_mgfs_davis.ipynb b/run_mgfs_davis.ipynb new file mode 100644 index 0000000..ef7a093 --- /dev/null +++ b/run_mgfs_davis.ipynb @@ -0,0 +1,4314 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 15, + "id": "acc92286", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded 30 validation sequences\n" + ] + }, + { + "data": { + "text/plain": [ + "SAM2VideoPredictor(\n", + " (image_encoder): ImageEncoder(\n", + " (trunk): Hiera(\n", + " (patch_embed): PatchEmbed(\n", + " (proj): Conv2d(3, 144, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))\n", + " )\n", + " (blocks): ModuleList(\n", + " (0-1): 2 x MultiScaleBlock(\n", + " (norm1): LayerNorm((144,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MultiScaleAttention(\n", + " (qkv): Linear(in_features=144, out_features=432, bias=True)\n", + " (proj): Linear(in_features=144, out_features=144, bias=True)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((144,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=144, out_features=576, bias=True)\n", + " (1): Linear(in_features=576, out_features=144, bias=True)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " )\n", + " )\n", + " (2): MultiScaleBlock(\n", + " (norm1): LayerNorm((144,), eps=1e-06, elementwise_affine=True)\n", + " (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n", + " (attn): MultiScaleAttention(\n", + " (q_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n", + " (qkv): Linear(in_features=144, out_features=864, bias=True)\n", + " (proj): Linear(in_features=288, out_features=288, bias=True)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((288,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=288, out_features=1152, bias=True)\n", + " (1): Linear(in_features=1152, out_features=288, bias=True)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " )\n", + " (proj): Linear(in_features=144, out_features=288, bias=True)\n", + " )\n", + " (3-7): 5 x MultiScaleBlock(\n", + " (norm1): LayerNorm((288,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MultiScaleAttention(\n", + " (qkv): Linear(in_features=288, out_features=864, bias=True)\n", + " (proj): Linear(in_features=288, out_features=288, bias=True)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((288,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=288, out_features=1152, bias=True)\n", + " (1): Linear(in_features=1152, out_features=288, bias=True)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " )\n", + " )\n", + " (8): MultiScaleBlock(\n", + " (norm1): LayerNorm((288,), eps=1e-06, elementwise_affine=True)\n", + " (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n", + " (attn): MultiScaleAttention(\n", + " (q_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n", + " (qkv): Linear(in_features=288, out_features=1728, bias=True)\n", + " (proj): Linear(in_features=576, out_features=576, bias=True)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((576,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=576, out_features=2304, bias=True)\n", + " (1): Linear(in_features=2304, out_features=576, bias=True)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " )\n", + " (proj): Linear(in_features=288, out_features=576, bias=True)\n", + " )\n", + " (9-43): 35 x MultiScaleBlock(\n", + " (norm1): LayerNorm((576,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MultiScaleAttention(\n", + " (qkv): Linear(in_features=576, out_features=1728, bias=True)\n", + " (proj): Linear(in_features=576, out_features=576, bias=True)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((576,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=576, out_features=2304, bias=True)\n", + " (1): Linear(in_features=2304, out_features=576, bias=True)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " )\n", + " )\n", + " (44): MultiScaleBlock(\n", + " (norm1): LayerNorm((576,), eps=1e-06, elementwise_affine=True)\n", + " (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n", + " (attn): MultiScaleAttention(\n", + " (q_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n", + " (qkv): Linear(in_features=576, out_features=3456, bias=True)\n", + " (proj): Linear(in_features=1152, out_features=1152, bias=True)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=1152, out_features=4608, bias=True)\n", + " (1): Linear(in_features=4608, out_features=1152, bias=True)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " )\n", + " (proj): Linear(in_features=576, out_features=1152, bias=True)\n", + " )\n", + " (45-47): 3 x MultiScaleBlock(\n", + " (norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MultiScaleAttention(\n", + " (qkv): Linear(in_features=1152, out_features=3456, bias=True)\n", + " (proj): Linear(in_features=1152, out_features=1152, bias=True)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=1152, out_features=4608, bias=True)\n", + " (1): Linear(in_features=4608, out_features=1152, bias=True)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (neck): FpnNeck(\n", + " (position_encoding): PositionEmbeddingSine()\n", + " (convs): ModuleList(\n", + " (0): Sequential(\n", + " (conv): Conv2d(1152, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (1): Sequential(\n", + " (conv): Conv2d(576, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (2): Sequential(\n", + " (conv): Conv2d(288, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (3): Sequential(\n", + " (conv): Conv2d(144, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (mask_downsample): Conv2d(1, 1, kernel_size=(4, 4), stride=(4, 4))\n", + " (memory_attention): MemoryAttention(\n", + " (layers): ModuleList(\n", + " (0-3): 4 x MemoryAttentionLayer(\n", + " (self_attn): RoPEAttention(\n", + " (q_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " (k_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " (v_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " (out_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " )\n", + " (cross_attn_image): RoPEAttention(\n", + " (q_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " (k_proj): Linear(in_features=64, out_features=256, bias=True)\n", + " (v_proj): Linear(in_features=64, out_features=256, bias=True)\n", + " (out_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " )\n", + " (linear1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (linear2): Linear(in_features=2048, out_features=256, bias=True)\n", + " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " (dropout3): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (memory_encoder): MemoryEncoder(\n", + " (mask_downsampler): MaskDownSampler(\n", + " (encoder): Sequential(\n", + " (0): Conv2d(1, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (1): LayerNorm2d()\n", + " (2): GELU(approximate='none')\n", + " (3): Conv2d(4, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (4): LayerNorm2d()\n", + " (5): GELU(approximate='none')\n", + " (6): Conv2d(16, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (7): LayerNorm2d()\n", + " (8): GELU(approximate='none')\n", + " (9): Conv2d(64, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (10): LayerNorm2d()\n", + " (11): GELU(approximate='none')\n", + " (12): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " )\n", + " (pix_feat_proj): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " (fuser): Fuser(\n", + " (proj): Identity()\n", + " (layers): ModuleList(\n", + " (0-1): 2 x CXBlock(\n", + " (dwconv): Conv2d(256, 256, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=256)\n", + " (norm): LayerNorm2d()\n", + " (pwconv1): Linear(in_features=256, out_features=1024, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (pwconv2): Linear(in_features=1024, out_features=256, bias=True)\n", + " (drop_path): Identity()\n", + " )\n", + " )\n", + " )\n", + " (position_encoding): PositionEmbeddingSine()\n", + " (out_proj): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (sam_prompt_encoder): PromptEncoder(\n", + " (pe_layer): PositionEmbeddingRandom()\n", + " (point_embeddings): ModuleList(\n", + " (0-3): 4 x Embedding(1, 256)\n", + " )\n", + " (not_a_point_embed): Embedding(1, 256)\n", + " (mask_downscaling): Sequential(\n", + " (0): Conv2d(1, 4, kernel_size=(2, 2), stride=(2, 2))\n", + " (1): LayerNorm2d()\n", + " (2): GELU(approximate='none')\n", + " (3): Conv2d(4, 16, kernel_size=(2, 2), stride=(2, 2))\n", + " (4): LayerNorm2d()\n", + " (5): GELU(approximate='none')\n", + " (6): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (no_mask_embed): Embedding(1, 256)\n", + " )\n", + " (sam_mask_decoder): MaskDecoder(\n", + " (transformer): TwoWayTransformer(\n", + " (layers): ModuleList(\n", + " (0-1): 2 x TwoWayAttentionBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " (k_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " (v_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " (out_proj): Linear(in_features=256, out_features=256, bias=True)\n", + " )\n", + " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (cross_attn_token_to_image): Attention(\n", + " (q_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (k_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (v_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (out_proj): Linear(in_features=128, out_features=256, bias=True)\n", + " )\n", + " (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=256, out_features=2048, bias=True)\n", + " (1): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (act): ReLU()\n", + " )\n", + " (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (norm4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " (cross_attn_image_to_token): Attention(\n", + " (q_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (k_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (v_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (out_proj): Linear(in_features=128, out_features=256, bias=True)\n", + " )\n", + " )\n", + " )\n", + " (final_attn_token_to_image): Attention(\n", + " (q_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (k_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (v_proj): Linear(in_features=256, out_features=128, bias=True)\n", + " (out_proj): Linear(in_features=128, out_features=256, bias=True)\n", + " )\n", + " (norm_final_attn): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (iou_token): Embedding(1, 256)\n", + " (mask_tokens): Embedding(4, 256)\n", + " (obj_score_token): Embedding(1, 256)\n", + " (output_upscaling): Sequential(\n", + " (0): ConvTranspose2d(256, 64, kernel_size=(2, 2), stride=(2, 2))\n", + " (1): LayerNorm2d()\n", + " (2): GELU(approximate='none')\n", + " (3): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))\n", + " (4): GELU(approximate='none')\n", + " )\n", + " (conv_s0): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_s1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (output_hypernetworks_mlps): ModuleList(\n", + " (0-3): 4 x MLP(\n", + " (layers): ModuleList(\n", + " (0-1): 2 x Linear(in_features=256, out_features=256, bias=True)\n", + " (2): Linear(in_features=256, out_features=32, bias=True)\n", + " )\n", + " (act): ReLU()\n", + " )\n", + " )\n", + " (iou_prediction_head): MLP(\n", + " (layers): ModuleList(\n", + " (0-1): 2 x Linear(in_features=256, out_features=256, bias=True)\n", + " (2): Linear(in_features=256, out_features=4, bias=True)\n", + " )\n", + " (act): ReLU()\n", + " )\n", + " (pred_obj_score_head): MLP(\n", + " (layers): ModuleList(\n", + " (0-1): 2 x Linear(in_features=256, out_features=256, bias=True)\n", + " (2): Linear(in_features=256, out_features=1, bias=True)\n", + " )\n", + " (act): ReLU()\n", + " )\n", + " )\n", + " (obj_ptr_proj): MLP(\n", + " (layers): ModuleList(\n", + " (0-2): 3 x Linear(in_features=256, out_features=256, bias=True)\n", + " )\n", + " (act): ReLU()\n", + " )\n", + " (obj_ptr_tpos_proj): Linear(in_features=256, out_features=64, bias=True)\n", + ")" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import sys, os, json, time, subprocess, pathlib\n", + "from pathlib import Path\n", + "from davis2017.davis import DAVIS\n", + "import imageio.v3 as iio\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "import torch\n", + "\n", + "# ── USER‐CONFIGURABLE PATHS ──────────────────────────────────────────────────\n", + "DAVIS_ROOT = Path(\"./data/davis/DAVIS\") # ← point this at your DAVIS folder\n", + "OUT_DIR = Path(\"./data/sam2_preds\") # ← where we’ll write out PNGs\n", + "OUT_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "# ── STEP 1: load the DAVIS “val” split (semi‐supervised task) ───────────────────\n", + "ds = DAVIS(str(DAVIS_ROOT), task=\"semi-supervised\", subset=\"val\", resolution=\"480p\")\n", + "print(f\"Loaded {len(ds.sequences)} validation sequences\")\n", + "\n", + "# ── STEP 2: build SAM 2 video predictor ─────────────────────────────────────────\n", + "from sam2.build_sam import build_sam2_video_predictor\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "sam2_checkpoint = \"./checkpoints/sam2.1_hiera_large.pt\" # ← adjust if needed\n", + "model_cfg = \"configs/sam2.1/sam2.1_hiera_l.yaml\" # ← adjust if needed\n", + "\n", + "predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)\n", + "predictor.to(device) # make sure model is on CUDA if available" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "bdd669c2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "=== Processing sequence: bike-packing ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 69/69 [00:01<00:00, 60.43it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 2 unique non‐black colors in bike-packing/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 7%|▋ | 5/68 [00:00<00:01, 45.88it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 1 due to low MAD (0.04)\n", + "Skipping frame 2 due to low MAD (0.06)\n", + "Skipping frame 3 due to low MAD (0.07)\n", + "Skipping frame 4 due to low MAD (0.09)\n", + "Skipping frame 6 due to low MAD (0.07)\n", + "Skipping frame 7 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 15%|█▍ | 10/68 [00:00<00:04, 13.88it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 9 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 19%|█▉ | 13/68 [00:01<00:07, 7.50it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 13 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 22%|██▏ | 15/68 [00:01<00:07, 7.41it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 15 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 25%|██▌ | 17/68 [00:01<00:06, 7.35it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 17 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 34%|███▍ | 23/68 [00:03<00:09, 4.83it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 23 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 38%|███▊ | 26/68 [00:03<00:08, 5.06it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 26 due to low MAD (0.08)\n", + "Skipping frame 27 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 50%|█████ | 34/68 [00:04<00:02, 12.47it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 29 due to low MAD (0.04)\n", + "Skipping frame 30 due to low MAD (0.05)\n", + "Avg FPS last 30 frames: 7.27\n", + "Skipping frame 31 due to low MAD (0.05)\n", + "Skipping frame 32 due to low MAD (0.06)\n", + "Skipping frame 33 due to low MAD (0.08)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 53%|█████▎ | 36/68 [00:04<00:02, 10.67it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 35 due to low MAD (0.07)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 56%|█████▌ | 38/68 [00:04<00:03, 9.55it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 37 due to low MAD (0.07)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 62%|██████▏ | 42/68 [00:05<00:04, 5.72it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 44 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 81%|████████ | 55/68 [00:07<00:02, 5.75it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 55 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 84%|████████▍ | 57/68 [00:07<00:01, 6.06it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 57 due to low MAD (0.08)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 90%|████████▉ | 61/68 [00:08<00:01, 5.29it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 6.83\n", + "Skipping frame 61 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 93%|█████████▎| 63/68 [00:08<00:00, 5.84it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 63 due to low MAD (0.07)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bike-packing: 96%|█████████▌| 65/68 [00:09<00:00, 6.24it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 65 due to low MAD (0.07)\n", + "Skipping frame 66 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 69/69 [00:09<00:00, 7.34it/s]2it/s]\n", + "Propagating bike-packing: 69it [00:09, 7.33it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 68 due to low MAD (0.04)\n", + "Skipped 28 frames due to low MAD.\n", + "→ Saved all predicted masks for bike-packing in data/sam2_preds_multi/bike-packing\n", + "\n", + "=== Processing sequence: blackswan ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 50/50 [00:00<00:00, 60.41it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in blackswan/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating blackswan: 65%|██████▌ | 32/49 [00:05<00:03, 5.29it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.39\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 50/50 [00:09<00:00, 5.55it/s] \n", + "Propagating blackswan: 50it [00:09, 5.55it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "→ Saved all predicted masks for blackswan in data/sam2_preds_multi/blackswan\n", + "\n", + "=== Processing sequence: bmx-trees ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 80/80 [00:01<00:00, 60.17it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 2 unique non‐black colors in bmx-trees/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bmx-trees: 39%|███▉ | 31/79 [00:07<00:12, 3.95it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 4.10\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating bmx-trees: 77%|███████▋ | 61/79 [00:14<00:04, 3.97it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 3.96\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 80/80 [00:18<00:00, 4.39it/s] \n", + "Propagating bmx-trees: 80it [00:18, 4.39it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "→ Saved all predicted masks for bmx-trees in data/sam2_preds_multi/bmx-trees\n", + "\n", + "=== Processing sequence: breakdance ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 84/84 [00:01<00:00, 63.11it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in breakdance/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 4%|▎ | 3/83 [00:00<00:04, 16.40it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 1 due to low MAD (0.09)\n", + "Skipping frame 3 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 11%|█ | 9/83 [00:00<00:07, 9.26it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 7 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 19%|█▉ | 16/83 [00:01<00:08, 7.55it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 14 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 29%|██▉ | 24/83 [00:03<00:08, 7.07it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 22 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 31%|███▏ | 26/83 [00:03<00:07, 7.90it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 24 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 39%|███▊ | 32/83 [00:04<00:07, 7.18it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 30 due to low MAD (0.10)\n", + "Avg FPS last 30 frames: 7.06\n", + "Skipping frame 32 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 43%|████▎ | 36/83 [00:04<00:05, 8.72it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 34 due to low MAD (0.10)\n", + "Skipping frame 36 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 48%|████▊ | 40/83 [00:05<00:04, 9.53it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 38 due to low MAD (0.09)\n", + "Skipping frame 40 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 58%|█████▊ | 48/83 [00:06<00:04, 7.82it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 46 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 69%|██████▊ | 57/83 [00:07<00:03, 7.27it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 55 due to low MAD (0.08)\n", + "Skipping frame 57 due to low MAD (0.08)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 73%|███████▎ | 61/83 [00:08<00:02, 8.87it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 59 due to low MAD (0.10)\n", + "Avg FPS last 30 frames: 7.71\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 88%|████████▊ | 73/83 [00:10<00:01, 7.11it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 71 due to low MAD (0.09)\n", + "Skipping frame 73 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 93%|█████████▎| 77/83 [00:10<00:00, 8.80it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 75 due to low MAD (0.09)\n", + "Skipping frame 77 due to low MAD (0.08)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating breakdance: 98%|█████████▊| 81/83 [00:10<00:00, 9.59it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 79 due to low MAD (0.09)\n", + "Skipping frame 81 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 84/84 [00:11<00:00, 7.55it/s]t/s]\n", + "Propagating breakdance: 84it [00:11, 7.54it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 83 due to low MAD (0.09)\n", + "Skipped 23 frames due to low MAD.\n", + "→ Saved all predicted masks for breakdance in data/sam2_preds_multi/breakdance\n", + "\n", + "=== Processing sequence: camel ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 90/90 [00:01<00:00, 61.51it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in camel/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating camel: 36%|███▌ | 32/89 [00:05<00:10, 5.48it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.62\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating camel: 70%|██████▉ | 62/89 [00:10<00:04, 5.52it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.50\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 90/90 [00:14<00:00, 6.15it/s]\n", + "Propagating camel: 90it [00:14, 6.15it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "→ Saved all predicted masks for camel in data/sam2_preds_multi/camel\n", + "\n", + "=== Processing sequence: car-roundabout ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 75/75 [00:01<00:00, 61.59it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in car-roundabout/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating car-roundabout: 43%|████▎ | 32/74 [00:05<00:07, 5.54it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.67\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating car-roundabout: 84%|████████▍ | 62/74 [00:11<00:02, 5.32it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.43\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 75/75 [00:13<00:00, 5.55it/s] \n", + "Propagating car-roundabout: 75it [00:13, 5.55it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "→ Saved all predicted masks for car-roundabout in data/sam2_preds_multi/car-roundabout\n", + "\n", + "=== Processing sequence: car-shadow ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 40/40 [00:00<00:00, 61.89it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in car-shadow/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating car-shadow: 82%|████████▏ | 32/39 [00:05<00:01, 5.22it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.34\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 40/40 [00:07<00:00, 5.44it/s] \n", + "Propagating car-shadow: 40it [00:07, 5.44it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "→ Saved all predicted masks for car-shadow in data/sam2_preds_multi/car-shadow\n", + "\n", + "=== Processing sequence: cows ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 104/104 [00:01<00:00, 58.65it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in cows/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating cows: 12%|█▏ | 12/103 [00:01<00:16, 5.37it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 12 due to low MAD (0.09)\n", + "Skipping frame 13 due to low MAD (0.08)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating cows: 29%|██▉ | 30/103 [00:03<00:04, 17.70it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 7.76\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating cows: 34%|███▍ | 35/103 [00:04<00:06, 10.31it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 33 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating cows: 60%|██████ | 62/103 [00:09<00:07, 5.25it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.40\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating cows: 79%|███████▊ | 81/103 [00:13<00:04, 5.34it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 81 due to low MAD (0.07)\n", + "Skipping frame 82 due to low MAD (0.10)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating cows: 89%|████████▉ | 92/103 [00:14<00:01, 5.52it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.66\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 104/104 [00:17<00:00, 6.06it/s]\n", + "Propagating cows: 104it [00:17, 6.06it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipped 5 frames due to low MAD.\n", + "→ Saved all predicted masks for cows in data/sam2_preds_multi/cows\n", + "\n", + "=== Processing sequence: dance-twirl ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 90/90 [00:01<00:00, 60.09it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in dance-twirl/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating dance-twirl: 36%|███▌ | 32/89 [00:05<00:10, 5.32it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.45\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating dance-twirl: 55%|█████▌ | 49/89 [00:08<00:05, 6.72it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping frame 47 due to low MAD (0.09)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating dance-twirl: 70%|██████▉ | 62/89 [00:11<00:05, 5.35it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.48\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 90/90 [00:15<00:00, 5.98it/s]it/s]\n", + "Propagating dance-twirl: 90it [00:15, 5.98it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipped 1 frames due to low MAD.\n", + "→ Saved all predicted masks for dance-twirl in data/sam2_preds_multi/dance-twirl\n", + "\n", + "=== Processing sequence: dog ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 60/60 [00:01<00:00, 59.17it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 1 unique non‐black colors in dog/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating dog: 54%|█████▍ | 32/59 [00:05<00:05, 5.36it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Avg FPS last 30 frames: 5.46\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "propagate in video: 100%|██████████| 60/60 [00:10<00:00, 5.54it/s]\n", + "Propagating dog: 60it [00:10, 5.53it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "→ Saved all predicted masks for dog in data/sam2_preds_multi/dog\n", + "\n", + "=== Processing sequence: dogs-jump ===\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "frame loading (JPEG): 100%|██████████| 66/66 [00:01<00:00, 58.85it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 3 unique non‐black colors in dogs-jump/00000.png\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Propagating dogs-jump: 0%| | 0/65 [00:00\"\n", + " inference_state = predictor.init_state(\n", + " video_path=video_dir,\n", + " offload_video_to_cpu=False,\n", + " offload_state_to_cpu=False,\n", + " async_loading_frames=False\n", + " )\n", + "\n", + "# (b) load the single “00000.png” which contains two different colored regions\n", + " rgb = iio.imread(str(mask_dir / \"00000.png\")) # shape (H, W, 3)\n", + " H, W, C = rgb.shape\n", + " assert C == 3, \"Expected a 3‐channel (RGB) first‐frame mask.\"\n", + "\n", + " # (c) find all unique RGB colors except black\n", + " flat = rgb.reshape(-1, 3) # shape (H*W, 3)\n", + " uniq_colors = np.unique(flat, axis=0) # shape (K, 3), where K ≤ (H*W)\n", + " # Remove the black color (0,0,0) if present\n", + " non_black = [tuple(c) for c in uniq_colors if not np.all(c == 0)]\n", + " if len(non_black) == 0:\n", + " raise RuntimeError(f\"No non‐black colors found in {seq}/00000.png\")\n", + "\n", + " # (d) for each unique non‐black color, build a 2D boolean mask and register it\n", + " print(f\"Found {len(non_black)} unique non‐black colors in {seq}/00000.png\")\n", + " for idx, color in enumerate(non_black):\n", + " # color is something like (200, 0, 0) or (0, 200, 0)\n", + " R, G, B = color\n", + " # build a binary mask: True where pixel == this color\n", + " bin_mask = np.logical_and.reduce([\n", + " rgb[:, :, 0] == R,\n", + " rgb[:, :, 1] == G,\n", + " rgb[:, :, 2] == B\n", + " ]) # shape (H, W), dtype=bool\n", + "\n", + " # wrap as torch.bool on the same device as SAM 2\n", + " mask2d = torch.from_numpy(bin_mask).to(device)\n", + "\n", + " # register this mask as object `idx`\n", + " predictor.add_new_mask(\n", + " inference_state=inference_state,\n", + " frame_idx=0,\n", + " obj_id=idx, # choose 0,1,2,… per color\n", + " mask=mask2d\n", + " )\n", + "\n", + " # 3e) now propagate through all frames. As each new frame is processed,\n", + " # propagate_in_video yields (frame_idx, [obj_ids], video_res_masks).\n", + " #\n", + " # We’ll save each mask as “00000.png”, “00001.png”, … under OUT_DIR//\n", + " seq_out_dir = OUT_DIR / seq\n", + " seq_out_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + " for frame_idx, obj_ids, video_res_masks in tqdm(\n", + " predictor.propagate_in_video(inference_state),\n", + " total=len(img_paths)-1,\n", + " desc=f\"Propagating {seq}\"\n", + " ):\n", + " # # ‣ frame_idx is an integer (1,2,3,…). video_res_masks is a tensor of shape\n", + " # # (num_objects, H, W). For DAVIS, num_objects==1.\n", + " # #\n", + " # # ‣ Thresholding has already happened internally; `video_res_masks` is\n", + " # # a float‐tensor where positive values correspond to predicted “object.”\n", + " # mask_np = (video_res_masks[0].cpu().numpy() > 0.0).astype(np.uint8) * 255\n", + "\n", + " # # Save with zero‐padded five digits to match DAVIS naming:\n", + " # save_name = f\"{frame_idx:05d}.png\"\n", + " # save_path = seq_out_dir / save_name\n", + " # iio.imwrite(str(save_path), mask_np)\n", + "\n", + " # Suppose `video_res_masks` is whatever you get from propagate_in_video:\n", + " # • If there is only one object, it may be a 2D tensor of shape (H, W)\n", + " # • If there are multiple objects, it will be a 3D tensor of shape (O, H, W)\n", + "\n", + " pred_np = video_res_masks.cpu().numpy() # dtype=float32 or float; # ───────────────────────────────────────────────────────────────\n", + " # Assume you already did:\n", + " # pred_np = video_res_masks.cpu().numpy()\n", + "\n", + " # 1) Check how many dimensions `pred_np` has:\n", + " if pred_np.ndim == 2:\n", + " # Case A: single object, shape = (H, W)\n", + " H, W = pred_np.shape\n", + " O = 1\n", + " pred_np = pred_np[np.newaxis, ...] # -> now shape (1, H, W)\n", + "\n", + " elif pred_np.ndim == 3:\n", + " # Could be either:\n", + " # (A) shape = (1, H, W) ← single object with a leading axis\n", + " # (B) shape = (O, H, W) ← multiple objects, no extra channel axis\n", + " if pred_np.shape[0] == 1:\n", + " # Treat as “one‐object” → squeeze to (1, H, W) (already fits our convention)\n", + " O, H, W = pred_np.shape\n", + " else:\n", + " # Multi‐object already: (O, H, W)\n", + " O, H, W = pred_np.shape\n", + " # (no need to reshape because it’s already (O, H, W))\n", + "\n", + " elif pred_np.ndim == 4:\n", + " # Some SAM 2 builds return (O, 1, H, W). In that case:\n", + " # • pred_np.shape = (O, 1, H, W)\n", + " # → we want to drop the “channel” dimension (axis=1).\n", + " O = pred_np.shape[0]\n", + " H = pred_np.shape[2]\n", + " W = pred_np.shape[3]\n", + " pred_np = pred_np[:, 0, :, :] # now shape (O, H, W)\n", + "\n", + " else:\n", + " raise RuntimeError(f\"Unexpected mask array with ndim={pred_np.ndim}, shape={pred_np.shape}\")\n", + "\n", + " # At this point:\n", + " # • pred_np is guaranteed to have shape (O, H, W)\n", + " # • O, H, W are set correctly\n", + " # ───────────────────────────────────────────────────────────────\n", + "\n", + " # Now you can build your colored output exactly as before:\n", + "\n", + " colored = np.zeros((H, W, 3), dtype=np.uint8)\n", + "\n", + " for i in range(O):\n", + " mask_i = (pred_np[i] > 0.0) # boolean mask (H, W)\n", + " if not mask_i.any():\n", + " continue\n", + " R, G, B = non_black[i] # the original RGB for object i\n", + " colored[mask_i, 0] = R\n", + " colored[mask_i, 1] = G\n", + " colored[mask_i, 2] = B\n", + "\n", + " save_name = f\"{frame_idx:05d}.png\"\n", + " save_path = seq_out_dir / save_name\n", + " iio.imwrite(str(save_path), colored)\n", + "\n", + "\n", + " print(f\"→ Saved all predicted masks for {seq} in {seq_out_dir}\")\n", + "\n", + "print(\"\\nAll sequences processed.\")\n", + "print(f\"Your SAM 2 masks live under: {OUT_DIR}\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sam2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index c88e111..07709ae 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -9,14 +9,29 @@ import torch import torch.nn.functional as F -import numpy as np + from tqdm import tqdm from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames -from sam2.utils.change_detection import should_skip_frame -from sam2.utils.optical_flow import warp_mask_forward + +# MGFK Helpers import cv2 +import numpy as np +import time + +def _mean_abs_diff(img1: torch.Tensor, img2: torch.Tensor) -> float: + """ + Compute mean absolute pixel difference on two HWC images in uint8 range [0,255]. + Inputs are 3-D CPU tensors (H,W,C) with dtype uint8. + """ + # convert to NumPy for speed + a = img1.cpu().numpy().astype(np.int16) + b = img2.cpu().numpy().astype(np.int16) + return float(np.mean(np.abs(a - b))) + + + class SAM2VideoPredictor(SAM2Base): """The predictor class to handle user interactions and manage inference states.""" @@ -32,6 +47,7 @@ def __init__( # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames add_all_frames_to_correct_as_cond=False, + skip_mad_threshold: float = 0.05, **kwargs, ): super().__init__(**kwargs) @@ -39,6 +55,7 @@ def __init__( self.non_overlap_masks = non_overlap_masks self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.skip_mad_threshold = skip_mad_threshold @torch.inference_mode() def init_state( @@ -96,20 +113,13 @@ def init_state( # (we directly use their consolidated outputs during tracking) # metadata for each tracking frame (e.g. which direction it's tracked) inference_state["frames_tracked_per_obj"] = {} - # Warm up the visual backbone and cache the image feature on frame 0 - self._get_image_feature(inference_state, frame_idx=0, batch_size=1) - first = inference_state["images"][0] - # If CHW, swap dims - if first.ndim == 3 and first.shape[0] in (1,3): - _, h, w = first.shape - elif first.ndim == 3 and first.shape[2] in (1,3): - h, w, _ = first.shape - else: - # Fallback: assume H x W - h, w = first.shape[-2], first.shape[-1] - inference_state["video_size"] = (h, w) + # Add two slots to the inference state so they survive across frames + inference_state["last_processed_frame"] = None # uint8 HWC tensor + inference_state["last_processed_output"] = None # compact_current_out + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) return inference_state @classmethod @@ -293,6 +303,11 @@ def add_new_points_or_box( # Add the output to the output dict (to be used as future memory) obj_temp_output_dict[storage_key][frame_idx] = current_out + # Any user interaction invalidates reuse cache + inference_state["last_processed_frame"] = None + inference_state["last_processed_output"] = None + + # Resize the output mask to the original video resolution obj_ids = inference_state["obj_ids"] consolidated_out = self._consolidate_temp_output_across_obj( @@ -381,6 +396,10 @@ def add_new_mask( # Add the output to the output dict (to be used as future memory) obj_temp_output_dict[storage_key][frame_idx] = current_out + # Any user interaction invalidates reuse cache + inference_state["last_processed_frame"] = None + inference_state["last_processed_output"] = None + # Resize the output mask to the original video resolution obj_ids = inference_state["obj_ids"] consolidated_out = self._consolidate_temp_output_across_obj( @@ -564,201 +583,140 @@ def propagate_in_video( max_frame_num_to_track=None, reverse=False, ): - """ - Propagate the input prompts (points/masks) across all frames - in the video, optionally skipping low-change frames. - Yields (frame_idx, obj_ids, masks) once per frame. - """ - # Prepare various buffers and metadata + """Propagate the input points across frames to track in the entire video.""" self.propagate_in_video_preflight(inference_state) - obj_ids = inference_state["obj_ids"] # list of object IDs - num_frames = inference_state["num_frames"] # total frames - batch_size = len(obj_ids) # number of objects + obj_ids = inference_state["obj_ids"] + num_frames = inference_state["num_frames"] + batch_size = self._get_obj_num(inference_state) - # Determine start/end indices + # set start index, end index, and processing order if start_frame_idx is None: - # first conditioning frame across all objects + # default: start from the earliest frame with input points start_frame_idx = min( t - for obj_dict in inference_state["output_dict_per_obj"].values() - for t in obj_dict["cond_frame_outputs"] + for obj_output_dict in inference_state["output_dict_per_obj"].values() + for t in obj_output_dict["cond_frame_outputs"] ) if max_frame_num_to_track is None: + # default: track all the frames in the video max_frame_num_to_track = num_frames - if reverse: end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) - processing_order = ( - range(start_frame_idx, end_frame_idx - 1, -1) - if start_frame_idx > 0 - else [] - ) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + processing_order = [] # skip reverse tracking if starting from frame 0 else: - end_frame_idx = min(start_frame_idx + max_frame_num_to_track, num_frames - 1) + end_frame_idx = min( + start_frame_idx + max_frame_num_to_track, num_frames - 1 + ) processing_order = range(start_frame_idx, end_frame_idx + 1) - # Track the last full-inference frame and mask for skipping - prev_frame_idx = None - prev_pred_masks = None + t0 = time.time() + skipped_ctr = 0 - # Iterate through each frame index for frame_idx in tqdm(processing_order, desc="propagate in video"): - # Container for this frame’s predicted masks per object - pred_masks_per_obj = [None] * batch_size - - # Loop over each object independently - for obj_idx in range(batch_size): - obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] - storage_key = None - pred_masks = None - # 1) Conditioning frame (manual prompt) - if frame_idx in obj_output_dict["cond_frame_outputs"]: - storage_key = "cond_frame_outputs" - current_out = obj_output_dict[storage_key][frame_idx] - device = inference_state["device"] - pred_masks = current_out["pred_masks"].to(device, non_blocking=True) + # ========================================================= + # 1) FRAME-LEVEL SKIP DECISION (runs once per frame) + # ========================================================= + reuse_this_frame = False + if ( + inference_state["last_processed_frame"] is not None + and all( + frame_idx not in inference_state["output_dict_per_obj"][i]["cond_frame_outputs"] + for i in range(batch_size) + ) + ): + prev = inference_state["last_processed_frame"] + curr = inference_state["images"][frame_idx].cpu() + mad = _mean_abs_diff(prev, curr) + if mad < self.skip_mad_threshold: + reuse_this_frame = True + skipped_ctr += 1 + print(f"Skipping frame {frame_idx} due to low MAD ({mad:.2f})") - prev_frame_idx = frame_idx - prev_pred_masks = pred_masks + pred_masks_per_obj = [None] * batch_size - # Optional: clear nearby non_cond memory - if self.clear_non_cond_mem_around_input: - self._clear_obj_non_cond_mem_around_input( - inference_state, frame_idx, obj_idx - ) + # ========================================================= + # 2) FAST PATH – reuse cached outputs for *all* objects + # ========================================================= + if reuse_this_frame: + cached_list = inference_state["last_processed_outputs"] # len == batch_size + device = inference_state["device"] + + for obj_idx, cached_out in enumerate(cached_list): + obj_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_dict["non_cond_frame_outputs"][frame_idx] = cached_out + inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = { + "reverse": reverse + } + pred_masks_per_obj[obj_idx] = cached_out["pred_masks"].to( + device, non_blocking=True + ) - # 2) Already computed non-conditioning frame - elif frame_idx in obj_output_dict["non_cond_frame_outputs"]: - storage_key = "non_cond_frame_outputs" - current_out = obj_output_dict[storage_key][frame_idx] - pred_masks = current_out["pred_masks"] - - prev_frame_idx = frame_idx - prev_pred_masks = pred_masks - - # 3) New frame: decide skip vs full inference - else: - storage_key = "non_cond_frame_outputs" - - # --- a) Memory-guided skip? - if ( - self.skip_threshold is not None - and self.skip_threshold > 0 - and prev_frame_idx is not None - ): - prev_frame = inference_state["images"][prev_frame_idx].cpu() - curr_frame = inference_state["images"][frame_idx].cpu() - if should_skip_frame( - prev_frame, curr_frame, threshold=self.skip_threshold - ): - # Warp last mask and upsample - print(f"Skipping frame {frame_idx} (reuse frame {prev_frame_idx}), skip threshold {self.skip_threshold}") - warped_pred = warp_mask_forward( - prev_pred_masks.cpu(), - prev_frame.cpu().numpy(), - curr_frame.cpu().numpy(), + # ========================================================= + # 3) SLOW PATH – run the model once per object + # ========================================================= + else: + frame_outputs = [] + for obj_idx in range(batch_size): + obj_dict = inference_state["output_dict_per_obj"][obj_idx] + + # 3a) conditioning frame? + if frame_idx in obj_dict["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = obj_dict[storage_key][frame_idx] + device = inference_state["device"] + pred_masks = current_out["pred_masks"].to(device, non_blocking=True) + if self.clear_non_cond_mem_around_input: + self._clear_obj_non_cond_mem_around_input( + inference_state, frame_idx, obj_idx ) - # --------------------------------------------- - # 2) Shallow‐copy prior full‐inference dict so - # maskmem_features (and other keys) exist - # --------------------------------------------- - prev_mem_out = obj_output_dict["non_cond_frame_outputs"].get(prev_frame_idx) - if prev_mem_out is None: - prev_mem_out = obj_output_dict["cond_frame_outputs"][prev_frame_idx] - obj_output_dict[storage_key][frame_idx] = dict(prev_mem_out) - obj_output_dict[storage_key][frame_idx]["pred_masks"] = warped_pred - - # Now the memory bank for frame_idx has both key and mask features. - - # Pack into (1,C,Hmask,Wmask) - wp = warped_pred - if wp.ndim == 2: - wp = wp.unsqueeze(0).unsqueeze(0) - elif wp.ndim == 3: - wp = wp.unsqueeze(0) - - # Upsample → (1,C,Hvideo,Wvideo) - video_h, video_w = inference_state["video_size"] - upsamp = F.interpolate( - wp, - size=(video_h, video_w), - mode="bilinear", - align_corners=False, - ).squeeze(0) # → (C,Hvideo,Wvideo) - - inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = { - "reverse": reverse - } - - # Save full-res mask for concatenation - pred_masks_per_obj[obj_idx] = upsamp - - prev_frame_idx = frame_idx - prev_pred_masks = warped_pred - # done with this object - continue - - # --- b) Fallback: full SAM 2 inference on this object - current_out, pred_masks = self._run_single_frame_inference( - inference_state=inference_state, - output_dict=obj_output_dict, - frame_idx=frame_idx, - batch_size=1, - is_init_cond_frame=False, - point_inputs=None, - mask_inputs=None, - reverse=reverse, - run_mem_encoder=True, - ) - obj_output_dict[storage_key][frame_idx] = current_out + # 3b) plain frame – heavy inference + else: + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state, obj_dict, frame_idx, 1, + is_init_cond_frame=False, + point_inputs=None, mask_inputs=None, + reverse=reverse, run_mem_encoder=True + ) + obj_dict[storage_key][frame_idx] = current_out - prev_frame_idx = frame_idx - prev_pred_masks = pred_masks + inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = { + "reverse": reverse + } + pred_masks_per_obj[obj_idx] = pred_masks + frame_outputs.append(current_out) - # Mark this object as tracked and record its prediction - inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = { - "reverse": reverse - } - pred_masks_per_obj[obj_idx] = pred_masks + # ---- cache the full-frame outputs for the next frame ---- + inference_state["last_processed_frame"] = inference_state["images"][frame_idx].cpu() + inference_state["last_processed_outputs"] = frame_outputs - # --- After processing all objects for this frame --- - # Concatenate or select single-object masks + # ========================================================= + # 4) COMMON TAIL – resize masks & yield + # ========================================================= if batch_size > 1: all_pred_masks = torch.cat(pred_masks_per_obj, dim=0) else: all_pred_masks = pred_masks_per_obj[0] - video_H, video_W = inference_state["video_size"] - - # 1) Move to CPU numpy - npm = all_pred_masks.detach().cpu().numpy() - # 2) Squeeze out any leading singleton batch dims until ndim <= 3 - while npm.ndim > 3: - npm = np.squeeze(npm, axis=0) - # 3) Ensure channel axis: if 2-D, add channel dim - if npm.ndim == 2: - npm = npm[np.newaxis, ...] # → shape (1, Hmask, Wmask) - - # Now npm.ndim == 3: (C, Hmask, Wmask) - C, Hm, Wm = npm.shape - # 4) Prepare output array - resized = np.zeros((C, video_H, video_W), dtype=npm.dtype) - # 5) Resize each channel - for c in range(C): - resized[c] = cv2.resize( - npm[c], - dsize=(video_W, video_H), - interpolation=cv2.INTER_LINEAR, - ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, all_pred_masks + ) - # 6) Convert back to torch.Tensor on the predictor’s device - device = inference_state["device"] - am_up = torch.from_numpy(resized).to(device) - # 7) Yield one mask-per-frame tensor - yield frame_idx, obj_ids, am_up + if frame_idx % 30 == 0 and frame_idx > 0: + dt = time.time() - t0 + print(f"Avg FPS last 30 frames: {30/dt:.2f}") + t0 = time.time() + + yield frame_idx, obj_ids, video_res_masks + + if skipped_ctr: + print(f"Skipped {skipped_ctr} frames due to low MAD.") @torch.inference_mode() def clear_all_prompts_in_frame( @@ -816,6 +774,9 @@ def reset_state(self, inference_state): inference_state["output_dict_per_obj"].clear() inference_state["temp_output_dict_per_obj"].clear() inference_state["frames_tracked_per_obj"].clear() + inference_state["last_processed_frame"] = None + inference_state["last_processed_outputs"] = None + def _reset_tracking_results(self, inference_state): """Reset all tracking inputs and results across the videos."""