Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
8f72db8
feat: refactor config
vividf Feb 2, 2026
ddd6190
chore: fix property
vividf Feb 2, 2026
4471756
chore: clean code
vividf Feb 2, 2026
e979652
chore: temp remove centerpoint files
vividf Feb 2, 2026
f38c324
chore: clean code
vividf Feb 10, 2026
825e79a
feat: integrate centerpoint to deployment framework
vividf Feb 2, 2026
2eeeb3a
chore: centerpoint - clean or {}
vividf Feb 2, 2026
041cd45
chore: centerpoint-clean code
vividf Feb 2, 2026
fe26084
chore: temp remove centerpoint files
vividf Feb 2, 2026
d2e56bc
chore: add files back
vividf Feb 2, 2026
a6dec87
ci(pre-commit): autofix
pre-commit-ci[bot] Feb 16, 2026
9076534
chore: clean code
vividf Feb 16, 2026
92fae7c
chore: update threshold for centerpoint
vividf Feb 18, 2026
36662d1
chore: clean up code
vividf Feb 18, 2026
c77e548
chore: refactor base config - centerpoint
vividf Mar 5, 2026
bc51ed8
chore: clean up code: device spec, remove unused fucntion .etc - cent…
vividf Mar 10, 2026
227fa82
chore: fix Any
vividf Mar 10, 2026
2f93f33
chore: add docstring
vividf Mar 10, 2026
e3717ca
chore: refactor export compenent - centerpoint
vividf Mar 10, 2026
687480a
chore: fix more Device spec - centerpoint
vividf Mar 10, 2026
e6864cf
chore: fix
vividf Mar 10, 2026
55b37ca
chore: add more docstring
vividf Mar 10, 2026
dc5d845
chore: change file name
vividf Mar 10, 2026
2c0c385
chore: remove redundant check
vividf Mar 10, 2026
b8a7452
chore: orangize directory
vividf Mar 11, 2026
d098217
chore: rename sample file
vividf Mar 11, 2026
bbd3b42
chore: remove init
vividf Mar 11, 2026
a439d98
chore: add init back
vividf Mar 11, 2026
e449752
chore: fix trt verification
vividf Mar 25, 2026
eb982e3
chore: update deploy config
vividf Mar 25, 2026
8b47dfc
chore: fix more deploy config
vividf Mar 25, 2026
02a2fdb
chore: update deploy config
vividf Mar 25, 2026
c93d54a
chore: for loop for clean code
vividf Mar 25, 2026
f36ab0e
chore: remove duplicate code
vividf Mar 25, 2026
a6bdb02
chore: clean up sample adapter
vividf Mar 25, 2026
465e1de
clean code
vividf Mar 26, 2026
f73cdaa
chore: clean up centerpoint
vividf Mar 27, 2026
15ba5d4
chore: replace pring to logging
vividf Mar 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions deployment/projects/centerpoint/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""CenterPoint deployment bundle.

This package owns all CenterPoint deployment-specific code (runner/evaluator/loader/pipelines/export).
It registers a ProjectAdapter into the global `project_registry` so the unified CLI can invoke it.
"""

from __future__ import annotations

from deployment.projects.centerpoint.cli import add_args
from deployment.projects.centerpoint.entrypoint import run

# Trigger pipeline factory registration for this project.
from deployment.projects.centerpoint.pipelines.factory import CenterPointPipelineFactory # noqa: F401
from deployment.projects.registry import ProjectAdapter, project_registry

project_registry.register(
ProjectAdapter(
name="centerpoint",
add_args=add_args,
run=run,
)
)
14 changes: 14 additions & 0 deletions deployment/projects/centerpoint/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""CenterPoint CLI extensions."""

from __future__ import annotations

import argparse


def add_args(parser: argparse.ArgumentParser) -> None:
"""Register CenterPoint-specific CLI flags onto a project subparser."""
parser.add_argument(
"--rot-y-axis-reference",
action="store_true",
help="Convert rotation to y-axis clockwise reference (CenterPoint ONNX-compatible format)",
)
54 changes: 33 additions & 21 deletions deployment/projects/centerpoint/config/deploy_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@
CenterPoint Deployment Configuration
"""

# ============================================================================
# Task type for pipeline building
# Options: 'detection2d', 'detection3d', 'classification', 'segmentation'
# ============================================================================
task_type = "detection3d"

# ============================================================================
# Checkpoint Path - Single source of truth for PyTorch model
# ============================================================================
checkpoint_path = "work_dirs/centerpoint/best_checkpoint.pth"

# Log file path (relative paths are under export.work_dir). Set to None to disable file logging.
deploy_log_path = "deployment.log"

# ============================================================================
# Device settings (shared by export, evaluation, verification)
# ============================================================================
Expand All @@ -21,33 +18,37 @@
cuda="cuda:0",
)

# Single literal for deployment output root (used before `export` exists).
_DEPLOY_WORK_DIR = "work_dirs/centerpoint_deployment"
_WORK_DIR = _DEPLOY_WORK_DIR.rstrip("/")
_ONNX_DIR = f"{_WORK_DIR}/onnx"
_TENSORRT_DIR = f"{_WORK_DIR}/tensorrt"

# ============================================================================
# Export Configuration
# mode: "onnx", "trt", "both", "none"
# work_dir: path to the deployment output root
# onnx_path: path to the ONNX output directory (if mode="trt" and ONNX already exists)
# ============================================================================
export = dict(
mode="both",
work_dir="work_dirs/centerpoint_deployment",
onnx_path=None,
work_dir=_DEPLOY_WORK_DIR,
onnx_path=_ONNX_DIR,
)

# Derived artifact directories
_WORK_DIR = str(export["work_dir"]).rstrip("/")
_ONNX_DIR = f"{_WORK_DIR}/onnx"
_TENSORRT_DIR = f"{_WORK_DIR}/tensorrt"

# ============================================================================
# Unified Component Configuration (Single Source of Truth)
#
# Component key is the unique identifier (used for config lookup, filenames, logs).
# Each component defines:
# - name: Component identifier used in export
# - onnx_file: Output ONNX filename
# - engine_file: Output TensorRT engine filename
# - io: Input/output specification for ONNX export
# - tensorrt_profile: TensorRT optimization profile (min/opt/max shapes)
# ============================================================================
components = dict(
voxel_encoder=dict(
name="pts_voxel_encoder",
pts_voxel_encoder=dict(
onnx_file="pts_voxel_encoder.onnx",
engine_file="pts_voxel_encoder.engine",
io=dict(
Expand All @@ -64,14 +65,14 @@
),
tensorrt_profile=dict(
input_features=dict(
# Make sure to match the shape of the input to the model
min_shape=[1000, 32, 11],
opt_shape=[20000, 32, 11],
max_shape=[64000, 32, 11],
max_shape=[96000, 32, 11],
),
),
),
backbone_head=dict(
name="pts_backbone_neck_head",
pts_backbone_neck_head=dict(
onnx_file="pts_backbone_neck_head.onnx",
engine_file="pts_backbone_neck_head.engine",
io=dict(
Expand All @@ -98,6 +99,8 @@
),
tensorrt_profile=dict(
spatial_features=dict(
# Make sure to match the shape of the input to the model
# check grid size in the model config
min_shape=[1, 32, 1020, 1020],
opt_shape=[1, 32, 1020, 1020],
max_shape=[1, 32, 1020, 1020],
Expand All @@ -111,8 +114,8 @@
# ============================================================================
runtime_io = dict(
# This should be a path relative to `data_root` in the model config.
info_file="info/t4dataset_j6gen2_infos_val.pkl",
sample_idx=1,
info_file="info/t4dataset_j6gen2_base_infos_test.pkl",
sample_idx=5,
)

# ============================================================================
Expand All @@ -128,6 +131,7 @@

# ============================================================================
# TensorRT Build Settings (shared across all components)
# Supports `auto`, `fp16`, `fp32_tf32`, and `strongly_typed`
# ============================================================================
tensorrt_config = dict(
precision_policy="auto",
Expand Down Expand Up @@ -161,10 +165,18 @@

# ============================================================================
# Verification Configuration
#
# Tolerance is backend- and machine-dependent:
# - The same scenario can show very different max/mean diffs on different machines: GPU
# architecture, driver, ORT/CUDA/TRT versions, and ORT's CUDA graph partitioning (CPU
# fallback nodes for small ops) all change numerics. ONNX on CPU, ONNX on CUDA, and
# TensorRT on CUDA are not directly comparable to each other as "one true" references.
# - Additionally, the verification configuration should use a precision-aware tolerance,
# especially when FP16 is enabled.
# ============================================================================
verification = dict(
enabled=False,
tolerance=1e-1,
tolerance=1,
num_verify_samples=1,
devices=devices,
scenarios=dict(
Expand Down
84 changes: 84 additions & 0 deletions deployment/projects/centerpoint/entrypoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""CenterPoint deployment entrypoint invoked by the unified CLI."""

from __future__ import annotations

import argparse
import logging

from mmengine.config import Config

from deployment.cli.args import add_deployment_file_logging, setup_logging
from deployment.configs.base import BaseDeploymentConfig
from deployment.core.contexts import CenterPointExportContext
from deployment.projects.centerpoint.eval.evaluator import CenterPointEvaluator
from deployment.projects.centerpoint.eval.metrics_utils import extract_t4metric_v2_config
from deployment.projects.centerpoint.io.data_loader import CenterPointDataLoader
from deployment.projects.centerpoint.runner import CenterPointDeploymentRunner

_REQUIRED_COMPONENTS = ("pts_voxel_encoder", "pts_backbone_neck_head")


def _validate_required_components(components_cfg) -> None:
"""Validate that all CenterPoint required components exist in the config.

Args:
components_cfg: Components config with get_component(name).

Raises:
KeyError or similar: If any of _REQUIRED_COMPONENTS is missing.
"""
for component_name in _REQUIRED_COMPONENTS:
components_cfg.get_component(component_name)


def run(args: argparse.Namespace) -> int:
"""Run the CenterPoint deployment workflow for the unified CLI.

Args:
args: Parsed command-line arguments containing deploy_cfg and model_cfg paths.

Returns:
Exit code (0 for success).
"""
logger = setup_logging(args.log_level)

deploy_cfg = Config.fromfile(args.deploy_cfg)
model_cfg = Config.fromfile(args.model_cfg)
config = BaseDeploymentConfig(deploy_cfg)

log_file = config.resolved_deploy_log_file
if log_file:
add_deployment_file_logging(log_file)
logger.info("Deployment log file: %s", log_file)

_validate_required_components(config.components_cfg)

logger.info("=" * 80)
logger.info("CenterPoint Deployment Pipeline")
logger.info("=" * 80)

data_loader = CenterPointDataLoader(
info_file=config.runtime_config.info_file,
model_cfg=model_cfg,
)
logger.info(f"Loaded {data_loader.num_samples} samples")

metrics_config = extract_t4metric_v2_config(model_cfg, logger=logger)

evaluator = CenterPointEvaluator(
model_cfg=model_cfg,
metrics_config=metrics_config,
components_cfg=config.components_cfg,
)

runner = CenterPointDeploymentRunner(
data_loader=data_loader,
evaluator=evaluator,
config=config,
model_cfg=model_cfg,
logger=logger,
)

context = CenterPointExportContext(rot_y_axis_reference=bool(getattr(args, "rot_y_axis_reference", False)))
runner.run(context=context)
return 0
Empty file.
Loading