Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ build/
*.so
.eggs/
demo_render/
outputs/
CLAUDE.md
.claude/
.agents/
skyseg.onnx
skyseg.onnx
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,19 @@ python demo.py --model_path /path/to/checkpoint.pt \
| `--point_size` | `0.00001` | Point cloud point size |
| `--downsample_factor` | `10` | Spatial downsampling for point cloud display |

### Saving Predictions for Later Visualization

Use `--save_predictions` to write the postprocessed reconstruction to a `.pt`
file, then pass that file to `--load_predictions` to reopen it in the viewer
later without rerunning model inference.

```bash
python demo.py --model_path /path/to/checkpoint.pt \
--image_folder /path/to/images/ \
--save_predictions outputs/scene.pt

python demo.py --load_predictions outputs/scene.pt
```

### Without FlashInfer (SDPA fallback)

```bash
Expand Down
81 changes: 63 additions & 18 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,54 @@ def prepare_for_visualization(predictions, images=None):
return vis_predictions


def save_predictions(path, predictions, images, image_folder=None):
    """Write postprocessed predictions to ``path`` so they can be viewed later.

    The parent directory of ``path`` is created if it does not exist yet.
    The file is written with ``torch.save`` and holds a dict with the keys
    ``"predictions"``, ``"images"`` and ``"image_folder"``.

    Args:
        path: Destination file for the serialized payload.
        predictions: Prediction data to store under ``"predictions"``.
        images: Image data to store under ``"images"``.
        image_folder: Optional image folder stored under ``"image_folder"``.
    """
    # Resolve to an absolute path first so a bare filename still yields a
    # real directory (the current working directory) instead of "".
    parent_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(parent_dir, exist_ok=True)

    payload = {
        "predictions": predictions,
        "images": images,
        "image_folder": image_folder,
    }
    torch.save(payload, path)

    print(f"Saved predictions to {path}")


def load_saved_predictions(path):
    """Read back a predictions file written via ``--save_predictions``.

    Args:
        path: File to load.

    Returns:
        A ``(predictions, images, image_folder)`` tuple. ``images`` and
        ``image_folder`` are ``None`` when the file does not contain them.

    Raises:
        ValueError: If the loaded object is not a dict or has no
            ``"predictions"`` entry.
    """
    # NOTE(security): weights_only=False makes torch.load unpickle arbitrary
    # objects, which can execute code embedded in the file. Only load
    # prediction files that come from a trusted source.
    saved = torch.load(path, map_location="cpu", weights_only=False)

    looks_valid = isinstance(saved, dict) and "predictions" in saved
    if not looks_valid:
        raise ValueError(f"{path} is not a saved LingBot-Map predictions file")

    predictions = saved["predictions"]
    images = saved.get("images")
    image_folder = saved.get("image_folder")
    return predictions, images, image_folder


def launch_viewer(args, predictions, images_cpu, resolved_image_folder):
    """Launch the Viser point-cloud viewer for postprocessed predictions.

    Args:
        args: Parsed command-line namespace. Reads ``port``,
            ``conf_threshold``, ``downsample_factor``, ``point_size``,
            ``mask_sky``, ``sky_mask_dir`` and ``sky_mask_visualization_dir``.
        predictions: Postprocessed predictions dict; passed through
            ``prepare_for_visualization`` before being handed to the viewer.
        images_cpu: Images forwarded to ``prepare_for_visualization``.
        resolved_image_folder: Image folder forwarded to the viewer.

    If the viewer module cannot be imported, an install hint and the
    available prediction keys are printed and the function returns without
    starting a viewer.
    """
    # Guard only the import. Previously the whole construction/run sequence
    # sat inside the ``try``, so any unrelated ImportError raised while
    # building or running the viewer was misreported as "viser not installed"
    # and the real error was swallowed.
    try:
        from lingbot_map.vis import PointCloudViewer
    except ImportError:
        print("viser not installed. Install with: pip install lingbot-map[vis]")
        print(f"Predictions contain keys: {list(predictions.keys())}")
        return

    viewer = PointCloudViewer(
        pred_dict=prepare_for_visualization(predictions, images_cpu),
        port=args.port,
        vis_threshold=args.conf_threshold,
        downsample_factor=args.downsample_factor,
        point_size=args.point_size,
        mask_sky=args.mask_sky,
        image_folder=resolved_image_folder,
        sky_mask_dir=args.sky_mask_dir,
        sky_mask_visualization_dir=args.sky_mask_visualization_dir,
    )
    print(f"3D viewer at http://localhost:{args.port}")
    viewer.run()


# =============================================================================
# Main
# =============================================================================
Expand Down Expand Up @@ -290,8 +338,20 @@ def main():
help="Save sky mask visualizations (original | mask | overlay) to this directory")
parser.add_argument("--export_preprocessed", type=str, default=None,
help="Export stride-sampled, resized/cropped images to this folder")
parser.add_argument("--save_predictions", type=str, default=None,
help="Save postprocessed predictions to this .pt file before visualization")
parser.add_argument("--load_predictions", type=str, default=None,
help="Load predictions saved with --save_predictions and launch visualization only")

args = parser.parse_args()
if args.load_predictions:
if args.video_path is not None:
raise ValueError("--load_predictions cannot be combined with --video_path")
predictions, images_cpu, saved_image_folder = load_saved_predictions(args.load_predictions)
resolved_image_folder = args.image_folder or saved_image_folder
launch_viewer(args, predictions, images_cpu, resolved_image_folder)
return

assert args.image_folder or args.video_path, \
"Provide --image_folder or --video_path"

Expand Down Expand Up @@ -397,26 +457,11 @@ def main():
images_for_post = images

predictions, images_cpu = postprocess(predictions, images_for_post)
if args.save_predictions:
save_predictions(args.save_predictions, predictions, images_cpu, resolved_image_folder)

# ── Visualize ────────────────────────────────────────────────────────────
try:
from lingbot_map.vis import PointCloudViewer
viewer = PointCloudViewer(
pred_dict=prepare_for_visualization(predictions, images_cpu),
port=args.port,
vis_threshold=args.conf_threshold,
downsample_factor=args.downsample_factor,
point_size=args.point_size,
mask_sky=args.mask_sky,
image_folder=resolved_image_folder,
sky_mask_dir=args.sky_mask_dir,
sky_mask_visualization_dir=args.sky_mask_visualization_dir,
)
print(f"3D viewer at http://localhost:{args.port}")
viewer.run()
except ImportError:
print("viser not installed. Install with: pip install lingbot-map[vis]")
print(f"Predictions contain keys: {list(predictions.keys())}")
launch_viewer(args, predictions, images_cpu, resolved_image_folder)


if __name__ == "__main__":
Expand Down