diff --git a/autoware_ml/hooks/__init__.py b/autoware_ml/hooks/__init__.py index 62666b52e..ada9b7621 100644 --- a/autoware_ml/hooks/__init__.py +++ b/autoware_ml/hooks/__init__.py @@ -6,6 +6,8 @@ PytorchTrainingProfilerHook, PytorchValidationProfilerHook, ) +from .t4_seg_logger_hook import T4SegLoggerHook +from .t4_seg_tensorboard_hook import T4SegTensorboardHook __all__ = [ "MomentumInfoHook", @@ -14,4 +16,6 @@ "PytorchValidationProfilerHook", "LossScaleInfoHook", "LoggerHook", + "T4SegLoggerHook", + "T4SegTensorboardHook", ] diff --git a/autoware_ml/hooks/t4_seg_logger_hook.py b/autoware_ml/hooks/t4_seg_logger_hook.py new file mode 100644 index 000000000..2c3fa5d12 --- /dev/null +++ b/autoware_ml/hooks/t4_seg_logger_hook.py @@ -0,0 +1,13 @@ +from mmengine.registry import HOOKS + +from .logger_hook import LoggerHook + + +@HOOKS.register_module() +class T4SegLoggerHook(LoggerHook): + """Logger hook for T4 segmentation configs using custom TensorBoard metric tags.""" + + def after_val_epoch(self, runner, metrics=None) -> None: + """Log validation results without the default TensorBoard scalar dump.""" + _, log_str = runner.log_processor.get_log_after_epoch(runner, len(runner.val_dataloader), "val") + runner.logger.info(log_str) diff --git a/autoware_ml/hooks/t4_seg_tensorboard_hook.py b/autoware_ml/hooks/t4_seg_tensorboard_hook.py new file mode 100644 index 000000000..1f632f863 --- /dev/null +++ b/autoware_ml/hooks/t4_seg_tensorboard_hook.py @@ -0,0 +1,63 @@ +import matplotlib.pyplot as plt +from mmengine.hooks import Hook +from mmengine.registry import HOOKS +from mmengine.visualization import Visualizer + +from autoware_ml.segmentation3d.evaluation import ( + T4SegMetric, + build_t4_seg_tb_scalars, + figure_to_numpy, + iter_t4_seg_confusion_matrix_figures, +) + + +@HOOKS.register_module() +class T4SegTensorboardHook(Hook): + """Log shared T4 segmentation TensorBoard tags for MMEngine runners.""" + + priority = "LOW" + + def after_val_epoch(self, runner, metrics=None): + self._log_stage(runner, stage="val", step=runner.iter) + + def after_test_epoch(self, runner, metrics=None): + self._log_stage(runner, stage="test", step=0) + + def _log_stage(self, runner, stage: str, step: int) -> None: + metric = self._get_metric(runner, stage) + if metric is None or metric.last_eval_result is None: + return + + try: + vis = Visualizer.get_current_instance() + except Exception: + return + + class_names = [metric.last_label2cat[i] for i in sorted(metric.last_label2cat)] + scalars = build_t4_seg_tb_scalars( + metrics=metric.last_eval_result.metrics, + class_names=class_names, + stage=stage, + distance_ranges=metric.distance_ranges, + ) + if scalars: + vis.add_scalars(scalars, step=step) + + for tag, fig in iter_t4_seg_confusion_matrix_figures(metric.last_eval_result, class_names, stage): + try: + vis.add_image(tag, figure_to_numpy(fig), step=step) + except Exception: + pass + finally: + plt.close(fig) + + @staticmethod + def _get_metric(runner, stage: str): + loop = runner.val_loop if stage == "val" else runner.test_loop + evaluator = getattr(loop, "evaluator", None) + if evaluator is None: + return None + for metric in getattr(evaluator, "metrics", []): + if isinstance(metric, T4SegMetric): + return metric + return None diff --git a/autoware_ml/segmentation3d/evaluation/__init__.py b/autoware_ml/segmentation3d/evaluation/__init__.py new file mode 100644 index 000000000..d2400edc7 --- /dev/null +++ b/autoware_ml/segmentation3d/evaluation/__init__.py @@ -0,0 +1,45 @@ +# Copyright (c) TIER IV, Inc. All rights reserved. +"""Segmentation evaluation: functional helpers + MMEngine metric adapter.""" + +from .functional.t4_seg_eval import ( + SegEvalResult, + compute_bev_distance, + fast_hist, + figure_to_numpy, + get_acc, + get_acc_cls, + normalize_confusion_matrix, + per_class_f1, + per_class_iou, + per_class_precision, + per_class_recall, + plot_confusion_matrix, + range_label, + t4_seg_eval, + t4_seg_eval_from_hists, + update_seg_eval_histograms, +) +from .metrics.t4_seg_metric import T4SegMetric +from .tensorboard import build_t4_seg_tb_scalars, iter_t4_seg_confusion_matrix_figures + +__all__ = [ + "SegEvalResult", + "T4SegMetric", + "build_t4_seg_tb_scalars", + "compute_bev_distance", + "fast_hist", + "figure_to_numpy", + "get_acc", + "get_acc_cls", + "normalize_confusion_matrix", + "per_class_f1", + "per_class_iou", + "per_class_precision", + "per_class_recall", + "plot_confusion_matrix", + "range_label", + "t4_seg_eval", + "t4_seg_eval_from_hists", + "iter_t4_seg_confusion_matrix_figures", + "update_seg_eval_histograms", +] diff --git a/autoware_ml/segmentation3d/evaluation/functional/__init__.py b/autoware_ml/segmentation3d/evaluation/functional/__init__.py new file mode 100644 index 000000000..2d95736f4 --- /dev/null +++ b/autoware_ml/segmentation3d/evaluation/functional/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) TIER IV, Inc. All rights reserved. +from .t4_seg_eval import ( + SegEvalResult, + compute_bev_distance, + fast_hist, + figure_to_numpy, + get_acc, + get_acc_cls, + normalize_confusion_matrix, + per_class_f1, + per_class_iou, + per_class_precision, + per_class_recall, + plot_confusion_matrix, + range_label, + t4_seg_eval, +) + +__all__ = [ + "SegEvalResult", + "compute_bev_distance", + "fast_hist", + "figure_to_numpy", + "get_acc", + "get_acc_cls", + "normalize_confusion_matrix", + "per_class_f1", + "per_class_iou", + "per_class_precision", + "per_class_recall", + "plot_confusion_matrix", + "range_label", + "t4_seg_eval", +] diff --git a/autoware_ml/segmentation3d/evaluation/functional/t4_seg_eval.py b/autoware_ml/segmentation3d/evaluation/functional/t4_seg_eval.py new file mode 100644 index 000000000..3ee98d7ae --- /dev/null +++ b/autoware_ml/segmentation3d/evaluation/functional/t4_seg_eval.py @@ -0,0 +1,423 @@ +# Copyright (c) TIER IV, Inc. All rights reserved. +"""Helpers for 3D semantic segmentation evaluation.""" + +from __future__ import annotations + +import io +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.colors import Normalize as MplNormalize +from mmengine.logging import print_log +from PIL import Image +from terminaltables import AsciiTable + +_EPS = 1e-10 + + +def fast_hist(preds: np.ndarray, labels: np.ndarray, num_classes: int) -> np.ndarray: + """Confusion matrix for one sample (matches mmdet3d ``fast_hist``). + + ``hist[gt_class, pred_class]`` = number of points. + + Args: + preds: Predicted label array, shape ``(N,)``. + labels: Ground-truth label array, shape ``(N,)``. + num_classes: Number of classes. + + Returns: + ``np.ndarray`` of shape ``(num_classes, num_classes)``. + """ + k = (labels >= 0) & (labels < num_classes) & (preds >= 0) & (preds < num_classes) + bin_count = np.bincount( + num_classes * labels[k].astype(int) + preds[k], + minlength=num_classes**2, + ) + return bin_count[: num_classes**2].reshape(num_classes, num_classes) + + +def per_class_iou(hist: np.ndarray) -> np.ndarray: + """Per-class IoU from cumulative confusion matrix.""" + tp = np.diag(hist) + denom = hist.sum(1) + hist.sum(0) - tp + return np.where(denom > _EPS, tp / (denom + _EPS), np.nan) + + +def get_acc(hist: np.ndarray) -> float: + """Overall point-level accuracy.""" + return float(np.diag(hist).sum() / (hist.sum() + _EPS)) + + +def get_acc_cls(hist: np.ndarray) -> float: + """Class-average accuracy (same as macro recall).""" + return float(np.nanmean(np.diag(hist) / (hist.sum(axis=1) + _EPS))) + + +def per_class_precision(hist: np.ndarray) -> np.ndarray: + """Per-class precision: TP / (TP + FP) = TP / predicted-as-class.""" + tp = np.diag(hist) + predicted = hist.sum(axis=0) # column sums + return np.where(predicted > _EPS, tp / (predicted + _EPS), np.nan) + + +def per_class_recall(hist: np.ndarray) -> np.ndarray: + """Per-class recall: TP / (TP + FN) = TP / actual-class-count.""" + tp = np.diag(hist) + actual = hist.sum(axis=1) # row sums + return np.where(actual > _EPS, tp / (actual + _EPS), np.nan) + + +def per_class_f1(hist: np.ndarray) -> np.ndarray: + """Per-class F1 score: 2 * precision * recall / (precision + recall).""" + prec = per_class_precision(hist) + rec = per_class_recall(hist) + denom = prec + rec + return np.where(denom > _EPS, 2.0 * prec * rec / (denom + _EPS), np.nan) + + +def normalize_confusion_matrix(cm: np.ndarray) -> np.ndarray: + """Row-normalise so each row sums to 1 (GT-class perspective). + + Rows without any GT sample are set to 0 rather than NaN so the result + can be safely passed to matplotlib's ``imshow``. + """ + row_sums = cm.sum(axis=1, keepdims=True) + safe = np.where(row_sums > 0, row_sums, 1.0) + return cm / safe + + +def plot_confusion_matrix( + cm: np.ndarray, + class_names: List[str], + normalize: bool = True, + label: str = "", +) -> "matplotlib.figure.Figure": # type: ignore[name-defined] + """Render a confusion matrix as a matplotlib Figure for TensorBoard. + + * Y-axis = "True label", X-axis = "Predicted label". + * Color scale fixed at ``[0, 1]`` (normalised fractions) so plots from + different epochs are directly comparable. + * Numeric value annotated in every cell. + + Args: + cm: ``(num_classes, num_classes)`` confusion matrix, ``cm[gt][pred]``. + class_names: Human-readable class names. + normalize: If ``True`` (default), row-normalise before plotting. + label: Optional range label appended to the figure title. + + Returns: + ``matplotlib.figure.Figure`` - caller is responsible for closing it. + """ + nc = cm.shape[0] + cm_plot = normalize_confusion_matrix(cm) if normalize else cm.astype(float) + + fig, ax = plt.subplots(figsize=(max(10, nc * 0.6), max(8, nc * 0.55))) + im = ax.imshow( + cm_plot, + interpolation="nearest", + cmap="Blues", + norm=MplNormalize(vmin=0.0, vmax=1.0), + ) + fig.colorbar(im, ax=ax, shrink=0.8) + + font_size = max(4, 7 - nc // 10) + for i in range(nc): + for j in range(nc): + val = cm_plot[i, j] + color = "white" if val > 0.5 else "black" + ax.text(j, i, f"{val:.2f}", ha="center", va="center", fontsize=font_size, color=color) + + title = "Confusion Matrix" + if label: + title += f" [{label}]" + ax.set_title(title, fontsize=12) + ax.set_ylabel("True label", fontsize=11) + ax.set_xlabel("Predicted label", fontsize=11) + + tick_marks = np.arange(nc) + ax.set_xticks(tick_marks) + ax.set_yticks(tick_marks) + ax.set_xticklabels(class_names, rotation=45, ha="right", fontsize=7) + ax.set_yticklabels(class_names, fontsize=7) + fig.tight_layout() + return fig + + +def figure_to_numpy(fig) -> np.ndarray: + """Convert a matplotlib Figure to a uint8 HWC NumPy array (RGB). + + Uses the in-memory PNG path; does not require a display. + """ + buf = io.BytesIO() + fig.savefig(buf, format="png", bbox_inches="tight") + buf.seek(0) + img = Image.open(buf).convert("RGB") + return np.array(img) + + +def compute_bev_distance(coords: np.ndarray) -> np.ndarray: + """BEV distance from ego: ``sqrt(x^2 + y^2)`` for each point. + + Args: + coords: ``(N, ≥2)`` array (first two columns are X, Y in metres). + + Returns: + ``(N,)`` array of distances in metres. + """ + return np.sqrt(coords[:, 0] ** 2 + coords[:, 1] ** 2) + + +def range_label(lo: float, hi: float) -> str: + """Human-readable range label, e.g. ``'0-20m'``.""" + return f"{lo:g}-{hi:g}m" + + +@dataclass +class SegEvalResult: + """Evaluation result with scalar metrics and raw confusion matrices. + + Attributes: + metrics: Flat dict of scalar metrics keyed in mmdetection3d style: + ``miou``, ``acc``, ``acc_cls``, per-class IoU by name, + ``mprecision``, ``mrecall``, ``mf1``, + ``precision/{class}``, ``recall/{class}``, ``f1/{class}``; and + for each range bucket, the same keys prefixed with + ``{range_label}/`` (e.g. ``0-20m/miou``). + cm: Total confusion matrix ``(num_classes, num_classes)``. + range_cms: Per-range confusion matrices keyed by range label. + """ + + metrics: Dict[str, float] = field(default_factory=dict) + cm: np.ndarray = field(default_factory=lambda: np.zeros((0, 0))) + range_cms: Dict[str, np.ndarray] = field(default_factory=dict) + + +def update_seg_eval_histograms( + total_hist: np.ndarray, + pred: np.ndarray, + gt: np.ndarray, + num_classes: int, + ignore_index: int, + range_hists: Optional[Dict[str, np.ndarray]] = None, + coord: Optional[np.ndarray] = None, + distance_ranges: Optional[List[Tuple[float, float]]] = None, +) -> None: + """Accumulate one sample into total and optional range confusion matrices.""" + pred = np.asarray(pred, dtype=np.int64).copy() + gt = np.asarray(gt, dtype=np.int64).copy() + + pred[gt == ignore_index] = ignore_index + gt[gt == ignore_index] = ignore_index + + total_hist += fast_hist(pred, gt, num_classes) + + if not range_hists or not distance_ranges or coord is None: + return + + coord = np.asarray(coord) + if coord.ndim != 2 or coord.shape[1] < 2 or coord.shape[0] != gt.size: + return + + dist = compute_bev_distance(coord) + for lo, hi in distance_ranges: + lbl = range_label(lo, hi) + mask = (dist >= lo) & (dist < hi) + if not np.any(mask): + continue + range_hists[lbl] += fast_hist(pred[mask], gt[mask], num_classes) + + +def _compute_bucket_metrics( + hist: np.ndarray, + label2cat: Dict[int, str], + ignore_index: int, + prefix: str, +) -> Dict[str, float]: + """Derive all scalar metrics from a confusion histogram. + + Args: + hist: ``(num_classes, num_classes)`` cumulative confusion matrix. + label2cat: ``{index: class_name}`` mapping. + ignore_index: Class index to exclude from averages. + prefix: String prepended to every key (e.g. ``'0-20m/'``). + + Returns: + Flat dict of scalar metrics for this bucket. + """ + num_classes = hist.shape[0] + out: Dict[str, float] = {} + + # Per-class IoU - identical to mmdet3d seg_eval + iou = per_class_iou(hist) + if 0 <= ignore_index < num_classes: + iou[ignore_index] = np.nan + miou = float(np.nanmean(iou)) + + # Per-class precision / recall / F1 + prec = per_class_precision(hist) + rec = per_class_recall(hist) + f1 = per_class_f1(hist) + if 0 <= ignore_index < num_classes: + prec[ignore_index] = np.nan + rec[ignore_index] = np.nan + f1[ignore_index] = np.nan + + out[f"{prefix}miou"] = miou + out[f"{prefix}acc"] = get_acc(hist) + out[f"{prefix}acc_cls"] = get_acc_cls(hist) + out[f"{prefix}mprecision"] = float(np.nanmean(prec)) + out[f"{prefix}mrecall"] = float(np.nanmean(rec)) + out[f"{prefix}mf1"] = float(np.nanmean(f1)) + + for idx in range(num_classes): + if idx == ignore_index: + continue + name = label2cat.get(idx, str(idx)) + out[f"{prefix}{name}"] = float(iou[idx]) + out[f"{prefix}precision/{name}"] = float(prec[idx]) + out[f"{prefix}recall/{name}"] = float(rec[idx]) + out[f"{prefix}f1/{name}"] = float(f1[idx]) + + return out + + +def _print_bucket_table( + hist: np.ndarray, + label2cat: Dict[int, str], + ignore_index: int, + title: str, + logger=None, +) -> None: + """Print an AsciiTable for one evaluation bucket.""" + num_classes = hist.shape[0] + iou = per_class_iou(hist) + if 0 <= ignore_index < num_classes: + iou[ignore_index] = np.nan + prec = per_class_precision(hist) + rec = per_class_recall(hist) + f1 = per_class_f1(hist) + + header = ["class", "IoU", "Prec", "Rec", "F1"] + rows = [header] + for idx in range(num_classes): + if idx == ignore_index: + continue + name = label2cat.get(idx, str(idx)) + rows.append( + [ + name, + f"{iou[idx]:.4f}" if not np.isnan(iou[idx]) else "N/A", + f"{prec[idx]:.4f}" if not np.isnan(prec[idx]) else "N/A", + f"{rec[idx]:.4f}" if not np.isnan(rec[idx]) else "N/A", + f"{f1[idx]:.4f}" if not np.isnan(f1[idx]) else "N/A", + ] + ) + miou = float(np.nanmean(iou)) + mprec = float(np.nanmean(prec)) + mrec = float(np.nanmean(rec)) + mf1 = float(np.nanmean(f1)) + rows.append(["mean", f"{miou:.4f}", f"{mprec:.4f}", f"{mrec:.4f}", f"{mf1:.4f}"]) + rows.append(["acc", f"{get_acc(hist):.4f}", "-", "-", "-"]) + rows.append(["acc_cls", f"{get_acc_cls(hist):.4f}", "-", "-", "-"]) + + table = AsciiTable(rows, title=title) + table.inner_footing_row_border = True + print_log("\n" + table.table, logger=logger) + + +def t4_seg_eval_from_hists( + total_hist: np.ndarray, + label2cat: Dict[int, str], + ignore_index: int, + range_hists: Optional[Dict[str, np.ndarray]] = None, + logger=None, +) -> SegEvalResult: + """Build scalar metrics and tables from pre-aggregated confusion matrices.""" + total_hist = np.asarray(total_hist, dtype=np.float64) + range_hists = range_hists or {} + + _print_bucket_table(total_hist, label2cat, ignore_index, title="Total", logger=logger) + metrics = _compute_bucket_metrics(total_hist, label2cat, ignore_index, prefix="") + + for lbl, hist_r in range_hists.items(): + hist_r = np.asarray(hist_r, dtype=np.float64) + if hist_r.sum() == 0: + continue + _print_bucket_table(hist_r, label2cat, ignore_index, title=lbl, logger=logger) + metrics.update(_compute_bucket_metrics(hist_r, label2cat, ignore_index, prefix=f"{lbl}/")) + + return SegEvalResult(metrics=metrics, cm=total_hist, range_cms=range_hists) + + +def t4_seg_eval( + gt_labels: List[np.ndarray], + seg_preds: List[np.ndarray], + label2cat: Dict[int, str], + ignore_index: int, + coords_list: Optional[List[Optional[np.ndarray]]] = None, + distance_ranges: Optional[List[Tuple[float, float]]] = None, + logger=None, +) -> SegEvalResult: + """Semantic segmentation evaluation with optional range-based breakdown. + + Produces the same top-level keys as ``mmdet3d.evaluation.seg_eval`` + (``miou``, ``acc``, ``acc_cls``, per-class IoU by name) and additionally + adds precision / recall / F1 metrics and optional per-range variants. + + Args: + gt_labels: Ground-truth label arrays, one per sample. + seg_preds: Predicted label arrays, one per sample. + label2cat: ``{output_index: class_name}`` mapping. + ignore_index: Label to exclude from metric computation. + coords_list: Optional per-sample XYZ coordinate arrays ``(N, ≥2)``. + When provided together with ``distance_ranges``, range-based + metrics are computed; otherwise only total metrics are returned. + distance_ranges: List of ``(lo, hi)`` metre pairs, e.g. + ``[(0, 20), (20, 40), ..., (100, 120)]``. + logger: Optional logger for tabular output. + + Returns: + :class:`SegEvalResult` with scalar metrics dict, total CM, and + per-range CMs. + """ + assert len(gt_labels) == len(seg_preds), ( + f"gt and pred lists must have the same length " f"({len(gt_labels)} vs {len(seg_preds)})" + ) + + num_classes = len(label2cat) + use_ranges = bool(distance_ranges and coords_list is not None) + + total_hist = np.zeros((num_classes, num_classes), dtype=np.float64) + + if use_ranges: + range_hists: Dict[str, np.ndarray] = { + range_label(lo, hi): np.zeros((num_classes, num_classes), dtype=np.float64) + for lo, hi in distance_ranges # type: ignore[union-attr] + } + else: + range_hists = {} + + for i in range(len(gt_labels)): + update_seg_eval_histograms( + total_hist=total_hist, + pred=seg_preds[i], + gt=gt_labels[i], + num_classes=num_classes, + ignore_index=ignore_index, + range_hists=range_hists, + coord=coords_list[i] if use_ranges else None, + distance_ranges=distance_ranges if use_ranges else None, + ) + + return t4_seg_eval_from_hists( + total_hist=total_hist, + label2cat=label2cat, + ignore_index=ignore_index, + range_hists=range_hists, + logger=logger, + ) diff --git a/autoware_ml/segmentation3d/evaluation/metrics/__init__.py b/autoware_ml/segmentation3d/evaluation/metrics/__init__.py new file mode 100644 index 000000000..0c56a0fda --- /dev/null +++ b/autoware_ml/segmentation3d/evaluation/metrics/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) TIER IV, Inc. All rights reserved. +from .t4_seg_metric import T4SegMetric + +__all__ = ["T4SegMetric"] diff --git a/autoware_ml/segmentation3d/evaluation/metrics/t4_seg_metric.py b/autoware_ml/segmentation3d/evaluation/metrics/t4_seg_metric.py new file mode 100644 index 000000000..a1baf8e24 --- /dev/null +++ b/autoware_ml/segmentation3d/evaluation/metrics/t4_seg_metric.py @@ -0,0 +1,261 @@ +# Copyright (c) TIER IV, Inc. All rights reserved. +"""MMEngine metric adapter for shared T4 segmentation evaluation.""" + +from dataclasses import dataclass +from typing import Dict, List, Optional, Sequence, Tuple + +import numpy as np +from mmdet3d.registry import METRICS +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger + +from autoware_ml.segmentation3d.evaluation.functional.t4_seg_eval import ( + t4_seg_eval, +) + + +@dataclass +class T4SegMetricSample: + pred: np.ndarray + gt: np.ndarray + coord: Optional[np.ndarray] = None + + +@METRICS.register_module() +class T4SegMetric(BaseMetric): + """3D semantic segmentation evaluation metric for T4 datasets. + + Parameters + ---------- + num_classes: + Number of output classes (excluding the ignore class). + ignore_index: + Label value to skip during evaluation. Defaults to the value set in + ``dataset_meta``; the explicit argument takes priority. + distance_ranges: + Optional list of ``(lo, hi)`` metre pairs for range-based breakdown, + e.g. ``[(0, 20), (20, 40), (40, 60), (60, 80), (80, 100), (100, 120)]``. + collect_device: + Device used for collecting results across ranks. ``'cpu'`` or ``'gpu'``. + prefix: + Optional metric-name prefix. + """ + + default_prefix: Optional[str] = None + + def __init__( + self, + num_classes: Optional[int] = None, + ignore_index: Optional[int] = None, + distance_ranges: Optional[List[Tuple[float, float]]] = None, + collect_device: str = "cpu", + prefix: Optional[str] = None, + **kwargs, + ): + super().__init__(prefix=prefix, collect_device=collect_device) + self._num_classes = num_classes + self._ignore_index = ignore_index + self.distance_ranges = distance_ranges or [] + self.last_eval_result = None + self.last_label2cat: Dict[int, str] = {} + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Collect one batch of model outputs for later aggregation.""" + batch_coords = self._extract_batch_coords(data_batch, data_samples) + logger: MMLogger = MMLogger.get_current_instance() + + for i, data_sample in enumerate(data_samples): + pred_field = data_sample.get("pred_pts_seg", {}) + ann_field = data_sample.get("eval_ann_info", {}) + + pred = self._to_numpy(pred_field.get("pts_semantic_mask")) + gt = self._to_numpy(ann_field.get("pts_semantic_mask")) + + if pred is None or gt is None: + logger.warning("T4SegMetric: skipping sample with missing prediction or ground-truth labels.") + continue + if pred.size != gt.size: + logger.warning( + "T4SegMetric: skipping sample because prediction and ground-truth lengths differ: " + f"{pred.size} vs {gt.size}." + ) + continue + + coord_i = batch_coords[i] if batch_coords else None + if coord_i is not None: + if coord_i.shape[0] > gt.size: + coord_i = coord_i[: gt.size] + elif coord_i.shape[0] < gt.size: + coord_i = None + + self.results.append(T4SegMetricSample(pred=pred, gt=gt, coord=coord_i)) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Aggregate per-batch results and return the full metrics dict.""" + logger: MMLogger = MMLogger.get_current_instance() + + if not results: + logger.warning("T4SegMetric: no results to evaluate.") + return {} + + ignore_index = self._get_ignore_index() + label2cat = self._get_label2cat() + + # Do not include ignore_index in label2cat. When ignore_index sits outside [0, num_classes) fast_hist naturally + # drops those points via its ``labels < num_classes`` mask. Adding it would expand the confusion matrix + # and pollute acc / acc_cls. + label2cat.pop(ignore_index, None) + target_num_classes = self._num_classes or len(label2cat) + for idx in range(target_num_classes): + if idx not in label2cat and idx != ignore_index: + label2cat[idx] = str(idx) + + gt_labels = [r.gt for r in results] + seg_preds = [r.pred for r in results] + coords_list = [r.coord for r in results] if self.distance_ranges else None + if self.distance_ranges and (not coords_list or all(c is None for c in coords_list)): + logger.warning( + "T4SegMetric: distance_ranges is configured but no coordinates " + "were extracted from data_batch. Range-based confusion matrices " + "will be empty." + ) + + eval_result = t4_seg_eval( + gt_labels, + seg_preds, + label2cat, + ignore_index, + coords_list=coords_list, + distance_ranges=self.distance_ranges if self.distance_ranges else None, + logger=logger, + ) + + if self.distance_ranges and eval_result.cm.sum() > 0: + covered = sum(cm.sum() for cm in eval_result.range_cms.values()) + if covered == 0: + logger.warning( + "T4SegMetric: total confusion matrix is non-empty but all " + "range-based confusion matrices are empty. This usually " + "means distance_ranges do not cover observed distances or " + "coordinate extraction is still misaligned." + ) + + self.last_eval_result = eval_result + self.last_label2cat = dict(label2cat) + + return eval_result.metrics + + @staticmethod + def _to_numpy(v) -> Optional[np.ndarray]: + """Convert tensor / array-like to a flat int64 numpy array.""" + if v is None: + return None + if hasattr(v, "cpu"): + v = v.cpu().numpy() + arr = np.asarray(v, dtype=np.int64) + return arr.ravel() + + @staticmethod + def _extract_batch_coords(data_batch: dict, data_samples: Sequence[dict]) -> Optional[List]: + """Try to extract XY coordinates from packed input points. + + Returns a list of length ``len(data_samples)`` where each entry is either a + ``(N, 2)`` float32 array or ``None``. + """ + try: + n_samples = len(data_samples) + + def _unwrap_points_tensor(obj): + """Best-effort unwrapping for collate/data wrappers.""" + cur = obj + for _ in range(8): + if cur is None: + return None + if hasattr(cur, "tensor"): + cur = cur.tensor + continue + if hasattr(cur, "data") and not isinstance(cur, np.ndarray): + nxt = getattr(cur, "data") + if nxt is cur: + break + cur = nxt + continue + if isinstance(cur, (list, tuple)) and len(cur) == 1: + cur = cur[0] + continue + break + return cur + + inputs = data_batch.get("inputs") or {} + if not isinstance(inputs, dict): + inputs = {} + points_data = inputs.get("points") + if points_data is None: + return None + + num_points_list = [] + for ds in data_samples: + meta = getattr(ds, "metainfo", {}) or {} + n = meta.get("num_points", None) + num_points_list.append(int(n) if isinstance(n, (int, np.integer)) else None) + + if not isinstance(points_data, (list, tuple)): + raw = _unwrap_points_tensor(points_data) + if raw is not None and hasattr(raw, "cpu"): + raw = raw.cpu().numpy() + raw_arr = np.asarray(raw) if raw is not None else None + if ( + raw_arr is not None + and raw_arr.ndim >= 2 + and all(v is not None for v in num_points_list) + and sum(num_points_list) <= raw_arr.shape[0] + ): + split = [] + st = 0 + for n in num_points_list: + ed = st + int(n) + split.append(raw_arr[st:ed]) + st = ed + points_data = split + else: + points_data = [points_data] * n_samples + + coords = [] + for pts in points_data[:n_samples]: + if pts is None: + coords.append(None) + continue + tens = _unwrap_points_tensor(pts) + if tens is None: + coords.append(None) + continue + if hasattr(tens, "cpu"): + tens = tens.cpu().numpy() + arr = np.asarray(tens, dtype=np.float32) + if arr.ndim >= 2 and arr.shape[1] >= 2: + coords.append(arr[:, :2]) + else: + coords.append(None) + return coords + except Exception: + return None + + def _get_label2cat(self) -> Dict[int, str]: + """Resolve {output_index: class_name} from constructor args or meta.""" + meta = getattr(self, "dataset_meta", {}) or {} + label2cat = meta.get("label2cat") + if isinstance(label2cat, dict): + return {int(k): str(v) for k, v in label2cat.items()} + # Fallback: use class_names list if available + class_names = meta.get("classes") or meta.get("class_names") + if isinstance(class_names, (list, tuple)): + return {i: str(name) for i, name in enumerate(class_names)} + # Last resort: numeric class names + nc = self._num_classes or 1 + return {i: str(i) for i in range(nc)} + + def _get_ignore_index(self) -> int: + if self._ignore_index is not None: + return self._ignore_index + meta = getattr(self, "dataset_meta", {}) or {} + return int(meta.get("ignore_index", -1)) diff --git a/autoware_ml/segmentation3d/evaluation/tensorboard.py b/autoware_ml/segmentation3d/evaluation/tensorboard.py new file mode 100644 index 000000000..d79006f38 --- /dev/null +++ b/autoware_ml/segmentation3d/evaluation/tensorboard.py @@ -0,0 +1,61 @@ +# Copyright (c) TIER IV, Inc. All rights reserved. +"""Shared TensorBoard naming for T4 segmentation metrics.""" + +from __future__ import annotations + +from typing import Dict, Iterable, List, Optional, Tuple + +from .functional.t4_seg_eval import plot_confusion_matrix, range_label + +_SUMMARY_KEYS = ("miou", "acc", "acc_cls", "mprecision", "mrecall", "mf1") + + +def build_t4_seg_tb_scalars( + metrics: Dict[str, float], + class_names: List[str], + stage: str, + distance_ranges: Optional[Iterable[Tuple[float, float]]] = None, +) -> Dict[str, float]: + """Map canonical metric keys to the shared TensorBoard naming scheme.""" + tb_scalars: Dict[str, float] = {} + + for key in _SUMMARY_KEYS: + if key in metrics: + tb_scalars[f"{stage}/{key}"] = metrics[key] + + for class_name in class_names: + if class_name in metrics: + tb_scalars[f"{stage}/class_iou/{class_name}"] = metrics[class_name] + for sub in ("precision", "recall", "f1"): + metric_key = f"{sub}/{class_name}" + if metric_key in metrics: + tb_scalars[f"{stage}/class_{sub}/{class_name}"] = metrics[metric_key] + + for lo, hi in distance_ranges or []: + bucket = range_label(lo, hi) + for key in _SUMMARY_KEYS: + metric_key = f"{bucket}/{key}" + if metric_key in metrics: + tb_scalars[f"{stage}/range/{bucket}/{key}"] = metrics[metric_key] + for class_name in class_names: + metric_key = f"{bucket}/{class_name}" + if metric_key in metrics: + tb_scalars[f"{stage}/range/{bucket}/class_iou/{class_name}"] = metrics[metric_key] + for sub in ("precision", "recall", "f1"): + metric_key = f"{bucket}/{sub}/{class_name}" + if metric_key in metrics: + tb_scalars[f"{stage}/range/{bucket}/class_{sub}/{class_name}"] = metrics[metric_key] + + return tb_scalars + + +def iter_t4_seg_confusion_matrix_figures(eval_result, class_names: List[str], stage: str): + """Yield standardised TensorBoard tags and matplotlib figures.""" + if eval_result.cm is not None and eval_result.cm.sum() > 0: + yield f"{stage}/confusion_matrix", plot_confusion_matrix(eval_result.cm, class_names) + + for bucket, range_cm in eval_result.range_cms.items(): + if range_cm is None or range_cm.sum() == 0: + continue + tag = f"{stage}/confusion_matrix_{bucket.replace('-', '_').replace(' ', '_')}" + yield tag, plot_confusion_matrix(range_cm, class_names, label=bucket) diff --git a/projects/FRNet/configs/nuscenes/frnet_1xb4_nus-seg.py b/projects/FRNet/configs/nuscenes/frnet_1xb4_nus-seg.py index 8a79a6ce8..cdeac1488 100644 --- a/projects/FRNet/configs/nuscenes/frnet_1xb4_nus-seg.py +++ b/projects/FRNet/configs/nuscenes/frnet_1xb4_nus-seg.py @@ -309,4 +309,7 @@ log_processor = dict(type="LogProcessor", window_size=50, by_epoch=False) -default_hooks = dict(checkpoint=dict(type="CheckpointHook", by_epoch=False, interval=-1, save_best="miou")) +default_hooks = dict( + logger=dict(type="LoggerHook", log_metric_by_epoch=False), + checkpoint=dict(type="CheckpointHook", by_epoch=False, interval=-1, save_best="miou"), +) diff --git a/projects/FRNet/configs/t4dataset/frnet_1xb8_t4dataset-ot128-seg.py b/projects/FRNet/configs/t4dataset/frnet_1xb8_t4dataset-ot128-seg.py index e14c903b8..65f981b99 100644 --- a/projects/FRNet/configs/t4dataset/frnet_1xb8_t4dataset-ot128-seg.py +++ b/projects/FRNet/configs/t4dataset/frnet_1xb8_t4dataset-ot128-seg.py @@ -9,6 +9,8 @@ "projects.FRNet.frnet.datasets", "projects.FRNet.frnet.datasets.transforms", "projects.FRNet.frnet.models", + "autoware_ml.hooks", + "autoware_ml.segmentation3d.evaluation.metrics", ], allow_failed_imports=False, ) @@ -290,8 +292,19 @@ ) test_dataloader = val_dataloader -val_evaluator = dict(type="SegMetric") -test_evaluator = val_evaluator +distance_ranges = [(0, 20), (20, 40), (40, 60), (60, 80), (80, 100.0), (100.0, 120.0)] +val_evaluator = dict( + type="T4SegMetric", + num_classes=num_classes, + ignore_index=ignore_index, + distance_ranges=distance_ranges, +) +test_evaluator = dict( + type="T4SegMetric", + num_classes=num_classes, + ignore_index=ignore_index, + distance_ranges=distance_ranges, +) vis_backends = [dict(type="LocalVisBackend"), dict(type="TensorboardVisBackend")] @@ -323,4 +336,8 @@ log_processor = dict(type="LogProcessor", window_size=50, by_epoch=False) -default_hooks = dict(checkpoint=dict(type="CheckpointHook", by_epoch=False, interval=-1, save_best="miou")) +default_hooks = dict( + logger=dict(type="T4SegLoggerHook", log_metric_by_epoch=False), + checkpoint=dict(type="CheckpointHook", by_epoch=False, interval=-1, save_best="miou"), +) +custom_hooks = [dict(type="T4SegTensorboardHook")] diff --git a/projects/FRNet/configs/t4dataset/frnet_1xb8_t4dataset-qt128-seg.py b/projects/FRNet/configs/t4dataset/frnet_1xb8_t4dataset-qt128-seg.py index 4bed85c4b..c614f9758 100644 --- a/projects/FRNet/configs/t4dataset/frnet_1xb8_t4dataset-qt128-seg.py +++ b/projects/FRNet/configs/t4dataset/frnet_1xb8_t4dataset-qt128-seg.py @@ -9,6 +9,8 @@ "projects.FRNet.frnet.datasets", "projects.FRNet.frnet.datasets.transforms", "projects.FRNet.frnet.models", + "autoware_ml.hooks", + "autoware_ml.segmentation3d.evaluation.metrics", ], allow_failed_imports=False, ) @@ -290,8 +292,19 @@ ) test_dataloader = val_dataloader -val_evaluator = dict(type="SegMetric") -test_evaluator = val_evaluator +distance_ranges = [(0, 20), (20, 40), (40, 60), (60, 80), (80, 100.0), (100.0, 120.0)] +val_evaluator = dict( + type="T4SegMetric", + num_classes=num_classes, + ignore_index=ignore_index, + distance_ranges=distance_ranges, +) +test_evaluator = dict( + type="T4SegMetric", + num_classes=num_classes, + ignore_index=ignore_index, + distance_ranges=distance_ranges, +) vis_backends = [dict(type="LocalVisBackend"), dict(type="TensorboardVisBackend")] @@ -323,4 +336,8 @@ log_processor = dict(type="LogProcessor", window_size=50, by_epoch=False) -default_hooks = dict(checkpoint=dict(type="CheckpointHook", by_epoch=False, interval=-1, save_best="miou")) +default_hooks = dict( + logger=dict(type="T4SegLoggerHook", log_metric_by_epoch=False), + checkpoint=dict(type="CheckpointHook", by_epoch=False, interval=-1, save_best="miou"), +) +custom_hooks = [dict(type="T4SegTensorboardHook")] diff --git a/projects/PTv3/configs/semseg-pt-v3m1-0-t4dataset.py b/projects/PTv3/configs/semseg-pt-v3m1-0-t4dataset.py index 756fd4c32..e000357c3 100644 --- a/projects/PTv3/configs/semseg-pt-v3m1-0-t4dataset.py +++ b/projects/PTv3/configs/semseg-pt-v3m1-0-t4dataset.py @@ -54,6 +54,8 @@ "unpainted": ignore_index, } num_classes = 26 +distance_ranges = [(0, 20), (20, 40), (40, 60), (60, 80), (80, 100.0), (100.0, 120.0)] +metric_options = dict(distance_ranges=distance_ranges) # model settings model = dict( diff --git a/projects/PTv3/engines/hooks/evaluator.py b/projects/PTv3/engines/hooks/evaluator.py index 8732447bd..409f76c65 100644 --- a/projects/PTv3/engines/hooks/evaluator.py +++ b/projects/PTv3/engines/hooks/evaluator.py @@ -5,13 +5,23 @@ Please cite our work if the code is helpful to you. """ +import matplotlib import numpy as np + +matplotlib.use("Agg") +import matplotlib.pyplot as plt import torch import torch.distributed as dist import utils.comm as comm -from utils.misc import intersection_and_union_gpu from autoware_ml.segmentation3d.datasets.utils import class_mapping_to_names +from autoware_ml.segmentation3d.evaluation import ( + SegEvalResult, + build_t4_seg_tb_scalars, + iter_t4_seg_confusion_matrix_figures, + t4_seg_eval_from_hists, + update_seg_eval_histograms, +) from .builder import HOOKS from .default import HookBase @@ -26,6 +36,23 @@ def after_epoch(self): def eval(self): self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>") self.trainer.model.eval() + cfg = self.trainer.cfg + num_classes = cfg.data.num_classes + ignore_index = cfg.data.ignore_index + metric_options = getattr(cfg, "metric_options", None) or {} + distance_ranges = metric_options.get("distance_ranges") or [] + reduce_device = ( + torch.device("cuda", torch.cuda.current_device()) if torch.cuda.is_available() else torch.device("cpu") + ) + + total_hist = torch.zeros((num_classes, num_classes), dtype=torch.float64, device=reduce_device) + range_hist_tensors = { + f"{lo:g}-{hi:g}m": torch.zeros((num_classes, num_classes), dtype=torch.float64, device=reduce_device) + for lo, hi in distance_ranges + } + loss_sum = 0.0 + loss_count = 0 + for i, input_dict in enumerate(self.trainer.val_loader): for key in input_dict.keys(): if isinstance(input_dict[key], torch.Tensor): @@ -34,73 +61,93 @@ def eval(self): output_dict = self.trainer.model(input_dict) output = output_dict["seg_logits"] loss = output_dict["loss"] - pred = output.max(1)[1] - segment = input_dict["segment"] - intersection, union, target = intersection_and_union_gpu( - pred, - segment, - self.trainer.cfg.data.num_classes, - self.trainer.cfg.data.ignore_index, - ) - if comm.get_world_size() > 1: - dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(target) - intersection, union, target = ( - intersection.cpu().numpy(), - union.cpu().numpy(), - target.cpu().numpy(), + pred = output.max(1)[1].detach().cpu().numpy() + segment = input_dict["segment"].detach().cpu().numpy() + + # Extract BEV coordinate for range-based metrics. + coord_np = None + if "coord" in input_dict: + coord = input_dict["coord"] + if isinstance(coord, torch.Tensor): + coord_np = coord.detach().cpu().numpy() + if coord_np.ndim != 2 or coord_np.shape[1] < 2: + coord_np = None + + sample_total_hist = np.zeros((num_classes, num_classes), dtype=np.float64) + sample_range_hists = { + label: np.zeros((num_classes, num_classes), dtype=np.float64) for label in range_hist_tensors + } + update_seg_eval_histograms( + total_hist=sample_total_hist, + pred=pred, + gt=segment, + num_classes=num_classes, + ignore_index=ignore_index, + range_hists=sample_range_hists, + coord=coord_np, + distance_ranges=distance_ranges if distance_ranges else None, ) - # Here there is no need to sync since sync happened in dist.all_reduce - self.trainer.storage.put_scalar("val_intersection", intersection) - self.trainer.storage.put_scalar("val_union", union) - self.trainer.storage.put_scalar("val_target", target) - self.trainer.storage.put_scalar("val_loss", loss.item()) - info = "Test: [{iter}/{max_iter}] ".format(iter=i + 1, max_iter=len(self.trainer.val_loader)) - if "origin_coord" in input_dict.keys(): + total_hist += torch.from_numpy(sample_total_hist).to(device=total_hist.device) + for label, hist in sample_range_hists.items(): + range_hist_tensors[label] += torch.from_numpy(hist).to(device=total_hist.device) + loss_sum += float(loss.item()) + loss_count += 1 + + info = f"Test: [{i + 1}/{len(self.trainer.val_loader)}] " + if "origin_coord" in input_dict: info = "Interp. " + info - self.trainer.logger.info( - info + "Loss {loss:.4f} ".format(iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()) - ) - loss_avg = self.trainer.storage.history("val_loss").avg - intersection = self.trainer.storage.history("val_intersection").total - union = self.trainer.storage.history("val_union").total - target = self.trainer.storage.history("val_target").total - iou_class = intersection / (union + 1e-10) - acc_class = intersection / (target + 1e-10) - m_iou = np.mean(iou_class) - m_acc = np.mean(acc_class) - all_acc = sum(intersection) / (sum(target) + 1e-10) - self.trainer.logger.info("Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(m_iou, m_acc, all_acc)) - - mapped_class_names = class_mapping_to_names( - self.trainer.cfg.class_mapping, - self.trainer.cfg.data.ignore_index, + self.trainer.logger.info(info + f"Loss {loss.item():.4f}") + + comm.synchronize() + if comm.get_world_size() > 1: + dist.reduce(total_hist, dst=0) + for hist in range_hist_tensors.values(): + dist.reduce(hist, dst=0) + loss_reduced = comm.reduce_dict( + { + "loss_sum": torch.tensor(loss_sum, dtype=torch.float64, device=reduce_device), + "loss_count": torch.tensor(loss_count, dtype=torch.float64, device=reduce_device), + }, + average=False, ) - assert len(mapped_class_names) == self.trainer.cfg.data.num_classes, ( - "class_mapping_to_names length must match num_classes: " - f"{len(mapped_class_names)} vs {self.trainer.cfg.data.num_classes}" + if not comm.is_main_process(): + return + + loss_avg = float(loss_reduced["loss_sum"] / loss_reduced["loss_count"].clamp_min(1.0)) + + mapped_class_names = class_mapping_to_names(cfg.class_mapping, ignore_index) + assert len(mapped_class_names) == num_classes, ( + "class_mapping_to_names length must match num_classes: " f"{len(mapped_class_names)} vs {num_classes}" ) - for i in range(self.trainer.cfg.data.num_classes): - self.trainer.logger.info( - "Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format( - idx=i, - name=mapped_class_names[i], - iou=iou_class[i], - accuracy=acc_class[i], - ) - ) - current_epoch = self.trainer.epoch + 1 - if self.trainer.writer is not None: - self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch) - self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch) - self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch) - self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch) - for i in range(self.trainer.cfg.data.num_classes): - name = mapped_class_names[i] - self.trainer.writer.add_scalar(f"val_class_iou/{name}", iou_class[i], current_epoch) - self.trainer.writer.add_scalar(f"val_class_acc/{name}", acc_class[i], current_epoch) + label2cat = {i: mapped_class_names[i] for i in range(num_classes)} + + eval_result: SegEvalResult = t4_seg_eval_from_hists( + total_hist=total_hist.cpu().numpy(), + label2cat=label2cat, + ignore_index=ignore_index, + range_hists={label: hist.cpu().numpy() for label, hist in range_hist_tensors.items()}, + logger=self.trainer.logger, + ) + + epoch = self.trainer.epoch + 1 + writer = self.trainer.writer + if writer is not None: + writer.add_scalar("val/loss", loss_avg, epoch) + for tag, value in build_t4_seg_tb_scalars( + metrics=eval_result.metrics, + class_names=mapped_class_names, + stage="val", + distance_ranges=distance_ranges, + ).items(): + writer.add_scalar(tag, value, epoch) + + for tag, fig in iter_t4_seg_confusion_matrix_figures(eval_result, mapped_class_names, "val"): + writer.add_figure(tag, fig, epoch) + plt.close(fig) + self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<") - self.trainer.comm_info["current_metric_value"] = m_iou # save for saver - self.trainer.comm_info["current_metric_name"] = "mIoU" # save for saver + self.trainer.comm_info["current_metric_value"] = eval_result.metrics.get("miou", 0.0) + self.trainer.comm_info["current_metric_name"] = "miou" def after_train(self): - self.trainer.logger.info("Best {}: {:.4f}".format("mIoU", self.trainer.best_metric_value)) + self.trainer.logger.info("Best {}: {:.4f}".format("miou", self.trainer.best_metric_value)) diff --git a/projects/PTv3/engines/test.py b/projects/PTv3/engines/test.py index f06e989bc..39a5ccbbc 100644 --- a/projects/PTv3/engines/test.py +++ b/projects/PTv3/engines/test.py @@ -5,27 +5,40 @@ Please cite our work if the code is helpful to you. """ +import json import os import time from collections import OrderedDict +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt import numpy as np import torch +import torch.distributed as dist import torch.nn.functional as F import torch.utils.data import utils.comm as comm from datasets import build_dataset, collate_fn from models import build_model +from tensorboardX import SummaryWriter from utils.logger import get_root_logger from utils.misc import ( AverageMeter, - intersection_and_union, make_dirs, ) from utils.registry import Registry from utils.visualization import get_segmentation_colors, visualize_point_cloud from autoware_ml.segmentation3d.datasets.utils import class_mapping_to_names +from autoware_ml.segmentation3d.evaluation import ( + SegEvalResult, + build_t4_seg_tb_scalars, + iter_t4_seg_confusion_matrix_figures, + t4_seg_eval_from_hists, + update_seg_eval_histograms, +) from .defaults import create_ddp_model @@ -42,6 +55,7 @@ def __init__(self, cfg, model=None, test_loader=None, verbose=False) -> None: self.logger.info("=> Loading config ...") self.cfg = cfg self.verbose = verbose + self.writer = self.build_writer() if self.verbose: self.logger.info(f"Save path: {cfg.save_path}") self.logger.info(f"Config:\n{cfg.pretty_text}") @@ -100,6 +114,12 @@ def build_test_loader(self): ) return test_loader + def build_writer(self): + if not comm.is_main_process(): + return None + self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}") + return SummaryWriter(self.cfg.save_path) + def test(self): raise NotImplementedError @@ -128,17 +148,23 @@ def test(self): logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>") batch_time = AverageMeter() - intersection_meter = AverageMeter() - union_meter = AverageMeter() - target_meter = AverageMeter() + num_classes = self.cfg.data.num_classes + ignore_index = self.cfg.data.ignore_index + metric_options = getattr(self.cfg, "metric_options", None) or {} + distance_ranges = metric_options.get("distance_ranges") or [] + reduce_device = ( + torch.device("cuda", torch.cuda.current_device()) if torch.cuda.is_available() else torch.device("cpu") + ) + total_hist = torch.zeros((num_classes, num_classes), dtype=torch.float64, device=reduce_device) + range_hist_tensors = { + f"{lo:g}-{hi:g}m": torch.zeros((num_classes, num_classes), dtype=torch.float64, device=reduce_device) + for lo, hi in distance_ranges + } self.model.eval() save_path = os.path.join(self.cfg.save_path, "result") make_dirs(save_path) - # create submit folder only on main process if self.cfg.data.test.type == "NuScenesDataset" and comm.is_main_process(): - import json - make_dirs(os.path.join(save_path, "submit", "lidarseg", "test")) make_dirs(os.path.join(save_path, "submit", "test")) submission = dict( @@ -153,8 +179,6 @@ def test(self): with open(os.path.join(save_path, "submit", "test", "submission.json"), "w") as f: json.dump(submission, f, indent=4) comm.synchronize() - record = {} - # fragment inference for idx, data_dict in enumerate(self.test_loader): end = time.time() data_dict = data_dict[0] # current assume batch size is 1 @@ -162,15 +186,20 @@ def test(self): segment = data_dict.pop("segment") data_name = data_dict.pop("name") pred_save_path = os.path.join(save_path, "{}_pred.npy".format(data_name)) - feat_save_path = os.path.join(save_path, "{}_feat.npy".format(data_name)) result_save_path = os.path.join(save_path, "{}_{}_pred.npz".format(idx, data_name)) if os.path.isfile(pred_save_path): logger.info("{}/{}: {}, loaded pred and label.".format(idx + 1, len(self.test_loader), data_name)) pred = np.load(pred_save_path) + # Try to recover cached features from the corresponding NPZ file, if available. + feat_np = None + if os.path.isfile(result_save_path): + cached_result = np.load(result_save_path) + if "feat" in getattr(cached_result, "files", []): + feat_np = cached_result["feat"] if "origin_segment" in data_dict.keys(): segment = data_dict["origin_segment"] else: - pred = torch.zeros((segment.size, self.cfg.data.num_classes)).cuda() + pred = torch.zeros((segment.size, num_classes)).cuda() feat = torch.zeros((segment.size, 4)).cuda() for i in range(len(fragment_list)): fragment_batch_size = 1 @@ -201,17 +230,15 @@ def test(self): ) ) pred = pred.max(1)[1].data.cpu().numpy() + feat_np = feat.cpu().numpy() if "origin_segment" in data_dict.keys(): assert "inverse" in data_dict.keys() pred = pred[data_dict["inverse"]] - feat = feat[data_dict["inverse"]] + feat_np = feat_np[data_dict["inverse"]] segment = data_dict["origin_segment"] - # np.save(pred_save_path, pred) - # np.save(feat_save_path, feat.cpu().numpy()) - np.savez_compressed(result_save_path, pred=pred, feat=feat.cpu().numpy()) + np.savez_compressed(result_save_path, pred=pred, feat=feat_np) - # Call visualization if self.cfg.show: outputs = {"pred": pred, "segment": segment, "result_path": result_save_path} self.visualize_results(outputs, result_save_path) @@ -227,85 +254,85 @@ def test(self): ) ) - intersection, union, target = intersection_and_union( - pred, segment, self.cfg.data.num_classes, self.cfg.data.ignore_index + coord_np = feat_np[:, :3] if feat_np is not None and feat_np.ndim == 2 and feat_np.shape[1] >= 3 else None + sample_total_hist = np.zeros((num_classes, num_classes), dtype=np.float64) + sample_range_hists = { + label: np.zeros((num_classes, num_classes), dtype=np.float64) for label in range_hist_tensors + } + update_seg_eval_histograms( + total_hist=sample_total_hist, + pred=pred, + gt=segment, + num_classes=num_classes, + ignore_index=ignore_index, + range_hists=sample_range_hists, + coord=coord_np, + distance_ranges=distance_ranges if distance_ranges else None, ) - intersection_meter.update(intersection) - union_meter.update(union) - target_meter.update(target) - record[data_name] = dict(intersection=intersection, union=union, target=target) - - mask = union != 0 - iou_class = intersection / (union + 1e-10) - iou = np.mean(iou_class[mask]) - acc = sum(intersection) / (sum(target) + 1e-10) - - m_iou = np.mean(intersection_meter.sum / (union_meter.sum + 1e-10)) - m_acc = np.mean(intersection_meter.sum / (target_meter.sum + 1e-10)) + total_hist += torch.from_numpy(sample_total_hist).to(device=total_hist.device) + for label, hist in sample_range_hists.items(): + range_hist_tensors[label] += torch.from_numpy(hist).to(device=total_hist.device) batch_time.update(time.time() - end) logger.info( "Test: {} [{}/{}]-{} " - "Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) " - "Accuracy {acc:.4f} ({m_acc:.4f}) " - "mIoU {iou:.4f} ({m_iou:.4f})".format( + "Batch {batch_time.val:.3f} ({batch_time.avg:.3f})".format( data_name, idx + 1, len(self.test_loader), segment.size, batch_time=batch_time, - acc=acc, - m_acc=m_acc, - iou=iou, - m_iou=m_iou, ) ) logger.info("Syncing ...") comm.synchronize() - record_sync = comm.gather(record, dst=0) + if comm.get_world_size() > 1: + dist.reduce(total_hist, dst=0) + for hist in range_hist_tensors.values(): + dist.reduce(hist, dst=0) if comm.is_main_process(): - record = {} - for _ in range(len(record_sync)): - r = record_sync.pop() - record.update(r) - del r - intersection = np.sum([meters["intersection"] for _, meters in record.items()], axis=0) - union = np.sum([meters["union"] for _, meters in record.items()], axis=0) - target = np.sum([meters["target"] for _, meters in record.items()], axis=0) + mapped_class_names = class_mapping_to_names(self.cfg.class_mapping, ignore_index) + assert len(mapped_class_names) == num_classes, ( + "class_mapping_to_names length must match num_classes: " f"{len(mapped_class_names)} vs {num_classes}" + ) + label2cat = {i: mapped_class_names[i] for i in range(num_classes)} if self.cfg.data.test.type == "S3DISDataset": + s3dis_hist = total_hist.cpu().numpy() + intersection = np.diag(s3dis_hist) + union = s3dis_hist.sum(1) + s3dis_hist.sum(0) - np.diag(s3dis_hist) + target = s3dis_hist.sum(1) torch.save( dict(intersection=intersection, union=union, target=target), os.path.join(save_path, f"{self.test_loader.dataset.split}.pth"), ) - iou_class = intersection / (union + 1e-10) - accuracy_class = intersection / (target + 1e-10) - mIoU = np.mean(iou_class) - mAcc = np.mean(accuracy_class) - allAcc = sum(intersection) / (sum(target) + 1e-10) - - logger.info("Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}".format(mIoU, mAcc, allAcc)) - mapped_class_names = class_mapping_to_names( - self.cfg.class_mapping, - self.cfg.data.ignore_index, - ) - assert len(mapped_class_names) == self.cfg.data.num_classes, ( - "class_mapping_to_names length must match num_classes: " - f"{len(mapped_class_names)} vs {self.cfg.data.num_classes}" + eval_result: SegEvalResult = t4_seg_eval_from_hists( + total_hist=total_hist.cpu().numpy(), + label2cat=label2cat, + ignore_index=ignore_index, + range_hists={label: hist.cpu().numpy() for label, hist in range_hist_tensors.items()}, + logger=logger, ) - for i in range(self.cfg.data.num_classes): - logger.info( - "Class_{idx} - {name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format( - idx=i, - name=mapped_class_names[i], - iou=iou_class[i], - accuracy=accuracy_class[i], - ) - ) - logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<") + + if self.writer is not None: + for tag, value in build_t4_seg_tb_scalars( + metrics=eval_result.metrics, + class_names=mapped_class_names, + stage="test", + distance_ranges=distance_ranges, + ).items(): + self.writer.add_scalar(tag, value, 0) + for tag, fig in iter_t4_seg_confusion_matrix_figures(eval_result, mapped_class_names, "test"): + self.writer.add_figure(tag, fig, 0) + plt.close(fig) + self.writer.flush() + + if self.writer is not None: + self.writer.close() + logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<") @staticmethod def collate_fn(batch): diff --git a/projects/PTv3/tools/export.py b/projects/PTv3/tools/export.py index a96ded2ae..ae24c23b3 100644 --- a/projects/PTv3/tools/export.py +++ b/projects/PTv3/tools/export.py @@ -1,5 +1,4 @@ import numpy as np -import SparseConvolution # NOTE(knzo25): do not remove this import, it is needed for onnx export import spconv.pytorch as spconv import torch from engines.defaults import ( @@ -12,6 +11,9 @@ from models.utils.structure import Point, bit_length_tensor from torch.nn import functional as F +# NOTE: keep this import last; it overrides sparse conv registration for export. +import SparseConvolution # isort: skip + class WrappedModel(torch.nn.Module):