diff --git a/.gitignore b/.gitignore index cbcbae1..fdb76a2 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,8 @@ build/ *.so .eggs/ demo_render/ +outputs/ CLAUDE.md .claude/ .agents/ -skyseg.onnx \ No newline at end of file +skyseg.onnx diff --git a/README.md b/README.md index d73b1c6..fbdeedb 100644 --- a/README.md +++ b/README.md @@ -183,6 +183,19 @@ python demo.py --model_path /path/to/checkpoint.pt \ | `--point_size` | `0.00001` | Point cloud point size | | `--downsample_factor` | `10` | Spatial downsampling for point cloud display | +### Saving Predictions for Later Visualization + +Use `--save_predictions` to keep a processed reconstruction and reopen it later +without rerunning model inference. + +```bash +python demo.py --model_path /path/to/checkpoint.pt \ + --image_folder /path/to/images/ \ + --save_predictions outputs/scene.pt + +python demo.py --load_predictions outputs/scene.pt +``` + ### Without FlashInfer (SDPA fallback) ```bash diff --git a/demo.py b/demo.py index a8c8636..c2d0099 100644 --- a/demo.py +++ b/demo.py @@ -227,6 +227,54 @@ def prepare_for_visualization(predictions, images=None): return vis_predictions +def save_predictions(path, predictions, images, image_folder=None): + """Save postprocessed predictions for later visualization.""" + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + torch.save( + { + "predictions": predictions, + "images": images, + "image_folder": image_folder, + }, + path, + ) + print(f"Saved predictions to {path}") + + +def load_saved_predictions(path): + """Load predictions saved by --save_predictions.""" + payload = torch.load(path, map_location="cpu", weights_only=False) + if not isinstance(payload, dict) or "predictions" not in payload: + raise ValueError(f"{path} is not a saved LingBot-Map predictions file") + return ( + payload["predictions"], + payload.get("images"), + payload.get("image_folder"), + ) + + +def launch_viewer(args, predictions, images_cpu, resolved_image_folder): + """Launch the Viser point-cloud viewer for postprocessed predictions.""" + try: + from lingbot_map.vis import PointCloudViewer + viewer = PointCloudViewer( + pred_dict=prepare_for_visualization(predictions, images_cpu), + port=args.port, + vis_threshold=args.conf_threshold, + downsample_factor=args.downsample_factor, + point_size=args.point_size, + mask_sky=args.mask_sky, + image_folder=resolved_image_folder, + sky_mask_dir=args.sky_mask_dir, + sky_mask_visualization_dir=args.sky_mask_visualization_dir, + ) + print(f"3D viewer at http://localhost:{args.port}") + viewer.run() + except ImportError: + print("viser not installed. Install with: pip install lingbot-map[vis]") + print(f"Predictions contain keys: {list(predictions.keys())}") + + # ============================================================================= # Main # ============================================================================= @@ -290,8 +338,20 @@ def main(): help="Save sky mask visualizations (original | mask | overlay) to this directory") parser.add_argument("--export_preprocessed", type=str, default=None, help="Export stride-sampled, resized/cropped images to this folder") + parser.add_argument("--save_predictions", type=str, default=None, + help="Save postprocessed predictions to this .pt file before visualization") + parser.add_argument("--load_predictions", type=str, default=None, + help="Load predictions saved with --save_predictions and launch visualization only") args = parser.parse_args() + if args.load_predictions: + if args.video_path is not None: + raise ValueError("--load_predictions cannot be combined with --video_path") + predictions, images_cpu, saved_image_folder = load_saved_predictions(args.load_predictions) + resolved_image_folder = args.image_folder or saved_image_folder + launch_viewer(args, predictions, images_cpu, resolved_image_folder) + return + assert args.image_folder or args.video_path, \ "Provide --image_folder or --video_path" @@ -397,26 +457,11 @@ def main(): images_for_post = images predictions, images_cpu = postprocess(predictions, images_for_post) + if args.save_predictions: + save_predictions(args.save_predictions, predictions, images_cpu, resolved_image_folder) # ── Visualize ──────────────────────────────────────────────────────────── - try: - from lingbot_map.vis import PointCloudViewer - viewer = PointCloudViewer( - pred_dict=prepare_for_visualization(predictions, images_cpu), - port=args.port, - vis_threshold=args.conf_threshold, - downsample_factor=args.downsample_factor, - point_size=args.point_size, - mask_sky=args.mask_sky, - image_folder=resolved_image_folder, - sky_mask_dir=args.sky_mask_dir, - sky_mask_visualization_dir=args.sky_mask_visualization_dir, - ) - print(f"3D viewer at http://localhost:{args.port}") - viewer.run() - except ImportError: - print("viser not installed. Install with: pip install lingbot-map[vis]") - print(f"Predictions contain keys: {list(predictions.keys())}") + launch_viewer(args, predictions, images_cpu, resolved_image_folder) if __name__ == "__main__":