From 8bb1d251edd2cdc5e7eb9f5390d82244552d71e8 Mon Sep 17 00:00:00 2001 From: v-ichaconsil Date: Fri, 5 Jun 2026 19:21:29 +0000 Subject: [PATCH 1/4] feat(registry): add animaloc/registry/families.py for deployment defaults Introduces a single source of truth for per-model deployment defaults (stitcher, evaluator, model_kwargs, normalization stats, down_ratio, multi_class flag) that tools like `tools/infer.py` need but should not hard-code. Placed under `animaloc/registry/` rather than `animaloc/models/` so the model classes themselves do not pick up an accidental dependency on eval components. Includes entries for the seven registered models in this repo: HerdNet, OWLC, OWLT, OWLD_S, OWLD_B, OWLD_L, OWLD_H. The `resolve_family(name, *, checkpoint_meta, overrides)` helper returns the effective config with this resolution order: family defaults -> checkpoint metadata (mean/std/classes saved by tools/train.py, plus down_ratio when present) -> explicit CLI overrides (with model_kwargs merged, not replaced). Notable design choices: * Every family turns off pretrained-weight loading, via `pretrained=False` (HerdNet, OWLC, OWLD_*) or `pretrained_cnn=False` (OWLT), so inference does not re-fetch backbone weights -- the checkpoint state_dict supersedes them. * Normalization defaults to ImageNet stats. Verified against every OWL training config in this repo (incl. all DINOv3 ViT runs); the user trains with ImageNet stats throughout. Smoke-tested: - All 7 family names resolve. - resolve_family() with checkpoint_meta + overrides correctly merges model_kwargs and overrides scalar fields. - Unknown family name raises KeyError with an actionable message listing known families and the file to edit. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- animaloc/registry/__init__.py | 16 +++ animaloc/registry/families.py | 197 ++++++++++++++++++++++++++++++++++ 2 files changed, 213 insertions(+) create mode 100644 animaloc/registry/__init__.py create mode 100644 animaloc/registry/families.py diff --git a/animaloc/registry/__init__.py b/animaloc/registry/__init__.py new file mode 100644 index 0000000..fd21371 --- /dev/null +++ b/animaloc/registry/__init__.py @@ -0,0 +1,16 @@ +"""Deployment-defaults registries for animaloc. + +Reusable lookup tables that tell client code (CLI tools, notebooks) +which Stitcher, Evaluator, model_kwargs, normalization stats, etc. to +use for each registered model. The model classes themselves do NOT read +this — it's strictly a deployment / tooling concern. Keeping it out of +animaloc.models prevents accidental coupling between model code and +eval components. + +Consumers: + from animaloc.registry.families import FAMILIES, resolve_family +""" + +from .families import FAMILIES, ModelFamily, resolve_family + +__all__ = ["FAMILIES", "ModelFamily", "resolve_family"] diff --git a/animaloc/registry/families.py b/animaloc/registry/families.py new file mode 100644 index 0000000..1b50cc0 --- /dev/null +++ b/animaloc/registry/families.py @@ -0,0 +1,197 @@ +"""Per-model deployment defaults: which Stitcher, Evaluator, model +constructor kwargs, image normalization stats, and downsample ratio to +use for each registered model in `animaloc.models.MODELS`. + +This is a *deployment* concern (used by tools/infer.py and any future +prediction tool), NOT a property of the model itself. Models never +import from here. 
+ +## Design + +A `FAMILIES[name]` entry has the shape of `ModelFamily`: + + stitcher: str # name of class in animaloc.eval.stitchers + evaluator: str # name of class in animaloc.eval.evaluators + model_kwargs: dict[str, Any] # constructor kwargs for the model class + down_ratio: int # output stride; threaded into transforms + stitcher + mean: list[float] # image normalization mean (RGB) + std: list[float] # image normalization std (RGB) + multi_class: bool # True if model outputs (heatmap, classmap), False if heatmap-only + +## How tools should use it + +The `resolve_family(name, *, checkpoint_meta=None, overrides=None)` helper +returns the effective config for a given model name. Resolution order +(later wins): family defaults -> checkpoint metadata -> explicit CLI +overrides. + +## Extending + +To register a new model family, add an entry to `FAMILIES` here, NOT in +`tools/infer.py`. The model itself only needs to be registered with +`@MODELS.register()` (in its own file under `animaloc.models`). +""" + +from __future__ import annotations + +import copy +from dataclasses import dataclass, field +from typing import Any, Optional + + +# Normalization stats used by every config in this repo (HerdNet + all +# OWL variants). DINOv3 backbones happen to use these too in the OWLD_* +# training configs (verified against exp_dpt_vits_proj_r12_frozen.yaml, +# exp_dpt_vith_dinov3_overhead_generalized.yaml, etc.). 
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


@dataclass(frozen=True)
class ModelFamily:
    """Deployment defaults for one model family.

    `frozen=True` only blocks attribute rebinding; `model_kwargs`, `mean`
    and `std` are ordinary mutable containers. Read them through
    `as_dict()`, which hands out independent copies, rather than mutating
    a registry entry in place.
    """

    # Name of the stitcher class to use for this family.
    stitcher: str
    # Name of the evaluator class to use for this family.
    evaluator: str
    # Keyword arguments passed to the model constructor.
    model_kwargs: dict[str, Any] = field(default_factory=dict)
    # Downsample ratio stored alongside the model kwargs.
    down_ratio: int = 2
    # Image normalization statistics; each instance gets its own copy of
    # the ImageNet values so the module-level constants are never aliased.
    mean: list[float] = field(default_factory=lambda: list(_IMAGENET_MEAN))
    std: list[float] = field(default_factory=lambda: list(_IMAGENET_STD))
    # True for families whose entry is flagged as multi-class.
    multi_class: bool = False

    def as_dict(self) -> dict[str, Any]:
        """Return every field as a plain dict of independent copies.

        Keys follow field declaration order. The mapping is derived from
        the dataclass definition instead of a hand-written key list, so a
        field added to this class later cannot be silently left out of
        the resolved config.

        Returns:
            A new dict; mutating it (or its nested containers) does not
            affect this instance.
        """
        # Function-scope import: the module header imports only
        # `dataclass` and `field`.
        from dataclasses import fields

        return {
            spec.name: copy.deepcopy(getattr(self, spec.name))
            for spec in fields(self)
        }


FAMILIES: dict[str, ModelFamily] = {
    # Legacy HerdNet -- multi-class, outputs (heatmap, classmap).
    "HerdNet": ModelFamily(
        stitcher="HerdNetStitcher",
        evaluator="HerdNetEvaluator",
        model_kwargs=dict(
            num_layers=34,
            pretrained=False,  # inference loads from the .pth checkpoint
            down_ratio=2,
            head_conv=64,
        ),
        down_ratio=2,
        multi_class=True,
    ),
    # OWL-C: HerdNet detection branch, single-class FIDT heatmap, DLA-34.
    "OWLC": ModelFamily(
        stitcher="HerdNet_Detection_Branch_Stitcher",
        evaluator="HerdNet_Detection_Branch_Evaluator",
        model_kwargs=dict(
            num_layers=34,
            pretrained=False,
            down_ratio=2,
            head_conv=64,
        ),
        down_ratio=2,
        multi_class=False,
    ),
    # OWL-T: DLA-34 + Swin multiscale residual. Note kwarg `pretrained_cnn`,
    # not `pretrained`, on the DLA base.
    "OWLT": ModelFamily(
        stitcher="HerdNet_Detection_Branch_Stitcher",
        evaluator="HerdNet_Detection_Branch_Evaluator",
        model_kwargs=dict(
            num_layers=34,
            pretrained_cnn=False,
            down_ratio=2,
            head_conv=64,
        ),
        down_ratio=2,
        multi_class=False,
    ),
}

# OWL-D family: DINOv3 ViT (S/B/L/H) + DPT decoder.
All four variants +# share the same stitcher / evaluator / kwargs (the variant is selected +# by the class name itself). pretrained=False to make sure the +# constructor does not try to fetch DINOv3 hub weights at inference -- +# the checkpoint's state_dict supersedes them anyway. +_OWLD_DEFAULT_KWARGS = dict(down_ratio=2, freeze_backbone=True, pretrained=False) + +for _owld_name in ("OWLD_S", "OWLD_B", "OWLD_L", "OWLD_H"): + FAMILIES[_owld_name] = ModelFamily( + stitcher="HerdNet_Detection_Branch_Stitcher", + evaluator="HerdNet_Detection_Branch_Evaluator", + model_kwargs=dict(_OWLD_DEFAULT_KWARGS), + down_ratio=2, + multi_class=False, + ) + + +def resolve_family( + name: str, + *, + checkpoint_meta: Optional[dict[str, Any]] = None, + overrides: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + """Return the effective deployment config for the named model. + + Resolution order (later wins): + 1. `FAMILIES[name]` defaults + 2. Values pulled from the checkpoint metadata (`mean`, `std`, + `classes`, anything else stored by `tools/train.py`) + 3. Explicit CLI overrides + + Args: + name: Registered model name (must be a key of `FAMILIES`). + checkpoint_meta: Optional dict pulled from `torch.load(pth_file)`. + Recognized keys: `mean`, `std`, `classes` (passes through), + and any other key that matches a `ModelFamily` field. + overrides: Optional dict of CLI-driven overrides. Same recognized + keys as `checkpoint_meta`, plus `model_kwargs` (merged into + family defaults, not replaced). + + Returns: + Plain dict with the resolved config. Always has the keys: + `stitcher`, `evaluator`, `model_kwargs`, `down_ratio`, `mean`, + `std`, `multi_class`. Plus passthrough keys like `classes` when + present in metadata. + + Raises: + KeyError: if `name` is not in `FAMILIES`. The caller should + catch this and report the available families to the user. + """ + if name not in FAMILIES: + raise KeyError( + f"Unknown model family {name!r}. 
Known families: {sorted(FAMILIES.keys())}. " + "Add an entry to animaloc/registry/families.py for new model classes." + ) + + resolved = FAMILIES[name].as_dict() + + # Pull supported keys from checkpoint metadata (mean, std, classes, + # plus any direct field overrides). + if checkpoint_meta: + for key in ("mean", "std", "down_ratio"): + if key in checkpoint_meta and checkpoint_meta[key] is not None: + resolved[key] = checkpoint_meta[key] + if "classes" in checkpoint_meta: + resolved["classes"] = checkpoint_meta["classes"] + + # CLI overrides. `model_kwargs` is MERGED (not replaced) so users + # can override one kwarg without listing every default. + if overrides: + for key, value in overrides.items(): + if value is None: + continue + if key == "model_kwargs" and isinstance(value, dict): + merged = dict(resolved["model_kwargs"]) + merged.update(value) + resolved["model_kwargs"] = merged + else: + resolved[key] = value + + return resolved From 0ab5abc43eefb367798b8e5b12e1280539824d4b Mon Sep 17 00:00:00 2001 From: v-ichaconsil Date: Fri, 5 Jun 2026 19:28:02 +0000 Subject: [PATCH 2/4] refactor(tools): generalize infer.py to support any registered model `tools/infer.py` was hardcoded to the legacy multi-class HerdNet model: hardcoded `from animaloc.models import HerdNet`, hardcoded HerdNet stitcher/evaluator, a non-recoverable `assert num_classes == 7`, and a hardcoded `classes={1:..6:..}` dict. Single-class OWL models (OWLC, OWLT, OWLD_S/B/L/H) could not run inference through this script -- the docs routed users to `tools/test.py` (which needs ground-truth annotations) as a workaround. This commit rewrites infer.py to look up the model class, stitcher, evaluator, default kwargs, and normalization stats from the `animaloc.registry.families.FAMILIES` table added in the previous commit. Any registered model that has a family entry is usable. 
## New CLI surface (all flags optional except positional root + pth) --model NAME from FAMILIES keys; default HerdNet (back-compat) --model-kwarg KEY=VAL override a constructor kwarg (repeatable) --stitcher NAME override family-default stitcher class --evaluator NAME override family-default evaluator class --num-classes N explicit override (HerdNet without metadata only) --mean R,G,B override normalization mean --std R,G,B override normalization std --down-ratio N override down_ratio --lmds-kernel-size H,W LMDS kernel (default 3,3) --lmds-adapt-ts FLOAT LMDS adaptive threshold (default 0.2) --lmds-neg-ts FLOAT LMDS negative threshold (HerdNet family only) --output-dir PATH default {root}/{date}_{model}_results -size -over -device -pf -rot (unchanged) --skip-model-inference (renamed from -skip_model_inference) -ts removed (only the disabled thumbnail export used it) ## Resolution order For each setting: FAMILIES[name] defaults -> checkpoint metadata (`classes`, `mean`, `std`) -> explicit CLI override. `model_kwargs` is merged rather than replaced so users override one kwarg without listing every default. ## Behavior changes that matter * Output dir is configurable (`--output-dir`) and the folder name now includes the model (e.g. `20260605_OWLC_results`), not the hardcoded `{date}_HerdNet_results`. Default location is unchanged for HerdNet. * `assert num_classes == 7` is gone. For HerdNet without `classes` metadata, the layer-shape probe (`model.cls_head.2.weight`) is kept as a last-resort fallback. For OWL families, num_classes is not passed (the constructor doesn't accept it). * `state_dict` loading is now `strict=False` with explicit missing/unexpected key warnings. Catches partial-load checkpoints without crashing immediately, but tells the user what happened. * `.map(classes) + .dropna()` chain is gone. Detection rows whose label is unmapped now keep the raw label (as string) in `species` and emit a single warning listing unmapped labels. 
* `pretrained=False` (`pretrained_cnn=False` for OWLT) is set in every family's `model_kwargs` so the constructor never re-fetches DINOv3 or DLA-34 weights at inference time (the checkpoint's state_dict supersedes them). ## Sanity checks added at startup * One-shot dummy forward on `torch.zeros(1, 3, size, size)` to detect model/stitcher shape mismatches early with a clear error instead of a deep tuple-unpack failure inside LMDS. * Unwrap `LossWrapper`'s `(output, output_dict)` and ignore `None` entries in tuple outputs (e.g. OWLD_S returns `(heatmap, None)`) before counting outputs. ## Smoke validation * `tools/infer.py /tmp/owl-smoketest/val/ {pth} --model OWLC -device cpu` produces a 1856-row detections.csv with columns `images, labels, dscores, x, y, count_1, species`. * `tools/infer.py /tmp/owl-smoketest/val/ {pth} --model OWLD_S -device cpu` produces 2270 rows via the DINOv3 ViT-S/16 frozen backbone. End-to-end DINOv3 + animaloc inference works. * `tools/infer.py /tmp/owl-smoketest/val/ {pth}` (no --model, defaults to HerdNet against an OWL checkpoint) fails cleanly with "4 missing key(s) in state_dict" warning + LMDS shape error. ## Deferred to follow-up * The Evaluator path is still used as a wrapper because it already implements stitching + LMDS. Ground-truth values are dummy (x=0, y=0, label=1) and metrics are discarded. A future PR can factor out a pure inference function that does not go through Evaluator. * Adding `model_name` / `stitcher_name` / `evaluator_name` to the checkpoint metadata in `tools/train.py` so `--model` becomes fully auto-detected. Today the user still has to pass `--model` for non-HerdNet checkpoints. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tools/infer.py | 796 ++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 616 insertions(+), 180 deletions(-) diff --git a/tools/infer.py b/tools/infer.py index 678d798..3a4f54c 100644 --- a/tools/infer.py +++ b/tools/infer.py @@ -1,200 +1,636 @@ -__copyright__ = \ - """ - Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life - All rights reserved. +"""Run inference with any registered animaloc model on a folder of images. - This source code is under the MIT License. +This is the generic inference CLI. It works for the legacy multi-class +HerdNet model AND for the single-class OWL family (OWLC, OWLT, +OWLD_S/B/L/H). Add new models to `animaloc/registry/families.py` and +they become usable here without code changes. - Please contact the author Alexandre Delplanque (alexandre.delplanque@uliege.be) for any questions. +## Quickstart - Last modification: March 18, 2024 - """ -__author__ = "Alexandre Delplanque" -__license__ = "MIT License" -__version__ = "0.2.1" +Run the legacy HerdNet model (backwards-compatible default): + + python tools/infer.py /path/to/images /path/to/herdnet.pth + +Run an OWL-C model: + + python tools/infer.py /path/to/images /path/to/owlc.pth --model OWLC + +Run an OWL-D-L model and write results elsewhere: + + python tools/infer.py /path/to/images /path/to/owld_l.pth \\ + --model OWLD_L --output-dir /tmp/owld_l_results --device cpu + +Override a single model constructor kwarg: + + python tools/infer.py imgs/ ckpt.pth --model OWLT \\ + --model-kwarg down_ratio=4 + +See `animaloc/registry/families.py` for the supported model families +and their default kwargs/stitcher/evaluator/normalization. 
+## Outputs + +A timestamped folder under `--output-dir` (or under `/` by +default) containing: + + __results/ + _detections.csv columns: images, x, y, labels, scores, [species] + +The `species` column is included only when the checkpoint stores a +`classes` mapping (saved automatically by `tools/train.py`). + +## Vendored from HerdNet + +This script started as a copy of HerdNet's `tools/infer.py` (MIT, +Universite de Liege) and was rewritten in this repo to be model- +agnostic. See git log for the rewrite commit. +""" + +from __future__ import annotations import argparse -import torch import os -import pandas +import sys import warnings -import numpy -import PIL +from datetime import datetime +from typing import Any, Optional import albumentations as A - -from torch.utils.data import DataLoader -from PIL import Image - -import sys -import os +import pandas +import PIL +import torch sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import animaloc.eval.evaluators as evaluators_mod +import animaloc.eval.stitchers as stitchers_mod +import animaloc.models as models_mod from animaloc.data.transforms import DownSample, Rotate90 -from animaloc.models import LossWrapper, HerdNet -from animaloc.eval import HerdNetStitcher, HerdNetEvaluator -from animaloc.eval.metrics import PointsMetrics from animaloc.datasets import CSVDataset -from animaloc.utils.useful_funcs import mkdir, current_date -from animaloc.vizual import draw_points, draw_text +from animaloc.eval.metrics import PointsMetrics +from animaloc.models.utils import LossWrapper +from animaloc.registry.families import FAMILIES, resolve_family +from animaloc.utils.useful_funcs import current_date, mkdir +from torch.utils.data import DataLoader, SequentialSampler -warnings.filterwarnings('ignore') +warnings.filterwarnings("ignore") PIL.Image.MAX_IMAGE_PIXELS = None -parser = argparse.ArgumentParser( - prog='inference', - description='Collects the detections of a pretrained HerdNet model on 
a set of images ' - ) - -parser.add_argument('root', type=str, - help='path to the JPG images folder (str)') -parser.add_argument('pth', type=str, - help='path to PTH file containing your model parameters (str)') -parser.add_argument('-size', type=int, default=512, - help='patch size use for stitching. Defaults to 512.') -parser.add_argument('-over', type=int, default=160, - help='overlap for stitching. Defaults to 160.') -parser.add_argument('-device', type=str, default='cuda', - help='device on which model and images will be allocated (str). \ - Possible values are \'cpu\' or \'cuda\'. Defaults to \'cuda\'.') -parser.add_argument('-ts', type=int, default=256, - help='thumbnail size. Defaults to 256.') -parser.add_argument('-pf', type=int, default=10, - help='print frequence. Defaults to 10.') -parser.add_argument('-rot', type=int, default=0, - help='number of times to rotate by 90 degrees. Defaults to 0.') -parser.add_argument('-skip_model_inference', action='store_true', - help='if set, skips the model inference step (for debugging purposes).') - -args = parser.parse_args() - -def main(): - - # Create destination folder - curr_date = current_date() - dest = os.path.join(args.root, f"{curr_date}_HerdNet_results") - mkdir(dest) - - # Read info from PTH file - map_location = torch.device('cpu') - if torch.cuda.is_available(): - map_location = torch.device('cuda') - if not args.skip_model_inference: - print('Loading the model ...') - checkpoint = torch.load(args.pth, map_location=map_location) - #classes = checkpoint['classes'] - #num_classes = len(classes) + 1 - num_classes = checkpoint['model_state_dict']['model.cls_head.2.weight'].shape[0] - assert num_classes == 7, 'This code is currently hardcoded for 7 classes. Please update it accordingly.' 
- classes = {1:'class_1', 2:'class_2', 3:'class_3', 4:'class_4', 5:'class_5', 6:'class_6'} - #img_mean = checkpoint['mean'] - img_mean = [0.485, 0.456, 0.406] - #img_std = checkpoint['std'] - img_std = [0.229, 0.224, 0.225] - - # Prepare dataset and dataloader - possible_extensions = ['.tif', '.jpg', '.png', '.jpeg', '.TIF', '.JPG', '.PNG', '.JPEG'] - # Prepare dataset and dataloader - img_names = [i for i in os.listdir(args.root) - if any([i.endswith(ext) for ext in possible_extensions])] - n = len(img_names) - df = pandas.DataFrame(data={'images': img_names, 'x': [0]*n, 'y': [0]*n, 'labels': [1]*n}) - - end_transforms = [] - if args.rot != 0: - end_transforms.append(Rotate90(k=args.rot)) - end_transforms.append(DownSample(down_ratio = 2, anno_type = 'point')) - - albu_transforms = [A.Normalize(mean=img_mean, std=img_std)] - - dataset = CSVDataset( - csv_file = df, - root_dir = args.root, - albu_transforms = albu_transforms, - end_transforms = end_transforms - ) - - dataloader = DataLoader(dataset, batch_size=1, shuffle=False, - sampler=torch.utils.data.SequentialSampler(dataset)) - - # Build the trained model - print('Building the model ...') - device = torch.device(args.device) - model = HerdNet(num_classes=num_classes, pretrained=False) - # Count number of parameters - num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f'Number of parameters in the model: {num_params}') - model = LossWrapper(model, []) - model.load_state_dict(checkpoint['model_state_dict']) - - # Build the evaluator - stitcher = HerdNetStitcher( - model = model, - size = (args.size,args.size), - overlap = args.over, - down_ratio = 2, - up = True, - reduction = 'mean', - device_name = device - ) - - metrics = PointsMetrics(20, num_classes = num_classes) - evaluator = HerdNetEvaluator( - model = model, - dataloader = dataloader, - metrics = metrics, - lmds_kwargs = dict(kernel_size=(3,3), adapt_ts=0.2, neg_ts=0.1), - device_name = device, - print_freq = args.pf, - 
stitcher = stitcher, - work_dir=dest, - header = '[INFERENCE]' - ) - - # Start inference - print('Starting inference ...') - out = evaluator.evaluate(wandb_flag=False, viz=False, log_meters=False) - - # Save the detections - print('Saving the detections ...') - detections = evaluator.detections - detections.dropna(inplace=True) - detections['species'] = detections['labels'].map(classes) - detections.to_csv(os.path.join(dest, f'{curr_date}_detections.csv'), index=False) +# --------------------------------------------------------------------------- # +# CLI # +# --------------------------------------------------------------------------- # + +# Recognized string forms for boolean CLI values. The Python `bool()` +# builtin returns True for any non-empty string ("False" included), so we +# parse explicitly. +_TRUE_STRS = {"true", "1", "yes", "on", "t"} +_FALSE_STRS = {"false", "0", "no", "off", "f"} + + +def _parse_kv_value(raw: str) -> Any: + """Coerce a CLI key=value string into int/float/bool/str. + + Order: int -> float -> bool (explicit string set) -> str. + """ + try: + return int(raw) + except ValueError: + pass + try: + return float(raw) + except ValueError: + pass + low = raw.lower() + if low in _TRUE_STRS: + return True + if low in _FALSE_STRS: + return False + return raw + + +def _parse_kv_pair(s: str) -> tuple[str, Any]: + if "=" not in s: + raise argparse.ArgumentTypeError( + f"--model-kwarg expects key=value, got {s!r}" + ) + k, v = s.split("=", 1) + return k.strip(), _parse_kv_value(v.strip()) + + +def _parse_csv_floats(s: str) -> list[float]: + return [float(x) for x in s.split(",")] + + +def _parse_csv_ints(s: str) -> tuple[int, ...]: + return tuple(int(x) for x in s.split(",")) + + +def _build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + prog="infer", + description=( + "Run a pretrained animaloc model (HerdNet or any OWL variant) " + "on a folder of images and write the resulting detections to a " + "CSV. 
Defaults to HerdNet for backwards compatibility." + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + # Positional (unchanged). + p.add_argument("root", type=str, help="path to the folder of input images") + p.add_argument("pth", type=str, help="path to the .pth checkpoint") + + # Model selection. + p.add_argument( + "--model", + type=str, + default="HerdNet", + choices=sorted(FAMILIES.keys()), + help="registered model name (default: HerdNet, backwards-compat)", + ) + p.add_argument( + "--model-kwarg", + action="append", + type=_parse_kv_pair, + default=[], + metavar="KEY=VAL", + help=( + "override a single model constructor kwarg (repeatable). " + "Coerces to int/float/bool/str. Use --model-kwarg key=value." + ), + ) + p.add_argument( + "--stitcher", + type=str, + default=None, + help="override the family-default stitcher class name", + ) + p.add_argument( + "--evaluator", + type=str, + default=None, + help="override the family-default evaluator class name", + ) + p.add_argument( + "--num-classes", + type=int, + default=None, + help=( + "explicitly set num_classes (only needed for legacy HerdNet " + "checkpoints whose 'classes' metadata is missing AND whose " + "head shape probe fails)" + ), + ) + + # Normalization + geometry overrides. + p.add_argument( + "--mean", type=_parse_csv_floats, default=None, + metavar="R,G,B", help="image normalization mean (override checkpoint/family)", + ) + p.add_argument( + "--std", type=_parse_csv_floats, default=None, + metavar="R,G,B", help="image normalization std (override checkpoint/family)", + ) + p.add_argument( + "--down-ratio", type=int, default=None, + help="downsample ratio (override family default)", + ) + + # LMDS post-processing knobs (model-agnostic). 
+ p.add_argument( + "--lmds-kernel-size", type=_parse_csv_ints, default=(3, 3), + metavar="H,W", help="LMDS kernel size (default 3,3)", + ) + p.add_argument( + "--lmds-adapt-ts", type=float, default=0.2, + help="LMDS adaptive threshold (default 0.2)", + ) + p.add_argument( + "--lmds-neg-ts", type=float, default=0.1, + help="LMDS negative threshold (HerdNet family only)", + ) + + # Output. + p.add_argument( + "--output-dir", type=str, default=None, + help=( + "where to write results (default: /__results). " + "Useful when is read-only or shared." + ), + ) + + # Stitcher geometry + runtime knobs (kept from the original CLI). + p.add_argument("-size", type=int, default=512, help="patch size for stitching") + p.add_argument("-over", type=int, default=160, help="overlap for stitching") + p.add_argument( + "-device", type=str, default="cuda", + help="'cpu' or 'cuda' (default cuda)", + ) + p.add_argument("-pf", type=int, default=10, help="print frequency") + p.add_argument( + "-rot", type=int, default=0, + help="number of 90-degree CCW rotations to apply", + ) + p.add_argument( + "--skip-model-inference", + action="store_true", + help="skip the inference step (debug-only; preserved from upstream)", + ) + return p + + +# --------------------------------------------------------------------------- # +# Helpers # +# --------------------------------------------------------------------------- # + + +def _load_checkpoint(pth_path: str, map_location: torch.device) -> dict: + return torch.load(pth_path, map_location=map_location) + + +def _probe_herdnet_num_classes(state_dict: dict) -> Optional[int]: + """Last-resort fallback: read the HerdNet classification head shape. + + Returns None if the layer is not present (i.e. not a HerdNet checkpoint). 
+ """ + key = "model.cls_head.2.weight" + if key not in state_dict: + return None + return int(state_dict[key].shape[0]) + + +def _resolve_num_classes( + family_name: str, + args: argparse.Namespace, + checkpoint_meta: dict, + state_dict: dict, +) -> Optional[int]: + """Decide what num_classes to pass to the model constructor. + + For single-class OWL models: returns None (the constructor does + not accept num_classes). + + For HerdNet (multi_class=True): tries in order: + 1. --num-classes CLI override + 2. len(checkpoint_meta['classes']) + 1 (binary + per-species) + 3. layer-shape probe on `model.cls_head.2.weight` + 4. raises with an actionable message + """ + if not FAMILIES[family_name].multi_class: + return None + + if args.num_classes is not None: + return args.num_classes + + if "classes" in checkpoint_meta and checkpoint_meta["classes"]: + # +1 for background class 0 + return len(checkpoint_meta["classes"]) + 1 + + probed = _probe_herdnet_num_classes(state_dict) + if probed is not None: + return probed + + raise RuntimeError( + f"Cannot determine num_classes for {family_name!r} model: the checkpoint " + "has no 'classes' metadata and the head-shape probe failed. " + "Pass --num-classes N explicitly." + ) + + +def _build_model( + family_name: str, + resolved: dict, + num_classes: Optional[int], +) -> torch.nn.Module: + """Instantiate the model class with resolved kwargs.""" + if family_name not in models_mod.__dict__: + known = sorted(models_mod.MODELS.registry_names) + raise KeyError( + f"Model class {family_name!r} not found in animaloc.models. 
" + f"Known registered models: {known}" + ) + cls = models_mod.__dict__[family_name] + kwargs = dict(resolved["model_kwargs"]) + if num_classes is not None: + kwargs.setdefault("num_classes", num_classes) + return cls(**kwargs) + + +def _build_stitcher( + name: str, + model: torch.nn.Module, + size: int, + overlap: int, + down_ratio: int, + device: torch.device, +): + if name not in stitchers_mod.__dict__: + raise KeyError( + f"Stitcher class {name!r} not found in animaloc.eval.stitchers. " + "Check FAMILIES or pass --stitcher explicitly." + ) + cls = stitchers_mod.__dict__[name] + return cls( + model=model, + size=(size, size), + overlap=overlap, + down_ratio=down_ratio, + up=True, + reduction="mean", + device_name=device, + ) + + +def _build_evaluator( + name: str, + model: torch.nn.Module, + dataloader: DataLoader, + metrics: PointsMetrics, + stitcher, + device: torch.device, + print_freq: int, + work_dir: str, + lmds_kwargs: dict, +): + if name not in evaluators_mod.__dict__: + raise KeyError( + f"Evaluator class {name!r} not found in animaloc.eval.evaluators. " + "Check FAMILIES or pass --evaluator explicitly." + ) + cls = evaluators_mod.__dict__[name] + return cls( + model=model, + dataloader=dataloader, + metrics=metrics, + lmds_kwargs=lmds_kwargs, + device_name=device, + print_freq=print_freq, + stitcher=stitcher, + work_dir=work_dir, + header="[INFERENCE]", + ) + + +def _validate_model_stitcher_shape( + model: torch.nn.Module, family_name: str, size: int, device: torch.device +) -> None: + """One-shot dummy forward to catch model/stitcher mismatches early. + + Cheap (one forward on a 3xSxS tensor on the target device) and + catches the common 'used --model OWLC with --stitcher HerdNetStitcher' + error class with a clear message instead of a deep-stack tuple-unpacking + error later. + + The model is always wrapped in LossWrapper before this is called. 
+ LossWrapper.forward returns `(real_output, output_dict)` in eval + mode, so we unwrap one level before inspecting shape. + """ + family = FAMILIES[family_name] + try: + with torch.no_grad(): + model.eval() + wrapped_out = model(torch.zeros(1, 3, size, size, device=device)) + except Exception as e: + raise RuntimeError( + f"Dummy forward failed for {family_name!r} with size={size}: " + f"{type(e).__name__}: {e}" + ) from e + + # Unwrap LossWrapper's (output, loss_dict) tuple to get the real + # model output. Without this, every model looks like a 2-tuple. + if isinstance(wrapped_out, tuple) and len(wrapped_out) == 2 and isinstance(wrapped_out[1], dict): + real_out = wrapped_out[0] else: - print('Skipping model inference ...') - date = '20250318' - detections = pandas.read_csv(os.path.join(args.root, f'{date}_HerdNet_results', f'{date}_detections.csv')) - # Draw detections on images and create thumbnails - '''print('Exporting plots and thumbnails ...') - dest_plots = os.path.join(dest, 'plots') - mkdir(dest_plots) - dest_thumb = os.path.join(dest, 'thumbnails') - mkdir(dest_thumb) - img_names = numpy.unique(detections['images'].values).tolist() - for img_name in img_names: - img = Image.open(os.path.join(args.root, img_name)) - if args.rot != 0: - rot = args.rot * 90 - img = img.rotate(rot, expand=True) - img_cpy = img.copy() - pts = list(detections[detections['images']==img_name][['y','x']].to_records(index=False)) - pts = [(y, x) for y, x in pts] - output = draw_points(img, pts, color='red', size=10) - output.save(os.path.join(dest_plots, img_name), quality=95) - - # Create and export thumbnails - sp_score = list(detections[detections['images']==img_name][['species','scores']].to_records(index=False)) - for i, ((y, x), (sp, score)) in enumerate(zip(pts, sp_score)): - off = args.ts//2 - coords = (x - off, y - off, x + off, y + off) - thumbnail = img_cpy.crop(coords) - score = round(score * 100, 0) - thumbnail = draw_text(thumbnail, f"{sp} | {score}%", 
position=(10,5), font_size=int(0.08*args.ts)) - thumbnail.save(os.path.join(dest_thumb, img_name[:-4] + f'_{i}.JPG'))''' - -if __name__ == '__main__': - main() \ No newline at end of file + real_out = wrapped_out + + # Some OWL-family models (e.g. OWLD_*) return `(heatmap, None)` to + # match an optional secondary head signature. Filter Nones so the + # shape check sees only real tensors. + if isinstance(real_out, (tuple, list)): + real_tensors = [x for x in real_out if x is not None] + n_outputs = len(real_tensors) + else: + n_outputs = 1 + + if family.multi_class and n_outputs < 2: + raise RuntimeError( + f"{family_name!r} family expects a multi-output model (heatmap+classmap), " + f"but the model returned {n_outputs} tensor(s). " + f"Check the family definition or override --stitcher / --evaluator." + ) + if not family.multi_class and n_outputs > 1: + print( + f"[WARN] {family_name!r} family expects a single-output model but " + f"forward returned {n_outputs} outputs. Continuing — verify the result.", + file=sys.stderr, + ) + + +def _make_inference_dataset( + image_dir: str, mean: list[float], std: list[float], down_ratio: int, rot: int +) -> CSVDataset: + """Build a CSVDataset of dummy (x=0, y=0, label=1) entries -- one per + image in `image_dir`. The Evaluator path needs a dataloader of + (image, target) pairs; ground truth values are discarded. 
+ """ + possible_extensions = (".tif", ".tiff", ".jpg", ".jpeg", ".png") + img_names = sorted( + f for f in os.listdir(image_dir) + if f.lower().endswith(possible_extensions) + ) + if not img_names: + raise FileNotFoundError( + f"No images with extensions {possible_extensions} found in {image_dir!r}" + ) + + n = len(img_names) + df = pandas.DataFrame( + data={"images": img_names, "x": [0] * n, "y": [0] * n, "labels": [1] * n} + ) + + end_transforms = [] + if rot != 0: + end_transforms.append(Rotate90(k=rot)) + end_transforms.append(DownSample(down_ratio=down_ratio, anno_type="point")) + + albu_transforms = [A.Normalize(mean=mean, std=std)] + + return CSVDataset( + csv_file=df, + root_dir=image_dir, + albu_transforms=albu_transforms, + end_transforms=end_transforms, + ) + + +def _attach_species_column(detections: pandas.DataFrame, classes: dict) -> pandas.DataFrame: + """Add a 'species' column mapped from labels; emit raw label on miss. + + NOTE: The previous version of this code did + df.dropna(inplace=True) + after the .map(), which silently deleted every detection whose label + was not in the mapping. That hid both real bugs (wrong classes dict) + and minor mismatches. We now keep every row and warn about coverage + gaps instead. + """ + if not classes: + return detections + species = detections["labels"].map(classes) + unmapped_labels = detections.loc[species.isna(), "labels"].unique().tolist() + if unmapped_labels: + print( + f"[WARN] {len(unmapped_labels)} detection label(s) had no entry in " + f"the classes mapping: {sorted(unmapped_labels)}. 
" + "Keeping raw label; please update the checkpoint's `classes` metadata.", + file=sys.stderr, + ) + species = species.fillna(detections["labels"].astype(str)) + detections = detections.copy() + detections["species"] = species + return detections + + +# --------------------------------------------------------------------------- # +# Main # +# --------------------------------------------------------------------------- # + + +def main(argv: Optional[list[str]] = None) -> int: + args = _build_parser().parse_args(argv) + + # Map_location for torch.load: fall back to CPU if --device cuda but + # no CUDA available, so checkpoints load without crashing. + if args.device == "cuda" and not torch.cuda.is_available(): + print("[WARN] --device cuda requested but no CUDA available; using cpu.", file=sys.stderr) + args.device = "cpu" + device = torch.device(args.device) + map_location = device + + if args.skip_model_inference: + print("Skipping model inference (debug mode)") + return 0 + + # ---- Resolve family + checkpoint metadata ---- + print(f"Loading checkpoint: {args.pth}") + checkpoint = _load_checkpoint(args.pth, map_location) + + overrides = dict( + mean=args.mean, + std=args.std, + down_ratio=args.down_ratio, + stitcher=args.stitcher, + evaluator=args.evaluator, + model_kwargs=dict(args.model_kwarg) if args.model_kwarg else None, + ) + resolved = resolve_family(args.model, checkpoint_meta=checkpoint, overrides=overrides) + print(f"Resolved family for --model {args.model}:") + print(f" stitcher = {resolved['stitcher']}") + print(f" evaluator = {resolved['evaluator']}") + print(f" down_ratio = {resolved['down_ratio']}") + print(f" multi_class = {resolved['multi_class']}") + + # ---- num_classes resolution (HerdNet only) ---- + state_dict = checkpoint["model_state_dict"] + num_classes = _resolve_num_classes(args.model, args, checkpoint, state_dict) + if num_classes is not None: + print(f" num_classes = {num_classes}") + + # ---- Build + load model ---- + print(f"Building 
model {args.model}") + model = _build_model(args.model, resolved, num_classes) + n_params = sum(p.numel() for p in model.parameters()) + print(f" parameters = {n_params:,}") + + model = LossWrapper(model, []) + missing, unexpected = model.load_state_dict(state_dict, strict=False) + if missing: + print(f"[WARN] {len(missing)} missing key(s) in state_dict (first 5): {missing[:5]}", file=sys.stderr) + if unexpected: + print(f"[WARN] {len(unexpected)} unexpected key(s) in state_dict (first 5): {unexpected[:5]}", file=sys.stderr) + model = model.to(device) + + # ---- Sanity check ---- + _validate_model_stitcher_shape(model, args.model, args.size, device) + + # ---- Output dir ---- + curr_date = current_date() + output_dir = args.output_dir or os.path.join( + args.root, f"{curr_date}_{args.model}_results" + ) + mkdir(output_dir) + print(f"Output dir: {output_dir}") + + # ---- Dataset / dataloader ---- + print(f"Listing images under {args.root}") + dataset = _make_inference_dataset( + image_dir=args.root, + mean=resolved["mean"], + std=resolved["std"], + down_ratio=resolved["down_ratio"], + rot=args.rot, + ) + print(f" found {len(dataset)} image(s)") + + dataloader = DataLoader( + dataset, batch_size=1, shuffle=False, + sampler=SequentialSampler(dataset), + ) + + # ---- Stitcher + Evaluator ---- + stitcher = _build_stitcher( + resolved["stitcher"], + model=model, + size=args.size, + overlap=args.over, + down_ratio=resolved["down_ratio"], + device=device, + ) + + # PointsMetrics needs a num_classes. For single-class OWL we use 2 + # (background + animal); for HerdNet, the real num_classes. + metrics_num_classes = num_classes if num_classes is not None else 2 + metrics = PointsMetrics(20, num_classes=metrics_num_classes) + + # LMDS kwargs: HerdNet family uses neg_ts; Detection_Branch family does not. 
+ lmds_kwargs: dict[str, Any] = dict( + kernel_size=tuple(args.lmds_kernel_size), + adapt_ts=args.lmds_adapt_ts, + ) + if FAMILIES[args.model].multi_class: + lmds_kwargs["neg_ts"] = args.lmds_neg_ts + + evaluator = _build_evaluator( + resolved["evaluator"], + model=model, + dataloader=dataloader, + metrics=metrics, + stitcher=stitcher, + device=device, + print_freq=args.pf, + work_dir=output_dir, + lmds_kwargs=lmds_kwargs, + ) + + # ---- Run inference ---- + # We use the Evaluator pipeline (it already implements stitching + + # LMDS post-processing) but pass dummy ground truth so the computed + # metrics are meaningless and discarded. A future PR can factor out + # a pure inference path that does not go through Evaluator at all. + print(f"Running inference on {len(dataset)} image(s) ...") + evaluator.evaluate(wandb_flag=False, viz=False, log_meters=False) + + # ---- Save detections ---- + print("Saving detections ...") + detections = evaluator.detections + classes_meta = resolved.get("classes") or {} + detections = _attach_species_column(detections, classes_meta) + + out_csv = os.path.join(output_dir, f"{curr_date}_detections.csv") + detections.to_csv(out_csv, index=False) + print(f"Wrote {len(detections)} detection(s) to {out_csv}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 85b66c0e2d35f4fe5625e563c03719b78622bad6 Mon Sep 17 00:00:00 2001 From: v-ichaconsil Date: Fri, 5 Jun 2026 19:29:30 +0000 Subject: [PATCH 3/4] test(smoke): add tests/smoke_infer.sh + extend tests/README The OWL-C training and evaluation smoke runs were already in tests/README.md. This commit adds the inference smoke run that exercises the newly generalized tools/infer.py. `tests/smoke_infer.sh`: 1. Generates the synthetic dataset at /tmp/owl-smoketest/ if missing 2. Runs the OWL-C training smoke if no checkpoint exists under outputs/ 3. Runs `tools/infer.py /tmp/owl-smoketest/val/ --model OWLC --device cpu --output-dir /tmp/owl-smoketest-infer/` 4. 
Verifies the detections CSV exists, has > 0 rows, and contains the columns `images, x, y, labels` Exit code 0 on pass, non-zero on any failure. Runs in ~30 seconds on CPU when the training smoke has already produced a checkpoint, or ~90 seconds if it has to train one first. `tests/README.md`: - Adds step 6 (./tests/smoke_infer.sh) to the smoke-test sequence - Documents what the inference smoke script does and what it verifies Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/README.md | 18 +++++++++++ tests/smoke_infer.sh | 76 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100755 tests/smoke_infer.sh diff --git a/tests/README.md b/tests/README.md index 833579a..5e464db 100644 --- a/tests/README.md +++ b/tests/README.md @@ -42,6 +42,9 @@ WANDB_MODE=disabled uv run python tools/train.py train=owld_s_smoketest CKPT=$(ls -t outputs/*/*/best_model.pth | head -1 | xargs realpath) WANDB_MODE=disabled uv run python tools/test.py test=owld_s_smoketest \ "++test.model.pth_file=$CKPT" + +# 6. Inference smoke (auto-runs steps 2 and 3 if needed) +./tests/smoke_infer.sh ``` Expected runtime on CPU: ~1 min for forward-pass + dataset, ~30 s for @@ -83,3 +86,18 @@ Training complete | Best f1_score: ... at epoch 1 The evaluation smoke run writes `metrics_results.csv`, `confusion_matrix.csv`, `detections.csv`, and `plots/precision_recall_curve.png` under `outputs//