diff --git a/.gitignore b/.gitignore index 544d455..dbb7ff0 100755 --- a/.gitignore +++ b/.gitignore @@ -206,4 +206,13 @@ dmypy.json # Pyre type checker .pyre/ -learnableearthparser/fast_sampler/_sampler.c \ No newline at end of file +learnableearthparser/fast_sampler/_sampler.c +/.idea/inspectionProfiles/profiles_settings.xml +/.idea/.gitignore +/.idea/MILo_rtx50.iml +/.idea/misc.xml +/.idea/modules.xml +/.idea/vcs.xml +milo/data/* +!milo/data/.gitkeep +/milo/runs/* diff --git a/README.md b/README.md old mode 100755 new mode 100644 index b9ced4a..4c35562 --- a/README.md +++ b/README.md @@ -1,654 +1,484 @@ -
+# MILo_rtx50 — CUDA 12.8 / RTX 50 Series Local Compilation and Execution Guide (Ubuntu 24.04 + uv + PyTorch 2.7.1+cu128) -

-MILo: Mesh-In-the-Loop Gaussian Splatting for Detailed and Efficient Surface Reconstruction -

+> This README documents our **local compilation adaptation, key modifications, and reproducible experimental steps** for the forked project [Anttwo/MILo](https://github.com/Anttwo/MILo) in the **RTX 50 series + CUDA 12.8** environment. +> Goal: Complete submodule compilation, training, mesh extraction, rendering, and evaluation using **uv + venv** without Conda. - -SIGGRAPH Asia 2025 - Journal Track (TOG)
-
+--- - -Antoine Guédon*1  -Diego Gomez*1  -Nissim Maruani2
-Bingchen Gong1  -George Drettakis2  -Maks Ovsjanikov1  -
-
+## Environment - -1Ecole polytechnique, France
-2Inria, Université Côte d'Azur, France
-
+- OS: Ubuntu 24.04 +- GPU: RTX 50 Series (Blackwell) +- CUDA Toolkit: 12.8 (NVCC `/usr/local/cuda-12.8/bin/nvcc`) +- Python: 3.12.3 (venv management, package management with **uv**) +- PyTorch: **2.7.1+cu128** (official binary, **C++11 ABI=1**) +- C/C++: GCC 13.3 +- CMake: System version (apt) +- Important environment variables (commonly used during training/extraction/rendering): + ```bash + export NVDIFRAST_BACKEND=cuda + export TORCH_CUDA_ARCH_LIST="12.0+PTX" + export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,max_split_size_mb:32,garbage_collection_threshold:0.6" + export CUDA_DEVICE_MAX_CONNECTIONS=1 + # Mesh regularization grid resolution scaling, smaller value saves VRAM + export MILO_MESH_RES_SCALE=0.3 + # (Optional) Triangle chunk size to mitigate nvdiffrast CUDA backend VRAM peaks + export MILO_RAST_TRI_CHUNK=150000 + ``` - -*Both authors contributed equally to the paper. - +--- -| Webpage | arXiv | Presentation video | Data | +## Submodules Installation -![Teaser image](assets/teaser.png) -
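Before building the submodules, it can save time to confirm that the active venv really matches the environment above (PyTorch 2.7.1+cu128, CUDA 12.8, new C++11 ABI, Blackwell support). A quick optional sanity check, assuming the uv-managed venv is activated:

```bash
python - <<'PY'
import torch
print("torch:", torch.__version__)                    # expect 2.7.1+cu128
print("cuda runtime:", torch.version.cuda)            # expect 12.8
print("cuda available:", torch.cuda.is_available())
print("device:", torch.cuda.get_device_name(0))
print("arch list:", torch.cuda.get_arch_list())       # should include sm_120 for RTX 50 / Blackwell
print("cxx11 abi:", torch.compiled_with_cxx11_abi())  # expect True (ABI=1)
PY
```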
+> We have verified these can be successfully compiled/installed with CUDA 12.8 + PyTorch 2.7.1. -## Abstract +### 1) Install Gaussian Splatting Submodules (via **pip**) -_Our method introduces a novel differentiable mesh extraction framework that operates during the optimization of 3D Gaussian Splatting representations. At every training iteration, we differentiably extract a mesh—including both vertex locations and connectivity—only from Gaussian parameters. This enables gradient flow from the mesh to Gaussians, allowing us to promote bidirectional consistency between volumetric (Gaussians) and surface (extracted mesh) representations. This approach guides Gaussians toward configurations better suited for surface reconstruction, resulting in higher quality meshes with significantly fewer vertices. Our framework can be plugged into any Gaussian splatting representation, increasing performance while generating an order of magnitude fewer mesh vertices. MILo makes the reconstructions more practical for downstream applications like physics simulations and animation._ - -## To-do List - -- ⬛ Implement a simple training viewer using the GraphDeco viewer. -- ⬛ Add the mesh-based rendering evaluation scripts in `./milo/eval/mesh_nvs`. -- ✅ Add low-res and very-low-res training for light output meshes (under 50MB and under 20MB). -- ✅ Add T&T evaluation scripts in `./milo/eval/tnt/`. -- ✅ Add Blender add-on (for mesh-based editing and animation) to the repo. -- ✅ Clean code. -- ✅ Basic refacto. +```bash +pip install submodules/diff-gaussian-rasterization_ms +pip install submodules/diff-gaussian-rasterization +pip install submodules/diff-gaussian-rasterization_gof +pip install submodules/simple-knn +pip install submodules/fused-ssim +``` -## License +> Note: `nvdiffrast` uses **JIT compilation** (triggered at runtime by PyTorch cpp_extension). +> If choosing **OpenGL(GL) backend**, system headers are required: `sudo apt install -y libegl-dev libopengl-dev libgles2-mesa-dev ninja-build`. +> We switched to **CUDA backend** for simplicity: `export NVDIFRAST_BACKEND=cuda` (no EGL headers needed). -
-Click here to see content. +### 2) Install System Dependencies for `tetra_triangulation` (Delaunay Triangulation) -
This project builds on existing open-source implementations of various projects cited in the __Acknowledgements__ section. +The original project uses **conda** to install system-level C/C++ dependencies (cmake/gmp/cgal). Since we use **uv** for Python packages only, we need to install these C/C++ libraries via **apt** (system package manager): -Specifically, it builds on the original implementation of [3D Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting); As a result, parts of this code are licensed under the Gaussian-Splatting License (see `./LICENSE.md`). +```bash +# Install C/C++ dependencies via apt (Ubuntu 24.04) +sudo apt update +sudo apt install -y \ + build-essential \ + cmake ninja-build \ + libgmp-dev libmpfr-dev libcgal-dev \ + libboost-all-dev + +# (Optional) May be needed: +# sudo apt install -y libeigen3-dev +``` -This codebase also builds on various other repositories such as [Nvdiffrast](https://github.com/NVlabs/nvdiffrast); Please refer to the license files of the submodules for more details. -
+**Notes:** +- `libcgal-dev` provides CGAL headers (header-only on Ubuntu 24.04) +- `libgmp-dev` and `libmpfr-dev` are numerical backends for CGAL +- **uv only manages Python packages**; C/C++ dependencies must be installed via system package managers (apt/brew/pacman) +- For **macOS**: `brew install cmake cgal gmp mpfr boost` +- For **Arch Linux**: `sudo pacman -S cgal gmp mpfr boost cmake base-devel` -## 0. Quickstart +### 3) Compile `tetra_triangulation` with ABI Alignment -
-Click here to see content. +> **Important:** This module requires ABI alignment with PyTorch 2.7.1 (C++11 ABI=1). We use a header file approach to enforce this. -Please start by creating or downloading a COLMAP dataset, such as our COLMAP run for the Ignatius scene from the Tanks&Temples dataset. You can move the Ignatius directory to `./milo/data`. +**a) Create ABI enforcement header:** -After installing MILo as described in the next section, you can reconstruct a surface mesh from images by going to the `./milo/` directory and running the following commands: +Create `submodules/tetra_triangulation/src/force_abi.h`: +```cpp +#pragma once +// Force new ABI before any STL headers +#if defined(_GLIBCXX_USE_CXX11_ABI) +# undef _GLIBCXX_USE_CXX11_ABI +#endif +#define _GLIBCXX_USE_CXX11_ABI 1 +``` -```bash -# Training for an outdoor scene -python train.py -s ./data/Ignatius -m ./output/Ignatius --imp_metric outdoor --rasterizer radegs +**b) Modify source files:** -# Saves mesh as PLY with vertex colors after training -python mesh_extract_sdf.py -s ./data/Ignatius -m ./output/Ignatius --rasterizer radegs -``` -Please change `--imp_metric outdoor` to `--imp_metric indoor` if your scene is indoor. +Add `#include "force_abi.h"` as the **first line** of: +- `submodules/tetra_triangulation/src/py_binding.cpp` +- `submodules/tetra_triangulation/src/triangulation.cpp` -These commands use the lightest version of our approach, resulting in a small number of Gaussians and a light mesh. You can increase the number of Gaussians by adding `--dense_gaussians`, and improve the robustness to exposure variations with `--decoupled_appearance` as follows: +**c) Build and install:** ```bash -# Training with dense gaussians and better appearance model -python train.py -s ./data/Ignatius -m ./output/Ignatius --imp_metric outdoor --rasterizer radegs --dense_gaussians --decoupled_appearance - -# Saves mesh as PLY with vertex colors after training -python mesh_extract_sdf.py -s ./data/Ignatius -m ./output/Ignatius --rasterizer radegs +cd submodules/tetra_triangulation +rm -rf build CMakeCache.txt CMakeFiles tetranerf/utils/extension/tetranerf_cpp_extension*.so + +# Point to current PyTorch's CMake prefix/dynamic library path +export CMAKE_PREFIX_PATH="$(python - <<'PY' +import torch; print(torch.utils.cmake_prefix_path) +PY +)" +export TORCH_LIB_DIR="$(python - <<'PY' +import os, torch; print(os.path.join(os.path.dirname(torch.__file__), 'lib')) +PY +)" +export LD_LIBRARY_PATH="$TORCH_LIB_DIR:$LD_LIBRARY_PATH" + +cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$CMAKE_PREFIX_PATH" . +cmake --build . -j"$(nproc)" + +# Install (optional, convenient for editable reference) +uv pip install -e . +cd ../../ ``` -Please refer to the following sections for additional details on our training and mesh extraction scripts, including: -- How to use other rasterizers -- How to train MILo with high-resolution meshes -- Various mesh extraction methods -- How to easily integrate MILo's differentiable GS-to-mesh pipeline to your own GS project -
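After `uv pip install -e .`, a quick import test helps confirm that the extension was actually built against the new ABI. A minimal check (sketch), run from the repository root in the same venv:

```bash
python - <<'PY'
import torch
print("torch built with cxx11 abi:", torch.compiled_with_cxx11_abi())  # expect True
from tetranerf.utils import extension  # raises the `undefined symbol ... RKSs` error if the ABI is wrong
print("tetranerf extension imported OK")
PY
```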
+> **Note:** For troubleshooting ABI issues, see **Key Issue 1** section below. +--- -## 1. Installation +## Key Issue 1: `tetra_triangulation` ABI Mismatch (Resolved) -
-Click here to see content. +**Symptom** +Running `from tetranerf.utils import extension as ext` throws error: -### Clone this repository. -```bash -git clone https://github.com/Anttwo/MILo.git --recursive ``` - -### Install dependencies. - -Please start by creating an environment: -```bash -conda create -n milo python=3.9 -conda activate milo +undefined symbol: _ZN3c106detail14torchCheckFailEPKcS2_jRKSs ``` -Then, specify your own CUDA paths depending on your CUDA version: -```bash -# You can specify your own cuda path (depending on your CUDA version) -export CPATH=/usr/local/cuda-11.8/targets/x86_64-linux/include:$CPATH -export LD_LIBRARY_PATH=/usr/local/cuda-11.8/targets/x86_64-linux/lib:$LD_LIBRARY_PATH -export PATH=/usr/local/cuda-11.8/bin:$PATH -``` +The trailing `RKSs` indicates **old ABI (_GLIBCXX_USE_CXX11_ABI=0)**, while our PyTorch 2.7.1 uses **new ABI (=1)**. -Finally, you can run the following script to install all dependencies, including PyTorch and Gaussian Splatting submodules: -```bash -python install.py -``` -By default, the environment will be installed for CUDA 11.8. If using CUDA 12.1, you can provide the argument `--cuda_version 12.1` to `install.py`. **Please note that only CUDA 11.8 has been tested.** +**Fix** +Add a header file to force ABI in `submodules/tetra_triangulation` to **stably lock ABI=1**: -If you encounter problems or if the installation script does not work, please follow the detailed installation steps below. +* Create file: `src/force_abi.h` -
-Click here for detailed installation instructions + ```cpp + #pragma once + // Force new ABI before any STL headers + #if defined(_GLIBCXX_USE_CXX11_ABI) + # undef _GLIBCXX_USE_CXX11_ABI + #endif + #define _GLIBCXX_USE_CXX11_ABI 1 + ``` -```bash -# For CUDA 11.8 -conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=11.8 mkl=2023.1.0 -c pytorch -c nvidia +* Modify: Add `#include "force_abi.h"` as the **first line** of `src/py_binding.cpp` and `src/triangulation.cpp` -# For CUDA 12.1 (The code has only been tested on CUDA 11.8) -conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 mkl=2023.1.0 -c pytorch -c nvidia + ```cpp + #include "force_abi.h" + ``` -pip install -r requirements.txt +> **Note:** This header file approach is sufficient to enforce ABI=1. No additional CMakeLists.txt modifications are needed. -# Install submodules for Gaussian Splatting, including different rasterizers, aggressive densification, simplification, and utilities -pip install submodules/diff-gaussian-rasterization_ms -pip install submodules/diff-gaussian-rasterization -pip install submodules/diff-gaussian-rasterization_gof -pip install submodules/simple-knn -pip install submodules/fused-ssim +**Build Commands (in-source, outputs to package path)** -# Delaunay Triangulation from Tetra-Nerf +```bash cd submodules/tetra_triangulation -conda install cmake -conda install conda-forge::gmp -conda install conda-forge::cgal - -# You can specify your own cuda path (depending on your CUDA version) -export CPATH=/usr/local/cuda-11.8/targets/x86_64-linux/include:$CPATH -export LD_LIBRARY_PATH=/usr/local/cuda-11.8/targets/x86_64-linux/lib:$LD_LIBRARY_PATH -export PATH=/usr/local/cuda-11.8/bin:$PATH - -cmake . -make -pip install -e . -cd ../../ - -# Nvdiffrast for efficient mesh rasterization -cd ./submodules/nvdiffrast -pip install . -cd ../../ +rm -rf build CMakeCache.txt CMakeFiles tetranerf/utils/extension/tetranerf_cpp_extension*.so + +# Point to current PyTorch's CMake prefix/dynamic library path +export CMAKE_PREFIX_PATH="$(python - <<'PY' +import torch; print(torch.utils.cmake_prefix_path) +PY +)" +export TORCH_LIB_DIR="$(python - <<'PY' +import os, torch; print(os.path.join(os.path.dirname(torch.__file__), 'lib')) +PY +)" +export LD_LIBRARY_PATH="$TORCH_LIB_DIR:$LD_LIBRARY_PATH" + +cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$CMAKE_PREFIX_PATH" . +cmake --build . -j"$(nproc)" + +# Install (optional, convenient for editable reference) +uv pip install -e . ``` -
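If the import still fails after rebuilding, `nm` can show which ABI variant of the failing symbol the extension references. A diagnostic sketch, assuming binutils is installed and the `.so` was built in place (run from the repository root):

```bash
nm -D submodules/tetra_triangulation/tetranerf/utils/extension/tetranerf_cpp_extension*.so \
  | grep torchCheckFail
# symbol ending in `...RKSs`            -> built against the old ABI (=0); rebuild with the patch above
# symbol ending in `...RKNSt7__cxx11...` -> built against the new ABI (=1), matching PyTorch 2.7.1
```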
+--- + +## Key Issue 2: nvdiffrast GL Backend Missing `EGL/egl.h` (Bypassed) + +* Option A: `sudo apt install -y libegl-dev libopengl-dev libgles2-mesa-dev` and continue with GL. +* **Option B (We adopted)**: Switch to **CUDA backend**: `export NVDIFRAST_BACKEND=cuda`, no EGL header dependency. + +--- + +## Key Issue 3: nvdiffrast CUDA Backend VRAM OOM (Resolved) + +**Symptom** +`cudaMalloc(&m_gpuPtr, bytes)` OOM (error: 2), especially during Mesh regularization phase. + +**Fix (Two Points)** + +1. Replace `nvdiff_rasterization` implementation in `milo/scene/mesh.py`: + + * Support **triangle chunking** (env variable `MILO_RAST_TRI_CHUNK` specifies chunk size) + * **Fix CUDA backend `ranges` must be on CPU** (`dr.rasterize(..., ranges=)`) + +
+ Modified function (click to expand) + + ```python + def nvdiff_rasterization( + camera, + image_height: int, + image_width: int, + verts: torch.Tensor, + faces: torch.Tensor, + return_indices_only: bool = False, + glctx=None, + return_rast_out: bool = False, + return_positions: bool = False, + ): + """ + Replacement version equivalent to original function, supports triangle chunking (env: MILO_RAST_TRI_CHUNK), + and fixes: nvdiffrast CUDA backend's `ranges` must be on CPU. + """ + import os + import torch + import nvdiffrast.torch as dr + + device = verts.device + dtype = verts.dtype -
+ cam_mtx = camera.full_proj_transform + pos = torch.cat([verts, torch.ones([verts.shape[0], 1], device=device, dtype=dtype)], dim=1) + pos = torch.matmul(pos, cam_mtx)[None] # [1,V,4] -## 2. Training with MILo + faces = faces.to(torch.int32).contiguous() + faces_dev = faces.to(pos.device) -
-Click here to see content. + H, W = int(image_height), int(image_width) + chunk = int(os.getenv("MILO_RAST_TRI_CHUNK", "0") or "0") + use_chunking = chunk > 0 and faces.shape[0] > chunk -First, go to the MILo folder: -```bash -cd milo -``` + if not use_chunking: + rast_out, _ = dr.rasterize(glctx, pos=pos, tri=faces_dev, resolution=[H, W]) + bary_coords = rast_out[..., :2] + zbuf = rast_out[..., 2] + pix_to_face = rast_out[..., 3].to(torch.int32) - 1 + if return_indices_only: + return pix_to_face + _out = (bary_coords, zbuf, pix_to_face) + if return_rast_out: + _out += (rast_out,) + if return_positions: + _out += (pos,) + return _out -Then, to optimize a Gaussian Splatting representation with MILo using a COLMAP dataset, you can run the following command: -```bash -python train.py \ - -s \ - -m \ - --imp_metric <"indoor" OR "outdoor"> \ - --rasterizer <"radegs" OR "gof"> -``` -The main arguments are the following: -| Argument | Values | Default | Description | -|----------|--------|---------|-------------| -| `--imp_metric` | `"indoor"` or `"outdoor"` | Required | Type of scene to optimize. Modifies the importance sampling to better handle indoor or outdoor scenes. | -| `--rasterizer` | `"radegs"` or `"gof"` | `"radegs"` | Rasterization technique used during training. | -| `--dense_gaussians` | flag | disabled | Use more Gaussians during training. When active, only a subset of Gaussians will generate pivots for Delaunay triangulation. When inactive, all Gaussians generate pivots.| - -You can use a dense set of Gaussians by adding the argument `--dense_gaussians`: -```bash -python train.py \ - -s \ - -m \ - --imp_metric <"indoor" OR "outdoor"> \ - --rasterizer <"radegs" OR "gof"> \ - --dense_gaussians \ - --data_device cpu -``` + z_ndc = (pos[..., 2:3] / (pos[..., 3:4] + 1e-20)).contiguous() -The list of optional arguments is provided below: -| Category | Argument | Values | Default | Description | -|----------|----------|---------|---------|-------------| -| **Performance & Logging** | `--data_device` | `"cpu"` or `"cuda"` | `"cuda"` | Forces data to be loaded on CPU (less GPU memory usage, slightly slower training) | -| | `--log_interval` | integer | - | Log images every N training iterations (e.g., `200`) | -| **Mesh Configuration** | `--mesh_config` | `"default"`, `"highres"`, `"veryhighres"`, `"lowres"`, `"verylowres"` | `"default"` | Config file for mesh resolution and quality | -| **Evaluation & Appearance** | `--eval` | flag | disabled | Performs the usual train/test split for evaluation | -| | `--decoupled_appearance` | flag | disabled | Better handling of exposure variations | -| **Depth-Order Regularization** | `--depth_order` | flag | disabled | Enable depth-order regularization with DepthAnythingV2 | -| | `--depth_order_config` | `"default"` or `"strong"` | `"default"` | Strength of depth regularization | - -You can change the config file used during training (useful for ablation runs) by -specifying `--mesh_config `. The different config files are the following: - -- **Default config**: The default config file name is `default`. This config results in -lighter representations and lower resolution meshes, containing around 2M Delaunay vertices -for the base setting and 5M Delaunay vertices for the `--dense_gaussians` setting. -- **High Res config**: You can use `--mesh_config highres --dense_gaussians` for higher -resolution meshes. We recommend using this config with `--dense_gaussians`. 
This config -results in higher resolution representations, containing up to 9M Delaunay vertices. -- **Very High Res config**: You can use `--mesh_config veryhighres --dense_gaussians` for -even higher resolution meshes. We recommend using this config with `--dense_gaussians`. -This config results in even higher resolution representations, containing up to 14M Delaunay -vertices. This configuration requires more memory for training. -- **Low Res config**: You can use `--mesh_config lowres` for lower resolution meshes (less than 50MB). -This config results in lower resolution representations, containing up to 500k Delaunay vertices. -You can adjust the number of Gaussians used during training accordingly by decreasing the sampling factor -with `--sampling_factor 0.3`, for instance. -- **Very Low Res config**: You can use `--mesh_config verylowres` for even lower resolution meshes (less than 20MB). -This config results in even lower resolution representations, containing up to 250k Delaunay vertices. -You can adjust the number of Gaussians used during training accordingly by decreasing the sampling factor -with `--sampling_factor 0.1`, for instance. - -Please refer to the DepthAnythingV2 repo to download the `vitl` checkpoint required for Depth-Order regularization. Then, move the checkpoint file to `./submodules/Depth-Anything-V2/checkpoints/`. - -### Example Commands - -Basic training for indoor scenes with logging: -```bash -python train.py -s -m --imp_metric indoor --rasterizer radegs --log_interval 200 -``` + best_rast, best_depth = None, None + n_faces, start = int(faces.shape[0]), 0 -Dense Gaussians with high resolution in outdoor scenes: -```bash -python train.py -s -m --imp_metric outdoor --rasterizer radegs --dense_gaussians --mesh_config highres --data_device cpu -``` + def _normalize_tri_id(rast_chunk, start_idx, count_idx): + tri_raw = rast_chunk[..., 3:4].to(torch.int64) + if tri_raw.numel() == 0: + return rast_chunk[..., 3:4] + maxid = int(tri_raw.max().item()) + if maxid == 0: + return rast_chunk[..., 3:4] + if maxid <= count_idx: + tri_adj = torch.where(tri_raw > 0, tri_raw + start_idx, tri_raw) + else: + tri_adj = tri_raw + return tri_adj.to(rast_chunk.dtype) -Full featured training with very high resolution: -```bash -python train.py -s -m --imp_metric indoor --rasterizer radegs --dense_gaussians --mesh_config veryhighres --decoupled_appearance --log_interval 200 --data_device cpu -``` + while start < n_faces: + count = min(chunk, n_faces - start) + # ranges must be on CPU + ranges_cpu = torch.tensor([[start, count]], device="cpu", dtype=torch.int32) -Very low resolution training in indoor scenes for very light meshes (less than 20MB): -```bash -python train.py -s -m --imp_metric indoor --rasterizer radegs --sampling_factor 0.1 --mesh_config verylowres -``` - -Training with depth-order regularization: -```bash -python train.py -s -m --imp_metric indoor --rasterizer radegs --depth_order --depth_order_config strong --log_interval 200 --data_device cpu -``` - -
+ rast_chunk, _ = dr.rasterize(glctx, pos=pos, tri=faces_dev, resolution=[H, W], ranges=ranges_cpu) + depth_chunk, _ = dr.interpolate(z_ndc, rast_chunk, faces_dev) + tri_id_adj = _normalize_tri_id(rast_chunk, start, count) -## 3. Extracting a Mesh after Optimization + if best_rast is None: + best_rast = torch.zeros_like(rast_chunk) + best_depth = torch.full_like(depth_chunk, float("inf")) + + hit = (tri_id_adj > 0) + prev_hit = (best_rast[..., 3:4] > 0) + closer = hit & (~prev_hit | (depth_chunk < best_depth)) + + rast_chunk = torch.cat([rast_chunk[..., :3], tri_id_adj], dim=-1) + + best_depth = torch.where(closer, depth_chunk, best_depth) + best_rast = torch.where(closer.expand_as(best_rast), rast_chunk, best_rast) + + start += count + + rast_out = best_rast + bary_coords = rast_out[..., :2] + zbuf = rast_out[..., 2] + pix_to_face = rast_out[..., 3].to(torch.int32) - 1 + + if return_indices_only: + return pix_to_face + + _output = (bary_coords, zbuf, pix_to_face) + if return_rast_out: + _output += (rast_out,) + if return_positions: + _output += (pos,) + return _output + ``` + +
+ +2. Reduce memory peak at runtime: + + * `MILO_MESH_RES_SCALE=0.3` (mesh regularization resolution scaling) + * `MILO_RAST_TRI_CHUNK=150000` (triangle chunk size) + * `--data_device cpu` (cameras/data on CPU) + +--- + +## Reproduction Steps (Ignatius) -
-Click here to see content. +> Data path: `./data/Ignatius` +> Output directory: `./output/Ignatius` -First go to `./milo/`. +### 1) Training -### 3.1. Use learned SDF values - -You can then use the following command: ```bash -python mesh_extract_sdf.py \ - -s \ - -m \ - --rasterizer <"radegs" OR "gof"> +cd milo +export NVDIFRAST_BACKEND=cuda +export TORCH_CUDA_ARCH_LIST="12.0+PTX" +export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,max_split_size_mb:32,garbage_collection_threshold:0.6" +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export MILO_MESH_RES_SCALE=0.3 +export MILO_RAST_TRI_CHUNK=150000 + +python train.py -s ./data/Ignatius -m ./output/Ignatius \ + --imp_metric outdoor \ + --rasterizer gof \ + --mesh_config verylowres \ + --sampling_factor 0.2 \ + --data_device cpu \ + --log_interval 200 ``` -This script will further refine the SDF values for a short period of time (1000 iterations by default) with frozen Gaussians, then save the mesh as a PLY file with vertex colors. The mesh will be located at `/mesh_learnable_sdf.ply`. -**WARNING:** Make sure you use the same mesh config file as the one used during training. You can change the config file by specifying `--config `. The default config file name is `default`, but you can switch to `highres`, `veryhighres`, `lowres` or `verylowres`. +**Output** (located in `./output/Ignatius`) -You can use the usual train/test split by adding the argument `--eval`. +* Trained scene (Gaussians + learnable SDF and mesh regularization state, etc.) +* Logs and intermediate files (as configured by script, console prints training progress) -### 3.2. Use Integrated Opacity Field or scalable TSDF +### 2) Mesh Extraction (SDF) -To extract a mesh using the integrated opacity field as defined by the Gaussians in GOF and RaDe-GS, you can run the following command: ```bash -python mesh_extract_integration.py \ - -s \ - -m +python mesh_extract_sdf.py \ + -s ./data/Ignatius \ + -m ./output/Ignatius \ + --rasterizer gof \ + --config verylowres \ + --data_device cpu ``` -You can use the argument `--rasterizer ` to change the rasterization technique for computing the opacity field. Default is `gof`. We recommend using GOF in this context (even if RaDe-GS was used during training), as the opacity computation from GOF is more precise and will produce less surface erosion. -You can also use the argument `--sdf_mode <"integration" OR "depth_fusion">` to modify the SDF computation strategy. Default mode is `integration`, which uses the integrated opacity field. Please note that `depth_fusion` is not traditional TSDF performed on a regular grid, but our more efficient depth fusion strategy relying on the same Gaussian pivots as the ones used for `integration`. +**Output** -If using `integration`, you can modify the isosurface value with the argument `--isosurface_value `. The default value is 0.5. -```bash -python mesh_extract_integration.py \ - -s \ - -m \ - --rasterizer gof \ - --sdf_mode integration \ - --isosurface_value 0.5 -``` +* **`./output/Ignatius/mesh_learnable_sdf.ply`** (confirmed to open normally in MeshLab) -If using `depth_fusion`, you can modify the truncation margin with the argument `--trunc_margin `. If not provided, the value is automatically computed depending on the scale of the scene. We recommend not changing this value. 
-```bash -python mesh_extract_integration.py \ - -s \ - -m \ - --rasterizer gof \ - --sdf_mode depth_fusion \ - --trunc_margin 0.005 -``` +### 3) Rendering -The mesh will be saved at either `/mesh_integration_sdf.ply` or `/mesh_depth_fusion_sdf.ply` depending on the SDF computation method. - -
- -## 4. Using our differentiable Gaussians-to-Mesh pipeline in your own 3DGS project - -
-Click here to see content. -
- -In `milo.functional`, we provide straightforward functions to leverage our differentiable *Gaussians-to-Mesh pipeline* in your own 3DGS projects. - -These functions only require Gaussian parameters as inputs (`means`, `scales`, `rotations`, `opacities`) and can extract a mesh from these parameters in a differentiable manner, allowing for **performing differentiable operations on the surface mesh and backpropating gradients directly to the Gaussians**. - -We only assume that your own `Camera` class has the same structure as the class from the original [3DGS](https://github.com/graphdeco-inria/gaussian-splatting) implementation. - -Specifically, we propose the following functions: -- `sample_gaussians_on_surface`: This function samples Gaussians that are the most likely to be located on the surface of the scene. For more efficiency, we propose using only these Gaussians for generating pivots and applying Delaunay triangulation. -- `extract_gaussian_pivots`: This differentiable function builds pivots from the parameters of the sampled Gaussians. In practice, there is no need to explicitely call this function, as our other functions can recompute pivots on the fly. However, you might want to perform special treatment on the pivots. -- `compute_initial_sdf_values`: This function estimates initial truncated SDF values for any set of Gaussian pivots by performing depth-fusion over the provided viewpoints. You can directly provide the gaussian parameters to this function, in which case pivots will be computed on the fly. In the paper, we propose to learn optimal SDF values by maximizing the consistency between volumetric GS renderings and surface mesh renderings; We use this function to initialize the SDF values. -- `compute_delaunay_triangulation`: This function computes the Delaunay triangulation for a set of sampled Gaussians pivots. You can provide either pivots as inputs, or directly the parameters of the Gaussians (means, scales, rotations...), in which case the pivots will be recomputed on the fly. This function should not be applied at every training iteration as it is very slow, and the Delaunay graph does not change that much during training. -- `extract_mesh`: This differentiable function extracts the mesh from the Gaussian parameters, given a Delaunay triangulation and SDF values for the Gaussian pivots. - -We also propose additional functions such as `frustum_cull_mesh` which culls mesh vertices based on the view frustum of an input camera. - -We provide an example of how to use these functions below, using our codebase or any codebase following the same template as the original [3DGS](https://github.com/graphdeco-inria/gaussian-splatting) implementation. - -```python -from functional import ( - sample_gaussians_on_surface, - extract_gaussian_pivots, - compute_initial_sdf_values, - compute_delaunay_triangulation, - extract_mesh, - frustum_cull_mesh, -) - -# Load or initialize a 3DGS-like model and training cameras -gaussians = ... -train_cameras = ... - -# Define a simple wrapper for your Gaussian Splatting rendering function, -# following this template. It will be used only for initializing SDF values. -# The wrapper should accept just a camera as input, and return a dictionary -# with "render" and "depth" keys. -from gaussian_renderer.radegs import render_radegs -pipe = ... 
-background = torch.tensor([0., 0., 0.], device="cuda") -def render_func(view): - render_pkg = render_radegs( - viewpoint_camera=view, - pc=gaussians, - pipe=pipe, - bg_color=background, - kernel_size=0.0, - scaling_modifier = 1.0, - require_coord=False, - require_depth=True - ) - return { - "render": render_pkg["render"], - "depth": render_pkg["median_depth"], - } - -# Only the parameters of the Gaussians are needed for extracting the mesh. -means = gaussians.get_xyz -scales = gaussians.get_scaling -rotations = gaussians.get_rotation -opacities = gaussians.get_opacity - -# Sample Gaussians on the surface. -# Should be performed only once, or just once in a while. -# In this example, we sample at most 600_000 Gaussians. -surface_gaussians_idx = sample_gaussians_on_surface( - views=train_cameras, - means=means, - scales=scales, - rotations=rotations, - opacities=opacities, - n_max_samples=600_000, - scene_type='indoor', -) - -# Compute initial SDF values for pivots. Should be performed only once. -# In the paper, we propose to learn optimal SDF values by maximizing the -# consistency between volumetric renderings and surface mesh renderings. -initial_pivots_sdf = compute_initial_sdf_values( - views=train_cameras, - render_func=render_func, - means=means, - scales=scales, - rotations=rotations, - gaussian_idx=surface_gaussians_idx, -) - -# Compute Delaunay Triangulation. -# Can be performed once in a while. -delaunay_tets = compute_delaunay_triangulation( - means=means, - scales=scales, - rotations=rotations, - gaussian_idx=surface_gaussians_idx, -) - -# Differentiably extract a mesh from Gaussian parameters, including initial -# or updated SDF values for the Gaussian pivots. -# This function is differentiable with respect to the parameters of the Gaussians, -# as well as the SDF values. Can be performed at every training iteration. -mesh = extract_mesh( - delaunay_tets=delaunay_tets, - pivots_sdf=initial_pivots_sdf, - means=means, - scales=scales, - rotations=rotations, - gaussian_idx=surface_gaussians_idx, -) - -# You can now apply any differentiable operation on the extracted mesh, -# and backpropagate gradients back to the Gaussians! -# In the paper, we propose to use differentiable mesh rendering. -from scene.mesh import MeshRasterizer, MeshRenderer -renderer = MeshRenderer(MeshRasterizer(cameras=train_cameras)) - -# We cull the mesh based on the view frustum for more efficiency -i_view = np.random.randint(0, len(train_cameras)) -mesh_render_pkg = renderer( - frustum_cull_mesh(mesh, train_cameras[i_view]), - cam_idx=i_view, - return_depth=True, return_normals=True -) -mesh_depth = mesh_render_pkg["depth"] -mesh_normals = mesh_render_pkg["normals"] +```bash +python render.py \ + -m ./output/Ignatius \ + -s ./data/Ignatius \ + --rasterizer gof \ + --eval ``` -
- -## 5. Creating a COLMAP dataset with your own images - -
-Click here to see content. -
+**Output** -### 5.1. Estimate camera poses with COLMAP +* Rendered images (train/test views), saved to the rendering subdirectory in the model output directory (as indicated by script output) -Please first install a recent version of COLMAP (ideally CUDA-powered) and make sure to put the images you want to use in a directory `/input`. Then, run the script `milo/convert.py` from the original Gaussian splatting implementation to compute the camera poses for the images using COLMAP. Please refer to the original 3D Gaussian Splatting repository for more details. +### 4) Image Metrics -```shell -python milo/convert.py -s -``` - -Sometimes COLMAP fails to reconstruct all images into the same model and hence produces multiple sub-models. The smaller sub-models generally contain only a few images. However, by default, the script `convert.py` will apply Image Undistortion only on the first sub-model, which may contain only a few images. - -If this is the case, a simple solution is to keep only the largest sub-model and discard the others. To do this, open the source directory containing your input images, then open the sub-directory `/distorted/sparse/`. You should see several sub-directories named `0/`, `1/`, etc., each containing a sub-model. Remove all sub-directories except the one containing the largest files, and rename it to `0/`. Then, run the script `convert.py` one more time but skip the matching process: - -```shell -python milo/convert.py -s --skip_matching +```bash +python metrics.py -m ./output/Ignatius ``` -_Note: If the sub-models have common registered images, they could be merged into a single model as post-processing step using COLMAP; However, merging sub-models requires to run another global bundle adjustment after the merge, which can be time consuming._ - -### 5.2. Estimate camera poses with VGGT +**Output** -Coming soon. +* Console output of PSNR/SSIM (and corresponding files saved by repo script, located in model directory; based on actual implementation) -
+### 5) PLY Format Conversion (Optional) +The extracted PLY mesh can be converted to other common 3D formats (OBJ/GLB) using the `clean_convert_mesh.py` script for use in various 3D software. The script also provides optional mesh cleaning functionality. -## 6. Mesh-Based Editing and Animation with the MILo Blender Addon +**Install Additional Dependencies** +```bash +pip install pymeshlab trimesh plyfile +``` -
-Click here to see content. -
+**Basic Usage** +```bash +# Basic conversion (outputs PLY, OBJ, GLB) +python clean_convert_mesh.py --in ./output/Ignatius/mesh_learnable_sdf.ply -While MILo provides a differentiable solution for extracting meshes from 3DGS representations, it also implicitly encourages Gaussians to align with the surface of the mesh. As a result, any modification made to the mesh can be easily propagated to the Gaussians, making the reconstructed mesh an excellent proxy for editing and animating the 3DGS representation. +# Convert and simplify to 300k triangles +python clean_convert_mesh.py --in ./output/Ignatius/mesh_learnable_sdf.ply --simplify 300000 -Similarly to previous works SuGaR and Gaussian Frosting, we provide a Blender addon allowing to combine, edit and animate 3DGS representations just by manipulating meshes reconstructed with MILo in Blender. +# Clean small components during conversion (default 0.02 = remove components with diameter < 2% bbox diagonal) +python clean_convert_mesh.py --in ./output/Ignatius/mesh_learnable_sdf.ply --keep-components 0.02 -### 6.1. Installing the addon +# Output only specific formats +python clean_convert_mesh.py --in ./output/Ignatius/mesh_learnable_sdf.ply --no-glb # Skip GLB +python clean_convert_mesh.py --in ./output/Ignatius/mesh_learnable_sdf.ply --no-obj # Skip OBJ -1. Please start by installing `torch_geometric` and `torch_cluster` in your `milo` conda environment: -```shell -pip install torch_geometric -pip install torch_cluster +# Specify output directory and filename +python clean_convert_mesh.py --in ./output/Ignatius/mesh_learnable_sdf.ply \ + --out-dir ./output/Ignatius/converted \ + --stem mesh_final ``` -2. Then, install Blender (version 4.0.2 is recommended but not mandatory). - -3. Open Blender, and go to `Edit` > `Preferences` > `Add-ons` > `Install`, and select the file `milo_addon.py` located in `./milo/blender/`.
- -You have now installed the MILo addon for Blender! - -### 6.2. Usage +**Main Features** +- **Format Conversion**: Convert PLY to OBJ and GLB formats (suitable for different 3D software and Web display) +- **Optional Cleaning**: Remove duplicate vertices/faces, fix non-manifold edges, remove small floating components +- **Optional Simplification**: Shape-preserving simplification based on Quadric decimation -This Blender addon is almost identical as the SuGaR x Frosting Blender addon. You can refer to this previous repo for more details and illustrations. To combine, edit or animate scene with the addon, please follow the steps below: +**Output** (saved in input file directory by default) +* `mesh_clean.ply` - Converted PLY mesh (with vertex colors) +* `mesh_clean.obj` - OBJ format (Note: OBJ doesn't support vertex colors) +* `mesh_clean.glb` - GLB format (suitable for Web display and import into Blender/Unity etc.) -1. Please start by training Gaussians with MILo and extracting a mesh, as described in the Quickstart section. +--- -2. Open a new scene in Blender, and go to the `Render` tab in the Properties. You should see a panel named `Add MILo mesh`. The panel is not necessary at the top of the tab, so you may need to scroll down to find it. +## Our Modifications (Relative to Upstream) -3. **(a) Select a mesh.** Enter the path to the final mesh extracted with MILo in the `Path to mesh PLY` field. You can also click on the folder icon to select the file. The mesh should be located at `/mesh_learnable_sdf.ply`.

-**(b) Select a checkpoint.** Similarly, enter the path to the final checkpoint of the optimization in the `Path to 3DGS PLY` field. You can also click on the folder icon to select the file. The checkpoint should be located at `/point_cloud/iteration_18000/point_cloud.ply`.

-**(c) Load the mesh.** Finally, click on `Add mesh` to load the mesh in Blender. Feel free to rotate the mesh and change the shading mode to better visualize the mesh and its colors. +1. **`submodules/tetra_triangulation`** -4. **Now, feel free to edit your mesh using Blender!** -
You can segment it into different pieces, sculpt it, rig it, animate it using a parent armature, *etc*. You can also add other MILo meshes to the scene, and combine elements from different scenes.
-Feel free to set a camera in the scene and prepare an animation: You can animate the camera, the mesh, *etc*.
-Please avoid using `Apply Location`, `Apply Rotation`, or `Apply Scale` on the edited mesh, as we are still unsure how it will affect the correspondence between the mesh and the optimized checkpoint. + * Added `src/force_abi.h`, and `#include "force_abi.h"` at the first line of `src/py_binding.cpp` and `src/triangulation.cpp`: **Force use of C++11 new ABI (=1)** -5. Once you're done with your editing, you can prepare a rendering package ready to be rendered with Gaussians. To do so, go to the `Render` tab in the Properties again, and select the `./milo/` directory in the `Path to MILo directory` field.
-Finally, click on `Render Image` or `Render Animation` to save a rendering package for the scene.

-`Render Image` will render a single image of the scene, with the current camera position and mesh editions/poses.

-`Render Animation` will render a full animation of the scene, from the first frame to the last frame you set in the Blender Timeline. -

-The package should be saved as a `JSON` file and located in `./milo/blender/packages/`. +2. **`milo/scene/mesh.py`** -7. Finally, you can render the package with Gaussian Splatting. You just need to go to `./milo/` and run the following command: -```shell -python render_blender_scene.py \ - -p \ - --rasterizer <"radegs" or "gof">. -``` - -By default, renderings are saved in `./milo/blender/renders//`. However, you can change the output directory by adding `-o `. + * Replaced `nvdiff_rasterization`: -Please check the documentation of the `render_blender_scene.py` scripts for more information on the additional arguments. -If you get artifacts in the rendering, you can try to play with the various following hyperparameters: `binding_mode`, `filter_big_gaussians_with_th`, `clamp_big_gaussians_with_th`, and `filter_distant_gaussians_with_th`. + * Support **MILO_RAST_TRI_CHUNK** triangle chunking + * **Fixed CUDA backend `ranges` must be CPU Tensor** + * Other behavior remains consistent with original function -
+3. **Runtime Configuration** -## 7. Evaluation + * Default to nvdiffrast **CUDA** backend (`NVDIFRAST_BACKEND=cuda`), avoiding EGL dependency + * Specify `TORCH_CUDA_ARCH_LIST="12.0+PTX"` for Blackwell + * Reduce peak VRAM: `MILO_MESH_RES_SCALE=0.3`, `MILO_RAST_TRI_CHUNK=150000`, `--data_device cpu` -
-Click here to see content. -
+--- -For evaluation, please start by downloading [our COLMAP runs for the Tanks and Temples dataset](https://drive.google.com/drive/folders/1Bf7DM2DFtQe4J63bEFLceEycNf4qTcqm?usp=sharing), and make sure to move all COLMAP scene directories (Barn, Caterpillar, _etc._) inside the same directory. +## Common Issues and Troubleshooting -Then, please download ground truth point cloud, camera poses, alignments and cropfiles from [Tanks and Temples dataset](https://www.tanksandtemples.org/download/). The ground truth dataset should be organized as: -``` -GT_TNT_dataset -│ -└─── Barn -│ │ -| └─── Barn.json -│ │ -| └─── Barn.ply -│ │ -| └─── Barn_COLMAP_SfM.log -│ │ -| └─── Barn_trans.txt -│ -└─── Caterpillar -│ │ -...... -``` +* **`undefined symbol: ... torchCheckFail ... RKSs`** + This is an ABI=0 symbol; please recompile `tetra_triangulation` with the above patch. -We follow the exact same pipeline as GOF and RaDe-GS for evaluating MILo on T&T. Please go to `./milo/` and run the following script to run the full training and evaluation pipeline on all scenes: +* **`fatal error: EGL/egl.h: No such file or directory`** + If insisting on GL path: `sudo apt install -y libegl-dev libopengl-dev libgles2-mesa-dev ninja-build`; + Or directly use `export NVDIFRAST_BACKEND=cuda` for CUDA path. -```bash -python scripts/evaluate_tnt.py \ - --data_dir \ - --gt_dir \ - --output_dir \ - --rasterizer <"radegs" or "gof"> \ - --mesh_config <"default" or "highres" or "veryhighres"> -``` -You can add `--dense_gaussians` for using more Gaussians during training. Please note that `--dense_gaussians` will be automatically set to `True` if using `--mesh_config highres` or `--mesh_config veryhighres`. - -For evaluating only a single scene, you can run the following commands: - -```bash -# Training (you can add --dense_gaussians for higher performance) -python train.py \ - -s \ - -m \ - --imp_metric <"indoor" or "outdoor"> \ - --rasterizer <"radegs" or "gof"> \ - --mesh_config <"default" or "highres" or "veryhighres"> \ - --eval \ - --decoupled_appearance \ - --data_device cpu - -# Mesh extraction -python mesh_extract_sdf.py \ - -s \ - -m \ - --rasterizer <"radegs" or "gof"> \ - --config <"default" or "highres" or "veryhighres"> \ - --eval \ - --data_device cpu - -# Evaluation -python eval/tnt/run.py \ - --dataset-dir \ - --traj-path \ - --ply-path /recon.ply -``` - -### Novel View Synthesis -After training MILo on a scene with test/train split by using the argument `--eval`, you can evaluate the performance of the Novel View Synthesis by running the scripts below: - -```bash -python render.py \ - -m \ - -s \ - -- rasterizer <"radegs" or "gof"> - -python metrics.py -m # Compute error metrics on renderings -``` +* **nvdiffrast JIT compilation failure / wrong architecture** + Confirm `TORCH_CUDA_ARCH_LIST="12.0+PTX"` is exported, and clear cache: `rm -rf ~/.cache/torch_extensions`. -
+* **VRAM OOM** + Reduce `MILO_MESH_RES_SCALE` (e.g., 0.5 → 0.3 → 0.25), enable triangle chunking `MILO_RAST_TRI_CHUNK`, and use `--data_device cpu`. -## 8. Acknowledgements +--- -We build this project based on [Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting) and [Mini-Splatting2](https://github.com/fatPeter/mini-splatting2). +## Results Summary (This Ignatius Pipeline) -We propose to use rasterization techniques from [RaDe-GS](https://baowenz.github.io/radegs/) and [GOF](https://github.com/autonomousvision/gaussian-opacity-fields/tree/main). +* **Training (`train.py`)**: Completed, output model directory `./output/Ignatius` (contains training state and logs). +* **Mesh Extraction (`mesh_extract_sdf.py`)**: Obtained **`mesh_learnable_sdf.ply`**, verified visualization in MeshLab. +* **Rendering (`render.py`)**: Obtained rendered images from train/test views (saved in rendering subdirectory of output directory). +* **Metrics (`metrics.py`)**: Console prints PSNR/SSIM (and saves to model directory, filename based on actual implementation). -The latter incorporate the filters proposed in [Mip-Splatting](https://github.com/autonomousvision/mip-splatting), the loss functions of [2D GS](https://github.com/hbb1/2d-gaussian-splatting) and its preprocessed DTU dataset. +> For Tanks&Temples evaluation, you can symlink `mesh_learnable_sdf.ply` as `recon.ply`, then run evaluation scripts. -We use [Nvdiffrast](https://github.com/NVlabs/nvdiffrast) for differentiable triangle rasterization, and [DepthAnythingV2](https://github.com/DepthAnything/Depth-Anything-V2) for computing our optional depth-order regularization loss relying on monocular depth estimation. +--- -The evaluation scripts for the Tanks and Temples dataset are sourced from [TanksAndTemples](https://github.com/isl-org/TanksAndTemples/tree/master/python_toolbox/evaluation). +## License and Acknowledgments -We thank the authors of all the above projects for their great works. +This repository is an adaptation and engineering supplement to the original MILo project in the **CUDA 12.8 / RTX 50** environment, retaining the original project license and attribution. Thanks to the original authors and all submodule authors (Tetra-NeRF, nvdiffrast, 3D Gaussian Splatting, etc.) for their excellent work. 
\ No newline at end of file diff --git a/README_CN.md b/README_CN.md new file mode 100644 index 0000000..c1056ea --- /dev/null +++ b/README_CN.md @@ -0,0 +1,484 @@ +# MILo_rtx50 — CUDA 12.8 / RTX 50 系列本地编译与运行记录(Ubuntu 24.04 + uv + PyTorch 2.7.1+cu128) + +> 本 README 记录我们在 **RTX 50 系列 + CUDA 12.8** 环境下,fork 项目 [Anttwo/MILo](https://github.com/Anttwo/MILo) 的**本地编译适配、关键修改与可复现实验步骤**。 +> 目标:无需 Conda,使用 **uv + venv** 完成子模块编译与训练、网格提取、渲染和评测。 + +--- + +## 环境 + +- OS:Ubuntu 24.04 +- GPU:RTX 50 系列(Blackwell) +- CUDA Toolkit:12.8(NVCC `/usr/local/cuda-12.8/bin/nvcc`) +- Python:3.12.3(venv 管理,包管理用 **uv**) +- PyTorch:**2.7.1+cu128**(官方二进制,**C++11 ABI=1**) +- C/C++:GCC 13.3 +- CMake:系统版本(apt) +- 重要环境变量(训练/提取/渲染时常用): + ```bash + export NVDIFRAST_BACKEND=cuda + export TORCH_CUDA_ARCH_LIST="12.0+PTX" + export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,max_split_size_mb:32,garbage_collection_threshold:0.6" + export CUDA_DEVICE_MAX_CONNECTIONS=1 + # Mesh 正则化网格分辨率缩放,越小越省显存 + export MILO_MESH_RES_SCALE=0.3 + # (可选)按三角形分块的大小,缓解 nvdiffrast CUDA 后端显存峰值 + export MILO_RAST_TRI_CHUNK=150000 + ``` + +--- + +## Submodules 安装 + +> 这些我们已验证可在 CUDA 12.8 + PyTorch 2.7.1 下成功编译/安装。 + +### 1) 安装 Gaussian Splatting 子模块(通过 **pip**) + +```bash +pip install submodules/diff-gaussian-rasterization_ms +pip install submodules/diff-gaussian-rasterization +pip install submodules/diff-gaussian-rasterization_gof +pip install submodules/simple-knn +pip install submodules/fused-ssim +``` + +> 备注:`nvdiffrast` 走 **JIT 编译**(运行时由 PyTorch cpp_extension 触发)。 +> 若选择 **OpenGL(GL) 后端**,需要系统头:`sudo apt install -y libegl-dev libopengl-dev libgles2-mesa-dev ninja-build`。 +> 我们为省事改用 **CUDA 后端**:`export NVDIFRAST_BACKEND=cuda`(无需 EGL 头)。 + +### 2) 安装 `tetra_triangulation` 的系统依赖(Delaunay 三角剖分) + +原项目用 **conda** 安装系统级 C/C++ 依赖(cmake/gmp/cgal)。由于我们用 **uv** 只管 Python 包,需要通过 **apt**(系统包管理器)安装这些 C/C++ 库: + +```bash +# 用 apt 安装 C/C++ 依赖(Ubuntu 24.04) +sudo apt update +sudo apt install -y \ + build-essential \ + cmake ninja-build \ + libgmp-dev libmpfr-dev libcgal-dev \ + libboost-all-dev + +# (可选)可能用到: +# sudo apt install -y libeigen3-dev +``` + +**说明:** +- `libcgal-dev` 提供 CGAL 头文件(Ubuntu 24.04 上主要是 header-only) +- `libgmp-dev` 和 `libmpfr-dev` 是 CGAL 的数值后端 +- **uv 仅负责 Python 侧依赖**;像 CGAL/GMP/MPFR 这种 C/C++ 依赖必须走系统包管理器(apt、brew、pacman) +- **macOS**:`brew install cmake cgal gmp mpfr boost` +- **Arch Linux**:`sudo pacman -S cgal gmp mpfr boost cmake base-devel` + +### 3) 编译 `tetra_triangulation` 并对齐 ABI + +> **重要:** 该模块需要与 PyTorch 2.7.1(C++11 ABI=1)对齐 ABI。我们使用头文件方式来强制这一点。 + +**a) 创建 ABI 强制头文件:** + +创建 `submodules/tetra_triangulation/src/force_abi.h`: +```cpp +#pragma once +// 在任何 STL 头之前强制新 ABI +#if defined(_GLIBCXX_USE_CXX11_ABI) +# undef _GLIBCXX_USE_CXX11_ABI +#endif +#define _GLIBCXX_USE_CXX11_ABI 1 +``` + +**b) 修改源文件:** + +在以下文件的**第一行**加入 `#include "force_abi.h"`: +- `submodules/tetra_triangulation/src/py_binding.cpp` +- `submodules/tetra_triangulation/src/triangulation.cpp` + +**c) 构建与安装:** + +```bash +cd submodules/tetra_triangulation +rm -rf build CMakeCache.txt CMakeFiles tetranerf/utils/extension/tetranerf_cpp_extension*.so + +# 指向当前 PyTorch 的 CMake 前缀/动态库路径 +export CMAKE_PREFIX_PATH="$(python - <<'PY' +import torch; print(torch.utils.cmake_prefix_path) +PY +)" +export TORCH_LIB_DIR="$(python - <<'PY' +import os, torch; print(os.path.join(os.path.dirname(torch.__file__), 'lib')) +PY +)" +export LD_LIBRARY_PATH="$TORCH_LIB_DIR:$LD_LIBRARY_PATH" + +cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$CMAKE_PREFIX_PATH" . +cmake --build . 
-j"$(nproc)" + +# 安装(可选,便于可编辑引用) +uv pip install -e . +cd ../../ +``` + +> **说明:** 如需排查 ABI 问题,请参阅下方**关键问题 1**部分。 + +--- + +## 关键问题 1:`tetra_triangulation` ABI 不匹配(已解决) + +**现象** +运行 `from tetranerf.utils import extension as ext` 报错: + +``` +undefined symbol: _ZN3c106detail14torchCheckFailEPKcS2_jRKSs +``` + +尾部 `RKSs` 表示 **老 ABI(_GLIBCXX_USE_CXX11_ABI=0)**,而我们的 PyTorch 2.7.1 使用 **新 ABI(=1)**。 + +**修复** +在 `submodules/tetra_triangulation` 中加入强制 ABI 的头文件,**稳定锁定 ABI=1**: + +* 创建文件:`src/force_abi.h` + + ```cpp + #pragma once + // 在任何 STL 头之前强制新 ABI + #if defined(_GLIBCXX_USE_CXX11_ABI) + # undef _GLIBCXX_USE_CXX11_ABI + #endif + #define _GLIBCXX_USE_CXX11_ABI 1 + ``` + +* 修改:在 `src/py_binding.cpp` 与 `src/triangulation.cpp` 的**第一行**加入 + + ```cpp + #include "force_abi.h" + ``` + +> **说明:** 使用这种头文件方式即可强制 ABI=1,无需额外修改 CMakeLists.txt。 + +**构建命令(就地 in-source,产物落到包路径)** + +```bash +cd submodules/tetra_triangulation +rm -rf build CMakeCache.txt CMakeFiles tetranerf/utils/extension/tetranerf_cpp_extension*.so + +# 指向当前 PyTorch 的 CMake 前缀/动态库路径 +export CMAKE_PREFIX_PATH="$(python - <<'PY' +import torch; print(torch.utils.cmake_prefix_path) +PY +)" +export TORCH_LIB_DIR="$(python - <<'PY' +import os, torch; print(os.path.join(os.path.dirname(torch.__file__), 'lib')) +PY +)" +export LD_LIBRARY_PATH="$TORCH_LIB_DIR:$LD_LIBRARY_PATH" + +cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$CMAKE_PREFIX_PATH" . +cmake --build . -j"$(nproc)" + +# 安装(可选,便于可编辑引用) +uv pip install -e . +``` + +--- + +## 关键问题 2:nvdiffrast GL 后端缺少 `EGL/egl.h`(已绕过) + +* 方案 A:`sudo apt install -y libegl-dev libopengl-dev libgles2-mesa-dev` 后继续用 GL。 +* **方案 B(我们采用)**:切到 **CUDA 后端**:`export NVDIFRAST_BACKEND=cuda`,不依赖 EGL 头。 + +--- + +## 关键问题 3:nvdiffrast CUDA 后端显存 OOM(已解决) + +**现象** +`cudaMalloc(&m_gpuPtr, bytes)` OOM(error: 2),尤其在 Mesh 正则化阶段。 + +**修复(两点)** + +1. 在 `milo/scene/mesh.py` 中替换 `nvdiff_rasterization` 实现: + + * 支持**按三角形分块**(环境变量 `MILO_RAST_TRI_CHUNK` 指定块大小) + * **修正 CUDA 后端的 `ranges` 必须在 CPU**(`dr.rasterize(..., ranges=)`) + +
+ 修改后的函数(点击展开) + + ```python + def nvdiff_rasterization( + camera, + image_height: int, + image_width: int, + verts: torch.Tensor, + faces: torch.Tensor, + return_indices_only: bool = False, + glctx=None, + return_rast_out: bool = False, + return_positions: bool = False, + ): + """ + 与原函数等价的替换版,支持按三角形分块(env: MILO_RAST_TRI_CHUNK), + 并修正:nvdiffrast CUDA 后端的 `ranges` 必须在 CPU。 + """ + import os + import torch + import nvdiffrast.torch as dr + + device = verts.device + dtype = verts.dtype + + cam_mtx = camera.full_proj_transform + pos = torch.cat([verts, torch.ones([verts.shape[0], 1], device=device, dtype=dtype)], dim=1) + pos = torch.matmul(pos, cam_mtx)[None] # [1,V,4] + + faces = faces.to(torch.int32).contiguous() + faces_dev = faces.to(pos.device) + + H, W = int(image_height), int(image_width) + chunk = int(os.getenv("MILO_RAST_TRI_CHUNK", "0") or "0") + use_chunking = chunk > 0 and faces.shape[0] > chunk + + if not use_chunking: + rast_out, _ = dr.rasterize(glctx, pos=pos, tri=faces_dev, resolution=[H, W]) + bary_coords = rast_out[..., :2] + zbuf = rast_out[..., 2] + pix_to_face = rast_out[..., 3].to(torch.int32) - 1 + if return_indices_only: + return pix_to_face + _out = (bary_coords, zbuf, pix_to_face) + if return_rast_out: + _out += (rast_out,) + if return_positions: + _out += (pos,) + return _out + + z_ndc = (pos[..., 2:3] / (pos[..., 3:4] + 1e-20)).contiguous() + + best_rast, best_depth = None, None + n_faces, start = int(faces.shape[0]), 0 + + def _normalize_tri_id(rast_chunk, start_idx, count_idx): + tri_raw = rast_chunk[..., 3:4].to(torch.int64) + if tri_raw.numel() == 0: + return rast_chunk[..., 3:4] + maxid = int(tri_raw.max().item()) + if maxid == 0: + return rast_chunk[..., 3:4] + if maxid <= count_idx: + tri_adj = torch.where(tri_raw > 0, tri_raw + start_idx, tri_raw) + else: + tri_adj = tri_raw + return tri_adj.to(rast_chunk.dtype) + + while start < n_faces: + count = min(chunk, n_faces - start) + # ranges 必须在 CPU + ranges_cpu = torch.tensor([[start, count]], device="cpu", dtype=torch.int32) + + rast_chunk, _ = dr.rasterize(glctx, pos=pos, tri=faces_dev, resolution=[H, W], ranges=ranges_cpu) + depth_chunk, _ = dr.interpolate(z_ndc, rast_chunk, faces_dev) + tri_id_adj = _normalize_tri_id(rast_chunk, start, count) + + if best_rast is None: + best_rast = torch.zeros_like(rast_chunk) + best_depth = torch.full_like(depth_chunk, float("inf")) + + hit = (tri_id_adj > 0) + prev_hit = (best_rast[..., 3:4] > 0) + closer = hit & (~prev_hit | (depth_chunk < best_depth)) + + rast_chunk = torch.cat([rast_chunk[..., :3], tri_id_adj], dim=-1) + + best_depth = torch.where(closer, depth_chunk, best_depth) + best_rast = torch.where(closer.expand_as(best_rast), rast_chunk, best_rast) + + start += count + + rast_out = best_rast + bary_coords = rast_out[..., :2] + zbuf = rast_out[..., 2] + pix_to_face = rast_out[..., 3].to(torch.int32) - 1 + + if return_indices_only: + return pix_to_face + + _output = (bary_coords, zbuf, pix_to_face) + if return_rast_out: + _output += (rast_out,) + if return_positions: + _output += (pos,) + return _output + ``` + +
+ +2. 运行时降低内存峰值: + + * `MILO_MESH_RES_SCALE=0.3`(mesh 正则化分辨率缩放) + * `MILO_RAST_TRI_CHUNK=150000`(三角形分块大小) + * `--data_device cpu`(相机/数据在 CPU) + +--- + +## 复现实验步骤(Ignatius) + +> 数据路径:`./data/Ignatius` +> 输出目录:`./output/Ignatius` + +### 1) 训练 + +```bash +cd milo +export NVDIFRAST_BACKEND=cuda +export TORCH_CUDA_ARCH_LIST="12.0+PTX" +export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,max_split_size_mb:32,garbage_collection_threshold:0.6" +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export MILO_MESH_RES_SCALE=0.3 +export MILO_RAST_TRI_CHUNK=150000 + +python train.py -s ./data/Ignatius -m ./output/Ignatius \ + --imp_metric outdoor \ + --rasterizer gof \ + --mesh_config verylowres \ + --sampling_factor 0.2 \ + --data_device cpu \ + --log_interval 200 +``` + +**产出**(位于 `./output/Ignatius`) + +* 训练好的场景(Gaussians + learnable SDF 与 mesh 正则化状态等) +* 日志与中间文件(按脚本配置,控制台打印训练进度) + +### 2) 网格提取(SDF) + +```bash +python mesh_extract_sdf.py \ + -s ./data/Ignatius \ + -m ./output/Ignatius \ + --rasterizer gof \ + --config verylowres \ + --data_device cpu +``` + +**产出** + +* **`./output/Ignatius/mesh_learnable_sdf.ply`**(已确认可在 MeshLab 正常打开) + +### 3) 渲染 + +```bash +python render.py \ + -m ./output/Ignatius \ + -s ./data/Ignatius \ + --rasterizer gof \ + --eval +``` + +**产出** + +* 渲染图像(训练/测试视角),保存到模型输出目录中的渲染子目录(以脚本实际打印为准) + +### 4) 图像指标 + +```bash +python metrics.py -m ./output/Ignatius +``` + +**产出** + +* 控制台输出 PSNR/SSIM(以及仓库脚本保存的对应文件,位于模型目录下;以实际实现为准) + +### 5) PLY 格式转换(可选) + +提取的 PLY 网格可以使用 `clean_convert_mesh.py` 脚本转换为其他常用的 3D 格式(OBJ/GLB)以便在各种 3D 软件中使用。脚本同时提供网格清理功能作为可选项。 + +**安装额外依赖** +```bash +pip install pymeshlab trimesh plyfile +``` + +**基本用法** +```bash +# 基本转换(输出 PLY、OBJ、GLB) +python clean_convert_mesh.py --in ./output/Ignatius/mesh_learnable_sdf.ply + +# 转换并简化到 30 万三角形 +python clean_convert_mesh.py --in ./output/Ignatius/mesh_learnable_sdf.ply --simplify 300000 + +# 转换时清理小组件(默认 0.02 = 移除直径小于 2% bbox 对角线的组件) +python clean_convert_mesh.py --in ./output/Ignatius/mesh_learnable_sdf.ply --keep-components 0.02 + +# 只输出特定格式 +python clean_convert_mesh.py --in ./output/Ignatius/mesh_learnable_sdf.ply --no-glb # 不输出 GLB +python clean_convert_mesh.py --in ./output/Ignatius/mesh_learnable_sdf.ply --no-obj # 不输出 OBJ + +# 指定输出目录和文件名 +python clean_convert_mesh.py --in ./output/Ignatius/mesh_learnable_sdf.ply \ + --out-dir ./output/Ignatius/converted \ + --stem mesh_final +``` + +**主要功能** +- **格式转换**:将 PLY 转换为 OBJ、GLB 格式(适合不同 3D 软件和 Web 展示) +- **可选清理**:去除重复顶点/面、修复非流形边、移除小浮块 +- **可选简化**:基于 Quadric decimation 的保形简化 + +**产出**(默认保存在输入文件同目录) +* `mesh_clean.ply` - 转换后的 PLY 网格(带顶点颜色) +* `mesh_clean.obj` - OBJ 格式(注意:OBJ 不支持顶点颜色) +* `mesh_clean.glb` - GLB 格式(适合 Web 展示和 Blender/Unity 等软件导入) + +--- + +## 我们做了什么修改(相对上游) + +1. **`submodules/tetra_triangulation`** + + * 新增 `src/force_abi.h`,并在 `src/py_binding.cpp` 和 `src/triangulation.cpp` 首行 `#include "force_abi.h"`:**强制使用 C++11 新 ABI (=1)** + +2. **`milo/scene/mesh.py`** + + * 替换 `nvdiff_rasterization`: + + * 支持 **MILO_RAST_TRI_CHUNK** 三角形分块 + * **修正 CUDA 后端 `ranges` 必须是 CPU Tensor** + * 其余行为与原函数保持一致 + +3. **运行配置** + + * 默认使用 nvdiffrast **CUDA** 后端(`NVDIFRAST_BACKEND=cuda`),规避 EGL 依赖 + * 为 Blackwell 指定 `TORCH_CUDA_ARCH_LIST="12.0+PTX"` + * 降低峰值显存:`MILO_MESH_RES_SCALE=0.3`、`MILO_RAST_TRI_CHUNK=150000`、`--data_device cpu` + +--- + +## 常见问题与排障 + +* **`undefined symbol: ... torchCheckFail ... 
+
+---
+
+## Common Issues and Troubleshooting
+
+* **`undefined symbol: ... torchCheckFail ... RKSs`**
+  This is an ABI=0 symbol; rebuild `tetra_triangulation` with the patch described above.
+
+* **`fatal error: EGL/egl.h: No such file or directory`**
+  If you want to stick with the GL path: `sudo apt install -y libegl-dev libopengl-dev libgles2-mesa-dev ninja-build`;
+  or simply take the CUDA path with `export NVDIFRAST_BACKEND=cuda`.
+
+* **nvdiffrast JIT compilation fails / targets the wrong architecture**
+  Make sure `TORCH_CUDA_ARCH_LIST="12.0+PTX"` is exported and clear the cache: `rm -rf ~/.cache/torch_extensions`.
+
+* **Out-of-memory (VRAM OOM)**
+  Lower `MILO_MESH_RES_SCALE` (e.g., 0.5 → 0.3 → 0.25), enable triangle chunking via `MILO_RAST_TRI_CHUNK`, and use `--data_device cpu`.
+
+---
+
+## Results Summary (this Ignatius run)
+
+* **Training (`train.py`)**: completed; model output directory `./output/Ignatius` (with training state and logs).
+* **Mesh extraction (`mesh_extract_sdf.py`)**: produced **`mesh_learnable_sdf.ply`**, verified to display correctly in MeshLab.
+* **Rendering (`render.py`)**: produced train/test view renderings (saved in the rendering subdirectory of the output directory).
+* **Metrics (`metrics.py`)**: PSNR/SSIM printed to the console (and saved to the model directory; filenames follow the actual implementation).
+
+> For a Tanks&Temples evaluation, you can symlink `mesh_learnable_sdf.ply` to `recon.ply` and then run the evaluation scripts.
+
+---
+
+## License and Acknowledgements
+
+This repository is an adaptation and engineering supplement of the original MILo project for the **CUDA 12.8 / RTX 50** environment; the original project's license and attribution are retained. Many thanks to the original authors and to the authors of the submodules (Tetra-NeRF, nvdiffrast, 3D Gaussian Splatting, etc.) for their excellent work.
diff --git a/README_old.md b/README_old.md
new file mode 100755
index 0000000..b9ced4a
--- /dev/null
+++ b/README_old.md
@@ -0,0 +1,654 @@
+
+ +

+MILo: Mesh-In-the-Loop Gaussian Splatting for Detailed and Efficient Surface Reconstruction +

+ + +SIGGRAPH Asia 2025 - Journal Track (TOG)
+
+ + +Antoine Guédon*1  +Diego Gomez*1  +Nissim Maruani2
+Bingchen Gong1  +George Drettakis2  +Maks Ovsjanikov1  +
+
+ + +1Ecole polytechnique, France
+2Inria, Université Côte d'Azur, France
+
+ + +*Both authors contributed equally to the paper. + + +| Webpage | arXiv | Presentation video | Data | + +![Teaser image](assets/teaser.png) +
+ +## Abstract + +_Our method introduces a novel differentiable mesh extraction framework that operates during the optimization of 3D Gaussian Splatting representations. At every training iteration, we differentiably extract a mesh—including both vertex locations and connectivity—only from Gaussian parameters. This enables gradient flow from the mesh to Gaussians, allowing us to promote bidirectional consistency between volumetric (Gaussians) and surface (extracted mesh) representations. This approach guides Gaussians toward configurations better suited for surface reconstruction, resulting in higher quality meshes with significantly fewer vertices. Our framework can be plugged into any Gaussian splatting representation, increasing performance while generating an order of magnitude fewer mesh vertices. MILo makes the reconstructions more practical for downstream applications like physics simulations and animation._ + +## To-do List + +- ⬛ Implement a simple training viewer using the GraphDeco viewer. +- ⬛ Add the mesh-based rendering evaluation scripts in `./milo/eval/mesh_nvs`. +- ✅ Add low-res and very-low-res training for light output meshes (under 50MB and under 20MB). +- ✅ Add T&T evaluation scripts in `./milo/eval/tnt/`. +- ✅ Add Blender add-on (for mesh-based editing and animation) to the repo. +- ✅ Clean code. +- ✅ Basic refacto. + +## License + +
+Click here to see content. + +
This project builds on existing open-source implementations of various projects cited in the __Acknowledgements__ section. + +Specifically, it builds on the original implementation of [3D Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting); As a result, parts of this code are licensed under the Gaussian-Splatting License (see `./LICENSE.md`). + +This codebase also builds on various other repositories such as [Nvdiffrast](https://github.com/NVlabs/nvdiffrast); Please refer to the license files of the submodules for more details. +
+ +## 0. Quickstart + +
+Click here to see content. + +Please start by creating or downloading a COLMAP dataset, such as our COLMAP run for the Ignatius scene from the Tanks&Temples dataset. You can move the Ignatius directory to `./milo/data`. + +After installing MILo as described in the next section, you can reconstruct a surface mesh from images by going to the `./milo/` directory and running the following commands: + +```bash +# Training for an outdoor scene +python train.py -s ./data/Ignatius -m ./output/Ignatius --imp_metric outdoor --rasterizer radegs + +# Saves mesh as PLY with vertex colors after training +python mesh_extract_sdf.py -s ./data/Ignatius -m ./output/Ignatius --rasterizer radegs +``` +Please change `--imp_metric outdoor` to `--imp_metric indoor` if your scene is indoor. + +These commands use the lightest version of our approach, resulting in a small number of Gaussians and a light mesh. You can increase the number of Gaussians by adding `--dense_gaussians`, and improve the robustness to exposure variations with `--decoupled_appearance` as follows: + +```bash +# Training with dense gaussians and better appearance model +python train.py -s ./data/Ignatius -m ./output/Ignatius --imp_metric outdoor --rasterizer radegs --dense_gaussians --decoupled_appearance + +# Saves mesh as PLY with vertex colors after training +python mesh_extract_sdf.py -s ./data/Ignatius -m ./output/Ignatius --rasterizer radegs +``` + +Please refer to the following sections for additional details on our training and mesh extraction scripts, including: +- How to use other rasterizers +- How to train MILo with high-resolution meshes +- Various mesh extraction methods +- How to easily integrate MILo's differentiable GS-to-mesh pipeline to your own GS project +
+ + +## 1. Installation + +
+Click here to see content. + +### Clone this repository. +```bash +git clone https://github.com/Anttwo/MILo.git --recursive +``` + +### Install dependencies. + +Please start by creating an environment: +```bash +conda create -n milo python=3.9 +conda activate milo +``` + +Then, specify your own CUDA paths depending on your CUDA version: +```bash +# You can specify your own cuda path (depending on your CUDA version) +export CPATH=/usr/local/cuda-11.8/targets/x86_64-linux/include:$CPATH +export LD_LIBRARY_PATH=/usr/local/cuda-11.8/targets/x86_64-linux/lib:$LD_LIBRARY_PATH +export PATH=/usr/local/cuda-11.8/bin:$PATH +``` + +Finally, you can run the following script to install all dependencies, including PyTorch and Gaussian Splatting submodules: +```bash +python install.py +``` +By default, the environment will be installed for CUDA 11.8. If using CUDA 12.1, you can provide the argument `--cuda_version 12.1` to `install.py`. **Please note that only CUDA 11.8 has been tested.** + +If you encounter problems or if the installation script does not work, please follow the detailed installation steps below. + +
+Click here for detailed installation instructions + +```bash +# For CUDA 11.8 +conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=11.8 mkl=2023.1.0 -c pytorch -c nvidia + +# For CUDA 12.1 (The code has only been tested on CUDA 11.8) +conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 mkl=2023.1.0 -c pytorch -c nvidia + +pip install -r requirements.txt + +# Install submodules for Gaussian Splatting, including different rasterizers, aggressive densification, simplification, and utilities +pip install submodules/diff-gaussian-rasterization_ms +pip install submodules/diff-gaussian-rasterization +pip install submodules/diff-gaussian-rasterization_gof +pip install submodules/simple-knn +pip install submodules/fused-ssim + +# Delaunay Triangulation from Tetra-Nerf +cd submodules/tetra_triangulation +conda install cmake +conda install conda-forge::gmp +conda install conda-forge::cgal + +# You can specify your own cuda path (depending on your CUDA version) +export CPATH=/usr/local/cuda-11.8/targets/x86_64-linux/include:$CPATH +export LD_LIBRARY_PATH=/usr/local/cuda-11.8/targets/x86_64-linux/lib:$LD_LIBRARY_PATH +export PATH=/usr/local/cuda-11.8/bin:$PATH + +cmake . +make +pip install -e . +cd ../../ + +# Nvdiffrast for efficient mesh rasterization +cd ./submodules/nvdiffrast +pip install . +cd ../../ +``` + +
+ +
+ +## 2. Training with MILo + +
+Click here to see content. + +First, go to the MILo folder: +```bash +cd milo +``` + +Then, to optimize a Gaussian Splatting representation with MILo using a COLMAP dataset, you can run the following command: +```bash +python train.py \ + -s \ + -m \ + --imp_metric <"indoor" OR "outdoor"> \ + --rasterizer <"radegs" OR "gof"> +``` +The main arguments are the following: +| Argument | Values | Default | Description | +|----------|--------|---------|-------------| +| `--imp_metric` | `"indoor"` or `"outdoor"` | Required | Type of scene to optimize. Modifies the importance sampling to better handle indoor or outdoor scenes. | +| `--rasterizer` | `"radegs"` or `"gof"` | `"radegs"` | Rasterization technique used during training. | +| `--dense_gaussians` | flag | disabled | Use more Gaussians during training. When active, only a subset of Gaussians will generate pivots for Delaunay triangulation. When inactive, all Gaussians generate pivots.| + +You can use a dense set of Gaussians by adding the argument `--dense_gaussians`: +```bash +python train.py \ + -s \ + -m \ + --imp_metric <"indoor" OR "outdoor"> \ + --rasterizer <"radegs" OR "gof"> \ + --dense_gaussians \ + --data_device cpu +``` + +The list of optional arguments is provided below: +| Category | Argument | Values | Default | Description | +|----------|----------|---------|---------|-------------| +| **Performance & Logging** | `--data_device` | `"cpu"` or `"cuda"` | `"cuda"` | Forces data to be loaded on CPU (less GPU memory usage, slightly slower training) | +| | `--log_interval` | integer | - | Log images every N training iterations (e.g., `200`) | +| **Mesh Configuration** | `--mesh_config` | `"default"`, `"highres"`, `"veryhighres"`, `"lowres"`, `"verylowres"` | `"default"` | Config file for mesh resolution and quality | +| **Evaluation & Appearance** | `--eval` | flag | disabled | Performs the usual train/test split for evaluation | +| | `--decoupled_appearance` | flag | disabled | Better handling of exposure variations | +| **Depth-Order Regularization** | `--depth_order` | flag | disabled | Enable depth-order regularization with DepthAnythingV2 | +| | `--depth_order_config` | `"default"` or `"strong"` | `"default"` | Strength of depth regularization | + +You can change the config file used during training (useful for ablation runs) by +specifying `--mesh_config `. The different config files are the following: + +- **Default config**: The default config file name is `default`. This config results in +lighter representations and lower resolution meshes, containing around 2M Delaunay vertices +for the base setting and 5M Delaunay vertices for the `--dense_gaussians` setting. +- **High Res config**: You can use `--mesh_config highres --dense_gaussians` for higher +resolution meshes. We recommend using this config with `--dense_gaussians`. This config +results in higher resolution representations, containing up to 9M Delaunay vertices. +- **Very High Res config**: You can use `--mesh_config veryhighres --dense_gaussians` for +even higher resolution meshes. We recommend using this config with `--dense_gaussians`. +This config results in even higher resolution representations, containing up to 14M Delaunay +vertices. This configuration requires more memory for training. +- **Low Res config**: You can use `--mesh_config lowres` for lower resolution meshes (less than 50MB). +This config results in lower resolution representations, containing up to 500k Delaunay vertices. 
+You can adjust the number of Gaussians used during training accordingly by decreasing the sampling factor +with `--sampling_factor 0.3`, for instance. +- **Very Low Res config**: You can use `--mesh_config verylowres` for even lower resolution meshes (less than 20MB). +This config results in even lower resolution representations, containing up to 250k Delaunay vertices. +You can adjust the number of Gaussians used during training accordingly by decreasing the sampling factor +with `--sampling_factor 0.1`, for instance. + +Please refer to the DepthAnythingV2 repo to download the `vitl` checkpoint required for Depth-Order regularization. Then, move the checkpoint file to `./submodules/Depth-Anything-V2/checkpoints/`. + +### Example Commands + +Basic training for indoor scenes with logging: +```bash +python train.py -s -m --imp_metric indoor --rasterizer radegs --log_interval 200 +``` + +Dense Gaussians with high resolution in outdoor scenes: +```bash +python train.py -s -m --imp_metric outdoor --rasterizer radegs --dense_gaussians --mesh_config highres --data_device cpu +``` + +Full featured training with very high resolution: +```bash +python train.py -s -m --imp_metric indoor --rasterizer radegs --dense_gaussians --mesh_config veryhighres --decoupled_appearance --log_interval 200 --data_device cpu +``` + +Very low resolution training in indoor scenes for very light meshes (less than 20MB): +```bash +python train.py -s -m --imp_metric indoor --rasterizer radegs --sampling_factor 0.1 --mesh_config verylowres +``` + +Training with depth-order regularization: +```bash +python train.py -s -m --imp_metric indoor --rasterizer radegs --depth_order --depth_order_config strong --log_interval 200 --data_device cpu +``` + +
+ +## 3. Extracting a Mesh after Optimization + +
+Click here to see content. + +First go to `./milo/`. + +### 3.1. Use learned SDF values + +You can then use the following command: +```bash +python mesh_extract_sdf.py \ + -s \ + -m \ + --rasterizer <"radegs" OR "gof"> +``` +This script will further refine the SDF values for a short period of time (1000 iterations by default) with frozen Gaussians, then save the mesh as a PLY file with vertex colors. The mesh will be located at `/mesh_learnable_sdf.ply`. + +**WARNING:** Make sure you use the same mesh config file as the one used during training. You can change the config file by specifying `--config `. The default config file name is `default`, but you can switch to `highres`, `veryhighres`, `lowres` or `verylowres`. + +You can use the usual train/test split by adding the argument `--eval`. + +### 3.2. Use Integrated Opacity Field or scalable TSDF + +To extract a mesh using the integrated opacity field as defined by the Gaussians in GOF and RaDe-GS, you can run the following command: +```bash +python mesh_extract_integration.py \ + -s \ + -m +``` +You can use the argument `--rasterizer ` to change the rasterization technique for computing the opacity field. Default is `gof`. We recommend using GOF in this context (even if RaDe-GS was used during training), as the opacity computation from GOF is more precise and will produce less surface erosion. + +You can also use the argument `--sdf_mode <"integration" OR "depth_fusion">` to modify the SDF computation strategy. Default mode is `integration`, which uses the integrated opacity field. Please note that `depth_fusion` is not traditional TSDF performed on a regular grid, but our more efficient depth fusion strategy relying on the same Gaussian pivots as the ones used for `integration`. + +If using `integration`, you can modify the isosurface value with the argument `--isosurface_value `. The default value is 0.5. +```bash +python mesh_extract_integration.py \ + -s \ + -m \ + --rasterizer gof \ + --sdf_mode integration \ + --isosurface_value 0.5 +``` + +If using `depth_fusion`, you can modify the truncation margin with the argument `--trunc_margin `. If not provided, the value is automatically computed depending on the scale of the scene. We recommend not changing this value. +```bash +python mesh_extract_integration.py \ + -s \ + -m \ + --rasterizer gof \ + --sdf_mode depth_fusion \ + --trunc_margin 0.005 +``` + +The mesh will be saved at either `/mesh_integration_sdf.ply` or `/mesh_depth_fusion_sdf.ply` depending on the SDF computation method. + +
+ +## 4. Using our differentiable Gaussians-to-Mesh pipeline in your own 3DGS project + +
+Click here to see content. +

+
+In `milo.functional`, we provide straightforward functions to leverage our differentiable *Gaussians-to-Mesh pipeline* in your own 3DGS projects.
+
+These functions only require Gaussian parameters as inputs (`means`, `scales`, `rotations`, `opacities`) and can extract a mesh from these parameters in a differentiable manner, allowing for **performing differentiable operations on the surface mesh and backpropagating gradients directly to the Gaussians**.
+
+We only assume that your own `Camera` class has the same structure as the class from the original [3DGS](https://github.com/graphdeco-inria/gaussian-splatting) implementation.
+
+Specifically, we propose the following functions:
+- `sample_gaussians_on_surface`: This function samples Gaussians that are the most likely to be located on the surface of the scene. For more efficiency, we propose using only these Gaussians for generating pivots and applying Delaunay triangulation.
+- `extract_gaussian_pivots`: This differentiable function builds pivots from the parameters of the sampled Gaussians. In practice, there is no need to explicitly call this function, as our other functions can recompute pivots on the fly. However, you might want to perform special treatment on the pivots.
+- `compute_initial_sdf_values`: This function estimates initial truncated SDF values for any set of Gaussian pivots by performing depth-fusion over the provided viewpoints. You can directly provide the gaussian parameters to this function, in which case pivots will be computed on the fly. In the paper, we propose to learn optimal SDF values by maximizing the consistency between volumetric GS renderings and surface mesh renderings; We use this function to initialize the SDF values.
+- `compute_delaunay_triangulation`: This function computes the Delaunay triangulation for a set of sampled Gaussian pivots. You can provide either pivots as inputs, or directly the parameters of the Gaussians (means, scales, rotations...), in which case the pivots will be recomputed on the fly. This function should not be applied at every training iteration as it is very slow, and the Delaunay graph does not change that much during training.
+- `extract_mesh`: This differentiable function extracts the mesh from the Gaussian parameters, given a Delaunay triangulation and SDF values for the Gaussian pivots.
+
+We also propose additional functions such as `frustum_cull_mesh` which culls mesh vertices based on the view frustum of an input camera.
+
+We provide an example of how to use these functions below, using our codebase or any codebase following the same template as the original [3DGS](https://github.com/graphdeco-inria/gaussian-splatting) implementation.
+
+```python
+from functional import (
+    sample_gaussians_on_surface,
+    extract_gaussian_pivots,
+    compute_initial_sdf_values,
+    compute_delaunay_triangulation,
+    extract_mesh,
+    frustum_cull_mesh,
+)
+
+# Load or initialize a 3DGS-like model and training cameras
+gaussians = ...
+train_cameras = ...
+
+# Define a simple wrapper for your Gaussian Splatting rendering function,
+# following this template. It will be used only for initializing SDF values.
+# The wrapper should accept just a camera as input, and return a dictionary
+# with "render" and "depth" keys.
+from gaussian_renderer.radegs import render_radegs
+pipe = ... 
+background = torch.tensor([0., 0., 0.], device="cuda") +def render_func(view): + render_pkg = render_radegs( + viewpoint_camera=view, + pc=gaussians, + pipe=pipe, + bg_color=background, + kernel_size=0.0, + scaling_modifier = 1.0, + require_coord=False, + require_depth=True + ) + return { + "render": render_pkg["render"], + "depth": render_pkg["median_depth"], + } + +# Only the parameters of the Gaussians are needed for extracting the mesh. +means = gaussians.get_xyz +scales = gaussians.get_scaling +rotations = gaussians.get_rotation +opacities = gaussians.get_opacity + +# Sample Gaussians on the surface. +# Should be performed only once, or just once in a while. +# In this example, we sample at most 600_000 Gaussians. +surface_gaussians_idx = sample_gaussians_on_surface( + views=train_cameras, + means=means, + scales=scales, + rotations=rotations, + opacities=opacities, + n_max_samples=600_000, + scene_type='indoor', +) + +# Compute initial SDF values for pivots. Should be performed only once. +# In the paper, we propose to learn optimal SDF values by maximizing the +# consistency between volumetric renderings and surface mesh renderings. +initial_pivots_sdf = compute_initial_sdf_values( + views=train_cameras, + render_func=render_func, + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, +) + +# Compute Delaunay Triangulation. +# Can be performed once in a while. +delaunay_tets = compute_delaunay_triangulation( + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, +) + +# Differentiably extract a mesh from Gaussian parameters, including initial +# or updated SDF values for the Gaussian pivots. +# This function is differentiable with respect to the parameters of the Gaussians, +# as well as the SDF values. Can be performed at every training iteration. +mesh = extract_mesh( + delaunay_tets=delaunay_tets, + pivots_sdf=initial_pivots_sdf, + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, +) + +# You can now apply any differentiable operation on the extracted mesh, +# and backpropagate gradients back to the Gaussians! +# In the paper, we propose to use differentiable mesh rendering. +from scene.mesh import MeshRasterizer, MeshRenderer +renderer = MeshRenderer(MeshRasterizer(cameras=train_cameras)) + +# We cull the mesh based on the view frustum for more efficiency +i_view = np.random.randint(0, len(train_cameras)) +mesh_render_pkg = renderer( + frustum_cull_mesh(mesh, train_cameras[i_view]), + cam_idx=i_view, + return_depth=True, return_normals=True +) +mesh_depth = mesh_render_pkg["depth"] +mesh_normals = mesh_render_pkg["normals"] +``` + +
+ +## 5. Creating a COLMAP dataset with your own images + +
+Click here to see content. +
+ +### 5.1. Estimate camera poses with COLMAP + +Please first install a recent version of COLMAP (ideally CUDA-powered) and make sure to put the images you want to use in a directory `/input`. Then, run the script `milo/convert.py` from the original Gaussian splatting implementation to compute the camera poses for the images using COLMAP. Please refer to the original 3D Gaussian Splatting repository for more details. + +```shell +python milo/convert.py -s +``` + +Sometimes COLMAP fails to reconstruct all images into the same model and hence produces multiple sub-models. The smaller sub-models generally contain only a few images. However, by default, the script `convert.py` will apply Image Undistortion only on the first sub-model, which may contain only a few images. + +If this is the case, a simple solution is to keep only the largest sub-model and discard the others. To do this, open the source directory containing your input images, then open the sub-directory `/distorted/sparse/`. You should see several sub-directories named `0/`, `1/`, etc., each containing a sub-model. Remove all sub-directories except the one containing the largest files, and rename it to `0/`. Then, run the script `convert.py` one more time but skip the matching process: + +```shell +python milo/convert.py -s --skip_matching +``` + +_Note: If the sub-models have common registered images, they could be merged into a single model as post-processing step using COLMAP; However, merging sub-models requires to run another global bundle adjustment after the merge, which can be time consuming._ + +### 5.2. Estimate camera poses with VGGT + +Coming soon. + +
+ + +## 6. Mesh-Based Editing and Animation with the MILo Blender Addon + +
+Click here to see content. +
+ +While MILo provides a differentiable solution for extracting meshes from 3DGS representations, it also implicitly encourages Gaussians to align with the surface of the mesh. As a result, any modification made to the mesh can be easily propagated to the Gaussians, making the reconstructed mesh an excellent proxy for editing and animating the 3DGS representation. + +Similarly to previous works SuGaR and Gaussian Frosting, we provide a Blender addon allowing to combine, edit and animate 3DGS representations just by manipulating meshes reconstructed with MILo in Blender. + +### 6.1. Installing the addon + +1. Please start by installing `torch_geometric` and `torch_cluster` in your `milo` conda environment: +```shell +pip install torch_geometric +pip install torch_cluster +``` + +2. Then, install Blender (version 4.0.2 is recommended but not mandatory). + +3. Open Blender, and go to `Edit` > `Preferences` > `Add-ons` > `Install`, and select the file `milo_addon.py` located in `./milo/blender/`.

+
+You have now installed the MILo addon for Blender!
+
+### 6.2. Usage
+
+This Blender addon is almost identical to the SuGaR x Frosting Blender addon. You can refer to this previous repo for more details and illustrations. To combine, edit, or animate a scene with the addon, please follow the steps below:
+
+1. Please start by training Gaussians with MILo and extracting a mesh, as described in the Quickstart section.
+
+2. Open a new scene in Blender, and go to the `Render` tab in the Properties. You should see a panel named `Add MILo mesh`. The panel is not necessarily at the top of the tab, so you may need to scroll down to find it.
+
+3. **(a) Select a mesh.** Enter the path to the final mesh extracted with MILo in the `Path to mesh PLY` field. You can also click on the folder icon to select the file. The mesh should be located at `/mesh_learnable_sdf.ply`.

+**(b) Select a checkpoint.** Similarly, enter the path to the final checkpoint of the optimization in the `Path to 3DGS PLY` field. You can also click on the folder icon to select the file. The checkpoint should be located at `/point_cloud/iteration_18000/point_cloud.ply`.

+**(c) Load the mesh.** Finally, click on `Add mesh` to load the mesh in Blender. Feel free to rotate the mesh and change the shading mode to better visualize the mesh and its colors. + +4. **Now, feel free to edit your mesh using Blender!** +
You can segment it into different pieces, sculpt it, rig it, animate it using a parent armature, *etc*. You can also add other MILo meshes to the scene, and combine elements from different scenes.
+Feel free to set a camera in the scene and prepare an animation: You can animate the camera, the mesh, *etc*.
+Please avoid using `Apply Location`, `Apply Rotation`, or `Apply Scale` on the edited mesh, as we are still unsure how it will affect the correspondence between the mesh and the optimized checkpoint. + +5. Once you're done with your editing, you can prepare a rendering package ready to be rendered with Gaussians. To do so, go to the `Render` tab in the Properties again, and select the `./milo/` directory in the `Path to MILo directory` field.
+Finally, click on `Render Image` or `Render Animation` to save a rendering package for the scene.

+`Render Image` will render a single image of the scene, with the current camera position and mesh editions/poses.

+`Render Animation` will render a full animation of the scene, from the first frame to the last frame you set in the Blender Timeline. +

+The package should be saved as a `JSON` file and located in `./milo/blender/packages/`. + +7. Finally, you can render the package with Gaussian Splatting. You just need to go to `./milo/` and run the following command: +```shell +python render_blender_scene.py \ + -p \ + --rasterizer <"radegs" or "gof">. +``` + +By default, renderings are saved in `./milo/blender/renders//`. However, you can change the output directory by adding `-o `. + +Please check the documentation of the `render_blender_scene.py` scripts for more information on the additional arguments. +If you get artifacts in the rendering, you can try to play with the various following hyperparameters: `binding_mode`, `filter_big_gaussians_with_th`, `clamp_big_gaussians_with_th`, and `filter_distant_gaussians_with_th`. + +
+ +## 7. Evaluation + +
+Click here to see content. +
+ +For evaluation, please start by downloading [our COLMAP runs for the Tanks and Temples dataset](https://drive.google.com/drive/folders/1Bf7DM2DFtQe4J63bEFLceEycNf4qTcqm?usp=sharing), and make sure to move all COLMAP scene directories (Barn, Caterpillar, _etc._) inside the same directory. + +Then, please download ground truth point cloud, camera poses, alignments and cropfiles from [Tanks and Temples dataset](https://www.tanksandtemples.org/download/). The ground truth dataset should be organized as: +``` +GT_TNT_dataset +│ +└─── Barn +│ │ +| └─── Barn.json +│ │ +| └─── Barn.ply +│ │ +| └─── Barn_COLMAP_SfM.log +│ │ +| └─── Barn_trans.txt +│ +└─── Caterpillar +│ │ +...... +``` + +We follow the exact same pipeline as GOF and RaDe-GS for evaluating MILo on T&T. Please go to `./milo/` and run the following script to run the full training and evaluation pipeline on all scenes: + +```bash +python scripts/evaluate_tnt.py \ + --data_dir \ + --gt_dir \ + --output_dir \ + --rasterizer <"radegs" or "gof"> \ + --mesh_config <"default" or "highres" or "veryhighres"> +``` +You can add `--dense_gaussians` for using more Gaussians during training. Please note that `--dense_gaussians` will be automatically set to `True` if using `--mesh_config highres` or `--mesh_config veryhighres`. + +For evaluating only a single scene, you can run the following commands: + +```bash +# Training (you can add --dense_gaussians for higher performance) +python train.py \ + -s \ + -m \ + --imp_metric <"indoor" or "outdoor"> \ + --rasterizer <"radegs" or "gof"> \ + --mesh_config <"default" or "highres" or "veryhighres"> \ + --eval \ + --decoupled_appearance \ + --data_device cpu + +# Mesh extraction +python mesh_extract_sdf.py \ + -s \ + -m \ + --rasterizer <"radegs" or "gof"> \ + --config <"default" or "highres" or "veryhighres"> \ + --eval \ + --data_device cpu + +# Evaluation +python eval/tnt/run.py \ + --dataset-dir \ + --traj-path \ + --ply-path /recon.ply +``` + +### Novel View Synthesis +After training MILo on a scene with test/train split by using the argument `--eval`, you can evaluate the performance of the Novel View Synthesis by running the scripts below: + +```bash +python render.py \ + -m \ + -s \ + -- rasterizer <"radegs" or "gof"> + +python metrics.py -m # Compute error metrics on renderings +``` + +
+ +## 8. Acknowledgements + +We build this project based on [Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting) and [Mini-Splatting2](https://github.com/fatPeter/mini-splatting2). + +We propose to use rasterization techniques from [RaDe-GS](https://baowenz.github.io/radegs/) and [GOF](https://github.com/autonomousvision/gaussian-opacity-fields/tree/main). + +The latter incorporate the filters proposed in [Mip-Splatting](https://github.com/autonomousvision/mip-splatting), the loss functions of [2D GS](https://github.com/hbb1/2d-gaussian-splatting) and its preprocessed DTU dataset. + +We use [Nvdiffrast](https://github.com/NVlabs/nvdiffrast) for differentiable triangle rasterization, and [DepthAnythingV2](https://github.com/DepthAnything/Depth-Anything-V2) for computing our optional depth-order regularization loss relying on monocular depth estimation. + +The evaluation scripts for the Tanks and Temples dataset are sourced from [TanksAndTemples](https://github.com/isl-org/TanksAndTemples/tree/master/python_toolbox/evaluation). + +We thank the authors of all the above projects for their great works. diff --git a/milo/OPTIMIZATION_CONFIG_GUIDE.md b/milo/OPTIMIZATION_CONFIG_GUIDE.md new file mode 100644 index 0000000..ca15faa --- /dev/null +++ b/milo/OPTIMIZATION_CONFIG_GUIDE.md @@ -0,0 +1,335 @@ +# 优化配置系统使用指南 + +## 概述 + +我们为`yufu2mesh_new.py`实现了基于YAML的优化配置系统,大幅简化了超参数管理。现在只需要**一个参数**`--opt_config`就可以控制: + +1. ✅ 高斯参数的训练/冻结状态 +2. ✅ 每个参数的独立学习率 +3. ✅ 所有loss权重 +4. ✅ 深度处理参数(裁剪范围) +5. ✅ Mesh正则化权重 + +## 核心改进 + +### 改进前 (旧方式) + +```bash +python yufu2mesh_new.py \ + --lr 5e-4 \ + --depth_loss_weight 1.0 \ + --normal_loss_weight 1.5 \ + --depth_clip_min 0.0 \ + --depth_clip_max 50.0 \ + --mesh_depth_weight 0.1 \ + --mesh_normal_weight 0.15 \ + # 只能优化xyz,其他参数硬编码冻结 +``` + +**问题**: +- 超参数分散在多个命令行参数中 +- 无法灵活控制哪些参数可训练 +- 每个参数只能用统一学习率 +- 配置难以复用和版本管理 + +### 改进后 (新方式) + +```bash +python yufu2mesh_new.py --opt_config xyz_geometry +``` + +**优势**: +- 单一参数控制所有优化行为 +- YAML配置文件清晰易读,带详细注释 +- 支持版本控制和复用 +- 预设多种常用配置 + +## 文件结构 + +``` +milo/ +├── yufu2mesh_new.py # 主程序(已修改) +├── test_opt_config.py # 配置测试脚本(新增) +├── OPTIMIZATION_CONFIG_GUIDE.md # 本文档(新增) +└── configs/ + └── optimization/ # 优化配置目录(新增) + ├── README.md # 详细使用文档 + ├── default.yaml # 默认配置 + ├── xyz_only.yaml # 仅位置优化 + ├── xyz_geometry.yaml # 位置+几何优化 + ├── xyz_occupancy.yaml # 位置+占用优化 + └── full.yaml # 全参数优化 +``` + +## 快速开始 + +### 1. 使用预设配置 + +```bash +# 方式1: 使用默认配置(只优化位置) +python yufu2mesh_new.py --opt_config default + +# 方式2: 使用预设配置 +python yufu2mesh_new.py --opt_config xyz_geometry +``` + +### 2. 创建自定义配置 + +```bash +# 复制模板 +cd configs/optimization +cp default.yaml my_config.yaml + +# 编辑配置文件 +vim my_config.yaml + +# 使用自定义配置 +python yufu2mesh_new.py --opt_config my_config +``` + +### 3. 使用完整路径 + +```bash +python yufu2mesh_new.py --opt_config /path/to/my_config.yaml +``` + +## 预设配置对比 + +| 配置名称 | 可训练参数 | 适用场景 | 特点 | +|---------|-----------|---------|------| +| `default` | xyz | 已有良好初始化 | 默认选择,稳定 | +| `xyz_only` | xyz | 同default | 更保守,所有冻结参数lr=0 | +| `xyz_geometry` | xyz, scaling, rotation | 需要调整形状 | 增强法线loss | +| `xyz_occupancy` | xyz, occupancy_shift | 优化mesh质量 | 增强mesh正则化 | +| `full` | 所有(除base) | 初始化较差 | ⚠️ 容易过拟合 | + +## 配置文件示例 + +```yaml +# configs/optimization/xyz_geometry.yaml + +gaussian_params: + _xyz: + trainable: true # 优化位置 + lr: 5.0e-4 + + _scaling: + trainable: true # 优化尺度 + lr: 1.0e-4 + + _rotation: + trainable: true # 优化旋转 + lr: 1.0e-4 + + _features_dc: + trainable: false # 冻结颜色 + lr: 0.0 + + # ... 
其他参数 + +loss_weights: + depth: 1.0 # 深度loss权重 + normal: 1.5 # 法线loss权重(增强) + mesh_depth: 0.1 + mesh_normal: 0.15 + +depth_processing: + clip_min: 0.0 + clip_max: null # 不裁剪最大深度 + +mesh_regularization: + depth_weight: 0.1 + normal_weight: 0.15 # 增强法线正则化 +``` + +## 代码修改说明 + +### 新增函数 + +1. **`load_optimization_config(config_name: str) -> Dict`** + - 加载并验证YAML配置文件 + - 支持配置名称或完整路径 + +2. **`setup_gaussian_optimization(gaussians, opt_config) -> (optimizer, loss_weights)`** + - 根据配置设置参数的可训练性 + - 创建多参数组的Adam优化器 + - 返回优化器和loss权重字典 + +### 主要修改点 + +1. **命令行参数** (yufu2mesh_new.py:632-768) + - 新增 `--opt_config` 参数 + - 旧参数标记为已弃用,保留向后兼容 + +2. **参数初始化** (yufu2mesh_new.py:817-822) + - 移除硬编码的参数冻结逻辑 + - 使用 `setup_gaussian_optimization` 替代 + +3. **深度处理** (yufu2mesh_new.py:872-887) + - 从YAML配置读取裁剪参数 + +4. **Mesh正则化** (yufu2mesh_new.py:854-858) + - 从YAML配置读取权重覆盖 + +5. **Loss计算** (yufu2mesh_new.py:1024-1033) + - 使用配置中的loss权重 + +6. **梯度计算** (yufu2mesh_new.py:1035-1047) + - 支持多参数组的梯度范数计算 + +## 测试验证 + +### 运行配置测试 + +```bash +cd /home/zoyo/Desktop/MILo_rtx50/milo +python test_opt_config.py +``` + +预期输出: +``` +🎉 所有配置测试通过! +``` + +### 查看参数设置 + +运行主程序时会打印配置信息: + +``` +[INFO] 加载优化配置:xyz_geometry +[INFO] 配置高斯参数优化... +[INFO] 参数 _xyz: trainable=True, lr=0.000500 +[INFO] 参数 _features_dc: trainable=False +[INFO] 参数 _scaling: trainable=True, lr=0.000100 +... +[INFO] Loss权重配置: + > depth: 1.0 + > normal: 1.5 + > mesh_depth: 0.1 + > mesh_normal: 0.15 +``` + +## 向后兼容性 + +旧的命令行参数仍然可以使用,但会被YAML配置覆盖并显示警告: + +```bash +python yufu2mesh_new.py --opt_config default --lr 1e-3 +# 输出: [WARNING] --lr 已弃用,将使用YAML配置中的学习率设置 +``` + +建议尽快迁移到YAML配置方式。 + +## 高级使用 + +### 渐进式优化策略 + +```bash +# 阶段1: 快速收敛位置 (前1000次迭代) +python yufu2mesh_new.py \ + --opt_config xyz_only \ + --num_iterations 1000 \ + --heatmap_dir stage1_xyz + +# 阶段2: 细化几何 (1000次迭代) +# 先将stage1的最终高斯复制为初始值,然后: +python yufu2mesh_new.py \ + --opt_config xyz_geometry \ + --num_iterations 1000 \ + --heatmap_dir stage2_geometry + +# 阶段3: 优化mesh质量 (500次迭代) +python yufu2mesh_new.py \ + --opt_config xyz_occupancy \ + --num_iterations 500 \ + --heatmap_dir stage3_occupancy +``` + +### 调参建议 + +1. **学习率太大** → 训练不稳定,loss震荡 + - 解决:降低对应参数的lr,例如从5e-4降到1e-4 + +2. **学习率太小** → 收敛太慢 + - 解决:适当增加lr,或增加迭代次数 + +3. **过拟合** → 训练集loss很低但效果差 + - 解决:减少可训练参数,或降低loss权重 + +4. 
**欠拟合** → loss降不下来 + - 解决:增加可训练参数,或增加loss权重 + +## 常见问题 + +**Q: 如何查看所有可用的预设配置?** +```bash +ls configs/optimization/*.yaml +``` + +**Q: 如何知道某个配置具体训练哪些参数?** +```bash +python test_opt_config.py # 查看所有配置 +# 或查看YAML文件内容 +cat configs/optimization/xyz_geometry.yaml +``` + +**Q: 配置文件修改后需要重启吗?** + +不需要,每次运行时都会重新加载配置。 + +**Q: 可以在训练中切换配置吗?** + +不建议。如需切换,请使用checkpoint机制,分阶段训练。 + +**Q: 如何备份我的配置?** + +配置文件是纯文本,可以直接用git管理: +```bash +cd configs/optimization +git add my_config.yaml +git commit -m "Add custom optimization config" +``` + +## 相关文件 + +- 详细配置说明: `configs/optimization/README.md` +- 配置测试脚本: `test_opt_config.py` +- 主程序: `yufu2mesh_new.py` + +## 技术细节 + +### 优化器实现 + +使用PyTorch的参数组(parameter groups)功能: + +```python +optimizer = torch.optim.Adam([ + {"params": [gaussians._xyz], "lr": 5e-4, "name": "_xyz"}, + {"params": [gaussians._scaling], "lr": 1e-4, "name": "_scaling"}, + {"params": [gaussians._rotation], "lr": 1e-4, "name": "_rotation"}, +]) +``` + +每个参数可以有独立的学习率,Adam优化器会为每个参数组维护独立的动量和自适应学习率状态。 + +### 梯度范数计算 + +计算所有可训练参数的L2梯度范数: + +```python +grad_norm = sqrt(sum(||param.grad||^2 for param in trainable_params)) +``` + +用于监控训练稳定性和调试。 + +## 总结 + +通过YAML配置系统,我们实现了: + +✅ **超参数管理简化**: 从8+个命令行参数减少到1个 +✅ **灵活性提升**: 可以自由控制任意参数的训练状态 +✅ **可复用性**: 配置文件易于分享和版本控制 +✅ **可读性**: 清晰的YAML格式配合详细注释 +✅ **向后兼容**: 不影响现有代码的使用 + +这个系统让实验配置更加模块化和易于管理,非常适合需要频繁调整优化策略的研究场景。 diff --git a/milo/QUICK_START.md b/milo/QUICK_START.md new file mode 100644 index 0000000..3f96a18 --- /dev/null +++ b/milo/QUICK_START.md @@ -0,0 +1,114 @@ +# 快速开始指南 (Quick Start) + +## 基本用法 + +使用新的YAML配置系统,只需一个参数控制所有优化行为: + +```bash +python yufu2mesh_new.py --opt_config <配置名称> +``` + +## 预设配置 + +| 配置名称 | 优化参数 | 使用场景 | +|---------|---------|---------| +| `default` | 仅位置(xyz) | 默认选择,适合大多数情况 | +| `xyz_only` | 仅位置(xyz) | 最保守策略 | +| `xyz_geometry` | 位置+形状 | 需要调整高斯形状 | +| `xyz_occupancy` | 位置+占用 | 改善mesh提取质量 | +| `full` | 所有参数 | 初始化较差时使用 ⚠️ | + +## 常用命令 + +```bash +# 1. 使用默认配置 +python yufu2mesh_new.py --opt_config default + +# 2. 优化位置和几何形状 +python yufu2mesh_new.py --opt_config xyz_geometry + +# 3. 改善mesh质量 +python yufu2mesh_new.py --opt_config xyz_occupancy + +# 4. 指定迭代次数和输出目录 +python yufu2mesh_new.py \ + --opt_config xyz_geometry \ + --num_iterations 200 \ + --heatmap_dir my_output + +# 5. 使用自定义配置文件 +python yufu2mesh_new.py --opt_config /path/to/custom.yaml +``` + +## 自定义配置 + +### 1. 复制模板 + +```bash +cd configs/optimization +cp default.yaml my_config.yaml +``` + +### 2. 编辑配置 + +打开 `my_config.yaml`,修改你需要的部分: + +```yaml +gaussian_params: + _xyz: + trainable: true # 是否训练 + lr: 5.0e-4 # 学习率 + + _scaling: + trainable: false # 改为true可优化形状 + lr: 1.0e-4 + +loss_weights: + depth: 1.0 # 深度loss权重 + normal: 1.5 # 法线loss权重 +``` + +### 3. 使用自定义配置 + +```bash +python yufu2mesh_new.py --opt_config my_config +``` + +## 验证配置 + +测试所有配置文件是否正确: + +```bash +python test_opt_config.py +``` + +## 配置文件位置 + +- 预设配置: `configs/optimization/*.yaml` +- 详细文档: `configs/optimization/README.md` +- 完整指南: `OPTIMIZATION_CONFIG_GUIDE.md` + +## 与旧参数对比 + +### 旧方式(已弃用) +```bash +python yufu2mesh_new.py \ + --lr 5e-4 \ + --depth_loss_weight 1.0 \ + --normal_loss_weight 1.5 \ + --depth_clip_max 50.0 \ + --mesh_depth_weight 0.1 +``` + +### 新方式(推荐) +```bash +python yufu2mesh_new.py --opt_config xyz_geometry +``` + +所有参数都在YAML文件中统一管理,更清晰易维护。 + +## 需要帮助? 
+ +- 查看预设配置: `ls configs/optimization/` +- 查看配置内容: `cat configs/optimization/default.yaml` +- 详细文档: `configs/optimization/README.md` diff --git a/milo/clean ply/clean_3dgs.py b/milo/clean ply/clean_3dgs.py new file mode 100644 index 0000000..262c029 --- /dev/null +++ b/milo/clean ply/clean_3dgs.py @@ -0,0 +1,371 @@ +#!/usr/bin/env python3 +"""Filter 3D Gaussian Splatting models based on voxelized statistics. + +The script keeps voxels that are either sufficiently dense or contain very +fine Gaussians, then optionally expands the kept region by a halo and writes a +trimmed PLY. It follows the core workflow outlined in the clean_3dgs design +document. +""" + +from __future__ import annotations + +import argparse +import math +import shutil +from collections import deque +from pathlib import Path +from typing import Iterable + +import numpy as np +from plyfile import PlyData, PlyElement + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--input", required=True, help="Path to the source 3DGS PLY file") + parser.add_argument( + "--output", + help=( + "Path to the cleaned PLY file. Defaults to '_cleaned.ply' next " + "to the source file." + ), + ) + parser.add_argument( + "--voxel-size", + type=float, + default=0.25, + help="Voxel edge length used for statistics (default: 0.25 in world units)", + ) + parser.add_argument( + "--density-keep-ratio", + type=float, + default=0.3, + help="Fraction of the high-density reference used as the keep threshold", + ) + parser.add_argument( + "--density-reference-percentile", + type=float, + default=0.10, + help="Top percentile of voxel densities used to build the reference (default: 10%)", + ) + parser.add_argument( + "--volume-keep-ratio", + type=float, + default=3.0, + help="Multiplier applied to the fine-volume reference for the detail guard", + ) + parser.add_argument( + "--volume-reference-percentile", + type=float, + default=0.15, + help="Bottom percentile of voxel volumes used as the fine detail reference", + ) + parser.add_argument( + "--halo-voxels", + type=int, + default=3, + help="Voxel radius for spatial halo expansion (default: 3)", + ) + parser.add_argument( + "--max-gaussians", + type=int, + help="Optional cap on the number of Gaussians to keep after filtering", + ) + parser.add_argument( + "--disable-component-filter", + action="store_true", + help="Keep all disconnected voxel clusters instead of retaining only the largest", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed used when downsampling to --max-gaussians", + ) + args = parser.parse_args() + + if args.voxel_size <= 0: + parser.error("--voxel-size must be positive") + if args.density_keep_ratio <= 0: + parser.error("--density-keep-ratio must be positive") + if not 0 < args.density_reference_percentile <= 1: + parser.error("--density-reference-percentile must be in (0, 1]") + if args.volume_keep_ratio <= 0: + parser.error("--volume-keep-ratio must be positive") + if not 0 < args.volume_reference_percentile <= 1: + parser.error("--volume-reference-percentile must be in (0, 1]") + if args.halo_voxels < 0: + parser.error("--halo-voxels must be non-negative") + if args.max_gaussians is not None and args.max_gaussians <= 0: + parser.error("--max-gaussians must be a positive integer when provided") + + return args + + +def extract_positions(vertex_data: np.ndarray) -> np.ndarray: + required = ("x", "y", "z") + missing = [field for field in required if field not in 
vertex_data.dtype.names] + if missing: + raise ValueError(f"Input PLY is missing required position fields: {missing}") + + positions = np.empty((vertex_data.shape[0], 3), dtype=np.float64) + positions[:, 0] = vertex_data["x"] + positions[:, 1] = vertex_data["y"] + positions[:, 2] = vertex_data["z"] + return positions + + +def extract_log_scales(vertex_data: np.ndarray) -> tuple[np.ndarray, list[str]]: + scale_names = [name for name in vertex_data.dtype.names if name.startswith("scale_")] + if not scale_names: + raise ValueError("Input PLY is missing 'scale_*' fields required for volume estimates") + try: + scale_names.sort(key=lambda name: int(name.split("_")[-1])) + except ValueError as exc: + raise ValueError("Scale fields must follow the 'scale_' naming convention") from exc + + log_scales = np.empty((vertex_data.shape[0], len(scale_names)), dtype=np.float64) + for idx, field in enumerate(scale_names): + log_scales[:, idx] = vertex_data[field] + return log_scales, scale_names + + +def compute_reference(values: np.ndarray, fraction: float, descending: bool) -> float: + if values.size == 0: + return float("nan") + sorted_vals = np.sort(values) + if descending: + sorted_vals = sorted_vals[::-1] + top_n = max(int(math.ceil(len(sorted_vals) * fraction)), 1) + subset = sorted_vals[:top_n] + return float(np.median(subset)) + + +def describe_distribution(values: np.ndarray, label: str, fmt: str) -> None: + if values.size == 0: + print(f"[STATS] {label}: no data") + return + quantiles = np.percentile(values, [0, 50, 95, 100]) + summary = " ".join( + ( + f"min={format(quantiles[0], fmt)}", + f"median={format(quantiles[1], fmt)}", + f"p95={format(quantiles[2], fmt)}", + f"max={format(quantiles[3], fmt)}", + ) + ) + print(f"[STATS] {label}: {summary}") + + +def build_halo_offsets(radius: int) -> np.ndarray: + if radius <= 0: + return np.zeros((0, 3), dtype=np.int32) + coords = np.arange(-radius, radius + 1, dtype=np.int32) + grid = np.stack(np.meshgrid(coords, coords, coords, indexing="ij"), axis=-1) + offsets = grid.reshape(-1, 3) + keep = np.any(offsets != 0, axis=1) + return offsets[keep] + + +NEIGHBOR_OFFSETS = np.array( + [ + (1, 0, 0), + (-1, 0, 0), + (0, 1, 0), + (0, -1, 0), + (0, 0, 1), + (0, 0, -1), + ], + dtype=np.int32, +) + + +def filter_largest_component( + keep_mask: np.ndarray, + unique_voxels: np.ndarray, + voxel_to_index: dict[tuple[int, int, int], int], + neighbor_offsets: np.ndarray, +) -> tuple[np.ndarray, int, int]: + visited = np.zeros_like(keep_mask, dtype=bool) + largest_component: list[int] = [] + largest_size = 0 + component_count = 0 + + for start_idx in np.flatnonzero(keep_mask): + if visited[start_idx]: + continue + component_count += 1 + queue: deque[int] = deque([int(start_idx)]) + visited[start_idx] = True + current_component: list[int] = [] + + while queue: + idx = queue.popleft() + current_component.append(idx) + base = unique_voxels[idx] + for offset in neighbor_offsets: + neighbor_coord = tuple(int(v) for v in base + offset) + neighbor_idx = voxel_to_index.get(neighbor_coord) + if neighbor_idx is None or visited[neighbor_idx] or not keep_mask[neighbor_idx]: + continue + visited[neighbor_idx] = True + queue.append(neighbor_idx) + + if len(current_component) > largest_size: + largest_size = len(current_component) + largest_component = current_component + + new_keep = np.zeros_like(keep_mask, dtype=bool) + if largest_component: + new_keep[np.array(largest_component, dtype=np.int64)] = True + return new_keep, component_count, largest_size + + +def format_vec(values: 
Iterable[float]) -> str: + return ", ".join(f"{float(v):.3f}" for v in values) + + +def main() -> None: + args = parse_args() + input_path = Path(args.input).expanduser().resolve() + if args.output: + output_path = Path(args.output).expanduser().resolve() + else: + output_path = input_path.with_name(f"{input_path.stem}_cleaned{input_path.suffix}") + + if not input_path.exists(): + raise FileNotFoundError(f"Input PLY not found: {input_path}") + + print(f"[INFO] Loading PLY: {input_path}") + ply = PlyData.read(str(input_path)) + if "vertex" not in ply: + raise ValueError("Input PLY does not contain a 'vertex' element") + + vertex_data = np.asarray(ply["vertex"].data) + total_gaussians = vertex_data.shape[0] + print(f"[INFO] Total Gaussians: {total_gaussians}") + if total_gaussians == 0: + raise ValueError("Input PLY contains no Gaussians to process") + + positions = extract_positions(vertex_data) + log_scales, _ = extract_log_scales(vertex_data) + # 3DGS stores log-scale radii; exponentiate to obtain axis lengths and volume. + volumes = np.exp(np.sum(log_scales, axis=1)) + + bounds_min = positions.min(axis=0) + bounds_max = positions.max(axis=0) + print(f"[INFO] Bounding box min: [{format_vec(bounds_min)}]") + print(f"[INFO] Bounding box max: [{format_vec(bounds_max)}]") + + voxel_indices = np.floor((positions - bounds_min) / args.voxel_size).astype(np.int32) + unique_voxels, inverse_indices, counts = np.unique( + voxel_indices, axis=0, return_inverse=True, return_counts=True + ) + voxel_count = unique_voxels.shape[0] + print(f"[INFO] Occupied voxels: {voxel_count}") + + voxel_to_index = { + tuple(int(v) for v in coord): idx for idx, coord in enumerate(unique_voxels) + } + + volume_sums = np.bincount(inverse_indices, weights=volumes, minlength=voxel_count) + with np.errstate(divide="ignore", invalid="ignore"): + avg_volumes = volume_sums / counts + + describe_distribution(counts.astype(np.float64), "Voxel density", ".0f") + describe_distribution(avg_volumes, "Voxel avg_volume", ".3e") + + count_ref = compute_reference( + counts.astype(np.float64), args.density_reference_percentile, descending=True + ) + count_threshold = args.density_keep_ratio * count_ref + + volume_ref = compute_reference( + avg_volumes, args.volume_reference_percentile, descending=False + ) + if not np.isfinite(volume_ref) or volume_ref <= 0: + volume_ref = float(np.median(avg_volumes)) + volume_threshold = args.volume_keep_ratio * volume_ref + + print( + f"[INFO] Density ref={count_ref:.3f}, threshold={count_threshold:.3f} (ratio={args.density_keep_ratio})" + ) + print( + f"[INFO] Volume ref={volume_ref:.6e}, threshold={volume_threshold:.6e} (ratio={args.volume_keep_ratio})" + ) + + keep_by_density = counts >= count_threshold + keep_by_detail = (counts < count_threshold) & (avg_volumes <= volume_threshold) + keep_raw = keep_by_density | keep_by_detail + + kept_voxels_raw = int(np.count_nonzero(keep_raw)) + print( + f"[INFO] Voxels kept before halo: {kept_voxels_raw} ({kept_voxels_raw / voxel_count:.2%})" + ) + + keep_final = keep_raw.copy() + if args.halo_voxels > 0 and np.any(keep_raw): + offsets = build_halo_offsets(args.halo_voxels) + if offsets.size: + for idx in np.flatnonzero(keep_raw): + base = unique_voxels[idx] + for offset in offsets: + neighbor = tuple(int(v) for v in base + offset) + neighbor_idx = voxel_to_index.get(neighbor) + if neighbor_idx is not None: + keep_final[neighbor_idx] = True + + kept_voxels_final = int(np.count_nonzero(keep_final)) + print( + f"[INFO] Voxels kept after halo: 
{kept_voxels_final} ({kept_voxels_final / voxel_count:.2%})" + ) + + if not args.disable_component_filter and np.any(keep_final): + keep_final, component_count, largest_size = filter_largest_component( + keep_final, unique_voxels, voxel_to_index, NEIGHBOR_OFFSETS + ) + print( + f"[INFO] Component filter kept 1 of {component_count} components ({largest_size} voxels)" + ) + elif args.disable_component_filter: + print("[INFO] Component filter disabled; keeping all voxel clusters") + + keep_mask = keep_final[inverse_indices] + kept_gaussians = int(np.count_nonzero(keep_mask)) + print(f"[INFO] Gaussians kept after voxel filtering: {kept_gaussians}") + + if kept_gaussians == 0: + raise ValueError("Filtering removed all Gaussians; consider relaxing thresholds") + + if args.max_gaussians is not None and kept_gaussians > args.max_gaussians: + rng = np.random.default_rng(args.seed) + kept_indices = np.flatnonzero(keep_mask) + sample = rng.choice(kept_indices, size=args.max_gaussians, replace=False) + sample.sort() + new_keep_mask = np.zeros_like(keep_mask, dtype=bool) + new_keep_mask[sample] = True + keep_mask = new_keep_mask + kept_gaussians = args.max_gaussians + print(f"[INFO] Downsampled to max_gaussians={args.max_gaussians}") + + filtered_vertices = np.asarray(vertex_data[keep_mask], dtype=vertex_data.dtype) + + if output_path == input_path: + backup_path = input_path.with_suffix(input_path.suffix + ".bak") + if not backup_path.exists(): + print(f"[INFO] Creating backup at {backup_path}") + shutil.copyfile(input_path, backup_path) + else: + print(f"[WARNING] Backup already exists at {backup_path}; reusing it.") + + output_path.parent.mkdir(parents=True, exist_ok=True) + vertex_element = PlyElement.describe(filtered_vertices, "vertex") + new_ply = PlyData([vertex_element], text=ply.text, byte_order=ply.byte_order) + new_ply.comments = list(ply.comments) + print(f"[INFO] Writing cleaned PLY to {output_path}") + new_ply.write(str(output_path)) + print(f"[INFO] Done. Final Gaussians: {kept_gaussians}") + + +if __name__ == "__main__": + main() diff --git "a/milo/clean ply/clean_3dgs.py \350\256\276\350\256\241\350\223\235\345\233\276.md" "b/milo/clean ply/clean_3dgs.py \350\256\276\350\256\241\350\223\235\345\233\276.md" new file mode 100644 index 0000000..48049bd --- /dev/null +++ "b/milo/clean ply/clean_3dgs.py \350\256\276\350\256\241\350\223\235\345\233\276.md" @@ -0,0 +1,594 @@ +# clean\_3dgs.py设计蓝图 + +目标:输入一个大场景的3DGaussianSplatting模型,自动裁掉那些不重要、很稀疏、很糙的外围区域,只保留核心部分,然后导出结果。 + +--- + +## 0.背景与目标 + +我们当前的3DGS是从一个大范围实地扫描得到的:广场是高优先级区域,拍得多、近距离、细节高;广场之外(马路、码头)是低优先级区域,拍得少、远距离、细节糙。 + +标准3DGS表示法(Kerbletal.2023之后的常规形式)里,每个“高斯”就是一个各向异性的3D椭球,带有这些属性: + +* 中心点位置`(x,y,z)`(世界坐标,单位通常是米或类似尺度) + +* 尺度/半轴长度`(sx,sy,sz)`或等价参数,用来描述它的体积/空间覆盖范围 + +* 方向(旋转矩阵/四元数/SH主方向) + +* 颜色/不透明度等渲染属性 + + +经验规律: + +* 被相机近距离大量观测的部分,会出现更多高斯,且这些高斯更“小更精细”。 + +* 远处/很少扫到的部分,高斯稀疏、而且单个高斯会更“胖”(体积更大,用少量低分辨率大泡泡去解释远景)。 + + +我们想要做的,就是利用这些统计差异,把“重点区域”自动保留下来,而把“低价值外围”整体裁掉。 + +更具体地,脚本要实现: + +1. 把整个3DGS场景划分成一堆立方体小格(voxel分块,像3D魔方)。 + +2. 统计每个小格里的高斯分布特征。 + +3. 根据统计特征决定哪些格子属于“应该保留的区域”。 + +4. 删除“不重要格子”里的高斯。 + +5. 处理边界/细杆子之类的特殊情况,避免把我们其实想保留的东西误删。 + +6. 
输出一个“干净版”的3DGS(同格式),并可选地导出网格(OBJ/STL)。 + + +--- + +## 1.预期的输入/输出接口 + +### 输入 + +* 主输入:一个3DGS点云/高斯云文件 + 通常是`.ply`,里面一行/一条记录就是一个高斯,至少包含: + + * `x,y,z` + + * 三个尺度参数(可以叫`scale_x/scale_y/scale_z`或`scale`,`rot`,`covariance_xx...`;细节随你的格式) + + * 颜色/不透明度等我们可以原样透传 + + + 假设这些字段是可以被读出来放到内存中的。 + +* 可选参数(命令行或函数参数): + + * `voxel_size`:体素边长,决定分块粒度。当前脚本默认 **0.25m**;缩小时能捕捉薄片结构,但要同步考虑`halo_voxels`带来的物理膨胀厚度。 + + * `density_keep_ratio`:决定“密度阈值”的比例。默认 **0.30**,越小越宽松,能保留稀疏区域,但也会留下更多噪声小岛。 + + * `density_reference_percentile`:用于估计高密度参考的排名百分比。默认 **0.10**,即取最密集的前10%体素来求参考密度。 + + * `volume_keep_ratio`:决定“高精细体素”的体积阈值比例,详见第4节。默认 **3.0**,表示体素平均体积只要不超过超精细参考的3倍就可被细节守护逻辑保留。 + + * `volume_reference_percentile`:用来抽取“超精细”体素的比例。默认 **0.15**,即最小15%的`avg_volume`用于建立精细参考值。 + + * `halo_voxels`:保护带膨胀半径(整数,单位=体素数)。默认 **3**;体素越小,需要的halo步数通常越大才能保证同等物理厚度。 + + * `max_gaussians`:(可选)清理完后最多保留多少高斯。如果数量仍然太多,就做最后一步的随机/优先采样。 + + * `disable_component_filter`:布尔开关。默认不开,表示会自动剔除与主体不连通的小岛;开启后保留所有连通块。 + + +这些参数中,`voxel_size`,`density_keep_ratio`,`halo_voxels`最直接决定裁剪范围;`density_reference_percentile`与`volume_reference_percentile`则细化阈值自适应,常常需要和前三者联动调整。 + +### 输出 + +* 一个新的3DGS文件(同输入格式,比如`.ply`),但只含“被保留”的高斯。 + + +--- + +## 2.总体流水线(高层概览) + +整个脚本逻辑分成8个阶段: + +1. **读取数据** + 读入原始3DGS文件,提取每个高斯的: + + * 坐标`xyz` + + * 尺度信息(用来估计体积/大小) + + * 其它属性(原样保存,以便最终写回) + +2. **统计全局范围** + 找到所有高斯的包围盒(boundingbox),用于后面的体素划分。 + +3. **建立体素网格索引** + 用`voxel_size`把整个包围盒量化成3D网格坐标`(i,j,k)`,并把每个高斯分配到它所在的体素。 + +4. **为每个体素计算特征** + 对每个体素,计算: + + * `count`:这个小格里有多少个高斯 + + * `avg_volume`:这些高斯的平均体积(高斯大小) + + * (可存更多:中位数体积、最大体积、透明度均值等,看你需要) + +5. **根据密度和高斯大小做第一次筛选** + 目标:打上标签“保留/丢弃”到每个体素。 + + * 密度很高的体素→保留 + + * 密度低但平均高斯体积很小(意味着这是高精度、细杆子那类局部结构)→也保留 + + * 其他→丢弃 + 这样可以避免“孤立柱子被删”的问题。 + +6. **空间膨胀(halo)** + 目标:避免硬切边界。 + 把“保留体素”做一层或多层空间膨胀:凡是靠近保留体素的邻居体素,也一并标记为保留。 + 这相当于在三维上做一个max-pooling风格的3×3×3(或更大核)膨胀操作。 + +7. **回写到高斯层面** + 只保留落在“保留体素”里的高斯,丢弃剩下的高斯。 + 如果还指定了`max_gaussians`,在这一层之后再做一次下采样。 + +8. **导出** + + * 写新的3DGS(.ply) + + +--- + +## 3.细节设计 + +### 3.1读取与内部表示 + +把每个高斯存成一条结构体/记录,包含: + +* `pos=(x,y,z)` + +* `scale=(sx,sy,sz)`:这是高斯主轴方向上的标准差/半径。不同实现里名字可能叫`scale`,`covA`,`covB`,或者通过log-scale存的。只要能把它变成三个实际长度。 + +* 旋转/方向信息(保持原样,不一定要用到) + +* 颜色/不透明度等视觉属性(保持原样) + +* 一个唯一的index(方便之后回写) + + +这些数据要被全部读进内存,以便进行分组和过滤。 + +如果文件非常大(几百万高斯),内存还是够的(几百万行的float通常是几百MB级别,现代工作站可以撑)。如果特别极端,才需要分块流式处理;本设计先默认全量载入内存。 + +### 3.2统计全局范围(boundingbox) + +对所有高斯的`(x,y,z)`分别取最小值和最大值,得到: + +* `min_x,min_y,min_z` + +* `max_x,max_y,max_z` + + +稍微往外扩一丢丢(比如加/减一个很小的epsilon),避免边界精度问题。 + +这个包围盒决定了体素网格的整体尺寸: +网格大小大约是 + +* `nx=ceil((max_x-min_x)/voxel_size)` + +* `ny=ceil((max_y-min_y)/voxel_size)` + +* `nz=ceil((max_z-min_z)/voxel_size)` + + +### 3.3把每个高斯映射到体素(voxelization) + +体素坐标`(i,j,k)`的计算: + +1. 对每个高斯中心`(x,y,z)`: + + * `i=floor((x-min_x)/voxel_size)` + + * `j=floor((y-min_y)/voxel_size)` + + * `k=floor((z-min_z)/voxel_size)` + +2. 用`(i,j,k)`作为key,把这个高斯的index放进一个字典(哈希表) + + * key:`(i,j,k)` + + * value:一个高斯索引列表 + + +这样我们不会真的分配一个`nx*ny*nz`的稠密3D数组(可能会很大),而是只记录actually占用的体素。场景通常是稀疏的,这样会节省内存。 + +> 备注:这一点很重要。真实户外/广场场景的包围盒可能是几十米甚至上百米,稠密0.2m体素会爆内存。哈希体素(sparsevoxelmap)可以避免这个问题。 + +### 3.4体素特征计算 + +对于每个体素key`(i,j,k)`,我们要算两类指标: + +1. **密度指标:** + + * `count`=这个体素里高斯的数量 + 你也可以归一化成“数量/体素体积”,不过体素体积是常数(`voxel_size^3`),所以没必要;直接用数量即可。 + +2. 
**高斯平均体积指标:** + 我们需要一个“这个体素里的高斯是不是很细”的度量,来保护细杆子、路灯、旗杆那类孤立结构。 + + 对单个高斯的“体积”可以用它的主轴尺度乘积近似: + + * `gaussian_volume≈sx*sy*sz` + + 直觉: + + * 近处、被充分观察到的几何会用很多又小又尖锐的高斯,`sx,sy,sz`都比较小→乘积小。 + + * 远处杂乱背景会用几个巨大、胖乎乎的高斯去糊→尺度大→乘积大。 + + 对体素内所有高斯: + + * `avg_volume=mean(gaussian_volume)`(或者中位数也可以,更稳健) + + +我们会用`avg_volume`来判断“这格子虽然稀疏,但是不是包含很高精度的小结构”。 + +同时把每个体素的这些统计都保存下来,供后续阈值判断用。 + +--- + +## 4.体素级别的第一次分类(保留/丢弃) + +### 4.1基于密度的筛选 + +我们需要确定“高密度是什么级别”。做法: + +1. 收集所有体素的`count`。 + +2. 找到一个“参考高密度值”。一个稳妥的办法是: + +* 把所有`count`从大到小排序, + +* 取排名前若干百分位(脚本默认`density_reference_percentile=0.10`,即前10%)的平均值或中位数,叫它`count_ref`。 + + * 直觉:这些就是“真的被重点扫过的区域”的典型密度。 + +3. 设定密度阈值: + +* `count_threshold=density_keep_ratio*count_ref` + +* 例如当前默认`density_keep_ratio=0.3`就是“只要达到核心高密度区域的大约三成,我也认为它够重要”。 + + +如果一个体素的`count>=count_threshold`,则标记为“KEEP\_BY\_DENSITY=True”。 + +这样就筛出一大块我们最关心的广场主体。 + +### 4.2基于细节(高斯体积)的补救 + +这个是用来保护“单根柱子/灯杆/薄栏杆”场景的。它们的问题是: + +* count可能很低(因为真就只有一小撮点) + +* 但是那些高斯都很小很尖锐(代表高精细结构,值得保留) + + +做法: + +1. 收集所有体素的`avg_volume`。 + +2. 找出“高精细”的参考体积: + +* 将所有`avg_volume`排序(从小到大,因为越小=越细致), + +* 取前若干百分位(默认`volume_reference_percentile=0.15`,即最小15%)的平均值或中位数,叫它`volume_ref`。 + +3. 设定体积阈值: + + * `volume_threshold=volume_keep_ratio*volume_ref` + + * 注意这里的乘法方向要想清楚: + + * 体素越精细,`avg_volume`越小。 + + * 我们希望保留那些“跟最精细区域差不多小”的体素。 + +* 如果`volume_keep_ratio=3.0`(脚本当前默认),意思是“只要这个体素的avg\_volume不比核心超细区域的大三倍,我也保留”。 + + +最终规则: + +* 如果一个体素`count=20就保留”。 + 这是为了让脚本对不同场景自适应:如果你的场景整体都很高密度,那阈值会相应变高;如果整体都稀疏,阈值就会下降。 + +* halo步骤(3Ddilation)就是你设想的“卷积保护”,防止切得太硬,也让广场边缘保留下来而不是被一刀切掉。`voxel_size`越小,记得同步调大`halo_voxels`,脚本当前默认组合是0.25m×3≈0.75m的保护带。 + +* 连通性过滤会自动保留最大的连通块,清除远处小岛;只有在场景确实包含多个独立主体时才需要用`--disable-component-filter`关闭。 + +* 输出阶段,.ply只是删行,尽量不改schema,这样MILo后面的mesh优化、渲染等都能无缝接着跑。 + +* 网格导出的两条路线(未来): + + * 快速但粗糙的点云→Poisson→OBJ/STL + + * 或者把裁剪后结果喂给MILo的mesh优化管线,拿高质量mesh,再用已有的`clean_convert_mesh`风格脚本去导出最终OBJ/STL。 + + +--- diff --git a/milo/clean ply/clean_convert_mesh.py b/milo/clean ply/clean_convert_mesh.py new file mode 100644 index 0000000..0873d31 --- /dev/null +++ b/milo/clean ply/clean_convert_mesh.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +# clean_convert_mesh.py +# Usage: +# python clean_convert_mesh.py --in ./output/Ignatius/mesh_learnable_sdf.ply +# python clean_convert_mesh.py --in ./output/Ignatius/mesh_learnable_sdf.ply --simplify 300000 +# python clean_convert_mesh.py --in ./output/Ignatius/mesh_learnable_sdf.ply --keep-components 0.02 --no-glb + +import argparse +import os +import sys +from pathlib import Path + +# deps: pip install pymeshlab trimesh plyfile +try: + import pymeshlab as ml +except Exception as e: + print("[ERR] pymeshlab not available. `pip install pymeshlab` first.\n", e) + sys.exit(1) + +try: + import trimesh +except Exception as e: + print("[WARN] trimesh not available. 
GLB export will be disabled.\n", e)
+    trimesh = None
+
+
+def human(n: int) -> str:
+    return f"{n:,}"
+
+
+def bbox_diag_from_ms(ms: ml.MeshSet) -> float:
+    """Compute bounding-box diagonal length from current mesh (in MeshSet)."""
+    vm = ms.current_mesh().vertex_matrix()
+    if vm.size == 0:
+        return 0.0
+    vmin = vm.min(axis=0)
+    vmax = vm.max(axis=0)
+    return float(((vmax - vmin) ** 2).sum() ** 0.5)
+
+
+def print_stats(tag: str, ms: ml.MeshSet):
+    m = ms.current_mesh()
+    print(f"{tag:>10} | V={human(m.vertex_number())} F={human(m.face_number())}")
+
+
+def clean_mesh(ms: ml.MeshSet, keep_components_frac: float):
+    """
+    Gentle cleanup that preserves the overall topology:
+    - remove duplicate vertices/faces and unreferenced vertices
+    - repair/remove non-manifold edges (filter chosen automatically per pymeshlab version)
+    - drop small floating components based on a diameter threshold
+    - recompute normals
+    """
+    # Basic deduplication / reference fixes
+    if hasattr(ms, "meshing_remove_duplicate_faces"):
+        ms.meshing_remove_duplicate_faces()
+    if hasattr(ms, "meshing_remove_duplicate_vertices"):
+        ms.meshing_remove_duplicate_vertices()
+    if hasattr(ms, "meshing_remove_unreferenced_vertices"):
+        ms.meshing_remove_unreferenced_vertices()
+
+    # Non-manifold edges: the filter name differs across pymeshlab versions, so support both
+    if hasattr(ms, "meshing_repair_non_manifold_edges"):
+        ms.meshing_repair_non_manifold_edges()
+    elif hasattr(ms, "meshing_remove_non_manifold_edges"):
+        ms.meshing_remove_non_manifold_edges()
+
+    # Remove small isolated components / floaters (relative to the overall bbox diagonal)
+    diag = bbox_diag_from_ms(ms)
+    if keep_components_frac > 0 and diag > 0 and hasattr(ms, "meshing_remove_isolated_pieces_wrt_diameter"):
+        thr = float(diag * keep_components_frac)  # absolute length threshold
+        ms.meshing_remove_isolated_pieces_wrt_diameter(mincomponentdiag=thr)
+
+    # Normals
+    if hasattr(ms, "compute_normals_for_point_sets"):
+        ms.compute_normals_for_point_sets()
+    if hasattr(ms, "compute_normals_for_meshes"):
+        ms.compute_normals_for_meshes()
+
+
+
+def simplify_mesh(ms: ml.MeshSet, target_tris: int):
+    """
+    Quadric decimation (mesh simplification) with conservative settings; skipped if the
+    current face count is already below the target.
+    """
+    current_f = ms.current_mesh().face_number()
+    if target_tris <= 0 or current_f <= target_tris:
+        print(f"[INFO] Skip simplify: current F={human(current_f)} <= target {human(max(target_tris,0))}")
+        return
+
+    print(f"[INFO] Simplifying: {human(current_f)} → {human(target_tris)} (quadric)")
+    # Try to preserve boundaries and normals
+    ms.meshing_decimation_quadric_edge_collapse(
+        targetfacenum=target_tris,
+        preservenormal=True,
+        preservetopology=True,
+        qualitythr=0.3,  # fairly conservative
+        optimalplacement=True,
+        planarquadric=True,
+        autoclean=True
+    )
+    ms.compute_normals_for_meshes()
+
+
+def export_all(ms: ml.MeshSet, out_dir: Path, stem: str, do_ply: bool, do_obj: bool, do_glb: bool):
+    """
+    For compatibility across pymeshlab versions, no optional arguments are passed when saving.
+    Save PLY/OBJ first (if enabled), then export GLB from the PLY via trimesh.
+    """
+    out_dir.mkdir(parents=True, exist_ok=True)
+    ply_path = out_dir / f"{stem}.ply"
+    obj_path = out_dir / f"{stem}.obj"
+    glb_path = out_dir / f"{stem}.glb"
+
+    # 1) Save PLY (recommended; the GLB export below also relies on it)
+    if do_ply:
+        ms.save_current_mesh(str(ply_path))  # no save_* kwargs, to avoid version differences
+        print(f"[OK] Saved: {ply_path}")
+
+    # 2) Save OBJ (likewise without optional args; note OBJ cannot store per-vertex colors)
+    if do_obj:
+        ms.save_current_mesh(str(obj_path))
+        print(f"[OK] Saved: {obj_path}")
+
+    # 3) Export GLB (through trimesh)
+    if do_glb:
+        if trimesh is None:
+            print("[WARN] trimesh not installed; skip GLB.")
+            return
+
+        # If PLY export was disabled, first write an intermediate PLY for trimesh to read
+        tmp_ply = ply_path
+        need_tmp = False
+        if not do_ply or not tmp_ply.exists():
+            tmp_ply = out_dir / f"{stem}__tmp_for_glb.ply"
+            ms.save_current_mesh(str(tmp_ply))
+            need_tmp = True
+
+        tm = trimesh.load(str(tmp_ply), process=False)
+        tm.export(str(glb_path))
+        print(f"[OK] Saved: {glb_path}")
+
+        # Optional: clean up the temporary file
+        if need_tmp and tmp_ply.exists():
+            try:
tmp_ply.unlink() + except Exception: + pass + + + +def main(): + ap = argparse.ArgumentParser(description="Clean and convert MILo mesh_learnable_sdf.ply to PLY/OBJ/GLB.") + ap.add_argument("--in", dest="in_path", required=True, help="Input PLY path (e.g., ./output/Ignatius/mesh_learnable_sdf.ply)") + ap.add_argument("--out-dir", default=None, help="Output directory (default: same as input)") + ap.add_argument("--stem", default="mesh_clean", help="Output filename stem (default: mesh_clean)") + ap.add_argument("--keep-components", type=float, default=0.02, + help="Remove small isolated components with diameter < frac * bbox_diag (default: 0.02)") + ap.add_argument("--simplify", type=int, default=0, + help="Target triangle count for decimation (0=disable). Example: 300000") + ap.add_argument("--no-ply", action="store_true", help="Do not export PLY") + ap.add_argument("--no-obj", action="store_true", help="Do not export OBJ") + ap.add_argument("--no-glb", action="store_true", help="Do not export GLB") + args = ap.parse_args() + + in_path = Path(args.in_path) + if not in_path.exists(): + print(f"[ERR] Input not found: {in_path}") + sys.exit(2) + + out_dir = Path(args.out_dir) if args.out_dir else in_path.parent + do_ply = not args.no_ply + do_obj = not args.no_obj + do_glb = not args.no_glb + + ms = ml.MeshSet() + try: + ms.load_new_mesh(str(in_path)) + except Exception as e: + print(f"[ERR] Failed to load mesh: {in_path}\n{e}") + sys.exit(3) + + # 基本类型检查 + m = ms.current_mesh() + if m.face_number() == 0: + print("[ERR] The input PLY has 0 faces (looks like a point cloud). Aborting.") + sys.exit(4) + + print_stats("Loaded", ms) + + # 清理 + clean_mesh(ms, keep_components_frac=args.keep_components) + print_stats("Cleaned", ms) + + # 可选简化 + if args.simplify > 0: + simplify_mesh(ms, target_tris=args.simplify) + print_stats("Simplify", ms) + + # 导出 + export_all(ms, out_dir=out_dir, stem=args.stem, do_ply=do_ply, do_obj=do_obj, do_glb=do_glb) + + print("\n[DONE] All good.") + + +if __name__ == "__main__": + main() diff --git a/milo/configs/mesh/medium.yaml b/milo/configs/mesh/medium.yaml new file mode 100644 index 0000000..1f1ec72 --- /dev/null +++ b/milo/configs/mesh/medium.yaml @@ -0,0 +1,62 @@ +# Regularization schedule +start_iter: 8_001 +mesh_update_interval: 1 +stop_iter: 18_000 + +# Depth loss weight +use_depth_loss: true +depth_weight: 0.05 +depth_ratio: 1.0 +mesh_depth_loss_type: "log" + +# Normal loss weight +use_normal_loss: true +normal_weight: 0.05 +use_depth_normal: true + +# Delaunay computation +delaunay_reset_interval: 500 +n_max_points_in_delaunay: 1_200_000 +delaunay_sampling_method: "surface" +filter_large_edges: true +collapse_large_edges: false + +# Rasterization +use_scalable_renderer: false + +# SDF computation +sdf_reset_interval: 500 +sdf_default_isosurface: 0.5 +transform_sdf_to_linear_space: false +min_occupancy_value: 0.0000000001 + +# > For Integrate +use_ema: true +alpha_ema: 0.4 + +# > For TSDF +trunc_margin: 0.002 + +# > For learnable +occupancy_mode: "occupancy_shift" +# > Gaussian centers regularization +enforce_occupied_centers: true +occupied_centers_weight: 0.005 +# > Occupancy labels loss +use_occupancy_labels_loss: true +reset_occupancy_labels_every: 400 +occupancy_labels_loss_weight: 0.005 +# > SDF reset +fix_set_of_learnable_sdfs: true +learnable_sdf_reset_mode: "ema" +learnable_sdf_reset_stop_iter: 13_001 +learnable_sdf_reset_alpha_ema: 0.4 + +method_to_reset_sdf: "depth_fusion" +n_binary_steps_to_reset_sdf: 0 +sdf_reset_linearization_n_steps: 20 
+sdf_reset_linearization_enforce_std: 0.5 +depth_fusion_reset_tolerance: 0.1 + +# Foreground Culling +radius_culling: -1.0 diff --git a/milo/configs/optimization/README.md b/milo/configs/optimization/README.md new file mode 100644 index 0000000..9406ae1 --- /dev/null +++ b/milo/configs/optimization/README.md @@ -0,0 +1,236 @@ +# 优化配置文件说明 (Optimization Configuration) + +本目录包含用于`yufu2mesh_new.py`的优化配置文件。通过YAML配置文件,你可以精细控制高斯参数的优化、损失权重和其他训练超参数。 + +## 快速开始 (Quick Start) + +### 基本用法 + +```bash +# 使用默认配置 +python yufu2mesh_new.py --opt_config default + +# 使用预设配置 +python yufu2mesh_new.py --opt_config xyz_only +python yufu2mesh_new.py --opt_config xyz_geometry + +# 使用自定义配置文件(完整路径) +python yufu2mesh_new.py --opt_config /path/to/my_config.yaml +``` + +## 预设配置 (Available Presets) + +### 1. `default.yaml` - 默认配置 +- **优化参数**: 仅XYZ位置 +- **适用场景**: 已有良好初始化的高斯,只需微调位置 +- **特点**: 保守策略,稳定收敛 + +### 2. `xyz_only.yaml` - 纯位置优化 +- **优化参数**: 仅XYZ位置 +- **适用场景**: 与default类似,但所有冻结参数的学习率都为0 +- **特点**: 最保守的优化策略 + +### 3. `xyz_geometry.yaml` - 位置+几何优化 +- **优化参数**: XYZ位置 + Scaling + Rotation +- **适用场景**: 需要调整高斯形状以更好拟合深度和法线 +- **特点**: 法线loss权重增加,mesh正则化权重也相应提高 + +### 4. `xyz_occupancy.yaml` - 位置+占用优化 +- **优化参数**: XYZ位置 + Occupancy Shift +- **适用场景**: 需要精细调整mesh提取质量,改善mesh拓扑 +- **特点**: 增强mesh正则化,较小的occupancy学习率 + +### 5. `full.yaml` - 全参数优化 +- **优化参数**: 所有参数(除base_occupancy) +- **适用场景**: 初始化质量较差,需要全面优化 +- **警告**: 可能导致过拟合,建议谨慎使用 + +## 配置文件结构 (Configuration Structure) + +每个YAML配置文件包含以下4个主要部分: + +### 1. `gaussian_params` - 高斯参数设置 + +控制哪些高斯参数需要被优化,以及各自的学习率。 + +```yaml +gaussian_params: + _xyz: # 高斯中心位置 + trainable: true # 是否训练 + lr: 5.0e-4 # 学习率 + + _features_dc: # 球谐0阶系数(主颜色) + trainable: false + lr: 1.0e-4 + + _scaling: # 高斯椭球的缩放 + trainable: false + lr: 1.0e-4 + + _rotation: # 高斯椭球的旋转 + trainable: false + lr: 1.0e-4 + + # ... 其他参数 +``` + +**可用参数列表**: +- `_xyz`: 高斯中心位置(3D坐标) +- `_features_dc`: 球谐函数0阶系数(基础颜色) +- `_features_rest`: 球谐函数高阶系数(视角相关颜色) +- `_scaling`: 高斯椭球的3轴缩放 +- `_rotation`: 高斯椭球的旋转四元数 +- `_opacity`: 不透明度 +- `_base_occupancy`: SDF占用基础值(通常不训练) +- `_occupancy_shift`: SDF占用偏移量(可学习) + +### 2. `loss_weights` - 损失权重 + +控制各个损失项在总损失中的比重。 + +```yaml +loss_weights: + depth: 1.0 # 深度一致性损失 + normal: 1.0 # 法线一致性损失 + mesh_depth: 0.1 # Mesh深度损失 + mesh_normal: 0.1 # Mesh法线损失 +``` + +### 3. `depth_processing` - 深度处理 + +控制深度图的加载和预处理。 + +```yaml +depth_processing: + clip_min: 0.0 # 最小深度值,null表示不裁剪 + clip_max: null # 最大深度值,null表示不裁剪 +``` + +**常用设置**: +- 室内场景: `clip_max: 10.0` 或 `20.0` +- 室外场景: `clip_max: 50.0` 或 `100.0` +- 无裁剪: `clip_max: null` + +### 4. `mesh_regularization` - Mesh正则化 + +覆盖mesh配置文件中的权重设置。 + +```yaml +mesh_regularization: + depth_weight: 0.1 # Mesh深度项权重 + normal_weight: 0.1 # Mesh法线项权重 +``` + +## 创建自定义配置 (Creating Custom Configuration) + +### 方法1: 复制并修改预设配置 + +```bash +cd /home/zoyo/Desktop/MILo_rtx50/milo/configs/optimization +cp default.yaml my_custom.yaml +# 编辑 my_custom.yaml +``` + +### 方法2: 从模板开始 + +使用`default.yaml`作为模板,它包含所有参数的完整注释。 + +### 配置建议 + +1. **学习率设置**: + - 位置参数 (_xyz): `5e-4` 到 `1e-3` + - 几何参数 (_scaling, _rotation): `1e-4` 到 `5e-4` + - 外观参数 (_features_*): `1e-4` 到 `2.5e-4` + - 占用参数 (_occupancy_shift): `5e-5` 到 `1e-4` + +2. **损失权重平衡**: + - 深度和法线损失通常在 `1.0` 左右 + - Mesh损失通常为深度/法线损失的 `0.1` 到 `0.2` 倍 + - 如果优化几何参数,可适当增加法线权重 + +3. 
**渐进式优化策略**: + - 第1阶段: 使用`xyz_only`快速收敛位置 + - 第2阶段: 使用`xyz_geometry`细化形状 + - 第3阶段: 使用`xyz_occupancy`优化mesh质量 + +## 调试技巧 (Debugging Tips) + +### 验证配置文件 + +```bash +python test_opt_config.py +``` + +### 查看当前使用的配置 + +运行`yufu2mesh_new.py`时,会在开始打印所有参数设置: + +``` +[INFO] 加载优化配置:xyz_geometry +[INFO] 参数 _xyz: trainable=True, lr=0.0005 +[INFO] 参数 _scaling: trainable=True, lr=0.0001 +... +[INFO] Loss权重配置: + > depth: 1.0 + > normal: 1.5 +... +``` + +### 常见问题 + +**Q: 为什么优化后效果变差了?** +- A: 可能学习率过大,尝试降低学习率或减少可训练参数 + +**Q: 收敛太慢怎么办?** +- A: 可以适当增加学习率,或者增加可训练参数的数量 + +**Q: Mesh质量不好?** +- A: 尝试使用`xyz_occupancy`配置,或增加mesh正则化权重 + +**Q: 旧的命令行参数还能用吗?** +- A: 可以,但会被YAML配置覆盖,建议迁移到YAML配置 + +## 配置迁移 (Migration from Old Parameters) + +如果你之前使用命令行参数: + +```bash +# 旧方式 +python yufu2mesh_new.py \ + --lr 5e-4 \ + --depth_loss_weight 1.0 \ + --normal_loss_weight 1.5 \ + --mesh_depth_weight 0.1 \ + --mesh_normal_weight 0.15 +``` + +现在应该创建一个YAML配置文件: + +```yaml +# my_config.yaml +gaussian_params: + _xyz: + trainable: true + lr: 5.0e-4 + # ... 其他参数设为不可训练 + +loss_weights: + depth: 1.0 + normal: 1.5 + mesh_depth: 0.1 + mesh_normal: 0.15 + +# ... 其他设置 +``` + +然后使用: + +```bash +python yufu2mesh_new.py --opt_config my_config +``` + +## 更多信息 + +- 配置测试脚本: `test_opt_config.py` +- 主程序: `yufu2mesh_new.py` +- 相关文档: 参见各配置文件内的详细注释 diff --git a/milo/configs/optimization/default.yaml b/milo/configs/optimization/default.yaml new file mode 100644 index 0000000..0ecb48b --- /dev/null +++ b/milo/configs/optimization/default.yaml @@ -0,0 +1,148 @@ + +# ============================================================================ +# 高斯参数优化配置文件 (Gaussian Parameters Optimization Config) +# ============================================================================ +# 此配置文件用于控制yufu2mesh训练过程中的所有可训练参数、学习率和损失权重 +# This config controls all trainable parameters, learning rates and loss weights + +# ---------------------------------------------------------------------------- +# 1. 高斯参数优化设置 (Gaussian Parameters Optimization) +# ---------------------------------------------------------------------------- +# 定义哪些高斯参数需要被优化,以及各自的学习率 +# Define which Gaussian parameters to optimize and their learning rates + +gaussian_params: + # 位置参数 (Position parameters) + # _xyz: 高斯中心位置,3D坐标 + _xyz: + trainable: true # 是否训练此参数 (whether to train this parameter) + lr: 5.0e-4 # 学习率 (learning rate) + + # 外观参数 (Appearance parameters - color/SH coefficients) + # _features_dc: 球谐函数0阶系数(主颜色) + _features_dc: + trainable: false # 默认冻结外观参数 (freeze appearance by default) + lr: 1.0e-4 + + # _features_rest: 球谐函数高阶系数(视角相关颜色) + _features_rest: + trainable: false + lr: 1.0e-4 + + # 几何形状参数 (Geometry parameters) + # _scaling: 高斯椭球的3轴缩放 + _scaling: + trainable: false # 默认冻结形状参数 (freeze geometry by default) + lr: 1.0e-4 + + # _rotation: 高斯椭球的旋转四元数 + _rotation: + trainable: false + lr: 1.0e-4 + + # 不透明度参数 (Opacity parameter) + # _opacity: 每个高斯的不透明度 + _opacity: + trainable: false # 默认冻结不透明度 (freeze opacity by default) + lr: 1.0e-4 + + # SDF占用参数 (SDF Occupancy parameters for mesh extraction) + # _base_occupancy: 占用网格基础值(由深度融合初始化,不训练) + _base_occupancy: + trainable: false # 通常不训练base值 (usually don't train base values) + lr: 0.0 + + # _occupancy_shift: 占用网格偏移量(可学习的调整) + _occupancy_shift: + trainable: false # 默认冻结,避免过度修改mesh topology + lr: 1.0e-4 + +# ---------------------------------------------------------------------------- +# 2. 
损失函数权重 (Loss Weights) +# ---------------------------------------------------------------------------- +# 控制各个损失项在总损失中的比重 +# Control the weight of each loss component in the total loss + +loss_weights: + # 深度一致性损失 (Depth consistency loss) + # 衡量渲染深度与GT深度的差异 (measures difference between rendered and GT depth) + depth: 1.0 + + # 法线一致性损失 (Normal consistency loss) + # 衡量渲染法线与GT法线的差异 (measures difference between rendered and GT normals) + normal: 1.0 + + # Mesh正则化 - 深度损失 (Mesh regularization - depth loss) + # 从提取的mesh渲染的深度与高斯渲染深度的一致性 + mesh_depth: 0.1 + + # Mesh正则化 - 法线损失 (Mesh regularization - normal loss) + # 从提取的mesh渲染的法线与高斯渲染法线的一致性 + mesh_normal: 0.1 + +# ---------------------------------------------------------------------------- +# 3. 深度处理设置 (Depth Processing Settings) +# ---------------------------------------------------------------------------- +# 控制深度图的加载和预处理 +# Control depth map loading and preprocessing + +depth_processing: + # 深度裁剪最小值 (minimum depth clipping value) + # 小于此值的深度会被裁剪,设为null或0表示不裁剪 + clip_min: 0.0 + + # 深度裁剪最大值 (maximum depth clipping value) + # 大于此值的深度会被裁剪,设为null表示不裁剪 + # 示例: 对于户外场景可以设置为50.0或100.0来去除远景噪声 + clip_max: null + +# ---------------------------------------------------------------------------- +# 4. Mesh正则化权重覆盖 (Mesh Regularization Weight Override) +# ---------------------------------------------------------------------------- +# 这些权重会覆盖mesh配置文件(如configs/mesh/medium.yaml)中的默认值 +# These weights override the defaults in mesh config files (e.g., configs/mesh/medium.yaml) + +mesh_regularization: + # mesh深度项权重 (mesh depth term weight) + # 控制mesh提取过程中深度一致性的重要程度 + depth_weight: 0.1 + + # mesh法线项权重 (mesh normal term weight) + # 控制mesh提取过程中法线平滑度的重要程度 + normal_weight: 0.1 + +# ============================================================================ +# 使用说明 (Usage Instructions) +# ============================================================================ +# +# 1. 基础使用 (Basic usage): +# python yufu2mesh_new.py --opt_config default +# +# 2. 使用自定义配置 (Use custom config): +# python yufu2mesh_new.py --opt_config my_config +# (会自动查找 configs/optimization/my_config.yaml) +# +# 3. 
使用完整路径 (Use full path): +# python yufu2mesh_new.py --opt_config /path/to/custom_config.yaml +# +# ============================================================================ +# 常用配置示例 (Common Configuration Examples) +# ============================================================================ +# +# 场景1: 只优化位置 (Only optimize positions) +# - 适用于已有良好初始化的高斯,只需微调位置 +# - _xyz: trainable=true, 其他全部false +# +# 场景2: 优化位置+形状 (Optimize positions + geometry) +# - 适用于需要调整高斯形状以更好拟合深度的情况 +# - _xyz, _scaling, _rotation: trainable=true +# +# 场景3: 全参数优化 (Full optimization) +# - 适用于初始化质量较差,需要全面优化的情况 +# - 所有参数: trainable=true (但要小心过拟合) +# +# 场景4: 优化位置+占用 (Optimize positions + occupancy) +# - 适用于需要精细调整mesh提取质量的情况 +# - _xyz, _occupancy_shift: trainable=true +# +# ============================================================================ \ No newline at end of file diff --git a/milo/configs/optimization/full.yaml b/milo/configs/optimization/full.yaml new file mode 100644 index 0000000..e08b7ba --- /dev/null +++ b/milo/configs/optimization/full.yaml @@ -0,0 +1,43 @@ +# 全参数优化配置 (Full parameter optimization) +# 适用场景: 初始化质量较差,需要全面优化 +# 警告: 可能导致过拟合,建议谨慎使用 + +gaussian_params: + _xyz: + trainable: true + lr: 5.0e-4 + _features_dc: + trainable: true # 优化颜色 + lr: 2.5e-4 + _features_rest: + trainable: true + lr: 2.5e-5 # 高阶SH系数用更小学习率 + _scaling: + trainable: true + lr: 1.0e-4 + _rotation: + trainable: true + lr: 1.0e-4 + _opacity: + trainable: true # 优化不透明度 + lr: 5.0e-5 + _base_occupancy: + trainable: false # base值仍然不训练 + lr: 0.0 + _occupancy_shift: + trainable: true + lr: 5.0e-5 + +loss_weights: + depth: 1.0 + normal: 1.5 + mesh_depth: 0.15 + mesh_normal: 0.15 + +depth_processing: + clip_min: 0.0 + clip_max: null + +mesh_regularization: + depth_weight: 0.12 + normal_weight: 0.12 diff --git a/milo/configs/optimization/geometry_only.yaml b/milo/configs/optimization/geometry_only.yaml new file mode 100644 index 0000000..e6d379f --- /dev/null +++ b/milo/configs/optimization/geometry_only.yaml @@ -0,0 +1,42 @@ +# 纯几何优化配置 (Geometry-only optimization) +# 适用场景: 位置已经优化好,只需要调整高斯的形状和朝向以更好地拟合法线 + +gaussian_params: + _xyz: + trainable: false # 冻结位置 + lr: 0.0 + _features_dc: + trainable: false + lr: 0.0 + _features_rest: + trainable: false + lr: 0.0 + _scaling: + trainable: true # 优化高斯椭球的尺度 + lr: 1.0e-4 + _rotation: + trainable: true # 优化高斯椭球的朝向 + lr: 1.0e-4 + _opacity: + trainable: false + lr: 0.0 + _base_occupancy: + trainable: false + lr: 0.0 + _occupancy_shift: + trainable: false + lr: 0.0 + +loss_weights: + depth: 0.5 # 降低深度权重,因为位置不变 + normal: 2.0 # 大幅增加法线权重,几何形状主要影响法线 + mesh_depth: 0.05 + mesh_normal: 0.2 # 增加mesh法线权重 + +depth_processing: + clip_min: 0.0 + clip_max: null + +mesh_regularization: + depth_weight: 0.05 + normal_weight: 0.2 # 增强法线正则化 diff --git a/milo/configs/optimization/xyz_geometry.yaml b/milo/configs/optimization/xyz_geometry.yaml new file mode 100644 index 0000000..2044c51 --- /dev/null +++ b/milo/configs/optimization/xyz_geometry.yaml @@ -0,0 +1,42 @@ +# 优化位置+几何配置 (Position + Geometry optimization) +# 适用场景: 需要调整高斯形状以更好拟合深度和法线 + +gaussian_params: + _xyz: + trainable: true + lr: 5.0e-4 + _features_dc: + trainable: false + lr: 0.0 + _features_rest: + trainable: false + lr: 0.0 + _scaling: + trainable: true # 允许调整高斯椭球的尺度 + lr: 1.0e-4 + _rotation: + trainable: true # 允许调整高斯椭球的朝向 + lr: 1.0e-4 + _opacity: + trainable: false + lr: 0.0 + _base_occupancy: + trainable: false + lr: 0.0 + _occupancy_shift: + trainable: false + lr: 0.0 + +loss_weights: + depth: 1.0 + normal: 1.5 # 增加法线权重,因为几何参数会影响法线 + mesh_depth: 0.1 + 
mesh_normal: 0.15 + +depth_processing: + clip_min: 0.0 + clip_max: null + +mesh_regularization: + depth_weight: 0.1 + normal_weight: 0.15 \ No newline at end of file diff --git a/milo/configs/optimization/xyz_occupancy.yaml b/milo/configs/optimization/xyz_occupancy.yaml new file mode 100644 index 0000000..26f06ad --- /dev/null +++ b/milo/configs/optimization/xyz_occupancy.yaml @@ -0,0 +1,42 @@ +# 优化位置+占用配置 (Position + Occupancy optimization) +# 适用场景: 需要精细调整mesh提取质量,改善mesh拓扑 + +gaussian_params: + _xyz: + trainable: true + lr: 5.0e-4 + _features_dc: + trainable: false + lr: 0.0 + _features_rest: + trainable: false + lr: 0.0 + _scaling: + trainable: false + lr: 0.0 + _rotation: + trainable: false + lr: 0.0 + _opacity: + trainable: false + lr: 0.0 + _base_occupancy: + trainable: false + lr: 0.0 + _occupancy_shift: + trainable: true # 允许调整SDF占用值 + lr: 5.0e-5 # 较小的学习率避免拓扑突变 + +loss_weights: + depth: 1.0 + normal: 1.0 + mesh_depth: 0.2 # 增加mesh损失权重 + mesh_normal: 0.2 + +depth_processing: + clip_min: 0.0 + clip_max: null + +mesh_regularization: + depth_weight: 0.15 # 增加mesh正则化强度 + normal_weight: 0.15 diff --git a/milo/configs/optimization/xyz_only.yaml b/milo/configs/optimization/xyz_only.yaml new file mode 100644 index 0000000..f9f57bf --- /dev/null +++ b/milo/configs/optimization/xyz_only.yaml @@ -0,0 +1,42 @@ +# 只优化位置配置 (Position-only optimization) +# 适用场景: 已有良好初始化,只需微调位置以匹配深度 + +gaussian_params: + _xyz: + trainable: true + lr: 5.0e-4 + _features_dc: + trainable: false + lr: 0.0 + _features_rest: + trainable: false + lr: 0.0 + _scaling: + trainable: false + lr: 0.0 + _rotation: + trainable: false + lr: 0.0 + _opacity: + trainable: false + lr: 0.0 + _base_occupancy: + trainable: false + lr: 0.0 + _occupancy_shift: + trainable: false + lr: 0.0 + +loss_weights: + depth: 1.0 + normal: 1.0 + mesh_depth: 0.1 + mesh_normal: 0.1 + +depth_processing: + clip_min: 0.0 + clip_max: null + +mesh_regularization: + depth_weight: 0.1 + normal_weight: 0.1 \ No newline at end of file diff --git a/milo/gaussian_renderer/__init__.py b/milo/gaussian_renderer/__init__.py index d05ac5f..074b51c 100644 --- a/milo/gaussian_renderer/__init__.py +++ b/milo/gaussian_renderer/__init__.py @@ -175,8 +175,18 @@ def render_imp(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tens else: colors_precomp = override_color - if culling==None: - culling=torch.zeros(means3D.shape[0], dtype=torch.bool, device='cuda') + if culling is None: + culling = torch.zeros(means3D.shape[0], dtype=torch.bool, device=means3D.device) + else: + if culling.dtype != torch.bool: + culling = culling.to(torch.bool) + if culling.shape != (means3D.shape[0],): + culling = culling.reshape(-1) + if culling.numel() != means3D.shape[0]: + raise ValueError(f"Culling mask has {culling.numel()} entries, " + f"but {means3D.shape[0]} Gaussians provided.") + if not culling.is_contiguous(): + culling = culling.contiguous() # Rasterize visible Gaussians to image, obtain their radii (on screen). 
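+    # `culling` is now guaranteed to be a contiguous bool mask with exactly one entry per
+    # Gaussian (see the normalization above) before it is handed to the rasterizer.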
rendered_image, radii, accum_max_count = rasterizer( diff --git a/milo/regularization/regularizer/mesh.py b/milo/regularization/regularizer/mesh.py index c6d4c7d..c135f3b 100644 --- a/milo/regularization/regularizer/mesh.py +++ b/milo/regularization/regularizer/mesh.py @@ -76,6 +76,10 @@ def initialize_mesh_regularization( "surface_delaunay_xyz_idx": None, "reset_delaunay_samples": True, "reset_sdf_values": True, + "surface_sample_export_path": None, + "surface_sample_saved": False, + "surface_sample_saved_iter": None, + "latest_mesh": None, } return mesh_renderer, mesh_state @@ -242,6 +246,7 @@ def compute_mesh_regularization( if downsample_gaussians_for_delaunay: print(f"[INFO] Downsampling Delaunay Gaussians from {n_gaussians_to_sample_from} to {n_max_gaussians_for_delaunay}.") + surface_indices_for_export = None if config["delaunay_sampling_method"] == "random": delaunay_xyz_idx = torch.randperm( n_gaussians_to_sample_from, device="cuda" @@ -257,6 +262,7 @@ def compute_mesh_regularization( n_samples=n_max_gaussians_for_delaunay, sampling_mask=delaunay_sampling_radius_mask, ) + surface_indices_for_export = delaunay_xyz_idx elif config["delaunay_sampling_method"] == "surface+opacity": delaunay_xyz_idx = gaussians.sample_surface_gaussians( scene=scene, @@ -268,9 +274,10 @@ def compute_mesh_regularization( n_samples=n_max_gaussians_for_delaunay, sampling_mask=delaunay_sampling_radius_mask, ) + surface_indices_for_export = delaunay_xyz_idx.clone() n_remaining_gaussians_to_sample = n_max_gaussians_for_delaunay - delaunay_xyz_idx.shape[0] if n_remaining_gaussians_to_sample > 0: - mesh_state["surface_delaunay_xyz_idx"] = delaunay_xyz_idx.clone() + mesh_state["surface_delaunay_xyz_idx"] = surface_indices_for_export.clone() opacity_sample_mask = torch.ones(gaussians._xyz.shape[0], device="cuda", dtype=torch.bool) opacity_sample_mask[delaunay_xyz_idx] = False delaunay_xyz_idx = torch.cat( @@ -286,6 +293,19 @@ def compute_mesh_regularization( delaunay_xyz_idx = torch.sort(delaunay_xyz_idx, dim=0)[0] else: raise ValueError(f"Invalid Delaunay sampling method: {config['delaunay_sampling_method']}") + + if ( + surface_indices_for_export is not None + and surface_indices_for_export.numel() > 0 + and mesh_state.get("surface_sample_export_path") + and not mesh_state.get("surface_sample_saved", False) + ): + export_path = mesh_state["surface_sample_export_path"] + gaussians.save_subset_ply(export_path, surface_indices_for_export) + mesh_state["surface_sample_saved"] = True + mesh_state["surface_sample_saved_iter"] = iteration + print(f"[INFO] Exported initial surface Gaussians ({surface_indices_for_export.numel()} points) to {export_path}.") + print(f"[INFO] Downsampled Delaunay Gaussians from {n_gaussians_to_sample_from} to {len(delaunay_xyz_idx)}.") reset_occupancy_labels_for_new_delaunay_sites = True else: @@ -456,6 +476,32 @@ def compute_mesh_regularization( # --- Build and Render Mesh --- mesh = Meshes(verts=verts, faces=faces[faces_mask]) + mesh_triangles = mesh.faces + if mesh_triangles.numel() == 0: + mesh_state["mesh_triangles"] = mesh_triangles + mesh_state["latest_mesh"] = None + return { + "mesh_loss": torch.zeros((), device=gaussians._xyz.device), + "mesh_depth_loss": torch.zeros((), device=gaussians._xyz.device), + "mesh_normal_loss": torch.zeros((), device=gaussians._xyz.device), + "occupied_centers_loss": torch.zeros((), device=gaussians._xyz.device), + "occupancy_labels_loss": torch.zeros((), device=gaussians._xyz.device), + "updated_state": mesh_state, + "mesh_render_pkg": { + "depth": 
torch.zeros_like(render_pkg.get("median_depth", torch.zeros(1, device=gaussians._xyz.device))), + "normals": torch.zeros_like(render_pkg.get("normal", torch.zeros_like(render_pkg.get("render", torch.zeros(gaussians._xyz.shape[0], device=gaussians._xyz.device))))) + }, + "voronoi_points_count": voronoi_points_count, + "mesh_triangles": mesh_triangles, + } + + mesh_state["mesh_triangles"] = mesh_triangles + # Cache the current mesh for logging (detached to avoid graph retention) + mesh_state["latest_mesh"] = Meshes( + verts=mesh.verts.detach(), + faces=mesh.faces.detach(), + verts_colors=mesh.verts_colors.detach() if mesh.verts_colors is not None else None, + ) mesh_render_pkg = mesh_renderer( mesh, @@ -489,7 +535,7 @@ def compute_mesh_regularization( voronoi_occupancy_labels, _ = evaluate_mesh_occupancy( points=voronoi_points, views=scene.getTrainCameras().copy(), - mesh=Meshes(verts=verts, faces=faces), + mesh=mesh, masks=None, return_colors=True, use_scalable_renderer=config["use_scalable_renderer"], @@ -624,4 +670,5 @@ def reset_mesh_state_at_next_iteration(mesh_state): mesh_state["reset_delaunay_samples"] = True mesh_state["reset_sdf_values"] = True mesh_state["delaunay_tets"] = None + mesh_state["latest_mesh"] = None return mesh_state diff --git a/milo/regularization/sdf/depth_fusion.py b/milo/regularization/sdf/depth_fusion.py index cf3ea67..88d2f78 100644 --- a/milo/regularization/sdf/depth_fusion.py +++ b/milo/regularization/sdf/depth_fusion.py @@ -536,6 +536,8 @@ def evaluate_mesh_occupancy( for cam_id, view in enumerate(tqdm(views, desc=f"Computing occupancy from mesh")): faces_mask = is_in_view_frustum(mesh.verts, view)[mesh.faces].any(axis=1) + if faces_mask.sum() == 0: + continue render_pkg = mesh_renderer( Meshes(verts=mesh.verts, faces=mesh.faces[faces_mask]), cam_idx=cam_id, @@ -609,6 +611,8 @@ def evaluate_mesh_colors( for cam_id, view in enumerate(tqdm(views, desc=f"Computing vertex colors")): faces_mask = is_in_view_frustum(mesh.verts, view)[mesh.faces].any(axis=1) + if faces_mask.sum() == 0: + continue render_pkg = mesh_renderer( Meshes(verts=mesh.verts, faces=mesh.faces[faces_mask]), cam_idx=cam_id, diff --git a/milo/scene/gaussian_model.py b/milo/scene/gaussian_model.py index cce6f67..36882ce 100644 --- a/milo/scene/gaussian_model.py +++ b/milo/scene/gaussian_model.py @@ -635,6 +635,68 @@ def save_ply(self, path): el = PlyElement.describe(elements, 'vertex') PlyData([el]).write(path) + def save_subset_ply(self, path, indices): + if indices is None: + raise ValueError("indices must be provided to save_subset_ply.") + + if not isinstance(indices, torch.Tensor): + indices = torch.as_tensor(indices, dtype=torch.long, device=self._xyz.device) + else: + indices = indices.to(self._xyz.device, dtype=torch.long) + + mkdir_p(os.path.dirname(path)) + + if indices.numel() == 0: + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + elements = np.empty(0, dtype=dtype_full) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + return + + xyz = self._xyz[indices].detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = ( + self._features_dc[indices] + .detach() + .transpose(1, 2) + .flatten(start_dim=1) + .contiguous() + .cpu() + .numpy() + ) + f_rest = ( + self._features_rest[indices] + .detach() + .transpose(1, 2) + .flatten(start_dim=1) + .contiguous() + .cpu() + .numpy() + ) + opacities = self._opacity[indices].detach().cpu().numpy() + scale = self._scaling[indices].detach().cpu().numpy() + rotation = 
self._rotation[indices].detach().cpu().numpy() + + if self.use_mip_filter: + filter_3D = self.filter_3D[indices].detach().cpu().numpy() + + if self.learn_occupancy: + base_occupancy = self._base_occupancy[indices].detach().cpu().numpy() + occupancy_shift = self._occupancy_shift[indices].detach().cpu().numpy() + + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + to_concatenate = (xyz, normals, f_dc, f_rest, opacities, scale, rotation) + if self.use_mip_filter: + to_concatenate = to_concatenate + (filter_3D,) + if self.learn_occupancy: + to_concatenate = to_concatenate + (base_occupancy, occupancy_shift) + attributes = np.concatenate(to_concatenate, axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + def reset_opacity(self): if self.use_mip_filter: # reset opacity to by considering 3D filter @@ -686,10 +748,25 @@ def load_ply(self, path): features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] - extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) + extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split('_')[-1])) if self.max_sh_degree is None: self.max_sh_degree = int(np.sqrt(len(extra_f_names) / 3 + 1) - 1) - assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 + else: + expected_extra = 3 * (self.max_sh_degree + 1) ** 2 - 3 + if len(extra_f_names) != expected_extra: + inferred_degree = int(np.sqrt(len(extra_f_names) / 3 + 1) - 1) + print( + f"[WARNING] Expected SH degree {self.max_sh_degree} (extra features {expected_extra}), " + f"but file provides {len(extra_f_names)} extra features. " + f"Adjusting SH degree to {inferred_degree}." + ) + self.max_sh_degree = inferred_degree + expected_extra = 3 * (self.max_sh_degree + 1) ** 2 - 3 + if len(extra_f_names) != expected_extra: + raise ValueError( + f"Cannot reconcile SH degree. File has {len(extra_f_names)} extra features, " + f"which does not match expected {expected_extra} for degree {self.max_sh_degree}." 
+ ) features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) for idx, attr_name in enumerate(extra_f_names): features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) @@ -1064,9 +1141,6 @@ def init_culling(self, num_views): self._culling=torch.zeros((self._xyz.shape[0], num_views), dtype=torch.bool, device='cuda') self.factor_culling=torch.ones((self._xyz.shape[0],1), device='cuda') - - - def depth_reinit(self, scene, render_depth, iteration, num_depth, args, pipe, background): out_pts_list=[] @@ -1281,7 +1355,7 @@ def culling_with_clone(self, scene, render_simp, iteration, args, pipe, backgrou self.factor_culling=count_vis/(count_rad+1e-1) non_prune_mask = init_cdf_mask(importance=imp_score, thres=0.999) - prune_mask = (count_vis<=1)[:,0] + prune_mask = (count_vis==0)[:,0] prune_mask = torch.logical_or(prune_mask, non_prune_mask==False) self.prune_points(prune_mask) @@ -1699,4 +1773,4 @@ def add_densification_stats_radegs( self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True) self.xyz_gradient_accum_abs[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,2:], dim=-1, keepdim=True) self.xyz_gradient_accum_abs_max[update_filter] = torch.max(self.xyz_gradient_accum_abs_max[update_filter], torch.norm(viewspace_point_tensor.grad[update_filter,2:], dim=-1, keepdim=True)) - self.denom[update_filter] += 1 \ No newline at end of file + self.denom[update_filter] += 1 diff --git a/milo/scene/mesh.py b/milo/scene/mesh.py index 988cd8c..43415dd 100644 --- a/milo/scene/mesh.py +++ b/milo/scene/mesh.py @@ -3,9 +3,11 @@ import nvdiffrast.torch as dr from scene.cameras import Camera from utils.geometry_utils import transform_points_world_to_view +import os +import nvdiffrast.torch as dr -def nvdiff_rasterization( +def nvdiff_rasterization_old( camera, image_height:int, image_width:int, @@ -44,6 +46,123 @@ def nvdiff_rasterization( _output = _output + (pos,) return _output +def nvdiff_rasterization( + camera, + image_height: int, + image_width: int, + verts: torch.Tensor, + faces: torch.Tensor, + return_indices_only: bool = False, + glctx=None, + return_rast_out: bool = False, + return_positions: bool = False, +): + """ + 与原函数等价的替换版,支持按三角形分块(env: MILO_RAST_TRI_CHUNK), + 并修正:nvdiffrast CUDA 后端的 `ranges` 必须在 CPU。 + """ + import os + import torch + import nvdiffrast.torch as dr + + device = verts.device + dtype = verts.dtype + + # 1) 裁剪空间坐标(与原版一致) + cam_mtx = camera.full_proj_transform + pos = torch.cat([verts, torch.ones([verts.shape[0], 1], device=device, dtype=dtype)], dim=1) + pos = torch.matmul(pos, cam_mtx)[None] # [1,V,4] + + # 2) 准备 tri / 形参 + faces = faces.to(torch.int32).contiguous() + # 与 pos 同设备更稳(nvdiffrast 在很多场景都支持 tri 在 GPU/CPU,但统一更省心) + faces_dev = faces.to(pos.device) + + H, W = int(image_height), int(image_width) + chunk = int(os.getenv("MILO_RAST_TRI_CHUNK", "0") or "0") + use_chunking = chunk > 0 and faces.shape[0] > chunk + + # 快速路径:不分块,完全等价原实现 + if not use_chunking: + rast_out, _ = dr.rasterize(glctx, pos=pos, tri=faces_dev, resolution=[H, W]) + bary_coords = rast_out[..., :2] + zbuf = rast_out[..., 2] + pix_to_face = rast_out[..., 3].to(torch.int32) - 1 # 未命中 => -1 + if return_indices_only: + return pix_to_face + _out = (bary_coords, zbuf, pix_to_face) + if return_rast_out: + _out += (rast_out,) + if return_positions: + _out += (pos,) + return _out + + # 分块路径:ranges 必须在 CPU! 
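+    # Chunked path: rasterize the face list in slices of `chunk` triangles using nvdiffrast's
+    # `ranges` argument (an int32 [[start, count]] tensor that must live on the CPU), interpolate
+    # per-pixel NDC depth for each slice, and keep the closest hit per pixel across slices
+    # (a manual z-buffer merge). This caps peak VRAM at the cost of one rasterize/interpolate
+    # pass per slice.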
+ z_ndc = (pos[..., 2:3] / (pos[..., 3:4] + 1e-20)).contiguous() # [1,V,1] + + best_rast = None + best_depth = None + n_faces = int(faces.shape[0]) + start = 0 + + def _normalize_tri_id(rast_chunk, start_idx, count_idx): + tri_raw = rast_chunk[..., 3:4].to(torch.int64) # >0 => 命中 + if tri_raw.numel() == 0: + return rast_chunk[..., 3:4] + maxid = int(tri_raw.max().item()) + if maxid == 0: + return rast_chunk[..., 3:4] + # 如果是局部编号(1..count),平移为全局(1..n_faces) + if maxid <= count_idx: + tri_adj = torch.where(tri_raw > 0, tri_raw + start_idx, tri_raw) + else: + tri_adj = tri_raw + return tri_adj.to(rast_chunk.dtype) + + while start < n_faces: + count = min(chunk, n_faces - start) + # !!!关键修正:ranges 在 CPU 上!!! + ranges_cpu = torch.tensor([[start, count]], device="cpu", dtype=torch.int32) + + rast_chunk, _ = dr.rasterize( + glctx, pos=pos, tri=faces_dev, resolution=[H, W], ranges=ranges_cpu + ) + + depth_chunk, _ = dr.interpolate(z_ndc, rast_chunk, faces_dev) # [1,H,W,1] + tri_id_adj = _normalize_tri_id(rast_chunk, start, count) + + if best_rast is None: + best_rast = torch.zeros_like(rast_chunk) + best_depth = torch.full_like(depth_chunk, float("inf")) + + hit = (tri_id_adj > 0) + prev_hit = (best_rast[..., 3:4] > 0) + closer = hit & (~prev_hit | (depth_chunk < best_depth)) + + rast_chunk = torch.cat([rast_chunk[..., :3], tri_id_adj], dim=-1) + + best_depth = torch.where(closer, depth_chunk, best_depth) + best_rast = torch.where(closer.expand_as(best_rast), rast_chunk, best_rast) + + start += count + + rast_out = best_rast + bary_coords = rast_out[..., :2] + zbuf = rast_out[..., 2] + pix_to_face = rast_out[..., 3].to(torch.int32) - 1 # 未命中 => -1 + + if return_indices_only: + return pix_to_face + + _output = (bary_coords, zbuf, pix_to_face) + if return_rast_out: + _output += (rast_out,) + if return_positions: + _output += (pos,) + return _output + + + class Meshes(torch.nn.Module): """ @@ -191,7 +310,11 @@ def __init__( self.cameras = cameras if use_opengl: - self.gl_context = dr.RasterizeGLContext() + backend = os.environ.get("NVDIFRAST_BACKEND", "gl").lower() + if backend == "cuda": + self.gl_context = dr.RasterizeCudaContext() + else: + self.gl_context = dr.RasterizeGLContext() else: self.gl_context = dr.RasterizeCudaContext() @@ -429,6 +552,17 @@ def forward( filtered_pix_to_face = filtered_pix_to_face - 1 mesh = Meshes(verts=mesh.verts, faces=filtered_faces, verts_colors=mesh.verts_colors) fragments.pix_to_face = filtered_pix_to_face + + if mesh.faces.shape[0] == 0: + H, W = fragments.zbuf.shape[1:3] + output_pkg = {} + if return_depth: + output_pkg["depth"] = torch.zeros(1, H, W, 1, device=mesh.verts.device) + if return_normals: + output_pkg["normals"] = torch.zeros(1, H, W, 3, device=mesh.verts.device) + if return_pix_to_face: + output_pkg["pix_to_face"] = fragments.pix_to_face + return output_pkg # Rebuild rast_out rast_out = torch.zeros(*fragments.zbuf.shape[:-1], 4, device=fragments.zbuf.device) @@ -457,14 +591,21 @@ def forward( color_idx = features.shape[-1] features = torch.cat([features, mesh.verts_colors], dim=-1) # Shape (N, n_features) + # Short-circuit: no visible triangles -> return zero maps + if mesh.faces.shape[0] == 0: + H, W = fragments.zbuf.shape[1:3] + if return_depth: + output_pkg["depth"] = torch.zeros(1, H, W, 1, device=mesh.verts.device) + if return_colors: + output_pkg["rgb"] = torch.zeros(1, H, W, 3, device=mesh.verts.device) + if return_normals: + output_pkg["normals"] = torch.zeros(1, H, W, 3, device=mesh.verts.device) + if return_pix_to_face: + 
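+            # Every face was filtered out for this view; return the empty buffers instead of
+            # rebuilding rast_out and interpolating.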
output_pkg["pix_to_face"] = fragments.pix_to_face + return output_pkg + # Compute image - if True: - feature_img, _ = dr.interpolate(features[None], rast_out, mesh.faces) # Shape (1, H, W, n_features) - else: - pix_to_verts = mesh.faces[fragments.pix_to_face] # Shape (1, H, W, 1, 3) - pix_to_features = features[pix_to_verts] # Shape (1, H, W, 1, 3, n_features) - feature_img = (pix_to_features * fragments.bary_coords[..., None]).sum(dim=-2) # Shape (1, H, W, 1, n_features) - feature_img = feature_img.squeeze(-2) # Shape (1, H, W, n_features) + feature_img, _ = dr.interpolate(features[None], rast_out, mesh.faces) # Shape (1, H, W, n_features) # Antialiasing for propagating gradients if use_antialiasing: @@ -498,4 +639,4 @@ def forward( output_pkg["rast_out"] = rast_out #### TO REMOVE - return output_pkg \ No newline at end of file + return output_pkg diff --git a/milo/scripts/check_view_depth_alignment.py b/milo/scripts/check_view_depth_alignment.py new file mode 100644 index 0000000..ae4d682 --- /dev/null +++ b/milo/scripts/check_view_depth_alignment.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +""" +Utility to inspect specific Discoverse camera views and verify whether their +depth maps stay consistent with the reference bridge mesh. + +Example: + python scripts/check_view_depth_alignment.py \\ + --camera_json milo/data/bridge_clean/camera_poses_cam1.json \\ + --depth_dir milo/data/bridge_clean/depth \\ + --ply_path milo/data/bridge_clean/yufu_bridge_cleaned.ply \\ + --view_indices 72 187 --stride 8 --max_samples 20000 +""" + +from __future__ import annotations + +import argparse +import json +import math +import random +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Sequence + +import numpy as np +import trimesh +from scipy.spatial import cKDTree + +OPENGL_TO_COLMAP = np.diag([1.0, -1.0, -1.0]).astype(np.float64) + + +@dataclass +class CameraPose: + name: str + rotation_c2w: np.ndarray # (3, 3) + camera_center: np.ndarray # (3,) + + +def quaternion_to_rotation_matrix(quaternion: Sequence[float]) -> np.ndarray: + q = np.asarray(quaternion, dtype=np.float64) + if q.shape != (4,): + raise ValueError(f"Quaternion needs 4 numbers, got shape {q.shape}") + w, x, y, z = q + xx = x * x + yy = y * y + zz = z * z + xy = x * y + xz = x * z + yz = y * z + wx = w * x + wy = w * y + wz = w * z + return np.array( + [ + [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)], + [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)], + [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)], + ], + dtype=np.float64, + ) + + +def load_camera_list(json_path: Path) -> List[CameraPose]: + with json_path.open("r", encoding="utf-8") as fh: + payload = json.load(fh) + + if isinstance(payload, dict): + for key in ("frames", "poses", "camera_poses"): + if key in payload and isinstance(payload[key], list): + payload = payload[key] + break + else: + raise ValueError(f"{json_path} does not contain a pose list.") + + cameras: List[CameraPose] = [] + for idx, entry in enumerate(payload): + if "quaternion" in entry: + rotation_c2w = quaternion_to_rotation_matrix(entry["quaternion"]) + elif "rotation" in entry: + rotation_c2w = np.asarray(entry["rotation"], dtype=np.float64) + else: + raise KeyError(f"Entry {idx} missing quaternion/rotation.") + + rotation_c2w = rotation_c2w @ OPENGL_TO_COLMAP + + if "position" in entry: + camera_center = np.asarray(entry["position"], dtype=np.float64) + elif "translation" in entry: + camera_center = np.asarray(entry["translation"], dtype=np.float64) 
+ else: + raise KeyError(f"Entry {idx} missing position/translation.") + + name = entry.get("name") or f"view_{idx:04d}" + cameras.append(CameraPose(name=name, rotation_c2w=rotation_c2w, camera_center=camera_center)) + return cameras + + +def build_intrinsics(width: int, height: int, fov_y_deg: float) -> Dict[str, float]: + fov_y = math.radians(fov_y_deg) + fy = 0.5 * height / math.tan(0.5 * fov_y) + aspect = width / height + fov_x = 2.0 * math.atan(aspect * math.tan(fov_y * 0.5)) + fx = 0.5 * width / math.tan(0.5 * fov_x) + return { + "fx": fx, + "fy": fy, + "cx": (width - 1) * 0.5, + "cy": (height - 1) * 0.5, + } + + +def depth_to_world_points( + depth_map: np.ndarray, + camera: CameraPose, + intrinsics: Dict[str, float], + stride: int, +) -> np.ndarray: + h, w = depth_map.shape + ys = np.arange(0, h, stride) + xs = np.arange(0, w, stride) + grid_y, grid_x = np.meshgrid(ys, xs, indexing="ij") + samples = depth_map[grid_y, grid_x] + valid = np.isfinite(samples) & (samples > 0.0) + if not np.any(valid): + return np.empty((0, 3), dtype=np.float64) + samples = samples[valid] + px = grid_x[valid].astype(np.float64) + py = grid_y[valid].astype(np.float64) + fx, fy = intrinsics["fx"], intrinsics["fy"] + cx, cy = intrinsics["cx"], intrinsics["cy"] + x_cam = (px - cx) / fx * samples + y_cam = (py - cy) / fy * samples + z_cam = samples + cam_points = np.stack([x_cam, y_cam, z_cam], axis=1) + world_points = cam_points @ camera.rotation_c2w.T + camera.camera_center + return world_points.astype(np.float64) + + +def summarize_points(points: np.ndarray) -> Dict[str, np.ndarray]: + if points.size == 0: + return {"min": np.array([]), "max": np.array([]), "mean": np.array([])} + return { + "min": points.min(axis=0), + "max": points.max(axis=0), + "mean": points.mean(axis=0), + } + + +def build_mesh_kdtree(mesh_path: Path, sample_vertices: int) -> tuple[cKDTree, np.ndarray]: + mesh = trimesh.load(mesh_path, process=False) + if isinstance(mesh, trimesh.Scene): + combined = [] + for geom in mesh.geometry.values(): + combined.append(np.asarray(geom.vertices, dtype=np.float64)) + vertices = np.concatenate(combined, axis=0) if combined else np.empty((0, 3), dtype=np.float64) + else: + vertices = np.asarray(mesh.vertices, dtype=np.float64) + if vertices.size == 0: + raise ValueError(f"No vertices found in {mesh_path}") + if sample_vertices > 0 and sample_vertices < len(vertices): + rng = np.random.default_rng(seed=0) + choice = rng.choice(len(vertices), size=sample_vertices, replace=False) + vertices = vertices[choice] + tree = cKDTree(vertices) + return tree, vertices + + +def compute_distances(tree: cKDTree, points: np.ndarray, max_samples: int) -> np.ndarray: + if points.shape[0] == 0: + return np.empty(0, dtype=np.float64) + if max_samples > 0 and points.shape[0] > max_samples: + rng = np.random.default_rng(seed=42) + choice = rng.choice(points.shape[0], size=max_samples, replace=False) + pts = points[choice] + else: + pts = points + distances, _ = tree.query(pts, workers=-1) + return distances + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Check Discoverse depth / pose consistency for given view indices.") + parser.add_argument("--camera_json", type=Path, required=True, help="Path to camera pose JSON.") + parser.add_argument("--depth_dir", type=Path, required=True, help="Directory containing depth_img_0_*.npy files.") + parser.add_argument("--ply_path", type=Path, required=True, help="Reference mesh PLY used during capture.") + 
parser.add_argument("--view_indices", type=int, nargs="+", required=True, help="List of view indices to inspect.") + parser.add_argument("--width", type=int, default=1280) + parser.add_argument("--height", type=int, default=720) + parser.add_argument("--fov_y", type=float, default=75.0) + parser.add_argument("--stride", type=int, default=8, help="Pixel stride when sub-sampling depth map.") + parser.add_argument("--mesh_sample", type=int, default=200000, help="Number of mesh vertices used to build KD-tree.") + parser.add_argument("--max_samples", type=int, default=20000, help="Maximum number of world points for NN query.") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + cameras = load_camera_list(args.camera_json) + if max(args.view_indices) >= len(cameras): + raise IndexError(f"Requested view index exceeds available cameras ({len(cameras)}).") + intr = build_intrinsics(args.width, args.height, args.fov_y) + depth_dir = args.depth_dir + if not depth_dir.is_dir(): + raise FileNotFoundError(f"Depth directory {depth_dir} not found.") + + tree, mesh_points = build_mesh_kdtree(args.ply_path, args.mesh_sample) + mesh_bounds = np.array([mesh_points.min(axis=0), mesh_points.max(axis=0)]) + + print(f"[INFO] Loaded {len(cameras)} cameras from {args.camera_json}.") + print(f"[INFO] Mesh bounds: min={mesh_bounds[0]}, max={mesh_bounds[1]}") + + for view_idx in args.view_indices: + camera = cameras[view_idx] + depth_path = depth_dir / f"depth_img_0_{view_idx:03d}.npy" + if not depth_path.is_file(): + depth_path = depth_dir / f"depth_img_0_{view_idx}.npy" + if not depth_path.is_file(): + print(f"[WARN] Depth file for view {view_idx} missing, skipping.") + continue + depth = np.load(depth_path).squeeze() + valid = np.isfinite(depth) & (depth > 0) + if not np.any(valid): + print(f"[WARN] View {view_idx} has no valid depth pixels.") + continue + depth_stats = { + "min": float(depth[valid].min()), + "max": float(depth[valid].max()), + "mean": float(depth[valid].mean()), + "std": float(depth[valid].std()), + } + world_points = depth_to_world_points(depth, camera, intr, stride=max(1, args.stride)) + point_stats = summarize_points(world_points) + dist = compute_distances(tree, world_points, args.max_samples) + out_of_bounds = np.logical_or(world_points < mesh_bounds[0], world_points > mesh_bounds[1]).any(axis=1) + print(f"\n[VIEW {view_idx:03d}] {camera.name}") + print(f" Depth stats : min={depth_stats['min']:.3f} max={depth_stats['max']:.3f} " + f"mean={depth_stats['mean']:.3f} std={depth_stats['std']:.3f}") + if world_points.size == 0: + print(" No valid reprojected points.") + continue + print(f" World bounds : min={point_stats['min']} max={point_stats['max']}") + print(f" Outside mesh BB fraction: {out_of_bounds.mean():.4f}") + if dist.size: + print(f" NN distance (m): mean={dist.mean():.3f} median={np.median(dist):.3f} " + f"p95={np.percentile(dist,95):.3f} max={dist.max():.3f}") + else: + print(" NN distance: n/a (insufficient points)") + + +if __name__ == "__main__": + main() diff --git a/milo/scripts/discoverse_to_colmap.py b/milo/scripts/discoverse_to_colmap.py new file mode 100644 index 0000000..afe859b --- /dev/null +++ b/milo/scripts/discoverse_to_colmap.py @@ -0,0 +1,433 @@ +#!/usr/bin/env python3 +""" +Convert DISCOVERSE simulator exports (poses + RGB) into a MILo-compatible +COLMAP text scene. 
+""" + +import argparse +import json +import math +import os +import re +import shutil +from pathlib import Path +from typing import Dict, Iterable, List, Optional + +import numpy as np +import struct + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert DISCOVERSE simulator exports (poses + stereo RGB) into a MILo-compatible COLMAP text scene." + ) + parser.add_argument("--source", required=True, help="Root directory containing DISCOVERSE exports.") + parser.add_argument("--output", required=True, help="Destination directory for the COLMAP layout.") + parser.add_argument("--poses-root", default=".", help="Relative subdirectory under source with pose JSON files.") + parser.add_argument("--poses-glob", default="camera_poses_cam*.json", + help="Glob pattern (relative to poses-root) to pick pose JSON files.") + parser.add_argument("--image-root", default=".", help="Relative subdirectory under source holding RGB images.") + parser.add_argument("--image-pattern", default="rgb_img_{cam}_{frame_padded}.png", + help="Format string used to resolve image filenames. Keys: cam, frame, frame_padded, name, name_safe.") + parser.add_argument("--frame-padding", type=int, default=6, + help="Zero padding width when building frame_padded (default: 6).") + parser.add_argument("--orientation", choices=["camera_to_world", "world_to_camera"], default="camera_to_world", + help="Whether quaternions encode camera-to-world (default) or world-to-camera rotations.") + parser.add_argument("--quaternion-order", choices=["wxyz", "xyzw"], default="wxyz", + help="Component order inside the pose JSON (default assumes [w,x,y,z]).") + parser.add_argument("--camera-ids", type=int, nargs="*", + help="Optional explicit camera ids (for image naming) ordered like pose files.") + parser.add_argument("--width", type=int, required=True, help="Default image width (pixels).") + parser.add_argument("--height", type=int, required=True, help="Default image height (pixels).") + parser.add_argument("--fx", type=float, required=True, help="Default focal length fx (pixels).") + parser.add_argument("--fy", type=float, required=True, help="Default focal length fy (pixels).") + parser.add_argument("--cx", type=float, required=True, help="Default principal point x.") + parser.add_argument("--cy", type=float, required=True, help="Default principal point y.") + parser.add_argument("--intrinsics-json", type=str, + help="Optional JSON file overriding intrinsics per camera id.") + parser.add_argument("--copy-images", dest="copy_images", action="store_true", + help="Copy RGBs into output/images (default).") + parser.add_argument("--link-images", dest="copy_images", action="store_false", + help="Use symlinks instead of copying RGBs.") + parser.set_defaults(copy_images=True) + parser.add_argument("--dry-run", action="store_true", help="Print actions without writing output files.") + return parser.parse_args() + + +def load_pose_list(json_path: Path) -> List[dict]: + with json_path.open("r", encoding="utf-8") as fh: + data = json.load(fh) + + if isinstance(data, list): + frames = data + elif isinstance(data, dict): + for key in ("frames", "poses", "camera_poses"): + if key in data and isinstance(data[key], list): + frames = data[key] + break + else: + raise ValueError(f"Unsupported JSON structure in {json_path}") + else: + raise ValueError(f"Unsupported JSON structure in {json_path}") + + for idx, frame in enumerate(frames): + if not isinstance(frame, dict): + raise ValueError(f"Frame {idx} in {json_path} 
is not an object") + if "position" not in frame: + raise ValueError(f"Frame {idx} in {json_path} misses 'position'") + if "quaternion" not in frame: + raise ValueError(f"Frame {idx} in {json_path} misses 'quaternion'") + return frames + + +def normalize_quaternion(q: Iterable[float]) -> np.ndarray: + q_arr = np.asarray(list(q), dtype=np.float64) + norm = np.linalg.norm(q_arr) + if norm == 0: + raise ValueError("Encountered zero-length quaternion") + return q_arr / norm + + +def quaternion_to_matrix(q: np.ndarray) -> np.ndarray: + w, x, y, z = q + return np.array([ + [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)], + [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)], + [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)], + ], dtype=np.float64) + + +def matrix_to_quaternion(R: np.ndarray) -> np.ndarray: + m00, m01, m02 = R[0] + m10, m11, m12 = R[1] + m20, m21, m22 = R[2] + trace = m00 + m11 + m22 + if trace > 0.0: + s = math.sqrt(trace + 1.0) * 2.0 + w = 0.25 * s + x = (m21 - m12) / s + y = (m02 - m20) / s + z = (m10 - m01) / s + elif m00 > m11 and m00 > m22: + s = math.sqrt(1.0 + m00 - m11 - m22) * 2.0 + w = (m21 - m12) / s + x = 0.25 * s + y = (m01 + m10) / s + z = (m02 + m20) / s + elif m11 > m22: + s = math.sqrt(1.0 + m11 - m00 - m22) * 2.0 + w = (m02 - m20) / s + x = (m01 + m10) / s + y = 0.25 * s + z = (m12 + m21) / s + else: + s = math.sqrt(1.0 + m22 - m00 - m11) * 2.0 + w = (m10 - m01) / s + x = (m02 + m20) / s + y = (m12 + m21) / s + z = 0.25 * s + q = np.array([w, x, y, z], dtype=np.float64) + if q[0] < 0: + q *= -1.0 + return q / np.linalg.norm(q) + + +def sanitize_name(name: str) -> str: + return re.sub(r"[^0-9A-Za-z._-]", "_", name) + + +def extract_frame_index(frame: dict) -> Optional[int]: + for key in ("frame_id", "frame_index", "index", "idx"): + if key in frame: + try: + return int(frame[key]) + except (TypeError, ValueError): + pass + name = frame.get("name") or frame.get("filename") or frame.get("image") + if isinstance(name, str): + matches = re.findall(r"traj[_-]?(\d+)", name) + if matches: + return int(matches[0]) + digits = re.findall(r"\d+", name) + if digits: + return int(digits[0]) + return None + + +def infer_camera_token(pose_path: Path, frames: List[dict]) -> Optional[int]: + name_digits: List[int] = [] + for frame in frames: + name = frame.get("name") + if not isinstance(name, str): + continue + match = re.search(r"cam[_-]?(\d+)", name) + if match: + name_digits.append(int(match.group(1))) + if name_digits: + return int(np.median(name_digits)) + + match = re.search(r"cam[_-]?(\d+)", pose_path.name) + if match: + return int(match.group(1)) + match = re.search(r"(\d+)", pose_path.stem) + if match: + return int(match.group(1)) + return None + + +def build_image_candidates(args: argparse.Namespace, cam_token: int, frame_idx: int, frame_name: Optional[str]) -> List[str]: + safe_name = sanitize_name(frame_name) if isinstance(frame_name, str) else None + padded = f"{frame_idx:0{args.frame_padding}d}" if args.frame_padding and frame_idx is not None else str(frame_idx) + fmt_values = { + "cam": cam_token, + "frame": frame_idx, + "frame_padded": padded, + "name": frame_name if isinstance(frame_name, str) else f"frame_{frame_idx}", + "name_safe": safe_name if safe_name else f"frame_{frame_idx}", + } + candidates: List[str] = [] + try: + candidates.append(args.image_pattern.format(**fmt_values)) + except KeyError: + pass + if frame_idx is not None: + for pad in (6, 5, 4, 3, 2, 1, 0): + if pad == args.frame_padding: + 
continue + fmt_values_mod = dict(fmt_values) + if pad > 0: + fmt_values_mod["frame_padded"] = f"{frame_idx:0{pad}d}" + else: + fmt_values_mod["frame_padded"] = str(frame_idx) + try: + candidate = args.image_pattern.format(**fmt_values_mod) + except KeyError: + continue + if candidate not in candidates: + candidates.append(candidate) + if safe_name: + fallback = [ + f"rgb_img_{cam_token}_{safe_name}.png", + f"rgb_img_{cam_token}_{frame_idx}.png", + f"rgb_img_{cam_token}_{frame_idx:06d}.png", + f"{safe_name}.png", + ] + for cand in fallback: + if cand not in candidates: + candidates.append(cand) + return candidates + + +def resolve_image_path(images_root: Path, candidates: List[str]) -> Optional[Path]: + for rel in candidates: + candidate_path = images_root / rel + if candidate_path.exists(): + return candidate_path + return None + + +def load_intrinsics(args: argparse.Namespace) -> Dict[int, Dict[str, float]]: + per_camera: Dict[int, Dict[str, float]] = {} + if args.intrinsics_json: + with open(args.intrinsics_json, "r", encoding="utf-8") as fh: + raw = json.load(fh) + if not isinstance(raw, dict): + raise ValueError("intrinsics JSON must map camera ids to parameter dicts") + for key, entry in raw.items(): + try: + cam_id = int(key) + except ValueError as exc: + raise ValueError(f"Invalid camera id '{key}' in intrinsics JSON") from exc + required = {"width", "height", "fx", "fy", "cx", "cy"} + if not required.issubset(entry.keys()): + missing = required - set(entry.keys()) + raise ValueError(f"Intrinsics entry for camera {cam_id} missing {missing}") + per_camera[cam_id] = {k: float(entry[k]) for k in required} + per_camera[cam_id]["width"] = int(per_camera[cam_id]["width"]) + per_camera[cam_id]["height"] = int(per_camera[cam_id]["height"]) + return per_camera + + +def main(): + args = parse_args() + + source_root = Path(args.source).resolve() + output_root = Path(args.output).resolve() + pose_root = (source_root / args.poses_root).resolve() + images_root = (source_root / args.image_root).resolve() + + intrinsics_map = load_intrinsics(args) + + pose_files = sorted(pose_root.glob(args.poses_glob)) + if not pose_files: + raise FileNotFoundError(f"No pose files matched '{args.poses_glob}' under {pose_root}") + + if args.camera_ids and len(args.camera_ids) != len(pose_files): + raise ValueError("Number of --camera-ids entries must match pose files found") + + default_intr = {"width": args.width, "height": args.height, "fx": args.fx, + "fy": args.fy, "cx": args.cx, "cy": args.cy} + + entries = [] + camera_id_map: Dict[int, int] = {} + next_camera_colmap_id = 1 + next_image_id = 1 + + for pose_idx, pose_path in enumerate(pose_files): + frames = load_pose_list(pose_path) + file_cam_token = infer_camera_token(pose_path, frames) + explicit_cam_token = None + if args.camera_ids: + explicit_cam_token = args.camera_ids[pose_idx] + cam_token_candidates = [] + if explicit_cam_token is not None: + cam_token_candidates.append(explicit_cam_token) + if file_cam_token is not None: + cam_token_candidates.extend([file_cam_token, file_cam_token - 1, file_cam_token + 1]) + cam_token_candidates.extend([pose_idx, pose_idx + 1]) + candidate_list = [c for c in cam_token_candidates if c is not None and c >= 0] + dedup_candidates = [] + for cand in candidate_list: + if cand not in dedup_candidates: + dedup_candidates.append(cand) + if not dedup_candidates: + dedup_candidates = [pose_idx] + + chosen_cam_token: Optional[int] = None + first_frame = frames[0] + frame_idx = extract_frame_index(first_frame) + for 
cam_candidate in dedup_candidates: + candidates = build_image_candidates(args, cam_candidate, frame_idx or 0, first_frame.get("name")) + found = resolve_image_path(images_root, candidates) + if found: + chosen_cam_token = cam_candidate + break + if chosen_cam_token is None: + raise FileNotFoundError( + f"Unable to locate RGB for first frame in {pose_path.name}. Tried camera ids {dedup_candidates} under {images_root}" + ) + + if chosen_cam_token not in camera_id_map: + camera_id_map[chosen_cam_token] = next_camera_colmap_id + next_camera_colmap_id += 1 + + intr = intrinsics_map.get(chosen_cam_token, default_intr) + for key in ("width", "height", "fx", "fy", "cx", "cy"): + if key not in intr: + raise ValueError(f"Missing intrinsic '{key}' for camera {chosen_cam_token}") + + for f_idx, frame in enumerate(frames): + position = np.asarray(frame["position"], dtype=np.float64).reshape(-1) + if position.size != 3: + raise ValueError(f"Position for frame {f_idx} in {pose_path.name} is not length 3") + quaternion_raw = frame["quaternion"] + if len(quaternion_raw) != 4: + raise ValueError(f"Quaternion for frame {f_idx} in {pose_path.name} does not have 4 components") + if args.quaternion_order == "wxyz": + q_ordered = quaternion_raw + else: + q_ordered = [quaternion_raw[3], quaternion_raw[0], quaternion_raw[1], quaternion_raw[2]] + q = normalize_quaternion(q_ordered) + R_input = quaternion_to_matrix(q) + if args.orientation == "camera_to_world": + R_w2c = R_input.T + else: + R_w2c = R_input + t = -R_w2c @ position.reshape(3, 1) + q_w2c = matrix_to_quaternion(R_w2c) + + frame_idx_val = extract_frame_index(frame) + if frame_idx_val is None: + frame_idx_val = f_idx + candidate_images = build_image_candidates(args, chosen_cam_token, frame_idx_val, frame.get("name")) + image_path = resolve_image_path(images_root, candidate_images) + if image_path is None: + raise FileNotFoundError( + f"RGB image for frame {frame.get('name', f_idx)} (camera {chosen_cam_token}) not found. 
" + f"Looked for: {candidate_images}" + ) + + entries.append({ + "image_id": next_image_id, + "camera_token": chosen_cam_token, + "camera_colmap_id": camera_id_map[chosen_cam_token], + "image_name": image_path.name, + "image_source": image_path, + "qvec": q_w2c, + "tvec": t.reshape(-1), + "frame_index": frame_idx_val, + }) + next_image_id += 1 + + entries.sort(key=lambda item: (item["frame_index"], item["camera_token"], item["image_id"])) + + if args.dry_run: + print(f"Would create output at {output_root}") + print("Cameras:") + for cam_token, colmap_id in camera_id_map.items(): + intr = intrinsics_map.get(cam_token, default_intr) + print(f" COLMAP id {colmap_id}: source cam {cam_token} -> PINHOLE {intr}") + print(f"Total images: {len(entries)}") + for entry in entries[:5]: + print(f" Example image {entry['image_id']}: {entry['image_name']} -> q={entry['qvec']} t={entry['tvec']}") + return + + sparse_dir = output_root / "sparse" / "0" + images_dir = output_root / "images" + if not output_root.exists(): + output_root.mkdir(parents=True, exist_ok=True) + sparse_dir.mkdir(parents=True, exist_ok=True) + images_dir.mkdir(parents=True, exist_ok=True) + + cameras_txt = sparse_dir / "cameras.txt" + with cameras_txt.open("w", encoding="utf-8") as fh: + fh.write("# Camera list with one line of data per camera:\n") + fh.write("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n") + for cam_token, colmap_id in sorted(camera_id_map.items(), key=lambda kv: kv[1]): + intr = intrinsics_map.get(cam_token, default_intr) + fh.write( + f"{colmap_id} PINHOLE {int(intr['width'])} {int(intr['height'])} " + f"{float(intr['fx']):.12f} {float(intr['fy']):.12f} {float(intr['cx']):.12f} {float(intr['cy']):.12f}\n" + ) + + images_txt = sparse_dir / "images.txt" + with images_txt.open("w", encoding="utf-8") as fh: + fh.write("# Image list with two lines of data per image:\n") + fh.write("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n") + fh.write("# POINTS2D[] as (X, Y, POINT3D_ID)\n") + for entry in entries: + q = entry["qvec"] + t = entry["tvec"] + fh.write( + f"{entry['image_id']} {q[0]:.12f} {q[1]:.12f} {q[2]:.12f} {q[3]:.12f} " + f"{t[0]:.12f} {t[1]:.12f} {t[2]:.12f} {entry['camera_colmap_id']} {entry['image_name']}\n" + ) + fh.write("\n") + + points_txt = sparse_dir / "points3D.txt" + with points_txt.open("w", encoding="utf-8") as fh: + fh.write("# Empty point cloud placeholder.\n") + + points_bin = sparse_dir / "points3D.bin" + with points_bin.open("wb") as fid: + fid.write(struct.pack(" Iterable[Path]: + if root.is_file(): + if root.suffix.lower() in VALID_EXTENSIONS: + yield root + return + + if not root.is_dir(): + print(f"[WARN] {root} 既不是文件也不是目录,跳过。") + return + + pattern = "**/*" if recursive else "*" + for path in root.glob(pattern): + if path.is_file() and path.suffix.lower() in VALID_EXTENSIONS: + yield path + + +def convert_to_convex_hull( + mesh_path: Path, + suffix: str, + overwrite: bool, +) -> Path: + mesh = trimesh.load(mesh_path, force="mesh", process=False) + + convex = mesh.convex_hull + output_path = mesh_path.with_name(f"{mesh_path.stem}{suffix}{mesh_path.suffix}") + if output_path.exists() and not overwrite: + raise FileExistsError(f"{output_path} 已存在,使用 --overwrite 以覆盖。") + convex.export(output_path) + return output_path + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="将单个 mesh 或目录中的所有 mesh 转为凸包表示,文件名追加 `_convex`。", + ) + parser.add_argument("input_path", type=Path, help="mesh 文件或目录路径。") + parser.add_argument( + "--recursive", + 
action="store_true", + help="若 input_path 为目录,递归遍历所有子目录。", + ) + parser.add_argument( + "--suffix", + default="_convex", + help="输出文件名后缀(默认 _convex)。", + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="若目标文件已存在则覆盖。", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + mesh_paths = list(iter_mesh_paths(args.input_path, args.recursive)) + if not mesh_paths: + print(f"[WARN] 在 {args.input_path} 下找不到 mesh 文件(支持扩展名: {', '.join(VALID_EXTENSIONS)})。") + return 1 + + for path in mesh_paths: + try: + output = convert_to_convex_hull( + path, + suffix=args.suffix, + overwrite=args.overwrite, + ) + print(f"[INFO] {path} -> {output}") + except Exception as exc: # noqa: BLE001 + print(f"[ERROR] 处理 {path} 失败:{exc}", file=sys.stderr) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/milo/scripts/run_depth_sweep.py b/milo/scripts/run_depth_sweep.py new file mode 100644 index 0000000..67f8c60 --- /dev/null +++ b/milo/scripts/run_depth_sweep.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +""" +Run a small sweep of depth-only training configurations. + +Each configuration executes `milo/depth_train.py` with 1000 iterations by default, +optionally restricting training to a single camera for reproducibility. After the +runs finish, a compact summary of the final metrics is emitted to stdout and +saved under `output/depth_sweep_summary.txt`. +""" + +from __future__ import annotations + +import argparse +import json +import os +import subprocess +import sys +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional + + +DEFAULT_CONFIGS: List[Dict] = [ + { + "name": "lr0.7_clip7_dense", + "initial_lr_scale": 0.70, + "depth_clip_min": 0.1, + "depth_clip_max": 7.0, + "enable_densification": True, + }, + { + "name": "lr0.8_clip7_dense", + "initial_lr_scale": 0.80, + "depth_clip_min": 0.2, + "depth_clip_max": 7.0, + "enable_densification": True, + }, + { + "name": "lr0.9_clip6p5_dense", + "initial_lr_scale": 0.90, + "depth_clip_min": 0.3, + "depth_clip_max": 6.5, + "enable_densification": True, + }, + { + "name": "lr1.0_clip6_dense", + "initial_lr_scale": 1.00, + "depth_clip_min": 0.3, + "depth_clip_max": 6.0, + "enable_densification": True, + }, + { + "name": "lr1.15_clip6_dense", + "initial_lr_scale": 1.15, + "depth_clip_min": 0.4, + "depth_clip_max": 6.0, + "enable_densification": True, + }, + { + "name": "lr1.3_clip5p5_dense", + "initial_lr_scale": 1.30, + "depth_clip_min": 0.5, + "depth_clip_max": 5.5, + "enable_densification": True, + }, + { + "name": "lr0.75_clip7_no_dense", + "initial_lr_scale": 0.75, + "depth_clip_min": 0.1, + "depth_clip_max": 7.0, + "enable_densification": False, + }, + { + "name": "lr1.0_clip6_no_dense", + "initial_lr_scale": 1.0, + "depth_clip_min": 0.3, + "depth_clip_max": 6.0, + "enable_densification": False, + }, +] + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Sweep depth training configurations.") + parser.add_argument("--ply_path", required=True, type=Path, help="Input Gaussian PLY.") + parser.add_argument("--camera_poses", required=True, type=Path, help="Camera JSON file.") + parser.add_argument("--depth_dir", required=True, type=Path, help="Directory of depth .npy files.") + parser.add_argument("--output_root", type=Path, default=Path("runs/depth_sweep"), help="Base directory for sweep outputs.") + parser.add_argument("--iterations", type=int, default=1000, help="Iterations per configuration.") + 
parser.add_argument("--fixed_view_idx", type=int, default=0, help="Camera index to lock during sweep (-1 = random shuffling).") + parser.add_argument("--cuda_blocking", action="store_true", help="Set CUDA_LAUNCH_BLOCKING=1 for each run.") + parser.add_argument("--extra_arg", action="append", default=[], help="Extra CLI arguments passed verbatim to depth_train.py.") + parser.add_argument("--resume_if_exists", action="store_true", help="Skip configs whose output directory already exists.") + return parser + + +def ensure_directory(path: Path) -> None: + path.mkdir(parents=True, exist_ok=True) + + +def run_depth_train( + script_path: Path, + cfg: Dict, + args: argparse.Namespace, + run_dir: Path, +) -> int: + cmd: List[str] = [ + sys.executable, + str(script_path), + "--ply_path", + str(args.ply_path), + "--camera_poses", + str(args.camera_poses), + "--depth_dir", + str(args.depth_dir), + "--output_dir", + str(run_dir), + "--iterations", + str(args.iterations), + "--initial_lr_scale", + str(cfg["initial_lr_scale"]), + "--log_depth_stats", + ] + if cfg.get("depth_clip_min", 0.0) > 0.0: + cmd.extend(["--depth_clip_min", str(cfg["depth_clip_min"])]) + if cfg.get("depth_clip_max") is not None: + cmd.extend(["--depth_clip_max", str(cfg["depth_clip_max"])]) + if cfg.get("enable_densification", False): + cmd.append("--enable_densification") + if args.fixed_view_idx >= 0: + cmd.extend(["--fixed_view_idx", str(args.fixed_view_idx)]) + for extra in args.extra_arg: + cmd.append(extra) + + env = os.environ.copy() + if args.cuda_blocking: + env["CUDA_LAUNCH_BLOCKING"] = "1" + + print(f"[SWEEP] Running {cfg['name']} -> {run_dir}") + print(" Command:", " ".join(cmd)) + sys.stdout.flush() + result = subprocess.run(cmd, env=env) + if result.returncode != 0: + print(f"[SWEEP] {cfg['name']} exited with code {result.returncode}") + return result.returncode + + +def read_final_metrics(run_dir: Path) -> Optional[Dict]: + log_path = run_dir / "logs" / "losses.jsonl" + if not log_path.exists(): + return None + last_line: Optional[str] = None + with open(log_path, "r", encoding="utf-8") as log_file: + for line in log_file: + last_line = line.strip() + if not last_line: + return None + try: + data = json.loads(last_line) + except json.JSONDecodeError: + return None + return { + "iteration": data.get("iteration"), + "depth_loss": data.get("depth_loss"), + "pred_depth_mean": data.get("pred_depth_mean"), + "target_depth_mean": data.get("target_depth_mean"), + "pred_depth_max": data.get("pred_depth_max"), + "pred_depth_min": data.get("pred_depth_min"), + "target_depth_max": data.get("target_depth_max"), + "target_depth_min": data.get("target_depth_min"), + } + + +def main() -> None: + parser = build_parser() + args = parser.parse_args() + + script_path = Path("milo/depth_train.py").resolve() + ensure_directory(args.output_root) + ensure_directory(Path("output")) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + summary_lines: List[str] = [] + + for cfg in DEFAULT_CONFIGS: + run_dir = args.output_root / f"{timestamp}_{cfg['name']}" + if run_dir.exists() and args.resume_if_exists: + print(f"[SWEEP] Skipping {cfg['name']} (directory exists).") + metrics = read_final_metrics(run_dir) + else: + ensure_directory(run_dir) + exit_code = run_depth_train(script_path, cfg, args, run_dir) + if exit_code != 0: + summary_lines.append(f"{cfg['name']}: FAILED (code {exit_code})") + continue + metrics = read_final_metrics(run_dir) + + if not metrics: + summary_lines.append(f"{cfg['name']}: missing/invalid log") + continue + + 
summary_lines.append( + "{name}: depth_loss={loss:.4f} pred_mean={p_mean:.4f} target_mean={t_mean:.4f}".format( + name=cfg["name"], + loss=metrics.get("depth_loss", float("nan")), + p_mean=metrics.get("pred_depth_mean", float("nan")), + t_mean=metrics.get("target_depth_mean", float("nan")), + ) + ) + + summary_path = Path("output") / "depth_sweep_summary.txt" + with open(summary_path, "a", encoding="utf-8") as summary_file: + summary_file.write(f"# Sweep {timestamp}\n") + for line in summary_lines: + summary_file.write(line + "\n") + summary_file.write("\n") + + print("\n[SWEEP] Summary:") + for line in summary_lines: + print(" -", line) + print(f"[SWEEP] Full summary appended to {summary_path}") + + +if __name__ == "__main__": + main() diff --git a/milo/scripts/scannet_to_colmap.py b/milo/scripts/scannet_to_colmap.py new file mode 100644 index 0000000..16ed700 --- /dev/null +++ b/milo/scripts/scannet_to_colmap.py @@ -0,0 +1,459 @@ +#!/usr/bin/env python3 +""" +将 ScanNet 场景 (RGB-D + 相机轨迹) 转换成 Milo 训练所需的纯 COLMAP 目录。 + +脚本会: +1. 解析 .sens 获取 RGB 帧与 camera-to-world pose。 +2. 应用 axisAlignment(若存在)以保持 ScanNet 公开场景的惯用坐标。 +3. 把选中的帧原封不动写成 JPEG,生成 COLMAP cameras.txt / images.txt。 +4. 从 *_vh_clean_2.ply(或备用 *_vh_clean.ply)抽取点云,写入 points3D.bin/txt。 +生成结果与 milo/data/Ignatius 相同(images + sparse/0)。 +""" + +from __future__ import annotations + +import argparse +import io +import struct +from dataclasses import dataclass +from pathlib import Path +from typing import Iterator, List, Optional, Sequence, Tuple + +import numpy as np +from plyfile import PlyData + +try: + from PIL import Image +except ImportError: + Image = None + + +@dataclass +class ScanNetFrame: + """Minimal容器,只保留 Milo 转换需要的信息。""" + + index: int + camera_to_world: np.ndarray # (4, 4) + timestamp_color: int + timestamp_depth: int + color_bytes: bytes + + +class ScanNetSensorData: + """直接解析 ScanNet .sens(二进制 RGB-D 轨迹)。""" + + def __init__(self, sens_path: Path): + self.sens_path = Path(sens_path) + self._fh = None + self.version = None + self.sensor_name = "" + self.intrinsic_color = None + self.extrinsic_color = None + self.intrinsic_depth = None + self.extrinsic_depth = None + self.color_compression = None + self.depth_compression = None + self.color_width = 0 + self.color_height = 0 + self.depth_width = 0 + self.depth_height = 0 + self.depth_shift = 0.0 + self.num_frames = 0 + + def __enter__(self) -> "ScanNetSensorData": + self._fh = self.sens_path.open("rb") + self._read_header() + return self + + def __exit__(self, exc_type, exc, tb): + if self._fh: + self._fh.close() + self._fh = None + + def _read_header(self) -> None: + fh = self._fh + assert fh is not None + read = fh.read + self.version = struct.unpack(" Iterator[ScanNetFrame]: + if self._fh is None: + raise RuntimeError("Sensor file is not opened. Use within a context manager.") + + for frame_idx in range(self.num_frames): + mat = np.frombuffer(self._fh.read(16 * 4), dtype=" np.ndarray: + """复制自 COLMAP read_write_model,实现矩阵->四元数。""" + R = np.asarray(R, dtype=np.float64) + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = np.array( + [ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], + ] + ) / 3.0 + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec / np.linalg.norm(qvec) + + +def parse_axis_alignment(meta_txt: Path) -> np.ndarray: + """读取 axisAlignment=... 
行,没有则返回单位阵。""" + if not meta_txt.is_file(): + return np.eye(4, dtype=np.float64) + + axis = None + with meta_txt.open("r", encoding="utf-8") as fh: + for line in fh: + if line.strip().startswith("axisAlignment"): + values = line.split("=", 1)[1].strip().split() + if len(values) != 16: + raise ValueError(f"axisAlignment 需要 16 个数,当前 {len(values)}") + axis = np.array([float(v) for v in values], dtype=np.float64).reshape(4, 4) + break + if axis is None: + axis = np.eye(4, dtype=np.float64) + return axis + + +def infer_scene_id(scene_root: Path) -> str: + """默认场景目录名就是 sceneXXXX_YY;否则尝试寻找唯一的 *.sens。""" + if scene_root.name.startswith("scene") and "_" in scene_root.name: + return scene_root.name + sens_files = list(scene_root.glob("*.sens")) + if len(sens_files) != 1: + raise ValueError("无法唯一确定 scene id,请使用 --scene-id") + return sens_files[0].stem + + +def find_point_cloud(scene_root: Path, scene_id: str, override: Optional[Path]) -> Path: + if override: + pc_path = Path(override) + if not pc_path.is_file(): + raise FileNotFoundError(f"指定点云 {pc_path} 不存在") + return pc_path + candidates = [ + scene_root / f"{scene_id}_vh_clean_2.ply", + scene_root / f"{scene_id}_vh_clean.ply", + ] + for cand in candidates: + if cand.is_file(): + return cand + raise FileNotFoundError("找不到 *_vh_clean_*.ply 点云,请用 --points-source 指定") + + +def load_point_cloud( + ply_path: Path, + stride: int = 1, + max_points: Optional[int] = None, + seed: int = 0, + transform: Optional[np.ndarray] = None, +) -> Tuple[np.ndarray, np.ndarray]: + ply = PlyData.read(str(ply_path)) + verts = ply["vertex"].data + xyz = np.vstack([verts["x"], verts["y"], verts["z"]]).T.astype(np.float64) + if {"red", "green", "blue"}.issubset(verts.dtype.names): + colors = np.vstack([verts["red"], verts["green"], verts["blue"]]).T.astype(np.uint8) + else: + colors = np.full((xyz.shape[0], 3), 255, dtype=np.uint8) + + idx = np.arange(xyz.shape[0]) + if stride > 1: + idx = idx[::stride] + if max_points is not None and idx.size > max_points: + rng = np.random.default_rng(seed) + idx = rng.choice(idx, size=max_points, replace=False) + xyz = xyz[idx] + if transform is not None: + if transform.shape != (4, 4): + raise ValueError("transform must be 4x4 homogeneous matrix") + homo = np.concatenate([xyz, np.ones((xyz.shape[0], 1), dtype=np.float64)], axis=1) + xyz = (homo @ transform.T)[:, :3] + return xyz, colors[idx] + + +def ensure_output_dirs(output_root: Path) -> Tuple[Path, Path]: + images_dir = output_root / "images" + sparse_dir = output_root / "sparse" / "0" + if output_root.exists(): + existing = [p for p in output_root.iterdir() if not p.name.startswith(".")] + if existing: + raise FileExistsError(f"{output_root} 已存在且非空,请指定新的输出目录") + sparse_dir.mkdir(parents=True, exist_ok=True) + images_dir.mkdir(parents=True, exist_ok=True) + return images_dir, sparse_dir + + +@dataclass +class ImageRecord: + image_id: int + name: str + qvec: np.ndarray + tvec: np.ndarray + frame_index: int + + +def write_cameras_txt(path: Path, camera_id: int, width: int, height: int, fx: float, fy: float, cx: float, cy: float) -> None: + with path.open("w", encoding="utf-8") as fh: + fh.write("# Camera list with one line of data per camera:\n") + fh.write("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n") + fh.write("# Number of cameras: 1\n") + fh.write(f"{camera_id} PINHOLE {width} {height} {fx:.9f} {fy:.9f} {cx:.9f} {cy:.9f}\n") + + +def write_images_txt(path: Path, camera_id: int, records: Sequence[ImageRecord]) -> None: + with path.open("w", encoding="utf-8") as fh: + fh.write("# 
Image list with two lines of data per image:\n") + fh.write("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n") + fh.write("# POINTS2D[] as (X, Y, POINT3D_ID)\n") + fh.write(f"# Number of images: {len(records)}, mean observations per image: 0\n") + for rec in records: + q = rec.qvec + t = rec.tvec + fh.write( + f"{rec.image_id} {q[0]:.12f} {q[1]:.12f} {q[2]:.12f} {q[3]:.12f} " + f"{t[0]:.12f} {t[1]:.12f} {t[2]:.12f} {camera_id} {rec.name}\n" + ) + fh.write("\n") + + +def write_points3d_txt(path: Path, xyz: np.ndarray, rgb: np.ndarray) -> None: + with path.open("w", encoding="utf-8") as fh: + fh.write("# 3D point list with one line of data per point:\n") + fh.write("# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[]\n") + fh.write(f"# Number of points: {xyz.shape[0]}\n") + for idx, (pos, color) in enumerate(zip(xyz, rgb), start=1): + fh.write( + f"{idx} {pos[0]:.9f} {pos[1]:.9f} {pos[2]:.9f} " + f"{int(color[0])} {int(color[1])} {int(color[2])} 0\n" + ) + + +def write_points3d_bin(path: Path, xyz: np.ndarray, rgb: np.ndarray) -> None: + with path.open("wb") as fh: + fh.write(struct.pack(" bytes: + if Image is None: + raise RuntimeError("Pillow 未安装,无法使用 --resize-width/--resize-height。请先 pip install pillow。") + width, height = target_size + resample_map = { + "nearest": Image.NEAREST, + "bilinear": Image.BILINEAR, + "bicubic": Image.BICUBIC, + "lanczos": Image.LANCZOS, + } + resample = resample_map[filter_name] + with Image.open(io.BytesIO(jpeg_bytes)) as img: + img = img.convert("RGB") + resized = img.resize((width, height), resample=resample) + buffer = io.BytesIO() + resized.save(buffer, format="JPEG", quality=quality) + return buffer.getvalue() + + +def convert_scene(args: argparse.Namespace) -> None: + if args.frame_step <= 0: + raise ValueError("frame-step 必须为正整数") + if args.start_frame < 0: + raise ValueError("start-frame 不能为负") + if args.points_stride <= 0: + raise ValueError("points-stride 必须为正整数") + resize_dims: Optional[Tuple[int, int]] = None + if args.resize_width is not None or args.resize_height is not None: + if args.resize_width is None or args.resize_height is None: + raise ValueError("--resize-width/--resize-height 需要同时指定") + if args.resize_width <= 0 or args.resize_height <= 0: + raise ValueError("resize 尺寸必须为正整数") + if not 1 <= args.resize_jpeg_quality <= 100: + raise ValueError("resize-jpeg-quality 需在 [1, 100]") + resize_dims = (args.resize_width, args.resize_height) + + scene_root = Path(args.scene_root).resolve() + scene_id = args.scene_id or infer_scene_id(scene_root) + sens_path = scene_root / f"{scene_id}.sens" + if not sens_path.is_file(): + raise FileNotFoundError(f"未找到 {sens_path}") + meta_path = scene_root / f"{scene_id}.txt" + axis = parse_axis_alignment(meta_path) if args.apply_axis_alignment else np.eye(4, dtype=np.float64) + images_dir, sparse_dir = ensure_output_dirs(Path(args.output).resolve()) + point_cloud_path = find_point_cloud(scene_root, scene_id, args.points_source) + + print(f"[INFO] 转换场景 {scene_id} -> {args.output}") + print(f"[INFO] 使用点云: {point_cloud_path}") + + with ScanNetSensorData(sens_path) as sensor: + if sensor.color_compression != 2: + raise NotImplementedError(f"暂不支持 color_compression={sensor.color_compression} 的 .sens") + + fx = float(sensor.intrinsic_color[0, 0]) + fy = float(sensor.intrinsic_color[1, 1]) + cx = float(sensor.intrinsic_color[0, 2]) + cy = float(sensor.intrinsic_color[1, 2]) + camera_id = 1 + + selected: List[ImageRecord] = [] + next_image_id = 1 + max_frames = args.max_frames if args.max_frames and args.max_frames 
> 0 else None + start = args.start_frame + for frame in sensor.iter_frames(): + if frame.index < start: + continue + if (frame.index - start) % args.frame_step != 0: + continue + if max_frames is not None and len(selected) >= max_frames: + break + + c2w = frame.camera_to_world + if args.apply_axis_alignment: + c2w = axis @ c2w + if not np.all(np.isfinite(c2w)): + print(f"[WARN] Skipping frame {frame.index}: pose contains NaN") + continue + + w2c = np.linalg.inv(c2w) + rot = w2c[:3, :3] + tvec = w2c[:3, 3] + qvec = rotmat_to_qvec(rot) + image_name = f"frame_{frame.index:06d}.jpg" + image_path = images_dir / image_name + if resize_dims is None: + with image_path.open("wb") as im_fh: + im_fh.write(frame.color_bytes) + else: + resized_bytes = resize_color_bytes( + frame.color_bytes, + resize_dims, + args.resize_jpeg_quality, + args.resize_filter, + ) + with image_path.open("wb") as im_fh: + im_fh.write(resized_bytes) + + selected.append( + ImageRecord( + image_id=next_image_id, + name=image_name, + qvec=qvec, + tvec=tvec, + frame_index=frame.index, + ) + ) + next_image_id += 1 + if len(selected) % 100 == 0: + print(f"[INFO] Wrote {len(selected)} images") + + if not selected: + raise RuntimeError("No frame satisfied the sampling conditions; please check the start/step/max arguments.") + + cams_txt = sparse_dir / "cameras.txt" + write_cameras_txt(cams_txt, camera_id, sensor.color_width, sensor.color_height, fx, fy, cx, cy) + imgs_txt = sparse_dir / "images.txt" + write_images_txt(imgs_txt, camera_id, selected) + + xyz, rgb = load_point_cloud( + point_cloud_path, + stride=args.points_stride, + max_points=args.points_max, + seed=args.points_seed, + transform=axis if args.apply_axis_alignment else None, + ) + write_points3d_txt(sparse_dir / "points3D.txt", xyz, rgb) + write_points3d_bin(sparse_dir / "points3D.bin", xyz, rgb) + + print(f"[INFO] Conversion finished: {len(selected)} images, {xyz.shape[0]} points.") + + + def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Convert a ScanNet scene into the COLMAP layout expected by Milo.") + parser.add_argument("--scene-root", required=True, help="Directory containing sceneXXXX_YY.* (the .sens/.txt/.ply files).") + parser.add_argument("--output", required=True, help="Output directory (must be new or empty; images/ and sparse/0/ will be created).") + parser.add_argument("--scene-id", help="Optional explicit sceneXXXX_YY id; defaults to the directory name or is inferred automatically.") + parser.add_argument("--start-frame", type=int, default=0, help="Frame index at which sampling starts (default 0).") + parser.add_argument("--frame-step", type=int, default=1, help="Frame sampling step, e.g. 5 keeps one frame out of every 5.") + parser.add_argument("--max-frames", type=int, help="Maximum number of frames to export (default: all).") + parser.add_argument("--no-axis-alignment", dest="apply_axis_alignment", action="store_false", help="Do not apply axisAlignment.") + parser.set_defaults(apply_axis_alignment=True) + parser.add_argument("--points-source", type=Path, help="Custom point cloud .ply path (defaults to auto-detecting *_vh_clean_2.ply).") + parser.add_argument("--points-stride", type=int, default=1, help="Point cloud subsampling stride (1 keeps every point).") + parser.add_argument("--points-max", type=int, help="Maximum number of points to keep from the point cloud.") + parser.add_argument("--points-seed", type=int, default=0, help="Random seed used for point cloud subsampling.") + parser.add_argument("--resize-width", type=int, help="Optional: resize the RGB output to this width in pixels. Requires pillow.") + parser.add_argument("--resize-height", type=int, help="Optional: resize the RGB output to this height in pixels.") + parser.add_argument( + "--resize-filter", + choices=["nearest", "bilinear", "bicubic", "lanczos"], + default="lanczos", + help="Resampling filter used when resizing (default lanczos).", + ) + parser.add_argument( + "--resize-jpeg-quality", + type=int, + default=95, + help="JPEG quality used when re-encoding resized frames (1-100, default 95).", + ) + return parser + + + def main(): + parser = build_arg_parser() + args = 
parser.parse_args() + convert_scene(args) + + +if __name__ == "__main__": + main() diff --git a/milo/scripts/verify_camera_poses.py b/milo/scripts/verify_camera_poses.py new file mode 100644 index 0000000..d139ec0 --- /dev/null +++ b/milo/scripts/verify_camera_poses.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +""" +Verification script to check camera pose interpretation. +This script loads the camera poses and PLY file to verify if Gaussians +are in front of or behind the cameras. +""" + +import json +import numpy as np +from pathlib import Path + + +def quaternion_to_rotation_matrix(q): + """Convert quaternion [w, x, y, z] to rotation matrix.""" + q = np.asarray(q, dtype=np.float64) + w, x, y, z = q + xx = x * x + yy = y * y + zz = z * z + xy = x * y + xz = x * z + yz = y * z + wx = w * x + wy = w * y + wz = w * z + rotation = np.array( + [ + [1.0 - 2.0 * (yy + zz), 2.0 * (xy - wz), 2.0 * (xz + wy)], + [2.0 * (xy + wz), 1.0 - 2.0 * (xx + zz), 2.0 * (yz - wx)], + [2.0 * (xz - wy), 2.0 * (yz + wx), 1.0 - 2.0 * (xx + yy)], + ], + dtype=np.float32, + ) + return rotation + + +def load_ply_points(ply_path): + """Load point positions from PLY file.""" + from plyfile import PlyData + + plydata = PlyData.read(ply_path) + vertices = plydata['vertex'] + positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T + return positions + + +def check_camera_pose_current(camera_entry, points): + """Check with CURRENT (incorrect) pose interpretation.""" + quaternion = camera_entry["quaternion"] + camera_center = np.array(camera_entry["position"], dtype=np.float32) + + # Current code (WRONG according to codex): + rotation = quaternion_to_rotation_matrix(quaternion) + translation = -rotation.T @ camera_center + + # Transform points to camera space + # With current interpretation: R is passed as-is, T is translation + # getWorld2View2 expects R to be C2W and builds W2C as [R.T | T] + # So W2C rotation is rotation.T, W2C translation is translation + R_w2c = rotation.T + t_w2c = translation + + # Transform points: p_cam = R_w2c @ p_world + t_w2c + points_cam = (R_w2c @ points.T).T + t_w2c + + # Check z coordinate (positive = in front of camera) + in_front = np.sum(points_cam[:, 2] > 0) + total = len(points) + fraction = in_front / total + + return fraction, in_front, total + + +def check_camera_pose_corrected(camera_entry, points): + """Check with CORRECTED pose interpretation.""" + quaternion = camera_entry["quaternion"] + camera_center = np.array(camera_entry["position"], dtype=np.float32) + + # Corrected code (as suggested by codex): + rotation_w2c = quaternion_to_rotation_matrix(quaternion) + rotation_c2w = rotation_w2c.T + translation_w2c = -rotation_w2c @ camera_center + + # With corrected interpretation: R_c2w is passed, T is translation_w2c + # getWorld2View2 builds W2C as [R_c2w.T | T] = [R_w2c | translation_w2c] + R_w2c = rotation_c2w.T # = rotation_w2c + t_w2c = translation_w2c + + # Transform points: p_cam = R_w2c @ p_world + t_w2c + points_cam = (R_w2c @ points.T).T + t_w2c + + # Check z coordinate (positive = in front of camera) + in_front = np.sum(points_cam[:, 2] > 0) + total = len(points) + fraction = in_front / total + + return fraction, in_front, total + + +def main(): + # Paths + camera_poses_path = Path("/milo/data/bridge_small/camera_poses_cam1.json") + ply_path = Path("/milo/data/bridge_small/yufu_bridge_small.ply") + + # Load data + print(f"Loading camera poses from {camera_poses_path}") + with open(camera_poses_path, 'r') as f: + camera_entries = json.load(f) + 
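+    # Each entry in the JSON list is expected to provide at least "name",
+    # "position" (length-3) and "quaternion" in [w, x, y, z] order; an
+    # illustrative record (values made up for this example):
+    #   {"name": "traj_0_cam0", "position": [1.0, 2.0, 0.5],
+    #    "quaternion": [0.99, 0.0, 0.1, 0.0]}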
print(f"Loaded {len(camera_entries)} cameras") + + print(f"\nLoading point cloud from {ply_path}") + points = load_ply_points(ply_path) + print(f"Loaded {len(points)} points") + + # Check first camera with both interpretations + print("\n" + "="*80) + print("CHECKING CAMERA 0 (traj_0_cam0)") + print("="*80) + + camera_0 = camera_entries[0] + print(f"Camera position: {camera_0['position']}") + print(f"Camera quaternion: {camera_0['quaternion']}") + + print("\n--- Current (WRONG) interpretation ---") + frac_current, in_front_current, total = check_camera_pose_current(camera_0, points) + print(f"Points in front of camera: {in_front_current}/{total} ({frac_current*100:.2f}%)") + + print("\n--- Corrected interpretation ---") + frac_corrected, in_front_corrected, total = check_camera_pose_corrected(camera_0, points) + print(f"Points in front of camera: {in_front_corrected}/{total} ({frac_corrected*100:.2f}%)") + + # Check a few more cameras + print("\n" + "="*80) + print("CHECKING FIRST 5 CAMERAS") + print("="*80) + print(f"{'Camera':<15} {'Current (wrong)':<20} {'Corrected':<20}") + print("-" * 80) + + for i in range(min(5, len(camera_entries))): + camera = camera_entries[i] + frac_current, _, _ = check_camera_pose_current(camera, points) + frac_corrected, _, _ = check_camera_pose_corrected(camera, points) + print(f"{camera['name']:<15} {frac_current*100:>6.2f}% {frac_corrected*100:>6.2f}%") + + print("\n" + "="*80) + print("CONCLUSION") + print("="*80) + if frac_current < 0.5 and frac_corrected > 0.5: + print("✓ Codex's analysis is CORRECT!") + print(" - Current code: Most points are BEHIND cameras (wrong)") + print(" - Corrected code: Most points are IN FRONT of cameras (correct)") + print("\nThe quaternions should be interpreted as world→camera rotations,") + print("and the fix suggested by codex is needed.") + else: + print("✗ Results don't match codex's analysis.") + print(" Further investigation needed.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/milo/test_opt_config.py b/milo/test_opt_config.py new file mode 100644 index 0000000..e7bedb5 --- /dev/null +++ b/milo/test_opt_config.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +""" +测试优化配置文件加载功能 +""" +import sys +from pathlib import Path + +# 添加当前目录到路径 +sys.path.insert(0, str(Path(__file__).parent)) + +from yufu2mesh_new import load_optimization_config + +def test_config(config_name: str): + """测试加载指定的配置文件""" + print(f"\n{'='*60}") + print(f"测试配置: {config_name}") + print('='*60) + + try: + config = load_optimization_config(config_name) + + print("\n✓ 配置加载成功!") + print("\n高斯参数配置:") + print("-" * 40) + for param_name, param_cfg in config["gaussian_params"].items(): + trainable = param_cfg.get("trainable", False) + lr = param_cfg.get("lr", 0.0) + status = "✓ 可训练" if trainable else "✗ 冻结" + print(f" {param_name:20s} {status:10s} lr={lr:.6f}") + + print("\nLoss权重配置:") + print("-" * 40) + for loss_name, weight in config["loss_weights"].items(): + print(f" {loss_name:20s} {weight:.3f}") + + print("\n深度处理配置:") + print("-" * 40) + depth_cfg = config["depth_processing"] + print(f" clip_min: {depth_cfg.get('clip_min')}") + print(f" clip_max: {depth_cfg.get('clip_max')}") + + print("\nMesh正则化配置:") + print("-" * 40) + mesh_cfg = config["mesh_regularization"] + print(f" depth_weight: {mesh_cfg.get('depth_weight')}") + print(f" normal_weight: {mesh_cfg.get('normal_weight')}") + + return True + + except Exception as e: + print(f"\n✗ 配置加载失败: {e}") + import traceback + traceback.print_exc() + return False + +def main(): + 
"""测试所有预设配置""" + configs_to_test = [ + "default", + "xyz_only", + "xyz_geometry", + "xyz_occupancy", + "full" + ] + + print("开始测试优化配置加载功能...") + + results = {} + for config_name in configs_to_test: + results[config_name] = test_config(config_name) + + print("\n" + "="*60) + print("测试总结") + print("="*60) + for config_name, success in results.items(): + status = "✓ 通过" if success else "✗ 失败" + print(f" {config_name:20s} {status}") + + all_passed = all(results.values()) + if all_passed: + print("\n🎉 所有配置测试通过!") + return 0 + else: + print("\n⚠️ 部分配置测试失败") + return 1 + +if __name__ == "__main__": + sys.exit(main()) diff --git a/milo/train.py b/milo/train.py index 2f0464a..4fb60cd 100644 --- a/milo/train.py +++ b/milo/train.py @@ -2,6 +2,8 @@ import sys import gc import yaml +import json +import random from functools import partial BASE_DIR = os.path.dirname(os.path.abspath(__file__)) ROOT_DIR = os.path.abspath(os.path.join(BASE_DIR, '..')) @@ -208,6 +210,9 @@ def training( culling=gaussians._culling[:,viewpoint_cam.uid], ) + if "area_max" not in render_pkg: + render_pkg["area_max"] = torch.zeros_like(render_pkg["radii"]) + # ---Compute losses--- image, viewspace_point_tensor, visibility_filter, radii = ( render_pkg["render"], render_pkg["viewspace_points"], diff --git a/milo/useless_maybe/depth_guided_refine.py b/milo/useless_maybe/depth_guided_refine.py new file mode 100644 index 0000000..605aff4 --- /dev/null +++ b/milo/useless_maybe/depth_guided_refine.py @@ -0,0 +1,602 @@ +#!/usr/bin/env python3 +"""Depth-guided refinement of Gaussian SDFs and mesh extraction. + +This script optimizes a pretrained Gaussian Splat (PLY) using per-view depth maps. +Compared to `iterative_occupancy_refine.py`, Gaussian geometry is trainable and +the supervision comes directly from depth instead of RGB images. 
+""" + +from __future__ import annotations + +import argparse +import json +import os +import random +import re +from dataclasses import dataclass +from typing import Dict, List, Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + +from arguments import OptimizationParams, PipelineParams +from gaussian_renderer import integrate_radegs +from milo.useless_maybe.ply2mesh import ( + ManualScene, + export_mesh_from_gaussians, + initialize_mesh_regularization, + load_cameras_from_json, + build_render_functions, +) +from regularization.regularizer.mesh import compute_mesh_regularization +from scene.gaussian_model import GaussianModel, SparseGaussianAdam +from utils.general_utils import get_expon_lr_func +from torch.nn.utils import clip_grad_norm_ + + +def ensure_learnable_occupancy(gaussians: GaussianModel) -> None: + """Ensure occupancy buffers exist and the shift tensor is trainable.""" + if not gaussians.learn_occupancy or not hasattr(gaussians, "_occupancy_shift"): + device = gaussians._xyz.device + n_pts = gaussians._xyz.shape[0] + base = torch.zeros((n_pts, 9), device=device) + shift = torch.zeros_like(base) + gaussians.learn_occupancy = True + gaussians._base_occupancy = torch.nn.Parameter(base.requires_grad_(False), requires_grad=False) + gaussians._occupancy_shift = torch.nn.Parameter(shift.requires_grad_(True)) + gaussians.set_occupancy_mode("occupancy_shift") + gaussians._occupancy_shift.requires_grad_(True) + + +def extract_loss_scalars(metrics: Dict) -> Dict[str, float]: + scalars: Dict[str, float] = {} + for key, value in metrics.items(): + if not key.endswith("_loss"): + continue + scalar: Optional[float] = None + if isinstance(value, torch.Tensor): + if value.ndim == 0: + scalar = float(value.item()) + elif isinstance(value, (float, int)): + scalar = float(value) + if scalar is not None: + scalars[key] = scalar + return scalars + + +def export_iteration_state( + iteration: int, + gaussians: GaussianModel, + mesh_state: Dict, + output_dir: str, + reference_camera=None, +) -> None: + os.makedirs(output_dir, exist_ok=True) + mesh_path = os.path.join(output_dir, f"mesh_iter_{iteration:05d}.ply") + ply_path = os.path.join(output_dir, f"gaussians_iter_{iteration:05d}.ply") + + export_mesh_from_gaussians( + gaussians=gaussians, + mesh_state=mesh_state, + output_path=mesh_path, + reference_camera=reference_camera, + ) + gaussians.save_ply(ply_path) + + +def natural_key(path: str) -> List[object]: + """Split path into text/number tokens for natural sorting.""" + return [ + int(token) if token.isdigit() else token + for token in re.split(r"(\d+)", path) + if token + ] + + +@dataclass +class DepthRecord: + depth: torch.Tensor # (1, H, W) on CPU + valid_mask: torch.Tensor # (1, H, W) on CPU, float mask in {0,1} + + +class DepthMapProvider: + """Loads and serves depth maps corresponding to camera viewpoints.""" + + def __init__( + self, + depth_dir: str, + cameras: Sequence, + depth_scale: float = 1.0, + depth_offset: float = 0.0, + clip_min: Optional[float] = None, + clip_max: Optional[float] = None, + ) -> None: + if not os.path.isdir(depth_dir): + raise FileNotFoundError(f"Depth directory not found: {depth_dir}") + self.depth_dir = depth_dir + self.depth_scale = depth_scale + self.depth_offset = depth_offset + self.clip_min = clip_min + self.clip_max = clip_max + + file_list = [f for f in os.listdir(depth_dir) if f.endswith(".npy")] + if not file_list: + raise ValueError(f"No depth npy files found in {depth_dir}") + + # Index depth files by 
(cam_idx, frame_idx) when possible. + pattern = re.compile(r"depth_img_(\d+)_(\d+)\.npy$") + indexed_files: Dict[Tuple[int, int], str] = {} + for filename in file_list: + match = pattern.match(filename) + if match: + cam_idx = int(match.group(1)) + frame_idx = int(match.group(2)) + indexed_files[(cam_idx, frame_idx)] = filename + + # Fallback: natural sorted list for sequential mapping. + natural_sorted_files = sorted(file_list, key=natural_key) + + self.depth_height: Optional[int] = None + self.depth_width: Optional[int] = None + self.global_min: float = float("inf") + self.global_max: float = float("-inf") + self.global_valid_pixels: int = 0 + + self.records: List[DepthRecord] = [] + for cam_idx, cam in enumerate(cameras): + depth_path = self._resolve_path( + cam.image_name if hasattr(cam, "image_name") else str(cam_idx), + cam_idx, + indexed_files, + natural_sorted_files, + ) + full_path = os.path.join(depth_dir, depth_path) + depth_np = np.load(full_path) + if depth_np.ndim == 3 and depth_np.shape[-1] == 1: + depth_np = depth_np[..., 0] + if depth_np.ndim == 2: + depth_np = depth_np[None, ...] # (1, H, W) + elif depth_np.ndim == 3 and depth_np.shape[0] == 1: + pass # already (1, H, W) + else: + raise ValueError(f"Unexpected depth shape {depth_np.shape} in {full_path}") + + depth_tensor = torch.from_numpy(depth_np.astype(np.float32)) + depth_tensor = depth_tensor * depth_scale + depth_offset + + if clip_min is not None or clip_max is not None: + depth_tensor = depth_tensor.clamp( + min=clip_min if clip_min is not None else float("-inf"), + max=clip_max if clip_max is not None else float("inf"), + ) + + valid_mask = (depth_tensor > 0.0).float() + # Track global statistics for diagnostics. + if self.depth_height is None: + self.depth_height, self.depth_width = depth_tensor.shape[-2:] + valid_values = depth_tensor[valid_mask > 0.5] + if valid_values.numel() > 0: + self.global_min = min(self.global_min, float(valid_values.min().item())) + self.global_max = max(self.global_max, float(valid_values.max().item())) + self.global_valid_pixels += int(valid_values.numel()) + + self.records.append(DepthRecord(depth=depth_tensor.contiguous(), valid_mask=valid_mask)) + + if len(self.records) != len(cameras): + raise RuntimeError("Depth map count does not match number of cameras.") + if self.global_min == float("inf"): + self.global_min = 0.0 + self.global_max = 0.0 + + def _resolve_path( + self, + camera_name: str, + camera_idx: int, + indexed_files: Dict[Tuple[int, int], str], + fallback_files: List[str], + ) -> str: + match = re.search(r"traj_(\d+)_cam(\d+)", camera_name) + if match: + frame_idx = int(match.group(1)) + cam_idx = int(match.group(2)) + candidate = indexed_files.get((cam_idx, frame_idx)) + if candidate: + return candidate + # Fallback to cam index with ordered list. + if camera_idx >= len(fallback_files): + raise IndexError( + f"Camera index {camera_idx} exceeds depth file count {len(fallback_files)}." 
+ ) + return fallback_files[camera_idx] + + def get(self, index: int, device: torch.device) -> DepthRecord: + record = self.records[index] + depth = record.depth.to(device, non_blocking=True) + valid = record.valid_mask.to(device, non_blocking=True) + return DepthRecord(depth=depth, valid_mask=valid) + + +def compute_depth_loss( + predicted: torch.Tensor, + target: torch.Tensor, + valid_mask: torch.Tensor, + epsilon: float = 1e-8, +) -> Tuple[torch.Tensor, float, float, int]: + """Compute masked L1 loss and return (loss, mean_abs_error, valid_fraction, valid_pixels).""" + if predicted.shape != target.shape: + target = F.interpolate( + target.unsqueeze(0), + size=predicted.shape[-2:], + mode="bilinear", + align_corners=True, + ).squeeze(0) + valid_mask = F.interpolate( + valid_mask.unsqueeze(0), + size=predicted.shape[-2:], + mode="nearest", + ).squeeze(0) + + valid = valid_mask > 0.5 + valid_pixels = valid.sum().item() + if valid_pixels == 0: + zero = torch.zeros((), device=predicted.device, dtype=predicted.dtype) + return zero, 0.0, 0.0, 0 + + diff = (predicted - target).abs() * valid_mask + loss = diff.sum() / (valid_mask.sum() + epsilon) + mae = diff.sum().item() / (valid_pixels + epsilon) + valid_fraction = valid_pixels / valid_mask.numel() + return loss, mae, valid_fraction, int(valid_pixels) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Depth-guided Gaussian refinement with mesh regularization.") + parser.add_argument("--ply_path", type=str, required=True, help="Input Gaussian PLY.") + parser.add_argument("--camera_poses", type=str, required=True, help="Camera pose JSON.") + parser.add_argument("--depth_dir", type=str, required=True, help="Directory containing per-view depth .npy files.") + parser.add_argument("--mesh_config", type=str, default="medium", help="Mesh regularization config name.") + parser.add_argument("--iterations", type=int, default=5000, help="Number of optimization steps.") + parser.add_argument("--depth_loss_weight", type=float, default=1.0, help="(Deprecated) Depth loss multiplier; kept for backward compatibility.") + parser.add_argument("--mesh_loss_weight", type=float, default=1.0, help="(Deprecated) Mesh loss multiplier; kept for backward compatibility.") + parser.add_argument("--occupancy_lr_scale", type=float, default=1.0, help="Multiplier applied to occupancy LR.") + parser.add_argument("--image_width", type=int, default=1280, help="Rendered image width.") + parser.add_argument("--image_height", type=int, default=720, help="Rendered image height.") + parser.add_argument("--fov_y", type=float, default=75.0, help="Vertical field-of-view in degrees.") + parser.add_argument("--seed", type=int, default=0, help="Random seed.") + parser.add_argument("--output_dir", type=str, default="./depth_refine_output", help="Directory to store outputs.") + parser.add_argument("--log_interval", type=int, default=100, help="Logging interval.") + parser.add_argument("--export_interval", type=int, default=1000, help="Mesh export interval (0 disables periodic export).") + parser.add_argument("--depth_scale", type=float, default=1.0, help="Scale factor applied to loaded depth maps.") + parser.add_argument("--depth_offset", type=float, default=0.0, help="Offset applied to loaded depth maps after scaling.") + parser.add_argument("--depth_clip_min", type=float, default=None, help="Clip depth to minimum value (after scaling).") + parser.add_argument("--depth_clip_max", type=float, default=None, help="Clip depth to maximum value 
(after scaling).") + parser.add_argument("--freeze_colors", dest="freeze_colors", action="store_true", help="Freeze SH features during optimization.") + parser.add_argument("--no-freeze_colors", dest="freeze_colors", action="store_false", help="Allow SH features to be optimized.") + parser.set_defaults(freeze_colors=True) + parser.add_argument("--grad_clip_norm", type=float, default=0.0, help="Apply gradient clipping with given norm (0 disables).") + parser.add_argument("--initial_lr_scale", type=float, default=1.0, help="Multiplier for position lr_init.") + parser.add_argument("--device", type=str, default="cuda", help="Compute device.") + parser.add_argument("--mesh_start_iter", type=int, default=1, help="Iteration at which mesh regularization starts.") + parser.add_argument("--mesh_stop_iter", type=int, default=None, help="Optional iteration to stop mesh regularization.") + parser.add_argument("--warn_until_iter", type=int, default=3000, help="Warmup iterations for surface sampling.") + parser.add_argument("--imp_metric", type=str, default="outdoor", choices=["outdoor", "indoor"], help="Importance metric for surface sampling.") + parser.add_argument("--depth_loss_epsilon", type=float, default=1e-6, help="Numerical epsilon for depth loss denominator.") + return parser + + +def main() -> None: + parser = build_arg_parser() + args = parser.parse_args() + + if not torch.cuda.is_available() and args.device.startswith("cuda"): + raise RuntimeError("CUDA device is required for depth-guided refinement.") + + device = torch.device(args.device) + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + cameras = load_cameras_from_json( + json_path=args.camera_poses, + image_height=args.image_height, + image_width=args.image_width, + fov_y_deg=args.fov_y, + ) + print(f"[INFO] Loaded {len(cameras)} cameras from {args.camera_poses}.") + + depth_provider = DepthMapProvider( + depth_dir=args.depth_dir, + cameras=cameras, + depth_scale=args.depth_scale, + depth_offset=args.depth_offset, + clip_min=args.depth_clip_min, + clip_max=args.depth_clip_max, + ) + print(f"[INFO] Loaded {len(depth_provider.records)} depth maps from {args.depth_dir}.") + if depth_provider.depth_height is not None: + depth_h, depth_w = depth_provider.depth_height, depth_provider.depth_width + if depth_h != args.image_height or depth_w != args.image_width: + print( + f"[WARNING] Depth resolution ({depth_w}x{depth_h}) differs from render resolution " + f"({args.image_width}x{args.image_height}). Depth maps will be interpolated." + ) + if depth_provider.global_valid_pixels == 0: + print("[WARNING] No valid depth pixels found across dataset; depth supervision will be ineffective.") + else: + print( + "[INFO] Depth value range after scaling: " + f"{depth_provider.global_min:.4f} – {depth_provider.global_max:.4f} " + f"({depth_provider.global_valid_pixels} valid pixels)." 
+ ) + + scene = ManualScene(cameras) + + gaussians = GaussianModel( + sh_degree=0, + use_mip_filter=False, + learn_occupancy=True, + use_appearance_network=False, + ) + gaussians.load_ply(args.ply_path) + print(f"[INFO] Loaded {gaussians._xyz.shape[0]} Gaussians from {args.ply_path}.") + + ensure_learnable_occupancy(gaussians) + gaussians.init_culling(len(cameras)) + if gaussians.spatial_lr_scale <= 0: + gaussians.spatial_lr_scale = 1.0 + + mesh_config = load_mesh_config( + name=args.mesh_config, + start_iter_override=args.mesh_start_iter, + stop_iter_override=args.mesh_stop_iter, + total_iterations=args.iterations, + ) + occupancy_mode = mesh_config.get("occupancy_mode", "occupancy_shift") + if occupancy_mode != "occupancy_shift": + raise ValueError( + f"Depth-guided refinement requires occupancy_mode 'occupancy_shift', got '{occupancy_mode}'. " + "Please adjust the mesh configuration." + ) + gaussians.set_occupancy_mode(occupancy_mode) + print( + "[INFO] Mesh config '{name}': start_iter={start}, stop_iter={stop}, n_max_points_in_delaunay={limit}".format( + name=args.mesh_config, + start=mesh_config.get("start_iter"), + stop=mesh_config.get("stop_iter"), + limit=mesh_config.get("n_max_points_in_delaunay"), + ) + ) + + opt_parser = argparse.ArgumentParser() + opt_params = OptimizationParams(opt_parser) + opt_params.iterations = args.iterations + opt_params.position_lr_init *= args.initial_lr_scale + opt_params.position_lr_final *= args.initial_lr_scale + + gaussians.training_setup(opt_params) + + if args.freeze_colors: + gaussians._features_dc.requires_grad_(False) + gaussians._features_rest.requires_grad_(False) + + lr_xyz_init = opt_params.position_lr_init * gaussians.spatial_lr_scale + + param_groups = [ + {"params": [gaussians._xyz], "lr": lr_xyz_init, "name": "xyz"}, + {"params": [gaussians._opacity], "lr": opt_params.opacity_lr, "name": "opacity"}, + {"params": [gaussians._scaling], "lr": opt_params.scaling_lr, "name": "scaling"}, + {"params": [gaussians._rotation], "lr": opt_params.rotation_lr, "name": "rotation"}, + ] + if not args.freeze_colors: + param_groups.append({"params": [gaussians._features_dc], "lr": opt_params.feature_lr, "name": "f_dc"}) + param_groups.append({"params": [gaussians._features_rest], "lr": opt_params.feature_lr / 20.0, "name": "f_rest"}) + if gaussians.learn_occupancy: + param_groups.append({"params": [gaussians._occupancy_shift], "lr": opt_params.opacity_lr * args.occupancy_lr_scale, "name": "occupancy_shift"}) + + gaussians.optimizer = SparseGaussianAdam(param_groups, lr=0.0, eps=1e-15) + gaussians.xyz_scheduler_args = get_expon_lr_func( + lr_init=lr_xyz_init, + lr_final=opt_params.position_lr_final * gaussians.spatial_lr_scale, + lr_delay_mult=opt_params.position_lr_delay_mult, + max_steps=opt_params.position_lr_max_steps, + ) + + background = torch.zeros(3, dtype=torch.float32, device=device) + pipe_parser = argparse.ArgumentParser() + pipe = PipelineParams(pipe_parser) + render_view, render_for_sdf = build_render_functions(gaussians, pipe, background) + mesh_renderer, mesh_state = initialize_mesh_regularization(scene, mesh_config) + mesh_state["reset_delaunay_samples"] = True + mesh_state["reset_sdf_values"] = True + + runtime_args = argparse.Namespace( + warn_until_iter=args.warn_until_iter, + imp_metric=args.imp_metric, + depth_reinit_iter=getattr(args, "depth_reinit_iter", args.warn_until_iter), + ) + + os.makedirs(args.output_dir, exist_ok=True) + log_dir = os.path.join(args.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + loss_log_path 
= os.path.join(log_dir, "losses.jsonl") + + ema_depth_loss = None + ema_mesh_loss = None + pending_view_indices: List[int] = [] + printed_depth_diagnostics = False + + with open(loss_log_path, "w", encoding="utf-8") as loss_log_file: + for iteration in range(1, args.iterations + 1): + if not pending_view_indices: + pending_view_indices = list(range(len(cameras))) + random.shuffle(pending_view_indices) + + view_idx = pending_view_indices.pop() + viewpoint = cameras[view_idx] + + depth_record = depth_provider.get(view_idx, device) + render_pkg = render_view(viewpoint) + + pred_depth = render_pkg["median_depth"] + depth_loss, depth_mae, valid_fraction, valid_pixels = compute_depth_loss( + predicted=pred_depth, + target=depth_record.depth, + valid_mask=depth_record.valid_mask, + epsilon=args.depth_loss_epsilon, + ) + + if valid_pixels == 0: + skipped_record = { + "iteration": iteration, + "view_index": view_idx, + "skipped": True, + "skipped_reason": "invalid_depth", + } + loss_log_file.write(json.dumps(skipped_record) + "\n") + loss_log_file.flush() + if iteration % args.log_interval == 0 or iteration == 1: + print(f"[Iter {iteration:05d}] skipped view {view_idx} due to invalid depth.") + continue + + total_loss = depth_loss + + mesh_pkg = compute_mesh_regularization( + iteration=iteration, + render_pkg=render_pkg, + viewpoint_cam=viewpoint, + viewpoint_idx=view_idx, + gaussians=gaussians, + scene=scene, + pipe=pipe, + background=background, + kernel_size=0.0, + config=mesh_config, + mesh_renderer=mesh_renderer, + mesh_state=mesh_state, + render_func=render_for_sdf, + weight_adjustment=100.0 / max(args.iterations, 1), + args=runtime_args, + integrate_func=integrate_radegs, + ) + mesh_state = mesh_pkg["updated_state"] + mesh_loss_tensor = mesh_pkg["mesh_loss"] + mesh_loss = mesh_loss_tensor + + if not printed_depth_diagnostics: + depth_valid = depth_record.depth[depth_record.valid_mask > 0.5] + print( + "[DIAG] First valid depth batch: " + f"depth range {float(depth_valid.min().item()):.4f} – {float(depth_valid.max().item()):.4f}, " + f"predicted range {float(pred_depth.min().item()):.4f} – {float(pred_depth.max().item()):.4f}" + ) + print(f"[DIAG] Gaussian spatial_lr_scale: {gaussians.spatial_lr_scale:.6f}") + mesh_loss_unweighted = mesh_loss_tensor.item() + mesh_loss_weighted_diag = mesh_loss.item() + print( + f"[DIAG] Initial losses — depth_loss={depth_loss.item():.6e}, " + f"mesh_loss_raw={mesh_loss_unweighted:.6e}, " + f"mesh_loss_weighted={mesh_loss_weighted_diag:.6e}" + ) + printed_depth_diagnostics = True + + total_loss = total_loss + mesh_loss + + gaussians.optimizer.zero_grad(set_to_none=True) + total_loss.backward() + if args.grad_clip_norm > 0.0: + trainable_params: List[torch.Tensor] = [] + for group in gaussians.optimizer.param_groups: + for param in group.get("params", []): + if isinstance(param, torch.Tensor) and param.requires_grad: + trainable_params.append(param) + if trainable_params: + clip_grad_norm_(trainable_params, args.grad_clip_norm) + gaussians.update_learning_rate(iteration) + visibility = render_pkg["visibility_filter"] + radii = render_pkg["radii"] + gaussians.optimizer.step(visibility, radii.shape[0]) + + total_loss_value = float(total_loss.item()) + depth_loss_value = float(depth_loss.item()) + mesh_loss_value = float(mesh_loss_tensor.item()) + weighted_mesh_loss_value = mesh_loss_value + + ema_depth_loss = depth_loss_value if ema_depth_loss is None else (0.9 * ema_depth_loss + 0.1 * depth_loss_value) + ema_mesh_loss = weighted_mesh_loss_value if ema_mesh_loss 
is None else (0.9 * ema_mesh_loss + 0.1 * weighted_mesh_loss_value) + + iteration_record = { + "iteration": iteration, + "view_index": view_idx, + "total_loss": total_loss_value, + "depth_loss": depth_loss_value, + "mesh_loss_raw": mesh_loss_value, + "mesh_loss_weighted": weighted_mesh_loss_value, + "ema_depth_loss": ema_depth_loss, + "ema_mesh_loss": ema_mesh_loss, + "depth_mae": depth_mae, + "valid_fraction": valid_fraction, + "valid_pixels": valid_pixels, + } + iteration_record.update(extract_loss_scalars(mesh_pkg)) + loss_log_file.write(json.dumps(iteration_record) + "\n") + loss_log_file.flush() + + if args.export_interval > 0 and iteration % args.export_interval == 0: + export_iteration_state( + iteration=iteration, + gaussians=gaussians, + mesh_state=mesh_state, + output_dir=args.output_dir, + reference_camera=None, + ) + + if iteration % args.log_interval == 0 or iteration == 1: + print( + "[Iter {iter:05d}] loss={loss:.6f} depth={depth:.6f} mesh={mesh:.6f} " + "depth_mae={mae:.6f} valid={valid:.3f}".format( + iter=iteration, + loss=total_loss_value, + depth=depth_loss_value, + mesh=weighted_mesh_loss_value, + mae=depth_mae, + valid=valid_fraction, + ) + ) + + final_dir = os.path.join(args.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + export_iteration_state( + iteration=args.iterations, + gaussians=gaussians, + mesh_state=mesh_state, + output_dir=final_dir, + reference_camera=None, + ) + print(f"[INFO] Depth-guided refinement completed. Results saved to {args.output_dir}.") + + +def load_mesh_config( + name: str, + start_iter_override: Optional[int] = None, + stop_iter_override: Optional[int] = None, + total_iterations: Optional[int] = None, +) -> Dict: + from milo.useless_maybe.ply2mesh import load_mesh_config_file + + config = load_mesh_config_file(name) + if start_iter_override is not None: + config["start_iter"] = max(1, start_iter_override) + else: + config["start_iter"] = max(1, config.get("start_iter", 1)) + if stop_iter_override is not None: + config["stop_iter"] = stop_iter_override + elif total_iterations is not None: + config["stop_iter"] = max(config.get("stop_iter", total_iterations), total_iterations) + config["stop_iter"] = max(config.get("stop_iter", config["start_iter"]), config["start_iter"]) + return config + + +if __name__ == "__main__": + main() diff --git a/milo/useless_maybe/depth_train.py b/milo/useless_maybe/depth_train.py new file mode 100644 index 0000000..aafeba7 --- /dev/null +++ b/milo/useless_maybe/depth_train.py @@ -0,0 +1,940 @@ +#!/usr/bin/env python3 +""" +Depth-supervised training loop for 3D Gaussian Splatting. + +This script mirrors the original MILo image-supervised training pipeline, but +replaces the photometric loss with a depth reconstruction objective fed by +per-view depth maps. It supports mesh-in-the-loop regularization, gaussian +densification/simplification, and periodic exports for inspection. 
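+
+An illustrative command line (paths are placeholders; the flag names mirror the
+ones run_depth_sweep.py passes when driving depth training and may need to be
+adapted for this variant of the script):
+
+    python milo/useless_maybe/depth_train.py --ply_path runs/example/point_cloud.ply --camera_poses milo/data/example/camera_poses.json --depth_dir milo/data/example/depths --output_dir runs/depth_train_example --iterations 1000 --initial_lr_scale 1.0 --enable_densification --log_depth_stats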
+""" + +from __future__ import annotations + +import argparse +import json +import math +import os +import random +import re +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +import yaml +from torch.nn.utils import clip_grad_norm_ +import trimesh + +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +ROOT_DIR = os.path.abspath(os.path.join(BASE_DIR, "..")) +sys.path.append(ROOT_DIR) + +from arguments import OptimizationParams, PipelineParams # noqa: E402 +from gaussian_renderer import render_simp # noqa: E402 +from gaussian_renderer.radegs import render_radegs as render_radegs # noqa: E402 +from gaussian_renderer.radegs import integrate_radegs as integrate # noqa: E402 +from regularization.regularizer.mesh import initialize_mesh_regularization # noqa: E402 +from regularization.regularizer.mesh import compute_mesh_regularization # noqa: E402 +from regularization.regularizer.mesh import reset_mesh_state_at_next_iteration # noqa: E402 +from scene.cameras import Camera # noqa: E402 +from scene.gaussian_model import GaussianModel # noqa: E402 +from utils.geometry_utils import flatten_voronoi_features # noqa: E402 +from utils.general_utils import safe_state # noqa: E402 +from functional import extract_mesh, compute_delaunay_triangulation # noqa: E402 +from functional.mesh import frustum_cull_mesh # noqa: E402 +from regularization.sdf.learnable import convert_occupancy_to_sdf # noqa: E402 + + +def quaternion_to_rotation_matrix(q: Sequence[float]) -> np.ndarray: + q = np.asarray(q, dtype=np.float64) + if q.shape != (4,): + raise ValueError("Quaternion must have shape (4,)") + w, x, y, z = q + xx = x * x + yy = y * y + zz = z * z + xy = x * y + xz = x * z + yz = y * z + wx = w * x + wy = w * y + wz = w * z + rotation = np.array( + [ + [1.0 - 2.0 * (yy + zz), 2.0 * (xy - wz), 2.0 * (xz + wy)], + [2.0 * (xy + wz), 1.0 - 2.0 * (xx + zz), 2.0 * (yz - wx)], + [2.0 * (xz - wy), 2.0 * (yz + wx), 1.0 - 2.0 * (xx + yy)], + ], + dtype=np.float32, + ) + return rotation + + +def load_cameras_from_json( + json_path: str, + image_height: int, + image_width: int, + fov_y_deg: float, + data_device: str, +) -> List[Camera]: + if not os.path.isfile(json_path): + raise FileNotFoundError(f"Camera JSON not found: {json_path}") + with open(json_path, "r", encoding="utf-8") as f: + entries = json.load(f) + if not entries: + raise ValueError(f"No camera entries in {json_path}") + + fov_y = math.radians(fov_y_deg) + aspect = image_width / image_height + fov_x = 2.0 * math.atan(aspect * math.tan(fov_y * 0.5)) + + cameras: List[Camera] = [] + for idx, entry in enumerate(entries): + if "quaternion" in entry: + rotation = quaternion_to_rotation_matrix(entry["quaternion"]) + elif "rotation" in entry: + rotation = np.asarray(entry["rotation"], dtype=np.float32) + if rotation.shape != (3, 3): + raise ValueError(f"Camera entry {idx} rotation must be 3x3") + else: + raise KeyError(f"Camera entry {idx} missing rotation or quaternion.") + + if "tvec" in entry: + translation = np.asarray(entry["tvec"], dtype=np.float32) + elif "translation" in entry: + translation = np.asarray(entry["translation"], dtype=np.float32) + elif "position" in entry: + camera_center = np.asarray(entry["position"], dtype=np.float32) + if camera_center.shape != (3,): + raise ValueError(f"Camera entry {idx} position must be length-3.") + translation = -rotation.T @ camera_center + else: + raise 
KeyError(f"Camera entry {idx} missing translation/position.") + + if translation.shape != (3,): + raise ValueError(f"Camera entry {idx} translation must be length-3.") + + image_name = ( + entry.get("name") + or entry.get("img_name") + or entry.get("image_name") + or f"view_{idx:04d}" + ) + + camera = Camera( + colmap_id=str(idx), + R=rotation, + T=translation, + FoVx=fov_x, + FoVy=fov_y, + image=torch.zeros(3, image_height, image_width), + gt_alpha_mask=None, + image_name=image_name, + uid=idx, + data_device=data_device, + ) + cameras.append(camera) + return cameras + + +def _clone_camera_for_scale(camera: Camera, scale: float) -> Camera: + if math.isclose(scale, 1.0): + return camera + + new_height = max(1, int(round(camera.image_height / scale))) + new_width = max(1, int(round(camera.image_width / scale))) + blank_image = torch.zeros(3, new_height, new_width, dtype=torch.float32) + + # Camera expects rotation/translation as numpy arrays; reuse existing values. + return Camera( + colmap_id=camera.colmap_id, + R=camera.R, + T=camera.T, + FoVx=camera.FoVx, + FoVy=camera.FoVy, + image=blank_image, + gt_alpha_mask=None, + image_name=camera.image_name, + uid=camera.uid, + data_device=str(camera.data_device), + ) + + +def _build_scaled_cameras( + cameras: Sequence[Camera], + scales: Sequence[float] = (1.0, 2.0), +) -> Dict[float, List[Camera]]: + scaled: Dict[float, List[Camera]] = {} + for scale in scales: + if math.isclose(scale, 1.0): + scaled[float(scale)] = list(cameras) + else: + scaled[float(scale)] = [_clone_camera_for_scale(cam, scale) for cam in cameras] + return scaled + + +class ManualScene: + """Minimal adapter exposing camera access expected by mesh regularizer.""" + + def __init__(self, cameras_by_scale: Dict[float, Sequence[Camera]]): + if 1.0 not in cameras_by_scale: + raise ValueError("At least scale 1.0 cameras must be provided.") + self._train_cameras: Dict[float, List[Camera]] = { + float(scale): list(cam_list) for scale, cam_list in cameras_by_scale.items() + } + + def getTrainCameras(self, scale: float = 1.0): + scale_key = float(scale) + if scale_key not in self._train_cameras: + scale_key = 1.0 + return list(self._train_cameras[scale_key]) + + def getTrainCameras_warn_up( + self, + iteration: int, + warn_until_iter: int, + scale: float = 1.0, + scale2: float = 2.0, + ): + preferred = scale2 if iteration <= warn_until_iter and float(scale2) in self._train_cameras else scale + fallback_scale = float(preferred) if float(preferred) in self._train_cameras else 1.0 + return list(self._train_cameras[fallback_scale]) + + +def build_render_functions( + gaussians: GaussianModel, + pipe: PipelineParams, + background: torch.Tensor, +): + def _render( + view: Camera, + pc_obj: GaussianModel, + pipe_obj: PipelineParams, + bg_color: torch.Tensor, + *, + kernel_size: float = 0.0, + require_coord: bool = False, + require_depth: bool = True, + ): + pkg = render_radegs( + viewpoint_camera=view, + pc=pc_obj, + pipe=pipe_obj, + bg_color=bg_color, + kernel_size=kernel_size, + scaling_modifier=1.0, + require_coord=require_coord, + require_depth=require_depth, + ) + if "area_max" not in pkg: + pkg["area_max"] = torch.zeros_like(pkg["radii"]) + return pkg + + def render_view(view: Camera): + return _render(view, gaussians, pipe, background) + + def render_for_sdf( + view: Camera, + gaussians_override: Optional[GaussianModel] = None, + pipeline_override: Optional[PipelineParams] = None, + background_override: Optional[torch.Tensor] = None, + kernel_size: float = 0.0, + require_depth: bool = 
True, + require_coord: bool = False, + ): + pc_obj = gaussians if gaussians_override is None else gaussians_override + pipe_obj = pipe if pipeline_override is None else pipeline_override + bg_color = background if background_override is None else background_override + pkg = _render( + view, + pc_obj, + pipe_obj, + bg_color, + kernel_size=kernel_size, + require_coord=require_coord, + require_depth=require_depth, + ) + return { + "render": pkg["render"].detach(), + "median_depth": pkg["median_depth"].detach(), + } + + return render_view, render_for_sdf + + +def export_mesh_from_gaussians( + gaussians: GaussianModel, + mesh_state: Dict, + output_path: str, + reference_camera: Optional[Camera] = None, +) -> None: + delaunay_tets = mesh_state.get("delaunay_tets") + gaussian_idx = mesh_state.get("delaunay_xyz_idx") + if delaunay_tets is None: + delaunay_tets = compute_delaunay_triangulation( + means=gaussians.get_xyz, + scales=gaussians.get_scaling, + rotations=gaussians.get_rotation, + gaussian_idx=gaussian_idx, + ) + + occupancy = ( + gaussians.get_occupancy + if gaussian_idx is None + else gaussians.get_occupancy[gaussian_idx] + ) + pivots_sdf = convert_occupancy_to_sdf(flatten_voronoi_features(occupancy)) + + mesh = extract_mesh( + delaunay_tets=delaunay_tets, + pivots_sdf=pivots_sdf, + means=gaussians.get_xyz, + scales=gaussians.get_scaling, + rotations=gaussians.get_rotation, + gaussian_idx=gaussian_idx, + ) + + mesh_to_export = mesh + if reference_camera is not None: + mesh_to_export = frustum_cull_mesh(mesh, reference_camera) + + verts = mesh_to_export.verts.detach().cpu().numpy() + faces = mesh_to_export.faces.detach().cpu().numpy() + trimesh.Trimesh(vertices=verts, faces=faces, process=False).export(output_path) + + +def load_mesh_config_file(name: str) -> Dict: + config_path = os.path.join(BASE_DIR, "configs", "mesh", f"{name}.yaml") + if not os.path.isfile(config_path): + raise FileNotFoundError(f"Mesh config not found: {config_path}") + with open(config_path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) + + +def ensure_learnable_occupancy(gaussians: GaussianModel) -> None: + """Ensure occupancy buffers exist and shifts are trainable.""" + if not gaussians.learn_occupancy or not hasattr(gaussians, "_occupancy_shift"): + device = gaussians._xyz.device + n_pts = gaussians._xyz.shape[0] + base = torch.zeros((n_pts, 9), device=device) + shift = torch.zeros_like(base) + gaussians.learn_occupancy = True + gaussians._base_occupancy = torch.nn.Parameter( + base.requires_grad_(False), requires_grad=False + ) + gaussians._occupancy_shift = torch.nn.Parameter(shift.requires_grad_(True)) + gaussians.set_occupancy_mode("occupancy_shift") + gaussians._occupancy_shift.requires_grad_(True) + + +@dataclass +class DepthRecord: + depth: torch.Tensor # (1, H, W) + valid_mask: torch.Tensor # (1, H, W) + + +class DepthMapProvider: + """Loads depth maps and matches them to cameras via naming convention.""" + + def __init__( + self, + depth_dir: Path, + cameras: Sequence, + depth_scale: float = 1.0, + depth_offset: float = 0.0, + clip_min: Optional[float] = None, + clip_max: Optional[float] = None, + ) -> None: + if not depth_dir.is_dir(): + raise FileNotFoundError(f"Depth directory not found: {depth_dir}") + + file_list = sorted([f.name for f in depth_dir.iterdir() if f.suffix == ".npy"]) + if not file_list: + raise ValueError(f"No depth .npy files found in {depth_dir}") + + pattern = re.compile(r"depth_img_(\d+)_(\d+)\.npy$") + indexed: Dict[Tuple[int, int], str] = {} + for filename in 
file_list: + match = pattern.match(filename) + if match: + cam_idx = int(match.group(1)) + frame_idx = int(match.group(2)) + indexed[(cam_idx, frame_idx)] = filename + + fallback_files = sorted(file_list, key=self._natural_key) + + self.depth_scale = depth_scale + self.depth_offset = depth_offset + self.clip_min = clip_min + self.clip_max = clip_max + + self.depth_height: Optional[int] = None + self.depth_width: Optional[int] = None + self.global_min: float = float("inf") + self.global_max: float = float("-inf") + self.global_valid_pixels: int = 0 + self.records: List[DepthRecord] = [] + + for cam_index, camera in enumerate(cameras): + depth_path = self._resolve_path( + camera_name=getattr(camera, "image_name", str(cam_index)), + camera_idx=cam_index, + indexed_files=indexed, + fallback_files=fallback_files, + ) + full_path = depth_dir / depth_path + depth_np = np.load(full_path) + if depth_np.ndim == 3 and depth_np.shape[-1] == 1: + depth_np = depth_np[..., 0] + if depth_np.ndim == 2: + depth_np = depth_np[None, ...] + elif depth_np.ndim == 3 and depth_np.shape[0] == 1: + pass + else: + raise ValueError(f"Unexpected depth shape {depth_np.shape} in {full_path}") + + depth = torch.from_numpy(depth_np.astype(np.float32)) + depth = depth * depth_scale + depth_offset + if clip_min is not None or clip_max is not None: + depth = depth.clamp( + min=clip_min if clip_min is not None else float("-inf"), + max=clip_max if clip_max is not None else float("inf"), + ) + + mask = (depth > 0.0).float() + if self.depth_height is None: + self.depth_height, self.depth_width = depth.shape[-2:] + + valid_values = depth[mask > 0.5] + if valid_values.numel() > 0: + self.global_min = min(self.global_min, float(valid_values.min())) + self.global_max = max(self.global_max, float(valid_values.max())) + self.global_valid_pixels += int(valid_values.numel()) + + self.records.append(DepthRecord(depth=depth.contiguous(), valid_mask=mask)) + + if self.global_min == float("inf"): + self.global_min = 0.0 + self.global_max = 0.0 + + @staticmethod + def _natural_key(path: str) -> List[object]: + tokens = re.split(r"(\d+)", Path(path).stem) + return [int(tok) if tok.isdigit() else tok for tok in tokens if tok] + + @staticmethod + def _resolve_path( + camera_name: str, + camera_idx: int, + indexed_files: Dict[Tuple[int, int], str], + fallback_files: Sequence[str], + ) -> str: + match = re.search(r"traj_(\d+)_cam(\d+)", camera_name) + if match: + frame_idx = int(match.group(1)) + cam_idx = int(match.group(2)) + candidate = indexed_files.get((cam_idx, frame_idx)) + if candidate: + return candidate + if camera_idx >= len(fallback_files): + raise IndexError( + f"Camera index {camera_idx} exceeds number of depth files ({len(fallback_files)})." 
+ ) + return fallback_files[camera_idx] + + def __len__(self) -> int: + return len(self.records) + + def get(self, index: int, device: torch.device) -> DepthRecord: + record = self.records[index] + return DepthRecord( + depth=record.depth.to(device, non_blocking=True), + valid_mask=record.valid_mask.to(device, non_blocking=True), + ) + + +def compute_depth_loss( + predicted: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor, + epsilon: float, +) -> Tuple[torch.Tensor, float, float, int]: + if predicted.shape != target.shape: + target = F.interpolate( + target.unsqueeze(0), + size=predicted.shape[-2:], + mode="bilinear", + align_corners=True, + ).squeeze(0) + mask = F.interpolate( + mask.unsqueeze(0), + size=predicted.shape[-2:], + mode="nearest", + ).squeeze(0) + + valid = mask > 0.5 + valid_pixels = int(valid.sum().item()) + if valid_pixels == 0: + zero = torch.zeros((), device=predicted.device, dtype=predicted.dtype) + return zero, 0.0, 0.0, 0 + + diff = (predicted - target).abs() * mask + loss = diff.sum() / (mask.sum() + epsilon) + mae = diff.sum().item() / (valid_pixels + epsilon) + valid_fraction = valid_pixels / mask.numel() + return loss, mae, valid_fraction, valid_pixels + + +class DepthTrainer: + """Orchestrates depth-supervised optimization of a Gaussian model.""" + + def __init__(self, args: argparse.Namespace) -> None: + self.args = args + self.device = torch.device(args.device) + self._prepare_seeds(args.seed) + + base_cameras = load_cameras_from_json( + json_path=args.camera_poses, + image_height=args.image_height, + image_width=args.image_width, + fov_y_deg=args.fov_y, + data_device=args.data_device, + ) + print(f"[INFO] Loaded {len(base_cameras)} cameras.") + self.cameras_by_scale = _build_scaled_cameras(base_cameras, scales=(1.0, 2.0)) + self.scene = ManualScene(self.cameras_by_scale) + self.cameras = self.cameras_by_scale[1.0] + + self.fixed_view_idx = args.fixed_view_idx + if self.fixed_view_idx is not None: + if not (0 <= self.fixed_view_idx < len(self.cameras)): + raise ValueError( + f"fixed_view_idx {self.fixed_view_idx} out of bounds for {len(self.cameras)} cameras." 
+ ) + + depth_dir = Path(args.depth_dir) + self.depth_provider = DepthMapProvider( + depth_dir=depth_dir, + cameras=self.cameras, + depth_scale=args.depth_scale, + depth_offset=args.depth_offset, + clip_min=args.depth_clip_min, + clip_max=args.depth_clip_max, + ) + if self.depth_provider.global_valid_pixels == 0: + raise RuntimeError("No valid depth pixels found across the dataset.") + print( + "[INFO] Depth statistics after scaling: " + f"{self.depth_provider.global_min:.4f} – {self.depth_provider.global_max:.4f} " + f"({self.depth_provider.global_valid_pixels} valid pixels)" + ) + + self.scene.cameras_extent = self._estimate_extent(args.ply_path) + + self.gaussians = GaussianModel( + sh_degree=args.sh_degree, + use_mip_filter=not args.disable_mip_filter, + learn_occupancy=True, + use_appearance_network=False, + ) + self.gaussians.load_ply(args.ply_path) + ensure_learnable_occupancy(self.gaussians) + self.gaussians.init_culling(len(self.cameras)) + if self.gaussians.spatial_lr_scale <= 0: + self.gaussians.spatial_lr_scale = 1.0 + + opt_parser = argparse.ArgumentParser(add_help=False) + opt_params = OptimizationParams(opt_parser) + opt_params.iterations = args.iterations + opt_params.position_lr_init *= args.initial_lr_scale + opt_params.position_lr_final *= args.initial_lr_scale + self.gaussians.training_setup(opt_params) + if args.freeze_colors: + if hasattr(self.gaussians, "_features_dc"): + self.gaussians._features_dc.requires_grad_(False) + if hasattr(self.gaussians, "_features_rest"): + self.gaussians._features_rest.requires_grad_(False) + + self.background = torch.zeros(3, dtype=torch.float32, device=self.device) + pipe_parser = argparse.ArgumentParser(add_help=False) + self.pipe = PipelineParams(pipe_parser) + self.pipe.compute_cov3D_python = args.compute_cov3d_python + self.pipe.convert_SHs_python = args.convert_shs_python + self.pipe.debug = args.debug + + self.render_view, self.render_for_mesh = build_render_functions( + self.gaussians, self.pipe, self.background + ) + self.mesh_enabled = args.mesh_regularization + if self.mesh_enabled: + mesh_config = self._load_mesh_config( + args.mesh_config, args.mesh_start_iter, args.mesh_stop_iter, args.iterations + ) + occupancy_mode = mesh_config.get("occupancy_mode", "occupancy_shift") + if occupancy_mode != "occupancy_shift": + raise ValueError( + f"Mesh config '{args.mesh_config}' must use occupancy_mode 'occupancy_shift', got '{occupancy_mode}'." 
+ ) + self.gaussians.set_occupancy_mode(occupancy_mode) + self.mesh_renderer, self.mesh_state = initialize_mesh_regularization( + self.scene, + mesh_config, + ) + self.mesh_state["reset_delaunay_samples"] = True + self.mesh_state["reset_sdf_values"] = True + self.mesh_config = mesh_config + self.runtime_args = argparse.Namespace( + warn_until_iter=args.warn_until_iter, + imp_metric=args.imp_metric, + depth_reinit_iter=args.depth_reinit_iter, + ) + self._warmup_mesh_visibility() + else: + self.mesh_renderer = None + self.mesh_state = {} + self.mesh_config = {} + self.runtime_args = None + + self.optimizer = self.gaussians.optimizer + self.opt_params = opt_params + + self.output_dir = Path(args.output_dir) + (self.output_dir / "logs").mkdir(parents=True, exist_ok=True) + self.loss_log_path = self.output_dir / "logs" / "losses.jsonl" + self.pending_indices: List[int] = [] + self.ema_depth: Optional[float] = None + self.ema_mesh: Optional[float] = None + self.printed_depth_diag = False + self.log_depth_stats = bool(args.log_depth_stats or self.fixed_view_idx is not None) + + @staticmethod + def _prepare_seeds(seed: int) -> None: + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + def _estimate_extent(self, ply_path: str) -> float: + import trimesh + + mesh = trimesh.load(ply_path, process=False) + if hasattr(mesh, "vertices"): + vertices = np.asarray(mesh.vertices) + center = vertices.mean(axis=0) + radius = np.linalg.norm(vertices - center, axis=1).max() + return float(radius) + raise ValueError("Could not estimate scene extent from PLY.") + + def _load_mesh_config( + self, + name: str, + start_iter_override: Optional[int], + stop_iter_override: Optional[int], + total_iterations: int, + ) -> Dict: + config = load_mesh_config_file(name) + if start_iter_override is not None: + config["start_iter"] = max(1, start_iter_override) + if stop_iter_override is not None: + config["stop_iter"] = stop_iter_override + else: + config["stop_iter"] = max(config.get("stop_iter", total_iterations), total_iterations) + config["stop_iter"] = max(config["stop_iter"], config.get("start_iter", 1)) + if "occupancy_mode" not in config: + config["occupancy_mode"] = "occupancy_shift" + self.mesh_config = config + return config + + def _check_gaussian_numerics(self, label: str) -> None: + """Detect NaNs/Infs or extreme magnitudes before hitting CUDA kernels.""" + stats = { + "xyz": self.gaussians.get_xyz, + "scaling": self.gaussians.get_scaling, + "rotation": self.gaussians.get_rotation, + "opacity": self.gaussians.get_opacity, + } + for name, tensor in stats.items(): + if not torch.isfinite(tensor).all(): + invalid_mask = ~torch.isfinite(tensor) + num_bad = int(invalid_mask.sum().item()) + example_idx = invalid_mask.nonzero(as_tuple=False)[:5].flatten().tolist() + raise RuntimeError( + f"[NUMERIC] Detected {num_bad} non-finite entries in '{name}' " + f"during {label}. 
Sample indices: {example_idx}" + ) + max_abs = tensor.abs().max().item() + if max_abs > 1e6: + print( + f"[WARN] Large magnitude detected in '{name}' during {label}: " + f"{max_abs:.3e}" + ) + + def _warmup_mesh_visibility(self) -> None: + warmup_views = self.scene.getTrainCameras_warn_up( + iteration=1, + warn_until_iter=self.args.warn_until_iter, + scale=1.0, + scale2=2.0, + ) + for view in warmup_views: + render_simp( + view, + self.gaussians, + self.pipe, + self.background, + culling=self.gaussians._culling[:, view.uid], + ) + + def _select_view(self) -> int: + if self.fixed_view_idx is not None: + return self.fixed_view_idx + if not self.pending_indices: + self.pending_indices = list(range(len(self.cameras))) + random.shuffle(self.pending_indices) + return self.pending_indices.pop() + + def _log_iteration( + self, + iteration: int, + view_idx: int, + total_loss: float, + depth_loss: float, + mesh_loss: float, + depth_mae: float, + valid_fraction: float, + valid_pixels: int, + extra: Dict[str, float], + ) -> None: + record = { + "iteration": iteration, + "view_index": view_idx, + "total_loss": total_loss, + "depth_loss": depth_loss, + "mesh_loss": mesh_loss, + "depth_mae": depth_mae, + "valid_fraction": valid_fraction, + "valid_pixels": valid_pixels, + } + if self.ema_depth is not None: + record["ema_depth_loss"] = self.ema_depth + if self.ema_mesh is not None: + record["ema_mesh_loss"] = self.ema_mesh + record.update(extra) + with open(self.loss_log_path, "a", encoding="utf-8") as f: + f.write(json.dumps(record) + "\n") + + def run(self) -> None: + os.makedirs(self.output_dir, exist_ok=True) + with open(self.loss_log_path, "w", encoding="utf-8"): + pass + + for iteration in range(1, self.args.iterations + 1): + self.gaussians.update_learning_rate(iteration) + view_idx = self._select_view() + viewpoint = self.cameras[view_idx] + + depth_record = self.depth_provider.get(view_idx, self.device) + render_pkg = self.render_view(viewpoint) + pred_depth = render_pkg["median_depth"] + + depth_loss, depth_mae, valid_fraction, valid_pixels = compute_depth_loss( + predicted=pred_depth, + target=depth_record.depth, + mask=depth_record.valid_mask, + epsilon=self.args.depth_loss_epsilon, + ) + if valid_pixels == 0: + if iteration % self.args.log_interval == 0 or iteration == 1: + print(f"[Iter {iteration:05d}] skip view {view_idx} (no valid depth)") + continue + + mask_valid = depth_record.valid_mask.to(pred_depth.device) > 0.5 + pred_valid = pred_depth[mask_valid] + target_valid = depth_record.depth[mask_valid] + + if not self.printed_depth_diag: + if target_valid.numel() > 0 and pred_valid.numel() > 0: + print( + "[DIAG] First depth batch — target range {t_min:.4f} – {t_max:.4f}, " + "predicted range {p_min:.4f} – {p_max:.4f}".format( + t_min=float(target_valid.min().item()), + t_max=float(target_valid.max().item()), + p_min=float(pred_valid.min().item()), + p_max=float(pred_valid.max().item()), + ) + ) + self.printed_depth_diag = True + + total_loss = depth_loss + mesh_loss_tensor = torch.zeros_like(depth_loss) + mesh_pkg: Dict[str, torch.Tensor] = {} + mesh_active = self.mesh_enabled and iteration >= self.mesh_config.get("start_iter", 1) + if mesh_active: + self._check_gaussian_numerics(f"iter_{iteration}_before_mesh") + mesh_pkg = compute_mesh_regularization( + iteration=iteration, + render_pkg=render_pkg, + viewpoint_cam=viewpoint, + viewpoint_idx=view_idx, + gaussians=self.gaussians, + scene=self.scene, + pipe=self.pipe, + background=self.background, + kernel_size=0.0, + 
config=self.mesh_config, + mesh_renderer=self.mesh_renderer, + mesh_state=self.mesh_state, + render_func=self.render_for_mesh, + weight_adjustment=100.0 / max(self.args.iterations, 1), + args=self.runtime_args, + integrate_func=integrate, + ) + mesh_loss_tensor = mesh_pkg["mesh_loss"] + self.mesh_state = mesh_pkg["updated_state"] + total_loss = total_loss + mesh_loss_tensor + + self.optimizer.zero_grad(set_to_none=True) + total_loss.backward() + if self.args.grad_clip_norm > 0.0: + params: List[torch.Tensor] = [] + for group in self.optimizer.param_groups: + for param in group.get("params", []): + if isinstance(param, torch.Tensor) and param.requires_grad: + params.append(param) + if params: + clip_grad_norm_(params, self.args.grad_clip_norm) + + visibility = render_pkg["visibility_filter"] + radii = render_pkg["radii"] + self.optimizer.step(visibility, radii.shape[0]) + + total_val = float(total_loss.item()) + depth_val = float(depth_loss.item()) + mesh_val = float(mesh_loss_tensor.item()) + self.ema_depth = depth_val if self.ema_depth is None else (0.9 * self.ema_depth + 0.1 * depth_val) + self.ema_mesh = mesh_val if self.ema_mesh is None else (0.9 * self.ema_mesh + 0.1 * mesh_val) + + extra = {k: float(v.item()) for k, v in mesh_pkg.items() if hasattr(v, "item") and k.endswith("_loss")} + if self.log_depth_stats and target_valid.numel() > 0: + extra.update( + { + "pred_depth_min": float(pred_valid.min().item()), + "pred_depth_max": float(pred_valid.max().item()), + "pred_depth_mean": float(pred_valid.mean().item()), + "pred_depth_std": float(pred_valid.std(unbiased=False).item()), + "target_depth_min": float(target_valid.min().item()), + "target_depth_max": float(target_valid.max().item()), + "target_depth_mean": float(target_valid.mean().item()), + "target_depth_std": float(target_valid.std(unbiased=False).item()), + } + ) + self._log_iteration( + iteration=iteration, + view_idx=view_idx, + total_loss=total_val, + depth_loss=depth_val, + mesh_loss=mesh_val, + depth_mae=depth_mae, + valid_fraction=valid_fraction, + valid_pixels=valid_pixels, + extra=extra, + ) + + if iteration % self.args.log_interval == 0 or iteration == 1: + print( + "[Iter {iter:05d}] loss={loss:.6f} depth={depth:.6f} mesh={mesh:.6f} " + "mae={mae:.6f} valid={valid:.3f}".format( + iter=iteration, + loss=total_val, + depth=depth_val, + mesh=mesh_val, + mae=depth_mae, + valid=valid_fraction, + ) + ) + + if mesh_active and mesh_pkg.get("gaussians_changed", False): + self.mesh_state = reset_mesh_state_at_next_iteration(self.mesh_state) + + if self.args.export_interval > 0 and iteration % self.args.export_interval == 0: + self._export_state(iteration) + + self._export_state(self.args.iterations, final=True) + + def _sink_path(self, iteration: int, final: bool = False) -> Path: + + target_dir = self.output_dir / ("final" if final else f"iter_{iteration:05d}") + target_dir.mkdir(parents=True, exist_ok=True) + return target_dir + + def _export_state(self, iteration: int, final: bool = False) -> None: + target_dir = self._sink_path(iteration, final) + ply_path = target_dir / f"gaussians_iter_{iteration:05d}.ply" + save_mesh = ( + self.mesh_enabled + and (iteration >= self.mesh_config.get("start_iter", 1) or final) + and self.mesh_state + ) + if save_mesh and self.mesh_state.get("delaunay_tets") is not None: + mesh_path = target_dir / f"mesh_iter_{iteration:05d}.ply" + export_mesh_from_gaussians( + gaussians=self.gaussians, + mesh_state=self.mesh_state, + output_path=str(mesh_path), + reference_camera=None, + ) + 
self.gaussians.save_ply(str(ply_path)) + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Depth-supervised training for Gaussian Splatting.") + parser.add_argument("--ply_path", type=str, required=True, help="Initial Gaussian PLY file.") + parser.add_argument("--camera_poses", type=str, required=True, help="Camera pose JSON compatible with ply2mesh.load_cameras_from_json.") + parser.add_argument("--depth_dir", type=str, required=True, help="Folder of per-view depth .npy files.") + parser.add_argument("--output_dir", type=str, default="./depth_training_output", help="Directory for logs and exports.") + parser.add_argument("--iterations", type=int, default=5000, help="Number of optimization steps.") + parser.add_argument("--seed", type=int, default=0, help="Random seed.") + parser.add_argument("--device", type=str, default="cuda", help="PyTorch device identifier.") + parser.add_argument("--data_device", type=str, default="cpu", help="Device to store camera image tensors.") + parser.add_argument("--image_width", type=int, default=1280, help="Rendered width.") + parser.add_argument("--image_height", type=int, default=720, help="Rendered height.") + parser.add_argument("--fov_y", type=float, default=75.0, help="Vertical field of view in degrees.") + parser.add_argument("--depth_scale", type=float, default=1.0, help="Scale factor applied to loaded depth maps.") + parser.add_argument("--depth_offset", type=float, default=0.0, help="Additive offset applied to depth.") + parser.add_argument("--depth_clip_min", type=float, default=None, help="Minimum depth after scaling.") + parser.add_argument("--depth_clip_max", type=float, default=None, help="Maximum depth after scaling.") + parser.add_argument("--depth_loss_epsilon", type=float, default=1e-6, help="Stability epsilon for depth loss denominator.") + parser.add_argument("--mesh_config", type=str, default="medium", help="Mesh-in-the-loop configuration name.") + parser.add_argument("--mesh_start_iter", type=int, default=1, help="Iteration to start mesh regularization.") + parser.add_argument("--mesh_stop_iter", type=int, default=None, help="Iteration to stop mesh regularization.") + parser.add_argument("--export_interval", type=int, default=1000, help="Export mesh/ply every N iterations.") + parser.add_argument("--log_interval", type=int, default=100, help="Console log interval.") + parser.add_argument("--grad_clip_norm", type=float, default=0.0, help="Gradient clipping norm (0 disables).") + parser.add_argument("--initial_lr_scale", type=float, default=1.0, help="Scaling factor for position learning rate.") + parser.add_argument("--convert_shs_python", action="store_true", help="Use PyTorch SH conversion (debug only).") + parser.add_argument("--compute_cov3d_python", action="store_true", help="Use PyTorch covariance (debug only).") + parser.add_argument("--debug", action="store_true", help="Enable renderer debug outputs.") + parser.add_argument("--disable_mip_filter", action="store_true", help="Disable 3D Mip filter.") + parser.add_argument("--sh_degree", type=int, default=0, help="Spherical harmonic degree for Gaussian colors.") + parser.add_argument("--mesh_regularization", action="store_true", help="Enable mesh-in-the-loop regularization.") + parser.add_argument("--freeze_colors", dest="freeze_colors", action="store_true", help="Freeze SH features during depth training.", default=True) + parser.add_argument("--no-freeze_colors", dest="freeze_colors", action="store_false", help="Allow SH features to 
be optimized.") + parser.add_argument("--warn_until_iter", type=int, default=3000, help="Warmup iterations for densification/mesh utilities.") + parser.add_argument("--imp_metric", type=str, default="outdoor", choices=["outdoor", "indoor"], help="Importance metric for mesh sampling heuristics.") + parser.add_argument("--depth_reinit_iter", type=int, default=2000, help="Iteration to trigger optional depth reinitialization routines.") + parser.add_argument("--fixed_view_idx", type=int, default=None, help="If provided, always train on this camera index (for debugging).") + parser.add_argument("--log_depth_stats", action="store_true", help="Record detailed depth statistics per iteration.") + return parser + + +def main() -> None: + parser = build_parser() + args = parser.parse_args() + safe_state(False) + trainer = DepthTrainer(args) + trainer.run() + + +if __name__ == "__main__": + main() diff --git a/milo/useless_maybe/downsample_colmap_points.py b/milo/useless_maybe/downsample_colmap_points.py new file mode 100644 index 0000000..8180ecb --- /dev/null +++ b/milo/useless_maybe/downsample_colmap_points.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 +"""Downsample COLMAP points3D.bin (and optionally regenerate points3D.ply).""" + +from __future__ import annotations + +import argparse +import os +import shutil +import struct +from pathlib import Path +from typing import Optional + +import numpy as np +from plyfile import PlyData, PlyElement + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--input-bin", + required=True, + help="Path to the COLMAP points3D.bin file", + ) + parser.add_argument( + "--output-bin", + help=( + "Path to the output points3D.bin file. Defaults to overwriting the input " + "after creating a .bak backup." + ), + ) + parser.add_argument( + "--target", + type=int, + default=4_000_000, + help="Maximum number of points to keep (default: 4,000,000)", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducible sampling (default: 42)", + ) + parser.add_argument( + "--mode", + choices=("random", "radius"), + default="random", + help=( + "Sampling strategy: 'random' selects uniformly at random; " + "'radius' keeps points closest to the centroid (default: random)" + ), + ) + parser.add_argument( + "--ply-output", + help=( + "Optional path to write a matching points3D.ply. Defaults to the sibling " + "points3D.ply next to the binary file when omitted. Use '--ply-output '' ' to skip." 
+        ),
+    )
+    return parser.parse_args()
+
+
+RECORD_STRUCT = struct.Struct("<Q3d3Bd")
+TRACK_LEN_STRUCT = struct.Struct("<Q")
+
+
+def gather_positions(path: Path, num_points: int) -> np.ndarray:
+    """Read point positions once to compute spatial statistics."""
+    positions = np.empty((num_points, 3), dtype=np.float32)
+    with path.open("rb") as fin:
+        fin.seek(8)  # Skip count header
+        for idx in range(num_points):
+            record_bytes = fin.read(RECORD_STRUCT.size)
+            if not record_bytes:
+                raise EOFError("Unexpected end of file while gathering positions")
+            _, x, y, z, _, _, _, _ = RECORD_STRUCT.unpack(record_bytes)
+            positions[idx] = (x, y, z)
+            (track_len,) = TRACK_LEN_STRUCT.unpack(fin.read(TRACK_LEN_STRUCT.size))
+            fin.seek(8 * track_len, os.SEEK_CUR)
+    return positions
+
+
+def pick_indices(
+    total: int,
+    target: int,
+    rng: np.random.Generator,
+    mode: str,
+    positions: np.ndarray | None = None,
+) -> np.ndarray:
+    if target <= 0 or total <= target:
+        return np.arange(total, dtype=np.int64)
+    if mode == "random":
+        indices = rng.choice(total, size=target, replace=False)
+    elif mode == "radius":
+        if positions is None:
+            raise ValueError("Positions are required for radius-based sampling.")
+        centroid = positions.mean(axis=0, dtype=np.float64)
+        dists = np.sum((positions - centroid) ** 2, axis=1)
+        partition = np.argpartition(dists, target - 1)[:target]
+        indices = np.sort(partition)
+    else:
+        raise ValueError(f"Unknown sampling mode: {mode}")
+    if not np.all(np.diff(indices) >= 0):
+        indices.sort()
+    return indices.astype(np.int64)
+
+
+def write_ply(path: Path, positions: np.ndarray, colors: np.ndarray) -> None:
+    count = positions.shape[0]
+    dtype = [
+        ("x", "f4"),
+        ("y", "f4"),
+        ("z", "f4"),
+        ("nx", "f4"),
+        ("ny", "f4"),
+        ("nz", "f4"),
+        ("red", "u1"),
+        ("green", "u1"),
+        ("blue", "u1"),
+    ]
+    elements = np.empty(count, dtype=dtype)
+    elements["x"] = positions[:, 0]
+    elements["y"] = positions[:, 1]
+    elements["z"] = positions[:, 2]
+    elements["nx"] = 0.0
+    elements["ny"] = 0.0
+    elements["nz"] = 0.0
+    elements["red"] = colors[:, 0]
+    elements["green"] = colors[:, 1]
+    elements["blue"] = colors[:, 2]
+    ply = PlyData([PlyElement.describe(elements, "vertex")], text=False)
+    ply.write(str(path))
+
+
+def main() -> None:
+    args = parse_args()
+    input_bin = Path(args.input_bin).expanduser().resolve()
+    output_bin = Path(args.output_bin).expanduser().resolve() if args.output_bin else input_bin
+
+    if not input_bin.exists():
+        raise FileNotFoundError(f"Input binary file not found: {input_bin}")
+
+    rng = np.random.default_rng(args.seed)
+
+    with input_bin.open("rb") as fin:
+        num_points = struct.unpack("<Q", fin.read(8))[0]
+
+    with input_bin.open("rb") as fin, tmp_bin.open("wb") as fout:
+        fin.seek(8)
+        fout.write(struct.pack("
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--input", required=True, help="Path to the source PLY file")
+    parser.add_argument(
+        "--output",
+        help=(
+            "Path to the output PLY file. 
If omitted, the input is overwritten " + "after creating a backup with suffix .bak" + ), + ) + parser.add_argument( + "--target", + type=int, + default=4_000_000, + help="Desired maximum number of points (default: 4,000,000)", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducible sampling (default: 42)", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + input_path = Path(args.input).expanduser().resolve() + if args.output: + output_path = Path(args.output).expanduser().resolve() + else: + output_path = input_path + + if not input_path.exists(): + raise FileNotFoundError(f"Input PLY not found: {input_path}") + + print(f"[INFO] Loading PLY: {input_path}") + ply = PlyData.read(str(input_path)) + + if "vertex" not in ply: + raise ValueError("Input PLY does not contain a vertex element") + + vertex_data = ply["vertex"] + total_vertices = len(vertex_data) + target = max(int(args.target), 0) + print(f"[INFO] Total vertices: {total_vertices}") + print(f"[INFO] Target vertices: {target}") + + if target == 0 or total_vertices <= target: + print("[INFO] No downsampling needed.") + if output_path != input_path: + print(f"[INFO] Copying file to {output_path}") + shutil.copyfile(input_path, output_path) + else: + print("[INFO] Input already satisfies target; nothing to do.") + return + + rng = np.random.default_rng(args.seed) + print("[INFO] Sampling indices...") + sample_indices = rng.choice(total_vertices, size=target, replace=False) + sample_indices.sort() + downsampled_vertex = vertex_data[sample_indices] + + print("[INFO] Preparing PLY structure...") + new_vertex_element = PlyElement.describe(downsampled_vertex, "vertex") + new_ply = PlyData([new_vertex_element], text=ply.text, byte_order=ply.byte_order) + new_ply.comments = ply.comments + + if output_path == input_path: + backup_path = input_path.with_suffix(input_path.suffix + ".bak") + if not backup_path.exists(): + print(f"[INFO] Creating backup at {backup_path}") + shutil.copyfile(input_path, backup_path) + else: + print(f"[WARNING] Backup already exists at {backup_path}; reusing it.") + + os.makedirs(output_path.parent, exist_ok=True) + print(f"[INFO] Writing downsampled PLY to {output_path}") + new_ply.write(str(output_path)) + print("[INFO] Done.") + + +if __name__ == "__main__": + main() diff --git a/milo/useless_maybe/iterative_occupancy_refine.py b/milo/useless_maybe/iterative_occupancy_refine.py new file mode 100644 index 0000000..2eaacf9 --- /dev/null +++ b/milo/useless_maybe/iterative_occupancy_refine.py @@ -0,0 +1,410 @@ +#!/usr/bin/env python3 +"""Iteratively refine learnable occupancy (SDF) while keeping Gaussian geometry fixed.""" + +import argparse +import json +import os +import random +from types import SimpleNamespace + +import numpy as np +import torch +import torch.nn.functional as F + +from arguments import PipelineParams +from gaussian_renderer.radegs import integrate_radegs +from ply2mesh import ( + ManualScene, + load_cameras_from_json, + freeze_gaussian_rigid_parameters, + build_render_functions, + load_mesh_config_file, + export_mesh_from_gaussians, +) +from regularization.regularizer.mesh import initialize_mesh_regularization, compute_mesh_regularization +from regularization.sdf.learnable import convert_occupancy_to_sdf +from scene.gaussian_model import GaussianModel +from utils.geometry_utils import flatten_voronoi_features + + +def ensure_learnable_occupancy(gaussians: GaussianModel) -> None: + """Ensure occupancy buffers exist and only 
the shift is trainable.""" + if not gaussians.learn_occupancy or not hasattr(gaussians, "_occupancy_shift"): + device = gaussians._xyz.device + n_pts = gaussians._xyz.shape[0] + base = torch.zeros((n_pts, 9), device=device) + shift = torch.zeros_like(base) + gaussians.learn_occupancy = True + gaussians._base_occupancy = torch.nn.Parameter(base.requires_grad_(False), requires_grad=False) + gaussians._occupancy_shift = torch.nn.Parameter(shift.requires_grad_(True)) + gaussians.set_occupancy_mode("occupancy_shift") + gaussians._occupancy_shift.requires_grad_(True) + + +def extract_loss_scalars(metrics: dict) -> dict: + """Extract scalar loss values from the mesh regularization outputs.""" + scalars = {} + for key, value in metrics.items(): + if not key.endswith("_loss"): + continue + scalar = None + if isinstance(value, torch.Tensor): + if value.ndim == 0: + scalar = float(value.item()) + elif isinstance(value, (float, int)): + scalar = float(value) + if scalar is not None: + scalars[key] = scalar + return scalars + + +def export_iteration_state( + iteration: int, + gaussians: GaussianModel, + mesh_state: dict, + output_dir: str, + reference_camera=None, +) -> None: + os.makedirs(output_dir, exist_ok=True) + mesh_path = os.path.join(output_dir, f"mesh_iter_{iteration:05d}.ply") + ply_path = os.path.join(output_dir, f"gaussians_iter_{iteration:05d}.ply") + + export_mesh_from_gaussians( + gaussians=gaussians, + mesh_state=mesh_state, + output_path=mesh_path, + reference_camera=reference_camera, + ) + gaussians.save_ply(ply_path) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Occupancy-only refinement from a pretrained Gaussian PLY.") + parser.add_argument("--ply_path", type=str, required=True, help="Input Gaussian PLY (assumed geometrically correct).") + parser.add_argument("--camera_poses", type=str, required=True, help="JSON with camera poses matching the scene.") + parser.add_argument("--mesh_config", type=str, default="default", help="Mesh regularization config name.") + parser.add_argument("--iterations", type=int, default=2000, help="Number of optimization steps.") + parser.add_argument("--occupancy_lr", type=float, default=0.001, help="Learning rate for occupancy shift.") + parser.add_argument( + "--mesh_loss_weight", + type=float, + default=5.0, + help="Global weight applied to the mesh regularization loss.", + ) + parser.add_argument("--log_interval", type=int, default=100, help="Logging interval.") + parser.add_argument("--export_interval", type=int, default=1000, help="Mesh export interval (0 disables periodic export).") + parser.add_argument("--output_dir", type=str, default="./occupancy_refine_output", help="Directory to store outputs.") + parser.add_argument("--fov_y", type=float, default=75.0, help="Vertical field-of-view in degrees.") + parser.add_argument("--image_width", type=int, default=1280, help="Rendered image width.") + parser.add_argument("--image_height", type=int, default=720, help="Rendered image height.") + parser.add_argument( + "--background", + type=float, + nargs=3, + default=(0.0, 0.0, 0.0), + help="Background color used for rendering (RGB).", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed.") + parser.add_argument("--mesh_start_iter", type=int, default=1, help="Iteration at which mesh regularization starts.") + parser.add_argument( + "--mesh_stop_iter", + type=int, + default=None, + help="Iteration after which mesh regularization stops (defaults to total iterations).", + ) + 
parser.add_argument("--warn_until_iter", type=int, default=3000, help="Warmup iterations for surface sampling.") + parser.add_argument( + "--imp_metric", + type=str, + default="outdoor", + choices=["outdoor", "indoor"], + help="Importance metric used for surface sampling.", + ) + parser.add_argument( + "--cull_on_export", + action="store_true", + help="Frustum cull meshes using the first camera before export.", + ) + parser.add_argument( + "--sdf_log_samples", + type=int, + default=32, + help="Number of SDF values recorded per iteration (0 disables sampling).", + ) + parser.add_argument( + "--loss_log_filename", + type=str, + default="losses.jsonl", + help="Filename used for per-iteration loss logs.", + ) + parser.add_argument( + "--sdf_log_filename", + type=str, + default="sdf_samples.jsonl", + help="Filename used for per-iteration SDF sample logs.", + ) + parser.add_argument( + "--surface_gaussians_filename", + type=str, + default="surface_gaussians_initial.ply", + help="Filename for the first batch of surface Gaussians (empty string disables export).", + ) + return parser + + +def main() -> None: + parser = build_arg_parser() + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA device is required for occupancy refinement.") + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + cameras = load_cameras_from_json( + json_path=args.camera_poses, + image_height=args.image_height, + image_width=args.image_width, + fov_y_deg=args.fov_y, + ) + print(f"[INFO] Loaded {len(cameras)} cameras from {args.camera_poses}.") + + scene = ManualScene(cameras) + + mesh_config = load_mesh_config_file(args.mesh_config) + mesh_config["start_iter"] = max(1, args.mesh_start_iter) + if args.mesh_stop_iter is not None: + mesh_config["stop_iter"] = args.mesh_stop_iter + else: + mesh_config["stop_iter"] = max(mesh_config.get("stop_iter", args.iterations), args.iterations) + + pipe_parser = argparse.ArgumentParser() + pipe = PipelineParams(pipe_parser) + background = torch.tensor(args.background, dtype=torch.float32, device="cuda") + + gaussians = GaussianModel( + sh_degree=0, + use_mip_filter=False, + learn_occupancy=True, + use_appearance_network=False, + ) + gaussians.load_ply(args.ply_path) + print(f"[INFO] Loaded {gaussians._xyz.shape[0]} Gaussians from {args.ply_path}.") + + ensure_learnable_occupancy(gaussians) + + if gaussians.spatial_lr_scale <= 0: + gaussians.spatial_lr_scale = 1.0 + + gaussians.init_culling(len(cameras)) + gaussians.set_occupancy_mode(mesh_config.get("occupancy_mode", "occupancy_shift")) + freeze_gaussian_rigid_parameters(gaussians) + + optimizer = torch.optim.Adam([gaussians._occupancy_shift], lr=args.occupancy_lr) + + render_view, render_for_sdf = build_render_functions(gaussians, pipe, background) + mesh_renderer, mesh_state = initialize_mesh_regularization(scene, mesh_config) + mesh_state["reset_delaunay_samples"] = True + mesh_state["reset_sdf_values"] = True + surface_gaussians_path = None + if args.surface_gaussians_filename: + surface_gaussians_path = os.path.join(args.output_dir, args.surface_gaussians_filename) + print(f"[INFO] Will export first sampled surface Gaussians to {surface_gaussians_path}.") + mesh_state["surface_sample_export_path"] = surface_gaussians_path + mesh_state["surface_sample_saved"] = False + mesh_state["surface_sample_saved_iter"] = None + + runtime_args = SimpleNamespace( + warn_until_iter=args.warn_until_iter, + imp_metric=args.imp_metric, + depth_reinit_iter=getattr(args, 
"depth_reinit_iter", args.warn_until_iter), + ) + + os.makedirs(args.output_dir, exist_ok=True) + log_dir = os.path.join(args.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + loss_log_path = os.path.join(log_dir, args.loss_log_filename) + sdf_log_path = os.path.join(log_dir, args.sdf_log_filename) + + ema_loss = None + pending_view_indices: list[int] = [] + sdf_sample_indices_tensor = None # Stored on the same device as pivots_sdf_flat + sdf_sample_indices_list = None + + with open(loss_log_path, "w", encoding="utf-8") as loss_log_file, open( + sdf_log_path, "w", encoding="utf-8" + ) as sdf_log_file: + # Iterate through all cameras without replacement; reshuffle when one pass finishes. + for iteration in range(1, args.iterations + 1): + if not pending_view_indices: + pending_view_indices = list(range(len(cameras))) + random.shuffle(pending_view_indices) + + view_idx = pending_view_indices.pop() + viewpoint = cameras[view_idx] + + render_pkg = render_view(viewpoint) + + mesh_pkg = compute_mesh_regularization( + iteration=iteration, + render_pkg=render_pkg, + viewpoint_cam=viewpoint, + viewpoint_idx=view_idx, + gaussians=gaussians, + scene=scene, + pipe=pipe, + background=background, + kernel_size=0.0, + config=mesh_config, + mesh_renderer=mesh_renderer, + mesh_state=mesh_state, + render_func=render_for_sdf, + weight_adjustment=100.0 / max(args.iterations, 1), + args=runtime_args, + integrate_func=integrate_radegs, + ) + mesh_state = mesh_pkg["updated_state"] + + with torch.no_grad(): + current_occ = torch.sigmoid(gaussians._base_occupancy + gaussians._occupancy_shift) + pivots_sdf = convert_occupancy_to_sdf(flatten_voronoi_features(current_occ)) + pivots_sdf_flat = pivots_sdf.view(-1).detach() + if pivots_sdf_flat.numel() > 0: + sdf_mean = float(pivots_sdf_flat.mean().item()) + sdf_std = float(pivots_sdf_flat.std(unbiased=False).item()) + else: + sdf_mean = 0.0 + sdf_std = 0.0 + + sample_indices_list = [] + sample_values_list = [] + if args.sdf_log_samples > 0 and pivots_sdf_flat.numel() > 0: + sample_count = min(args.sdf_log_samples, pivots_sdf_flat.numel()) + need_refresh = sdf_sample_indices_tensor is None or sdf_sample_indices_tensor.numel() != sample_count + if not need_refresh: + max_index = int(sdf_sample_indices_tensor.max().item()) + need_refresh = max_index >= pivots_sdf_flat.numel() + if need_refresh: + # Draw once so the same subset of pivots is tracked across iterations. 
+ sdf_sample_indices_tensor = torch.randperm( + pivots_sdf_flat.shape[0], device=pivots_sdf_flat.device + )[:sample_count] + sdf_sample_indices_list = sdf_sample_indices_tensor.detach().cpu().tolist() + else: + if sdf_sample_indices_tensor.device != pivots_sdf_flat.device: + sdf_sample_indices_tensor = sdf_sample_indices_tensor.to( + pivots_sdf_flat.device, non_blocking=True + ) + sample_values = pivots_sdf_flat[sdf_sample_indices_tensor] + sample_indices_list = sdf_sample_indices_list or [] + sample_values_list = sample_values.cpu().tolist() + + raw_mesh_loss = mesh_pkg["mesh_loss"] + loss = args.mesh_loss_weight * raw_mesh_loss + loss_value = float(loss.item()) + raw_loss_value = float(raw_mesh_loss.item()) + + loss_scalars = extract_loss_scalars(mesh_pkg) + skip_iteration = ( + mesh_pkg.get("mesh_triangles") is not None and mesh_pkg["mesh_triangles"].numel() == 0 + ) + + iteration_record = { + "iteration": iteration, + "view_index": view_idx, + "total_loss": loss_value, + "raw_mesh_loss": raw_loss_value, + "sdf_mean": sdf_mean, + "sdf_std": sdf_std, + "skipped": bool(skip_iteration), + } + if ema_loss is not None: + iteration_record["ema_loss"] = ema_loss + iteration_record.update(loss_scalars) + + sdf_record = { + "iteration": iteration, + "sdf_mean": sdf_mean, + "sdf_std": sdf_std, + "sample_count": len(sample_values_list), + "sample_indices": sample_indices_list, + "sample_values": sample_values_list, + } + + if skip_iteration: + iteration_record["skipped_reason"] = "empty_mesh" + loss_log_file.write(json.dumps(iteration_record) + "\n") + loss_log_file.flush() + sdf_record["skipped"] = True + sdf_log_file.write(json.dumps(sdf_record) + "\n") + sdf_log_file.flush() + print(f"[WARNING] Empty mesh at iteration {iteration}; skipping optimizer step.") + continue + + optimizer.zero_grad(set_to_none=True) + loss.backward() + optimizer.step() + + ema_loss = loss_value if ema_loss is None else (0.9 * ema_loss + 0.1 * loss_value) + iteration_record["ema_loss"] = ema_loss + + loss_log_file.write(json.dumps(iteration_record) + "\n") + loss_log_file.flush() + + sdf_log_file.write(json.dumps(sdf_record) + "\n") + sdf_log_file.flush() + + if iteration % args.log_interval == 0 or iteration == 1: + mesh_depth = loss_scalars.get("mesh_depth_loss", 0.0) + mesh_normal = loss_scalars.get("mesh_normal_loss", 0.0) + occupied_centers = loss_scalars.get("occupied_centers_loss", 0.0) + occupancy_labels = loss_scalars.get("occupancy_labels_loss", 0.0) + + print( + "[Iter {iter:05d}] loss={loss:.6f} ema={ema:.6f} depth={depth:.6f} " + "normal={normal:.6f} occ_centers={centers:.6f} labels={labels:.6f} " + "sdf_mean={sdf_mean:.6f} mesh_raw={raw_mesh:.6f}".format( + iter=iteration, + loss=loss_value, + ema=ema_loss, + depth=mesh_depth, + normal=mesh_normal, + centers=occupied_centers, + labels=occupancy_labels, + sdf_mean=sdf_mean, + raw_mesh=raw_loss_value, + ) + ) + + if args.export_interval > 0 and iteration % args.export_interval == 0: + export_iteration_state( + iteration=iteration, + gaussians=gaussians, + mesh_state=mesh_state, + output_dir=args.output_dir, + reference_camera=cameras[0] if args.cull_on_export else None, + ) + + if surface_gaussians_path and not mesh_state.get("surface_sample_saved", False): + print( + "[WARNING] Requested export of surface Gaussians but no samples were saved. " + "Verify surface sampling settings." 
+ ) + + final_dir = os.path.join(args.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + export_iteration_state( + iteration=args.iterations, + gaussians=gaussians, + mesh_state=mesh_state, + output_dir=final_dir, + reference_camera=cameras[0] if args.cull_on_export else None, + ) + + print(f"[INFO] Occupancy refinement completed. Results saved to {args.output_dir}.") + + +if __name__ == "__main__": + main() diff --git a/milo/mesh_extract_integration.py b/milo/useless_maybe/mesh_extract_integration.py similarity index 100% rename from milo/mesh_extract_integration.py rename to milo/useless_maybe/mesh_extract_integration.py diff --git a/milo/useless_maybe/ply2mesh.py b/milo/useless_maybe/ply2mesh.py new file mode 100644 index 0000000..d16f1c6 --- /dev/null +++ b/milo/useless_maybe/ply2mesh.py @@ -0,0 +1,459 @@ +import os +import json +import math +import random +from argparse import ArgumentParser +from typing import List, Optional, Sequence + +import yaml +from types import SimpleNamespace +import torch +import torch.nn as nn +import numpy as np +import trimesh + +from arguments import PipelineParams +from scene.cameras import Camera +from scene.gaussian_model import GaussianModel +from regularization.regularizer.mesh import initialize_mesh_regularization, compute_mesh_regularization +from functional import extract_mesh, compute_delaunay_triangulation +from functional.mesh import frustum_cull_mesh +from regularization.sdf.learnable import convert_occupancy_to_sdf +from utils.geometry_utils import flatten_voronoi_features +from gaussian_renderer.radegs import render_radegs, integrate_radegs + +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def quaternion_to_rotation_matrix(q: Sequence[float]) -> np.ndarray: + """Convert a unit quaternion [w, x, y, z] to a 3x3 rotation matrix.""" + q = np.asarray(q, dtype=np.float64) + if q.shape != (4,): + raise ValueError("Quaternion must have shape (4,)") + w, x, y, z = q + xx = x * x + yy = y * y + zz = z * z + xy = x * y + xz = x * z + yz = y * z + wx = w * x + wy = w * y + wz = w * z + rotation = np.array( + [ + [1.0 - 2.0 * (yy + zz), 2.0 * (xy - wz), 2.0 * (xz + wy)], + [2.0 * (xy + wz), 1.0 - 2.0 * (xx + zz), 2.0 * (yz - wx)], + [2.0 * (xz - wy), 2.0 * (yz + wx), 1.0 - 2.0 * (xx + yy)], + ], + dtype=np.float32, + ) + return rotation + + +class ManualScene: + """Minimal scene wrapper exposing the API expected by mesh regularization utilities.""" + + def __init__(self, cameras: Sequence[Camera]): + self._train_cameras = list(cameras) + + def getTrainCameras(self, scale: float = 1.0): + return list(self._train_cameras) + + def getTrainCameras_warn_up( + self, + iteration: int, + warn_until_iter: int, + scale: float = 1.0, + scale2: float = 2.0, + ): + return list(self._train_cameras) + + +def load_cameras_from_json( + json_path: str, + image_height: int, + image_width: int, + fov_y_deg: float, +) -> List[Camera]: + if not os.path.isfile(json_path): + raise FileNotFoundError(f"Camera JSON not found: {json_path}") + + with open(json_path, "r", encoding="utf-8") as f: + camera_entries = json.load(f) + + if not camera_entries: + raise ValueError(f"No camera entries found in {json_path}") + + fov_y = math.radians(fov_y_deg) + aspect_ratio = image_width / image_height + fov_x = 2.0 * math.atan(aspect_ratio * math.tan(fov_y * 0.5)) + + cameras: List[Camera] = [] + for idx, entry in enumerate(camera_entries): + if "quaternion" in entry: + rotation = quaternion_to_rotation_matrix(entry["quaternion"]) + elif "rotation" in entry: + 
rotation = np.asarray(entry["rotation"], dtype=np.float32) + if rotation.shape != (3, 3): + raise ValueError(f"Camera entry {idx} rotation must be 3x3, got {rotation.shape}") + else: + raise KeyError(f"Camera entry {idx} must provide either 'quaternion' or 'rotation'.") + + translation = None + if "tvec" in entry: + translation = np.asarray(entry["tvec"], dtype=np.float32) + elif "translation" in entry: + translation = np.asarray(entry["translation"], dtype=np.float32) + elif "position" in entry: + camera_center = np.asarray(entry["position"], dtype=np.float32) + if camera_center.shape != (3,): + raise ValueError(f"Camera entry {idx} position must be length-3, got shape {camera_center.shape}") + # Camera expects world-to-view translation (COLMAP convention t = -R * C). + rotation_w2c = rotation.T # rotation is camera-to-world + translation = -rotation_w2c @ camera_center + else: + raise KeyError(f"Camera entry {idx} must provide 'position', 'translation', or 'tvec'.") + + if translation.shape != (3,): + raise ValueError(f"Camera entry {idx} translation must be length-3, got shape {translation.shape}") + + image_name = ( + entry.get("name") + or entry.get("img_name") + or entry.get("image_name") + or f"view_{idx:04d}" + ) + camera = Camera( + colmap_id=str(idx), + R=rotation, + T=translation, + FoVx=fov_x, + FoVy=fov_y, + image=torch.zeros(3, image_height, image_width), + gt_alpha_mask=None, + image_name=image_name, + uid=idx, + data_device="cuda", + ) + cameras.append(camera) + return cameras + + +def freeze_gaussian_rigid_parameters(gaussians: GaussianModel) -> None: + """Disable gradients for geometric and appearance parameters, keeping occupancy shift trainable.""" + freeze_attrs = [ + "_xyz", + "_features_dc", + "_features_rest", + "_opacity", + "_scaling", + "_rotation", + ] + for attr in freeze_attrs: + param = getattr(gaussians, attr, None) + if isinstance(param, nn.Parameter): + param.requires_grad_(False) + + if hasattr(gaussians, "_base_occupancy") and isinstance(gaussians._base_occupancy, nn.Parameter): + gaussians._base_occupancy.requires_grad_(False) + if hasattr(gaussians, "_occupancy_shift") and isinstance(gaussians._occupancy_shift, nn.Parameter): + gaussians._occupancy_shift.requires_grad_(True) + + +def build_render_functions( + gaussians: GaussianModel, + pipe: PipelineParams, + background: torch.Tensor, +): + def _render( + view: Camera, + pc_obj: GaussianModel, + pipe_obj: PipelineParams, + bg_color: torch.Tensor, + *, + kernel_size: float = 0.0, + require_coord: bool = False, + require_depth: bool = True, + ): + pkg = render_radegs( + viewpoint_camera=view, + pc=pc_obj, + pipe=pipe_obj, + bg_color=bg_color, + kernel_size=kernel_size, + scaling_modifier=1.0, + require_coord=require_coord, + require_depth=require_depth, + ) + if "area_max" not in pkg: + pkg["area_max"] = torch.zeros_like(pkg["radii"]) + return pkg + + def render_view(view: Camera): + return _render(view, gaussians, pipe, background) + + def render_for_sdf( + view: Camera, + gaussians_override: Optional[GaussianModel] = None, + pipeline_override: Optional[PipelineParams] = None, + background_override: Optional[torch.Tensor] = None, + kernel_size: float = 0.0, + require_depth: bool = True, + require_coord: bool = False, + ): + pc_obj = gaussians if gaussians_override is None else gaussians_override + pipe_obj = pipe if pipeline_override is None else pipeline_override + bg_color = background if background_override is None else background_override + pkg = _render( + view, + pc_obj, + pipe_obj, + bg_color, + 
kernel_size=kernel_size, + require_coord=require_coord, + require_depth=require_depth, + ) + return { + "render": pkg["render"].detach(), + "median_depth": pkg["median_depth"].detach(), + } + + return render_view, render_for_sdf + + +def load_mesh_config_file(name: str) -> dict: + config_path = os.path.join(BASE_DIR, "configs", "mesh", f"{name}.yaml") + if not os.path.isfile(config_path): + raise FileNotFoundError(f"Mesh config not found: {config_path}") + with open(config_path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) + + +def export_mesh_from_gaussians( + gaussians: GaussianModel, + mesh_state: dict, + output_path: str, + reference_camera: Optional[Camera] = None, +) -> None: + delaunay_tets = mesh_state.get("delaunay_tets") + gaussian_idx = mesh_state.get("delaunay_xyz_idx") + + if delaunay_tets is None: + delaunay_tets = compute_delaunay_triangulation( + means=gaussians.get_xyz, + scales=gaussians.get_scaling, + rotations=gaussians.get_rotation, + gaussian_idx=gaussian_idx, + ) + + occupancy = ( + gaussians.get_occupancy + if gaussian_idx is None + else gaussians.get_occupancy[gaussian_idx] + ) + pivots_sdf = convert_occupancy_to_sdf(flatten_voronoi_features(occupancy)) + + mesh = extract_mesh( + delaunay_tets=delaunay_tets, + pivots_sdf=pivots_sdf, + means=gaussians.get_xyz, + scales=gaussians.get_scaling, + rotations=gaussians.get_rotation, + gaussian_idx=gaussian_idx, + ) + + mesh_to_save = mesh + if reference_camera is not None: + mesh_to_save = frustum_cull_mesh(mesh, reference_camera) + + verts = mesh_to_save.verts.detach().cpu().numpy() + faces = mesh_to_save.faces.detach().cpu().numpy() + trimesh.Trimesh(vertices=verts, faces=faces, process=False).export(output_path) + + +def train_occupancy_only(args) -> None: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA device is required for occupancy fine-tuning.") + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + cameras = load_cameras_from_json( + json_path=args.camera_poses, + image_height=args.image_height, + image_width=args.image_width, + fov_y_deg=args.fov_y, + ) + print(f"[INFO] Loaded {len(cameras)} cameras from {args.camera_poses}.") + + scene = ManualScene(cameras) + + mesh_config = load_mesh_config_file(args.mesh_config) + mesh_config["start_iter"] = max(1, args.mesh_start_iter) + if args.mesh_stop_iter is not None: + mesh_config["stop_iter"] = args.mesh_stop_iter + else: + mesh_config["stop_iter"] = max(mesh_config.get("stop_iter", args.iterations), args.iterations) + + pipe_parser = ArgumentParser() + pipe = PipelineParams(pipe_parser) + + background = torch.tensor(args.background, dtype=torch.float32, device="cuda") + + gaussians = GaussianModel( + sh_degree=0, + use_mip_filter=False, + learn_occupancy=True, + use_appearance_network=False, + ) + gaussians.load_ply(args.ply_path) + print(f"[INFO] Loaded {gaussians._xyz.shape[0]} Gaussians from {args.ply_path}.") + + if not gaussians.learn_occupancy or not hasattr(gaussians, "_occupancy_shift"): + print("[INFO] PLY does not provide occupancy buffers; initializing them to zeros.") + gaussians.learn_occupancy = True + base_occupancy = torch.zeros((gaussians._xyz.shape[0], 9), device=gaussians._xyz.device) + occupancy_shift = torch.zeros_like(base_occupancy) + gaussians._base_occupancy = nn.Parameter(base_occupancy.requires_grad_(False), requires_grad=False) + gaussians._occupancy_shift = nn.Parameter(occupancy_shift.requires_grad_(True)) + + if gaussians.spatial_lr_scale <= 0: + gaussians.spatial_lr_scale = 
1.0 + + gaussians.init_culling(len(cameras)) + gaussians.set_occupancy_mode(mesh_config.get("occupancy_mode", "occupancy_shift")) + freeze_gaussian_rigid_parameters(gaussians) + + optimizer = torch.optim.Adam([gaussians._occupancy_shift], lr=args.occupancy_lr) + + render_view, render_for_sdf = build_render_functions(gaussians, pipe, background) + mesh_renderer, mesh_state = initialize_mesh_regularization(scene, mesh_config) + mesh_state["reset_delaunay_samples"] = True + mesh_state["reset_sdf_values"] = True + + runtime_args = SimpleNamespace( + warn_until_iter=args.warn_until_iter, + imp_metric=args.imp_metric, + depth_reinit_iter=getattr(args, 'depth_reinit_iter', args.warn_until_iter), + ) + + os.makedirs(args.output_dir, exist_ok=True) + + ema_loss: Optional[float] = None + + for iteration in range(1, args.iterations + 1): + view_idx = random.randrange(len(cameras)) + viewpoint = cameras[view_idx] + + render_pkg = render_view(viewpoint) + + mesh_pkg = compute_mesh_regularization( + iteration=iteration, + render_pkg=render_pkg, + viewpoint_cam=viewpoint, + viewpoint_idx=view_idx, + gaussians=gaussians, + scene=scene, + pipe=pipe, + background=background, + kernel_size=0.0, + config=mesh_config, + mesh_renderer=mesh_renderer, + mesh_state=mesh_state, + render_func=render_for_sdf, + weight_adjustment=100.0 / max(args.iterations, 1), + args=runtime_args, + integrate_func=integrate_radegs, + ) + mesh_state = mesh_pkg["updated_state"] + + if mesh_pkg.get("mesh_triangles") is not None and mesh_pkg["mesh_triangles"].numel() == 0: + print(f"[WARNING] Empty mesh at iteration {iteration}; skipping optimizer step.") + continue + + loss = mesh_pkg["mesh_loss"] + optimizer.zero_grad(set_to_none=True) + loss.backward() + optimizer.step() + + loss_value = float(loss.item()) + ema_loss = loss_value if ema_loss is None else (0.9 * ema_loss + 0.1 * loss_value) + + if iteration % args.log_interval == 0 or iteration == 1: + print( + "[Iter {iter:05d}] loss={loss:.6f} ema={ema:.6f} depth={depth:.6f} " + "normal={normal:.6f} occ_centers={centers:.6f} labels={labels:.6f}".format( + iter=iteration, + loss=loss_value, + ema=ema_loss, + depth=mesh_pkg["mesh_depth_loss"].item(), + normal=mesh_pkg["mesh_normal_loss"].item(), + centers=mesh_pkg["occupied_centers_loss"].item(), + labels=mesh_pkg["occupancy_labels_loss"].item(), + ) + ) + + if args.export_interval > 0 and iteration % args.export_interval == 0: + iteration_dir = os.path.join(args.output_dir, f"iter_{iteration:05d}") + os.makedirs(iteration_dir, exist_ok=True) + gaussians.save_ply(os.path.join(iteration_dir, "point_cloud.ply")) + export_mesh_from_gaussians( + gaussians=gaussians, + mesh_state=mesh_state, + output_path=os.path.join(iteration_dir, "mesh.ply"), + reference_camera=cameras[0] if args.cull_on_export else None, + ) + + final_dir = os.path.join(args.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + gaussians.save_ply(os.path.join(final_dir, "point_cloud.ply")) + export_mesh_from_gaussians( + gaussians=gaussians, + mesh_state=mesh_state, + output_path=os.path.join(final_dir, "mesh.ply"), + reference_camera=cameras[0] if args.cull_on_export else None, + ) + + print(f"[INFO] Occupancy-only training completed. 
Results saved to {args.output_dir}.") + + +def build_arg_parser() -> ArgumentParser: + parser = ArgumentParser(description="Occupancy-only fine-tuning from a pretrained Gaussian PLY.") + parser.add_argument("--ply_path", type=str, required=True, help="Input PLY file with pretrained Gaussians.") + parser.add_argument("--camera_poses", type=str, required=True, help="JSON file containing camera poses.") + parser.add_argument("--mesh_config", type=str, default="default", help="Mesh regularization config name.") + parser.add_argument("--iterations", type=int, default=2000, help="Number of optimization steps.") + parser.add_argument("--occupancy_lr", type=float, default=0.01, help="Learning rate for occupancy shift.") + parser.add_argument("--log_interval", type=int, default=50, help="Console logging interval.") + parser.add_argument("--export_interval", type=int, default=200, help="Mesh/PLY export interval.") + parser.add_argument("--output_dir", type=str, default="./ply2mesh_output", help="Directory to store outputs.") + parser.add_argument("--fov_y", type=float, default=75.0, help="Vertical field-of-view in degrees.") + parser.add_argument("--image_width", type=int, default=1280, help="Rendered image width.") + parser.add_argument("--image_height", type=int, default=720, help="Rendered image height.") + parser.add_argument( + "--background", + type=float, + nargs=3, + default=(0.0, 0.0, 0.0), + help="Background color used for rendering (RGB).", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed.") + parser.add_argument("--mesh_start_iter", type=int, default=1, help="Iteration at which mesh regularization starts.") + parser.add_argument( + "--mesh_stop_iter", + type=int, + default=None, + help="Iteration after which mesh regularization stops (defaults to total iterations).", + ) + parser.add_argument("--warn_until_iter", type=int, default=3000, help="Warmup iterations for surface sampling.") + parser.add_argument("--imp_metric", type=str, default="outdoor", choices=["outdoor", "indoor"], help="Importance metric for surface sampling.") + parser.add_argument("--cull_on_export", action="store_true", help="Enable frustum culling using the first camera before exporting meshes.") + return parser + + +if __name__ == "__main__": + argument_parser = build_arg_parser() + parsed_args = argument_parser.parse_args() + train_occupancy_only(parsed_args) diff --git a/milo/useless_maybe/yufu2mesh_iterative.py b/milo/useless_maybe/yufu2mesh_iterative.py new file mode 100644 index 0000000..b95f7bf --- /dev/null +++ b/milo/useless_maybe/yufu2mesh_iterative.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python3 +"""Iteratively optimize SDF pivots so that the extracted mesh adheres to the provided Gaussian point cloud.""" + +import argparse +import json +import math +import os +from typing import List + +import numpy as np +import torch +import torch.nn.functional as F +import trimesh + +from arguments import PipelineParams +from functional import ( + sample_gaussians_on_surface, + extract_gaussian_pivots, + compute_initial_sdf_values, + compute_delaunay_triangulation, + extract_mesh, +) +from gaussian_renderer.radegs import render_radegs +from regularization.sdf.learnable import convert_sdf_to_occupancy +from scene.cameras import Camera +from scene.gaussian_model import GaussianModel + + +def quaternion_to_rotation_matrix(q: List[float]) -> np.ndarray: + """Convert unit quaternion [w, x, y, z] to a rotation matrix.""" + w, x, y, z = q + xx, yy, zz = x * x, y * y, z * z + xy, xz, yz = x * y, x * z, y 
* z + wx, wy, wz = w * x, w * y, w * z + return np.array([ + [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)], + [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)], + [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)], + ]) + + +def load_cameras( + poses_json: str, + height: int, + width: int, + fov_y_deg: float, + device: str, +) -> List[Camera]: + with open(poses_json, "r", encoding="utf-8") as f: + poses = json.load(f) + + fov_y = math.radians(fov_y_deg) + aspect = width / height + fov_x = 2.0 * math.atan(aspect * math.tan(fov_y / 2.0)) + + cameras: List[Camera] = [] + for idx, info in enumerate(poses): + cam = Camera( + colmap_id=str(idx), + R=quaternion_to_rotation_matrix(info["quaternion"]), + T=np.asarray(info["position"]), + FoVx=fov_x, + FoVy=fov_y, + image=torch.empty(3, height, width), + gt_alpha_mask=None, + image_name=info.get("name", f"view_{idx:05d}"), + uid=idx, + data_device=device, + ) + cameras.append(cam) + return cameras + + +def build_render_function( + gaussians: GaussianModel, + pipe: PipelineParams, + background: torch.Tensor, +): + def render_func(view: Camera): + render_pkg = render_radegs( + viewpoint_camera=view, + pc=gaussians, + pipe=pipe, + bg_color=background, + kernel_size=0.0, + scaling_modifier=1.0, + require_coord=False, + require_depth=True, + ) + return {"render": render_pkg["render"], "depth": render_pkg["median_depth"]} + + return render_func + + +def sample_tensor(tensor: torch.Tensor, max_samples: int) -> torch.Tensor: + if max_samples <= 0 or tensor.shape[0] <= max_samples: + return tensor + idx = torch.randperm(tensor.shape[0], device=tensor.device)[:max_samples] + return tensor[idx] + + +def export_mesh(mesh, path: str) -> None: + verts = mesh.verts.detach().cpu().numpy() + faces = mesh.faces.detach().cpu().numpy() + trimesh.Trimesh(vertices=verts, faces=faces, process=False).export(path) + + +def main(): + parser = argparse.ArgumentParser(description="Iteratively refine SDF pivots using Chamfer supervision from the Gaussian cloud.") + parser.add_argument("--ply_path", type=str, required=True, help="Perfect Gaussian PLY.") + parser.add_argument("--camera_poses", type=str, required=True, help="Camera pose JSON.") + parser.add_argument("--output_dir", type=str, default="./iter_occ_refine", help="Output directory.") + parser.add_argument("--iterations", type=int, default=400, help="Number of SDF optimization steps.") + parser.add_argument("--lr", type=float, default=1e-2, help="Learning rate for SDF pivots.") + parser.add_argument("--reg_weight", type=float, default=5e-4, help="L2 regularization weight towards the initial SDF.") + parser.add_argument("--mesh_sample_count", type=int, default=4096, help="Number of mesh vertices sampled per step.") + parser.add_argument("--gaussian_sample_count", type=int, default=4096, help="Number of Gaussian centers sampled per step.") + parser.add_argument("--surface_sample_limit", type=int, default=400000, help="Maximum Gaussians kept for Delaunay pivots.") + parser.add_argument("--clamp_sdf", type=float, default=1.0, help="Clamp range for SDF values.") + parser.add_argument("--log_interval", type=int, default=10, help="Logging interval.") + parser.add_argument("--export_interval", type=int, default=100, help="Mesh export interval (0 disables periodic export).") + parser.add_argument("--image_height", type=int, default=720, help="Renderer image height.") + parser.add_argument("--image_width", type=int, default=1280, help="Renderer image width.") + parser.add_argument("--fov_y", type=float, default=75.0, 
help="Vertical FoV in degrees.") + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA device is required for occupancy refinement.") + + device = "cuda" + os.makedirs(args.output_dir, exist_ok=True) + + cameras = load_cameras( + poses_json=args.camera_poses, + height=args.image_height, + width=args.image_width, + fov_y_deg=args.fov_y, + device=device, + ) + + gaussians = GaussianModel( + sh_degree=0, + use_mip_filter=False, + learn_occupancy=True, + use_appearance_network=False, + ) + gaussians.load_ply(args.ply_path) + + pipe_parser = argparse.ArgumentParser() + pipe = PipelineParams(pipe_parser) + background = torch.tensor([0.0, 0.0, 0.0], device=device) + render_func = build_render_function(gaussians, pipe, background) + + with torch.no_grad(): + means = gaussians.get_xyz.detach().clone() + scales = gaussians.get_scaling.detach().clone() + rotations = gaussians.get_rotation.detach().clone() + + with torch.no_grad(): + surface_gaussians_idx = sample_gaussians_on_surface( + views=cameras, + means=means, + scales=scales, + rotations=rotations, + opacities=gaussians.get_opacity, + n_max_samples=args.surface_sample_limit, + scene_type="outdoor", + ) + + if surface_gaussians_idx.numel() == 0: + raise RuntimeError("Surface sampling returned zero Gaussians.") + + surface_means = means[surface_gaussians_idx].detach() + + initial_sdf = compute_initial_sdf_values( + views=cameras, + render_func=render_func, + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, + ).detach() + + pivots, _ = extract_gaussian_pivots( + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, + ) + + delaunay_tets = compute_delaunay_triangulation( + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, + ) + + learned_sdf = torch.nn.Parameter(initial_sdf.clone()) + optimizer = torch.optim.Adam([learned_sdf], lr=args.lr) + + for iteration in range(1, args.iterations + 1): + mesh = extract_mesh( + delaunay_tets=delaunay_tets, + pivots_sdf=learned_sdf, + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, + ) + + mesh_verts = mesh.verts + if mesh_verts.numel() == 0: + print(f"[Iter {iteration:05d}] Empty mesh, skipping update.") + continue + + sampled_mesh_pts = sample_tensor(mesh_verts, args.mesh_sample_count) + sampled_gaussian_pts = sample_tensor(surface_means, args.gaussian_sample_count) + + with torch.no_grad(): + nn_idx_forward = torch.cdist( + sampled_mesh_pts.detach(), + sampled_gaussian_pts.detach(), + p=2, + ).argmin(dim=1) + nn_idx_backward = torch.cdist( + sampled_gaussian_pts, + sampled_mesh_pts.detach(), + p=2, + ).argmin(dim=1) + + nearest_gauss = sampled_gaussian_pts[nn_idx_forward] + nearest_mesh = sampled_mesh_pts[nn_idx_backward] + + loss_forward = torch.mean(torch.sum((sampled_mesh_pts - nearest_gauss) ** 2, dim=1)) + loss_backward = torch.mean(torch.sum((sampled_gaussian_pts - nearest_mesh) ** 2, dim=1)) + chamfer_loss = loss_forward + loss_backward + + reg_loss = F.mse_loss(learned_sdf, initial_sdf) + loss = chamfer_loss + args.reg_weight * reg_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + with torch.no_grad(): + learned_sdf.clamp_(-args.clamp_sdf, args.clamp_sdf) + + if iteration % args.log_interval == 0 or iteration == 1: + print( + f"[Iter {iteration:05d}] chamfer={chamfer_loss.item():.6f} " + f"reg={reg_loss.item():.6f} total={loss.item():.6f} " + f"|mesh|={sampled_mesh_pts.shape[0]} 
|gauss|={sampled_gaussian_pts.shape[0]}" + ) + + if args.export_interval > 0 and iteration % args.export_interval == 0: + export_mesh( + mesh=mesh, + path=os.path.join(args.output_dir, f"mesh_iter_{iteration:05d}.ply"), + ) + + final_mesh = extract_mesh( + delaunay_tets=delaunay_tets, + pivots_sdf=learned_sdf, + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, + ) + + export_mesh(final_mesh, os.path.join(args.output_dir, "final_mesh.ply")) + + with torch.no_grad(): + final_occ = convert_sdf_to_occupancy(learned_sdf.detach()).view(-1, 9) + base_occ = convert_sdf_to_occupancy(initial_sdf).view(-1, 9) + gaussians.learn_occupancy = True + total_gaussians = gaussians._xyz.shape[0] + base_buffer = base_occ.new_zeros((total_gaussians, 9)) + shift_buffer = base_occ.new_zeros((total_gaussians, 9)) + surface_idx = surface_gaussians_idx.long() + base_buffer.index_copy_(0, surface_idx, base_occ) + shift_buffer.index_copy_(0, surface_idx, final_occ - base_occ) + gaussians._base_occupancy = torch.nn.Parameter(base_buffer, requires_grad=False) + gaussians._occupancy_shift = torch.nn.Parameter(shift_buffer, requires_grad=False) + gaussians.save_ply(os.path.join(args.output_dir, "refined_gaussians.ply")) + + print(f"[INFO] Optimization complete. Results saved to {args.output_dir}.") + + +if __name__ == "__main__": + main() diff --git a/milo/utils/log_utils.py b/milo/utils/log_utils.py index 905d108..2430d27 100644 --- a/milo/utils/log_utils.py +++ b/milo/utils/log_utils.py @@ -1,5 +1,5 @@ import os -from typing import List, Union, Dict, Any +from typing import List, Union, Dict, Any, Optional import torch import numpy as np import math @@ -117,6 +117,22 @@ def make_log_figure( return log_images_dict +def save_inter_figure(depth_diff: torch.Tensor, normal_diff: torch.Tensor, save_path: str): + plt.figure(figsize=(12, 6)) + plt.suptitle("inter") + plt.subplot(1, 2, 1) + plt.imshow(depth_diff.cpu(), cmap="Spectral") + plt.title("depth") + plt.colorbar() + plt.subplot(1, 2, 2) + plt.imshow(normal_diff.cpu(), cmap="Spectral", vmin=0.0, vmax=2.0) + plt.title("normal") + plt.colorbar() + plt.tight_layout(rect=[0, 0, 1, 0.95]) + plt.savefig(save_path) + plt.close() + + def training_report(iteration, l1_loss, testing_iterations, scene, renderFunc, renderArgs): # Report test and samples of training set if iteration in testing_iterations: @@ -263,6 +279,35 @@ def log_training_progress( )) / 2. ) titles_to_log.append(f"Mesh Normals {viewpoint_idx}") + + # Depth/Normal diff for the current logging view + mesh_depth_map = mesh_render_pkg["depth"].detach().squeeze() + gauss_depth_map = render_pkg["median_depth"].detach().squeeze() + valid_depth_mask = (mesh_depth_map > 0) & (gauss_depth_map > 0) + depth_diff = torch.zeros_like(mesh_depth_map) + depth_diff[valid_depth_mask] = (mesh_depth_map - gauss_depth_map).abs()[valid_depth_mask] + + mesh_normals_view = fix_normal_map( + viewpoint_cam, + mesh_render_pkg["normals"].detach(), + normal_in_view_space=True, + ) + gauss_normals_view = fix_normal_map( + viewpoint_cam, + render_pkg["normal"].detach(), + normal_in_view_space=True, + ) + if mesh_normals_view.shape[0] == 3: + mesh_normals_view = mesh_normals_view.permute(1, 2, 0) + if gauss_normals_view.shape[0] == 3: + gauss_normals_view = gauss_normals_view.permute(1, 2, 0) + normal_dot = (mesh_normals_view * gauss_normals_view).sum(dim=-1).clamp(-1., 1.) + normal_diff = (1. 
- normal_dot) * valid_depth_mask.float()
+
+            images_to_log.append(depth_diff)
+            titles_to_log.append(f"Depth Diff {viewpoint_idx}")
+            images_to_log.append(normal_diff)
+            titles_to_log.append(f"Normal Diff {viewpoint_idx}")
 
         log_images_dict = make_log_figure(
             images=images_to_log,
@@ -319,4 +364,4 @@ def log_training_progress(
         ema_mesh_depth_loss_for_log, ema_mesh_normal_loss_for_log, ema_occupied_centers_loss_for_log,
         ema_occupancy_labels_loss_for_log, ema_depth_order_loss_for_log
-    )
\ No newline at end of file
+    )
diff --git a/milo/yufu2mesh.py b/milo/yufu2mesh.py
new file mode 100644
index 0000000..9fa015b
--- /dev/null
+++ b/milo/yufu2mesh.py
@@ -0,0 +1,219 @@
+from functional import (
+    sample_gaussians_on_surface,
+    extract_gaussian_pivots,
+    compute_initial_sdf_values,
+    compute_delaunay_triangulation,
+    extract_mesh,
+    frustum_cull_mesh,
+)
+from scene.cameras import Camera
+from scene.gaussian_model import GaussianModel
+import json, math, torch, trimesh
+import numpy as np
+from arguments import ModelParams, PipelineParams, OptimizationParams, read_config
+def quaternion_to_rotation_matrix(q):
+    """
+    Convert a unit quaternion to a 3x3 rotation matrix.
+
+    Args:
+        q: a list or array with four elements [w, x, y, z]
+
+    Returns:
+        R: a 3x3 NumPy array representing the rotation matrix.
+    """
+    w, x, y, z = q
+    # Precompute the repeated products to avoid redundant work
+    xx = x * x
+    yy = y * y
+    zz = z * z
+    xy = x * y
+    xz = x * z
+    yz = y * z
+    wx = w * x
+    wy = w * y
+    wz = w * z
+    R = np.array([
+        [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)],
+        [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)],
+        [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)]
+    ])
+    return R
+
+# Load or initialize a 3DGS-like model and training cameras
+ply_path = "/home/zoyo/Desktop/MILo_rtx50/milo/data/Bridge/yufu_bridge_cleaned.ply"
+camera_poses_json = "/home/zoyo/Desktop/MILo_rtx50/milo/data/Bridge/camera_poses_cam1.json"
+camera_poses = json.load(open(camera_poses_json))
+with open(camera_poses_json, 'r') as fcc_file:
+    fcc_data = json.load(fcc_file)
+    print(len(fcc_data), type(fcc_data))
+
+gaussians = GaussianModel(
+    sh_degree=0,
+    # use_mip_filter=use_mip_filter,
+    # learn_occupancy=args.mesh_regularization,
+    # use_appearance_network=args.decoupled_appearance,
+    )
+gaussians.load_ply(ply_path)
+train_cameras = []
+height = 720
+width = 1280
+fov_y = math.radians(75)
+# fov_x = math.radians(108)
+aspect_ratio = width / height
+fov_x = 2 * math.atan(aspect_ratio * math.tan(fov_y / 2))
+for i in range(len(fcc_data)):
+    camera_info = fcc_data[i]
+    camera = Camera(
+        colmap_id=str(i),
+        R=quaternion_to_rotation_matrix(camera_info['quaternion']),
+        T=np.asarray(camera_info['position']),
+        FoVx=fov_x,
+        FoVy=fov_y,
+        image=torch.empty(3, height, width),
+        gt_alpha_mask=None,
+        image_name=camera_info['name'],
+        uid=i,
+        data_device='cuda',
+    )
+    train_cameras.append(camera)
+
+# Define a wrapper around the rendering function following this template.
+# It will be used only for initializing SDF values.
+# The wrapper should accept just a camera as input, and return a dictionary
+# with "render" and "depth" keys.
+from gaussian_renderer.radegs import render_radegs + + +from argparse import ArgumentParser, Namespace +parser = ArgumentParser(description="Training script parameters") +parser.add_argument("--bug", type=bool, default=False) +pipe = PipelineParams(parser) +background = torch.tensor([0., 0., 0.], device="cuda") +def render_func(view): + render_pkg = render_radegs( + viewpoint_camera=view, + pc=gaussians, + pipe=pipe, + bg_color=background, + kernel_size=0.0, + scaling_modifier = 1.0, + require_coord=False, + require_depth=True + ) + return { + "render": render_pkg["render"], + "depth": render_pkg["median_depth"], + } + +# Only the parameters of the Gaussians are needed for extracting the mesh. +means = gaussians.get_xyz +scales = gaussians.get_scaling +rotations = gaussians.get_rotation +opacities = gaussians.get_opacity + +# Sample Gaussians on the surface. +# Should be performed only once, or just once in a while. +# In this example, we sample at most 600_000 Gaussians. +surface_gaussians_idx = sample_gaussians_on_surface( + views=train_cameras, + means=means, + scales=scales, + rotations=rotations, + opacities=opacities, + n_max_samples=600_000, + scene_type='outdoor', +) + +# Compute initial SDF values for pivots. Should be performed only once. +# In the paper, we propose to learn optimal SDF values by maximizing the +# consistency between volumetric renderings and surface mesh renderings. +initial_pivots_sdf = compute_initial_sdf_values( + views=train_cameras, + render_func=render_func, + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, +) + +# Compute Delaunay Triangulation. +# Can be performed once in a while. +delaunay_tets = compute_delaunay_triangulation( + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, +) + +# Differentiably extract a mesh from Gaussian parameters, including initial +# or updated SDF values for the Gaussian pivots. +# This function is differentiable with respect to the parameters of the Gaussians, +# as well as the SDF values. Can be performed at every training iteration. +mesh = extract_mesh( + delaunay_tets=delaunay_tets, + pivots_sdf=initial_pivots_sdf, + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, +) + + + +# You can now apply any differentiable operation on the extracted mesh, +# and backpropagate gradients back to the Gaussians! +# In the paper, we propose to use differentiable mesh rendering. 
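+# A minimal gradient-flow sketch (illustrative only; it assumes the pivot SDF
+# tensor is re-created as a leaf with requires_grad=True before extraction):
+#   pivots_sdf = initial_pivots_sdf.detach().clone().requires_grad_(True)
+#   check_mesh = extract_mesh(
+#       delaunay_tets=delaunay_tets, pivots_sdf=pivots_sdf,
+#       means=means, scales=scales, rotations=rotations,
+#       gaussian_idx=surface_gaussians_idx,
+#   )
+#   check_mesh.verts.sum().backward()  # any scalar on the mesh reaches pivots_sdf.grad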
+from scene.mesh import MeshRasterizer, MeshRenderer
+renderer = MeshRenderer(MeshRasterizer(cameras=train_cameras))
+
+# We cull the mesh based on the view frustum for more efficiency
+i_view = np.random.randint(0, len(train_cameras))
+refined_mesh = frustum_cull_mesh(mesh, train_cameras[i_view])
+
+mesh_render_pkg = renderer(
+    refined_mesh,
+    cam_idx=i_view,
+    return_depth=True, return_normals=True
+)
+mesh_depth = mesh_render_pkg["depth"]
+mesh_normals = mesh_render_pkg["normals"]
+
+# Convert tensors to numpy arrays before saving
+save_dict = {}
+for key, value in mesh_render_pkg.items():
+    if isinstance(value, torch.Tensor):
+        save_dict[key] = value.detach().cpu().numpy()
+    else:
+        save_dict[key] = value
+
+np.savez("mesh_render_output.npz", **save_dict)
+
+# Save the mesh
+# import trimesh
+
+# Extract vertices and faces from the Meshes object
+refined_vertices = refined_mesh.verts.detach().cpu().numpy()
+refined_faces = refined_mesh.faces.detach().cpu().numpy()
+
+# Build a trimesh object and export it
+refined_mesh_obj = trimesh.Trimesh(vertices=refined_vertices, faces=refined_faces)
+
+# # Export as OBJ
+# mesh_obj.export('extracted_mesh.obj')
+
+# Or export as PLY
+refined_mesh_obj.export(f'refined_mesh_{len(fcc_data)}.ply')
+
+vertices = mesh.verts.detach().cpu().numpy()
+faces = mesh.faces.detach().cpu().numpy()
+
+# Build a trimesh object and export it
+mesh_obj = trimesh.Trimesh(vertices=vertices, faces=faces)
+
+# # Export as OBJ
+# mesh_obj.export('extracted_mesh.obj')
+
+# Or export as PLY
+mesh_obj.export(f'mesh_{len(fcc_data)}.ply')
+
+# # Or export as STL
+# mesh_obj.export('extracted_mesh.stl')
\ No newline at end of file
diff --git a/milo/yufu2mesh_new.py b/milo/yufu2mesh_new.py
new file mode 100644
index 0000000..1595dc1
--- /dev/null
+++ b/milo/yufu2mesh_new.py
@@ -0,0 +1,1664 @@
+from pathlib import Path
+import math
+import json
+import random
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+from types import SimpleNamespace
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import matplotlib.pyplot as plt
+import trimesh
+import yaml
+
+from argparse import ArgumentParser
+
+from functional import (
+    compute_delaunay_triangulation,
+    extract_mesh,
+    frustum_cull_mesh,
+)
+from scene.cameras import Camera
+from scene.gaussian_model import GaussianModel
+from gaussian_renderer import render_full
+from gaussian_renderer.radegs import render_radegs, integrate_radegs
+from arguments import PipelineParams
+from regularization.regularizer.mesh import (
+    initialize_mesh_regularization,
+    compute_mesh_regularization,
+)
+from regularization.sdf.learnable import convert_occupancy_to_sdf
+from utils.geometry_utils import flatten_voronoi_features
+
+# DISCOVER-SE camera trajectories use the OpenGL right-handed convention (camera
+# forward is -Z, up is +Y), while the MILo/colmap rendering pipeline assumes
+# forward +Z and up -Y, so the axes must be flipped once when loading poses.
+OPENGL_TO_COLMAP = np.diag([1.0, -1.0, -1.0]).astype(np.float32)
+
+# Keep the script root and data directories in one place to build relative paths.
+MILO_DIR = Path(__file__).resolve().parent
+DATA_ROOT = MILO_DIR / "data"
+
+
+class DepthProvider:
+    """Loads and caches Discoverse depth maps, normalizing shape, clipping, and masks."""
+
+    def __init__(
+        self,
+        depth_root: Path,
+        image_height: int,
+        image_width: int,
+        device: torch.device,
+        clip_min: Optional[float] = None,
+        clip_max: Optional[float] = None,
+    ) -> None:
+        self.depth_root = Path(depth_root)
+        if not self.depth_root.is_dir():
+            raise FileNotFoundError(f"Depth directory does not exist: {self.depth_root}")
+        self.image_height = image_height
+        self.image_width = image_width
+        self.device = device
+        self.clip_min = clip_min
+        self.clip_max = clip_max
+        self._cache: Dict[int, torch.Tensor] = {}
+        self._mask_cache: Dict[int, torch.Tensor] = {}
+
+    def
_file_for_index(self, view_index: int) -> Path: + return self.depth_root / f"depth_img_0_{view_index}.npy" + + def _load_numpy(self, file_path: Path) -> np.ndarray: + depth_np = np.load(file_path) + depth_np = np.squeeze(depth_np) + if depth_np.ndim != 2: + raise ValueError(f"{file_path} 深度数组维度异常:{depth_np.shape}") + if depth_np.shape != (self.image_height, self.image_width): + raise ValueError( + f"{file_path} 深度分辨率应为 {(self.image_height, self.image_width)},当前为 {depth_np.shape}" + ) + if self.clip_min is not None or self.clip_max is not None: + min_val = self.clip_min if self.clip_min is not None else None + max_val = self.clip_max if self.clip_max is not None else None + depth_np = np.clip( + depth_np, + min_val if min_val is not None else depth_np.min(), + max_val if max_val is not None else depth_np.max(), + ) + return depth_np.astype(np.float32) + + def get(self, view_index: int) -> Tuple[torch.Tensor, torch.Tensor]: + """返回 (depth_tensor, valid_mask),均在 GPU 上。""" + if view_index not in self._cache: + file_path = self._file_for_index(view_index) + if not file_path.is_file(): + raise FileNotFoundError(f"缺少深度文件:{file_path}") + depth_np = self._load_numpy(file_path) + depth_tensor = torch.from_numpy(depth_np).to(self.device) + valid_mask = torch.isfinite(depth_tensor) & (depth_tensor > 0.0) + if self.clip_min is not None: + valid_mask &= depth_tensor >= self.clip_min + if self.clip_max is not None: + valid_mask &= depth_tensor <= self.clip_max + self._cache[view_index] = depth_tensor + self._mask_cache[view_index] = valid_mask + return self._cache[view_index], self._mask_cache[view_index] + + def as_numpy(self, view_index: int) -> np.ndarray: + depth_tensor, _ = self.get(view_index) + return depth_tensor.detach().cpu().numpy() + + +class NormalGroundTruthCache: + """缓存以初始高斯生成的法线 GT,避免训练阶段重复渲染。""" + + def __init__( + self, + cache_dir: Path, + image_height: int, + image_width: int, + device: torch.device, + ) -> None: + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.image_height = image_height + self.image_width = image_width + self.device = device + self._memory_cache: Dict[int, torch.Tensor] = {} + + def _file_path(self, view_index: int) -> Path: + return self.cache_dir / f"normal_view_{view_index:04d}.npy" + + def has(self, view_index: int) -> bool: + return self._file_path(view_index).is_file() + + def store(self, view_index: int, normal_tensor: torch.Tensor) -> None: + normals_np = prepare_normals(normal_tensor) + np.save(self._file_path(view_index), normals_np.astype(np.float16)) + + def get(self, view_index: int) -> torch.Tensor: + if view_index not in self._memory_cache: + path = self._file_path(view_index) + if not path.is_file(): + raise FileNotFoundError( + f"未找到视角 {view_index} 的法线缓存:{path},请先完成预计算。" + ) + normals_np = np.load(path) + expected_shape = (self.image_height, self.image_width, 3) + if normals_np.shape != expected_shape: + raise ValueError( + f"{path} 法线缓存尺寸应为 {expected_shape},当前为 {normals_np.shape}" + ) + normals_tensor = ( + torch.from_numpy(normals_np.astype(np.float32)) + .permute(2, 0, 1) + .to(self.device) + ) + self._memory_cache[view_index] = normals_tensor + return self._memory_cache[view_index] + + def clear_memory_cache(self) -> None: + self._memory_cache.clear() + + def ensure_all(self, cameras: Sequence[Camera], render_fn) -> None: + total = len(cameras) + for idx, camera in enumerate(cameras): + if self.has(idx): + continue + with torch.no_grad(): + pkg = render_fn(camera) + normal_tensor = pkg["normal"] + 
self.store(idx, normal_tensor) + if (idx + 1) % 10 == 0 or idx + 1 == total: + print(f"[INFO] 预计算法线缓存 {idx + 1}/{total}") + + +def compute_depth_loss_tensor( + pred_depth: torch.Tensor, + gt_depth: torch.Tensor, + gt_mask: torch.Tensor, +) -> Tuple[torch.Tensor, Dict[str, float]]: + """返回深度 L1 损失及统计信息。""" + if pred_depth.dim() == 3: + pred_depth = pred_depth.squeeze(0) + if pred_depth.shape != gt_depth.shape: + raise ValueError( + f"预测深度尺寸 {pred_depth.shape} 与 GT {gt_depth.shape} 不一致。" + ) + valid_mask = gt_mask & torch.isfinite(pred_depth) & (pred_depth > 0.0) + valid_pixels = int(valid_mask.sum().item()) + if valid_pixels == 0: + zero = torch.zeros((), device=pred_depth.device) + stats = {"valid_px": 0, "mae": float("nan"), "rmse": float("nan")} + return zero, stats + diff = pred_depth - gt_depth + abs_diff = diff.abs() + loss = abs_diff[valid_mask].mean() + rmse = torch.sqrt((diff[valid_mask] ** 2).mean()) + stats = { + "valid_px": valid_pixels, + "mae": float(abs_diff[valid_mask].mean().detach().item()), + "rmse": float(rmse.detach().item()), + } + return loss, stats + + +def compute_normal_loss_tensor( + pred_normals: torch.Tensor, + gt_normals: torch.Tensor, + base_mask: torch.Tensor, +) -> Tuple[torch.Tensor, Dict[str, float]]: + """基于余弦相似度的法线损失。""" + if pred_normals.dim() != 3 or pred_normals.shape[0] != 3: + raise ValueError(f"预测法线维度应为 (3,H,W),当前为 {pred_normals.shape}") + if gt_normals.shape != pred_normals.shape: + raise ValueError( + f"法线 GT 尺寸 {gt_normals.shape} 与预测 {pred_normals.shape} 不一致。" + ) + gt_mask = torch.isfinite(gt_normals).all(dim=0) + pred_mask = torch.isfinite(pred_normals).all(dim=0) + valid_mask = base_mask & gt_mask & pred_mask + valid_pixels = int(valid_mask.sum().item()) + if valid_pixels == 0: + zero = torch.zeros((), device=pred_normals.device) + stats = {"valid_px": 0, "mean_cos": float("nan")} + return zero, stats + + pred_unit = F.normalize(pred_normals, dim=0) + gt_unit = F.normalize(gt_normals, dim=0) + cos_sim = (pred_unit * gt_unit).sum(dim=0).clamp(-1.0, 1.0) + loss_map = (1.0 - cos_sim) * valid_mask + loss = loss_map.sum() / valid_mask.sum() + stats = {"valid_px": valid_pixels, "mean_cos": float(cos_sim[valid_mask].mean().item())} + return loss, stats + + +class ManualScene: + """最小 Scene 封装,提供 mesh regularization 所需的接口。""" + + def __init__(self, cameras: Sequence[Camera]): + self._train_cameras = list(cameras) + + def getTrainCameras(self, scale: float = 1.0) -> List[Camera]: + return list(self._train_cameras) + + def getTrainCameras_warn_up( + self, + iteration: int, + warn_until_iter: int, + scale: float = 1.0, + scale2: float = 2.0, + ) -> List[Camera]: + return list(self._train_cameras) + + +def build_render_functions( + gaussians: GaussianModel, + pipe: PipelineParams, + background: torch.Tensor, +): + """构建训练/重建所需的渲染器:训练走 render_full 以获得精确深度梯度,SDF 仍沿用 RaDe-GS。""" + + def render_view(view: Camera) -> Dict[str, torch.Tensor]: + pkg = render_full( + viewpoint_camera=view, + pc=gaussians, + pipe=pipe, + bg_color=background, + compute_expected_normals=False, + compute_expected_depth=True, + compute_accurate_median_depth_gradient=True, + ) + if "area_max" not in pkg: + pkg["area_max"] = torch.zeros_like(pkg["radii"]) + return pkg + + def render_for_sdf( + view: Camera, + gaussians_override: Optional[GaussianModel] = None, + pipeline_override: Optional[PipelineParams] = None, + background_override: Optional[torch.Tensor] = None, + kernel_size: float = 0.0, + require_depth: bool = True, + require_coord: bool = False, + ): + pc_obj = gaussians if 
gaussians_override is None else gaussians_override + pipe_obj = pipe if pipeline_override is None else pipeline_override + bg_color = background if background_override is None else background_override + pkg = render_radegs( + viewpoint_camera=view, + pc=pc_obj, + pipe=pipe_obj, + bg_color=bg_color, + kernel_size=kernel_size, + scaling_modifier=1.0, + require_coord=require_coord, + require_depth=require_depth, + ) + if "area_max" not in pkg: + pkg["area_max"] = torch.zeros_like(pkg["radii"]) + return { + "render": pkg["render"].detach(), + "median_depth": pkg["median_depth"].detach(), + } + + return render_view, render_for_sdf + + +def load_mesh_config(config_name: str) -> Dict[str, Any]: + """支持直接传文件路径或 configs/mesh/.yaml。""" + candidate = Path(config_name) + if not candidate.is_file(): + candidate = Path(__file__).resolve().parent / "configs" / "mesh" / f"{config_name}.yaml" + if not candidate.is_file(): + raise FileNotFoundError(f"无法找到 mesh 配置:{config_name}") + with candidate.open("r", encoding="utf-8") as fh: + return yaml.safe_load(fh) + + +def load_optimization_config(config_name: str) -> Dict[str, Any]: + """加载优化配置文件,支持直接传文件路径或 configs/optimization/.yaml。""" + candidate = Path(config_name) + if not candidate.is_file(): + candidate = Path(__file__).resolve().parent / "configs" / "optimization" / f"{config_name}.yaml" + if not candidate.is_file(): + raise FileNotFoundError(f"无法找到优化配置:{config_name}") + with candidate.open("r", encoding="utf-8") as fh: + config = yaml.safe_load(fh) + + # 验证配置结构 + required_keys = ["gaussian_params", "loss_weights", "depth_processing", "mesh_regularization"] + for key in required_keys: + if key not in config: + raise ValueError(f"优化配置文件缺少必需的键:{key}") + + return config + + +def setup_gaussian_optimization( + gaussians: GaussianModel, + opt_config: Dict[str, Any], +) -> Tuple[torch.optim.Optimizer, Dict[str, float]]: + """ + 根据优化配置设置高斯参数的可训练性和优化器。 + + Args: + gaussians: 高斯模型实例 + opt_config: 优化配置字典 + + Returns: + optimizer: 配置好的优化器 + loss_weights: 损失权重字典 + """ + param_groups = [] + params_config = opt_config["gaussian_params"] + + # 遍历所有配置的参数 + for param_name, param_cfg in params_config.items(): + if not hasattr(gaussians, param_name): + print(f"[WARNING] 高斯模型没有属性 {param_name},跳过") + continue + + param_tensor = getattr(gaussians, param_name) + if not isinstance(param_tensor, torch.Tensor): + print(f"[WARNING] {param_name} 不是张量,跳过") + continue + + trainable = param_cfg.get("trainable", False) + lr = param_cfg.get("lr", 0.0) + + # 设置梯度 + param_tensor.requires_grad_(trainable) + + # 如果可训练且学习率>0,添加到优化器参数组 + if trainable and lr > 0.0: + param_groups.append({ + "params": [param_tensor], + "lr": lr, + "name": param_name + }) + print(f"[INFO] 参数 {param_name}: trainable=True, lr={lr}") + else: + print(f"[INFO] 参数 {param_name}: trainable=False") + + if not param_groups: + raise ValueError("没有可训练的参数!请检查优化配置文件。") + + # 创建优化器 + optimizer = torch.optim.Adam(param_groups) + + # 提取损失权重 + loss_weights = opt_config["loss_weights"] + + return optimizer, loss_weights + + +def ensure_gaussian_occupancy(gaussians: GaussianModel) -> None: + """mesh regularization 依赖 9 维 occupancy 网格,此处在推理环境补齐缓冲。""" + needs_init = ( + not getattr(gaussians, "learn_occupancy", False) + or not hasattr(gaussians, "_occupancy_shift") + or gaussians._occupancy_shift.numel() == 0 + ) + if needs_init: + gaussians.learn_occupancy = True + base = torch.zeros((gaussians._xyz.shape[0], 9), device=gaussians._xyz.device) + shift = torch.zeros_like(base) + gaussians._base_occupancy = 
nn.Parameter(base.requires_grad_(False), requires_grad=False) + gaussians._occupancy_shift = nn.Parameter(shift.requires_grad_(True)) + + +def export_mesh_from_state( + gaussians: GaussianModel, + mesh_state: Dict[str, Any], + output_path: Path, + reference_camera: Optional[Camera] = None, +) -> None: + """根据当前 mesh_state 导出网格,并可选做视椎裁剪。""" + gaussian_idx = mesh_state.get("delaunay_xyz_idx") + delaunay_tets = mesh_state.get("delaunay_tets") + + if delaunay_tets is None: + delaunay_tets = compute_delaunay_triangulation( + means=gaussians.get_xyz, + scales=gaussians.get_scaling, + rotations=gaussians.get_rotation, + gaussian_idx=gaussian_idx, + ) + + occupancy = ( + gaussians.get_occupancy + if gaussian_idx is None + else gaussians.get_occupancy[gaussian_idx] + ) + pivots_sdf = convert_occupancy_to_sdf(flatten_voronoi_features(occupancy)) + + mesh = extract_mesh( + delaunay_tets=delaunay_tets, + pivots_sdf=pivots_sdf, + means=gaussians.get_xyz, + scales=gaussians.get_scaling, + rotations=gaussians.get_rotation, + gaussian_idx=gaussian_idx, + ) + + mesh_to_save = mesh if reference_camera is None else frustum_cull_mesh(mesh, reference_camera) + verts = mesh_to_save.verts.detach().cpu().numpy() + faces = mesh_to_save.faces.detach().cpu().numpy() + trimesh.Trimesh(vertices=verts, faces=faces, process=False).export(output_path) + + +def quaternion_to_rotation_matrix(quaternion: Sequence[float]) -> np.ndarray: + """将单位四元数转换为 3x3 旋转矩阵。""" + # 这里显式转换 DISCOVERSE 导出的四元数,确保后续符合 MILo 的旋转约定 + q = np.asarray(quaternion, dtype=np.float64) + if q.shape != (4,): + raise ValueError(f"四元数需要包含 4 个分量,当前形状为 {q.shape}") + w, x, y, z = q + xx = x * x + yy = y * y + zz = z * z + xy = x * y + xz = x * z + yz = y * z + wx = w * x + wy = w * y + wz = w * z + rotation = np.array( + [ + [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)], + [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)], + [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)], + ] + ) + return rotation + + +def freeze_gaussian_model(model: GaussianModel) -> None: + """显式关闭高斯模型中参数的梯度。""" + # 推理阶段冻结高斯参数,后续循环只做前向评估 + tensor_attrs = [ + "_xyz", + "_features_dc", + "_features_rest", + "_scaling", + "_rotation", + "_opacity", + ] + for attr in tensor_attrs: + value = getattr(model, attr, None) + if isinstance(value, torch.Tensor): + value.requires_grad_(False) + + +def prepare_depth_map(depth_tensor: torch.Tensor) -> np.ndarray: + """将深度张量转为二维 numpy 数组。""" + # 统一 squeeze 逻辑,防止 Matplotlib 因 shape 异常报错 + depth_np = depth_tensor.detach().cpu().numpy() + depth_np = np.squeeze(depth_np) + if depth_np.ndim == 1: + depth_np = np.expand_dims(depth_np, axis=0) + return depth_np + + +def prepare_normals(normal_tensor: torch.Tensor) -> np.ndarray: + """将法线张量转换为 HxWx3 的 numpy 数组。""" + # 兼容渲染输出为 (3,H,W) 或 (H,W,3) 的两种格式 + normals_np = normal_tensor.detach().cpu().numpy() + normals_np = np.squeeze(normals_np) + if normals_np.ndim == 3 and normals_np.shape[0] == 3: + normals_np = np.transpose(normals_np, (1, 2, 0)) + if normals_np.ndim == 2: + normals_np = normals_np[..., None] + return normals_np + + +def normals_to_rgb(normals: np.ndarray) -> np.ndarray: + """将 [-1,1] 范围的法线向量映射到 [0,1] 以便可视化。""" + normals = np.nan_to_num(normals, nan=0.0, posinf=0.0, neginf=0.0) + rgb = 0.5 * (normals + 1.0) + return np.clip(rgb, 0.0, 1.0).astype(np.float32) + + +def save_normal_visualization(normal_rgb: np.ndarray, output_path: Path) -> None: + """保存法线可视化图像。""" + plt.imsave(output_path, normal_rgb) + + +def load_cameras_from_json( + json_path: str, + image_height: int, + image_width: int, + 
fov_y_deg: float, +) -> List[Camera]: + """按照 MILo 的视图约定读取相机文件,并进行坐标系转换。""" + # 自定义读取 DISCOVERSE 风格 JSON,统一转换到 COLMAP 世界->相机坐标 + pose_path = Path(json_path) + if not pose_path.is_file(): + raise FileNotFoundError(f"未找到相机 JSON:{json_path}") + + with pose_path.open("r", encoding="utf-8") as fh: + camera_list = json.load(fh) + + if isinstance(camera_list, dict): + for key in ("frames", "poses", "camera_poses"): + if key in camera_list and isinstance(camera_list[key], list): + camera_list = camera_list[key] + break + else: + raise ValueError(f"{json_path} 中的 JSON 结构不包含可识别的相机列表。") + + if not isinstance(camera_list, list) or not camera_list: + raise ValueError(f"{json_path} 中没有有效的相机条目。") + + fov_y = math.radians(fov_y_deg) + aspect_ratio = image_width / image_height + fov_x = 2.0 * math.atan(aspect_ratio * math.tan(fov_y * 0.5)) + + cameras: List[Camera] = [] + for idx, entry in enumerate(camera_list): + if "quaternion" in entry: + rotation_c2w = quaternion_to_rotation_matrix(entry["quaternion"]).astype(np.float32) + elif "rotation" in entry: + rotation_c2w = np.asarray(entry["rotation"], dtype=np.float32) + else: + raise KeyError(f"相机条目 {idx} 未提供 quaternion 或 rotation。") + + if rotation_c2w.shape != (3, 3): + raise ValueError(f"相机条目 {idx} 的旋转矩阵形状应为 (3,3),实际为 {rotation_c2w.shape}") + + # DISCOVER-SE 的 quaternion/rotation 直接导入后,渲染出来的 PNG 会上下翻转, + # 说明其前进方向仍是 OpenGL 的 -Z。通过右乘 diag(1,-1,-1) 将其显式转换到 + # MILo/colmap 的坐标系,使得后续投影矩阵与深度图一致。 + rotation_c2w = rotation_c2w @ OPENGL_TO_COLMAP + + if "position" in entry: + camera_center = np.asarray(entry["position"], dtype=np.float32) + if camera_center.shape != (3,): + raise ValueError(f"相机条目 {idx} 的 position 应为 3 维向量,实际为 {camera_center.shape}") + rotation_w2c = rotation_c2w.T + translation = (-rotation_w2c @ camera_center).astype(np.float32) + elif "translation" in entry: + translation = np.asarray(entry["translation"], dtype=np.float32) + # 如果 JSON 已直接存储 colmap 风格的 T(即世界到相机),这里假设它与旋转 + # 一样来自 OpenGL 坐标。严格来说也应执行同样的轴变换,但现有数据集只有 + # position 字段;为避免重复转换,这里只做类型检查并保留原值。 + elif "tvec" in entry: + translation = np.asarray(entry["tvec"], dtype=np.float32) + else: + raise KeyError(f"相机条目 {idx} 未提供 position/translation/tvec 信息。") + + if translation.shape != (3,): + raise ValueError(f"相机条目 {idx} 的平移向量应为长度 3,实际为 {translation.shape}") + + image_name = ( + entry.get("name") + or entry.get("image_name") + or entry.get("img_name") + or f"view_{idx:04d}" + ) + + camera = Camera( + colmap_id=str(idx), + R=rotation_c2w, + T=translation, + FoVx=fov_x, + FoVy=fov_y, + image=torch.zeros(3, image_height, image_width), + gt_alpha_mask=None, + image_name=image_name, + uid=idx, + data_device="cuda", + ) + cameras.append(camera) + return cameras + + +def save_heatmap(data: np.ndarray, output_path: Path, title: str) -> None: + """将二维数据保存为热力图,便于直观观察差异。""" + plt.figure(figsize=(6, 4)) + finite_mask = np.isfinite(data) + if finite_mask.any(): + finite_values = data[finite_mask] + vmax = float(np.percentile(finite_values, 99.0)) + if (not np.isfinite(vmax)) or (vmax <= 0.0): + vmax = 1.0 + else: + vmax = 1.0 + masked_data = np.ma.array(data, mask=~finite_mask) + plt.imshow(masked_data, cmap="inferno", vmin=0.0, vmax=vmax) + plt.title(title) + plt.colorbar() + plt.tight_layout() + plt.savefig(output_path) + plt.close() + + +def resolve_data_path(user_path: str) -> Path: + """将相对路径映射到 milo/data 下,支持用户提供绝对路径。""" + if not user_path: + raise ValueError("路径参数不能为空。") + path = Path(user_path).expanduser() + if path.is_absolute(): + return path + return (DATA_ROOT / path).resolve() + + +def 
save_detail_visualizations( + iteration: int, + detail_dir: Path, + view_index: int, + gaussian_depth_map: np.ndarray, + mesh_depth_map: np.ndarray, + gt_depth_map: np.ndarray, + gaussian_normals_map: np.ndarray, + mesh_normals_map: np.ndarray, + gt_normals_map: np.ndarray, + gaussian_valid: np.ndarray, + mesh_valid: np.ndarray, + gt_valid: np.ndarray, + shared_min: float, + shared_max: float, + depth_stats: Dict[str, float], + normal_stats: Dict[str, float], + loss_summary: Dict[str, float], +) -> None: + """保存更详细的调试可视化,便于逐迭代排查。""" + detail_dir.mkdir(parents=True, exist_ok=True) + + shared_min = shared_min if np.isfinite(shared_min) else 0.0 + shared_max = shared_max if np.isfinite(shared_max) else 1.0 + if not np.isfinite(shared_max) or shared_max <= shared_min: + shared_max = shared_min + 1.0 + + gaussian_mask = gaussian_valid & gt_valid + mesh_mask = mesh_valid & gt_valid + + gaussian_depth_diff = np.full(gaussian_depth_map.shape, np.nan, dtype=np.float32) + mesh_depth_diff = np.full(mesh_depth_map.shape, np.nan, dtype=np.float32) + if gaussian_mask.any(): + gaussian_depth_diff[gaussian_mask] = np.abs( + gaussian_depth_map[gaussian_mask] - gt_depth_map[gaussian_mask] + ) + if mesh_mask.any(): + mesh_depth_diff[mesh_mask] = np.abs(mesh_depth_map[mesh_mask] - gt_depth_map[mesh_mask]) + + diff_values: List[np.ndarray] = [] + if gaussian_mask.any(): + diff_values.append(gaussian_depth_diff[gaussian_mask].reshape(-1)) + if mesh_mask.any(): + diff_values.append(mesh_depth_diff[mesh_mask].reshape(-1)) + if diff_values: + diff_stack = np.concatenate(diff_values) + diff_vmax = float(np.percentile(diff_stack, 99.0)) + if (not np.isfinite(diff_vmax)) or diff_vmax <= 0.0: + diff_vmax = 1.0 + else: + diff_vmax = 1.0 + + depth_fig, depth_axes = plt.subplots(2, 3, figsize=(18, 10), dpi=300) + ax_gt_depth, ax_gaussian_depth, ax_mesh_depth = depth_axes[0] + ax_gaussian_diff, ax_mesh_diff, ax_depth_hist = depth_axes[1] + + im_gt = ax_gt_depth.imshow(gt_depth_map, cmap="viridis", vmin=shared_min, vmax=shared_max) + ax_gt_depth.set_title("GT depth") + ax_gt_depth.axis("off") + depth_fig.colorbar(im_gt, ax=ax_gt_depth, fraction=0.046, pad=0.04) + + im_gaussian = ax_gaussian_depth.imshow( + gaussian_depth_map, + cmap="viridis", + vmin=shared_min, + vmax=shared_max, + ) + ax_gaussian_depth.set_title("Gaussian depth") + ax_gaussian_depth.axis("off") + depth_fig.colorbar(im_gaussian, ax=ax_gaussian_depth, fraction=0.046, pad=0.04) + + im_mesh = ax_mesh_depth.imshow( + mesh_depth_map, + cmap="viridis", + vmin=shared_min, + vmax=shared_max, + ) + ax_mesh_depth.set_title("Mesh depth") + ax_mesh_depth.axis("off") + depth_fig.colorbar(im_mesh, ax=ax_mesh_depth, fraction=0.046, pad=0.04) + + gaussian_diff_masked = np.ma.array(gaussian_depth_diff, mask=~gaussian_mask) + mesh_diff_masked = np.ma.array(mesh_depth_diff, mask=~mesh_mask) + ax_gaussian_diff.imshow(gaussian_diff_masked, cmap="magma", vmin=0.0, vmax=diff_vmax) + ax_gaussian_diff.set_title("|Gaussian - GT|") + ax_gaussian_diff.axis("off") + ax_mesh_diff.imshow(mesh_diff_masked, cmap="magma", vmin=0.0, vmax=diff_vmax) + ax_mesh_diff.set_title("|Mesh - GT|") + ax_mesh_diff.axis("off") + + ax_depth_hist.set_title("Depth diff histogram") + ax_depth_hist.set_xlabel("Absolute difference") + ax_depth_hist.set_ylabel("Count") + if diff_values: + if gaussian_mask.any(): + ax_depth_hist.hist( + gaussian_depth_diff[gaussian_mask].reshape(-1), + bins=60, + alpha=0.6, + label="Gaussian", + ) + if mesh_mask.any(): + ax_depth_hist.hist( + 
mesh_depth_diff[mesh_mask].reshape(-1), + bins=60, + alpha=0.6, + label="Mesh", + ) + ax_depth_hist.legend() + else: + ax_depth_hist.text(0.5, 0.5, "No valid depth diffs", ha="center", va="center") + depth_fig.suptitle( + f"Iter {iteration:04d} view {view_index} | depth_loss={loss_summary['depth_loss']:.4f}" + f" | normal_loss={loss_summary['normal_loss']:.4f} | mesh_loss={loss_summary['mesh_loss']:.4f}", + fontsize=12, + ) + depth_fig.tight_layout(rect=[0, 0, 1, 0.95]) + depth_fig.savefig(detail_dir / f"detail_depth_iter_{iteration:04d}.png", dpi=300) + plt.close(depth_fig) + + gaussian_normals_rgb = normals_to_rgb(gaussian_normals_map) + mesh_normals_rgb = normals_to_rgb(mesh_normals_map) + gt_normals_rgb = normals_to_rgb(gt_normals_map) + + gaussian_normal_mask = np.all(np.isfinite(gaussian_normals_map), axis=-1) & np.all( + np.isfinite(gt_normals_map), axis=-1 + ) + mesh_normal_mask = np.all(np.isfinite(mesh_normals_map), axis=-1) & np.all( + np.isfinite(gt_normals_map), axis=-1 + ) + gaussian_normal_diff = np.linalg.norm(gaussian_normals_map - gt_normals_map, axis=-1) + mesh_normal_diff = np.linalg.norm(mesh_normals_map - gt_normals_map, axis=-1) + gaussian_normal_diff = np.where(gaussian_normal_mask, gaussian_normal_diff, np.nan) + mesh_normal_diff = np.where(mesh_normal_mask, mesh_normal_diff, np.nan) + + normal_diff_values: List[np.ndarray] = [] + if gaussian_normal_mask.any(): + normal_diff_values.append(gaussian_normal_diff[gaussian_normal_mask].reshape(-1)) + if mesh_normal_mask.any(): + normal_diff_values.append(mesh_normal_diff[mesh_normal_mask].reshape(-1)) + if normal_diff_values: + normal_diff_stack = np.concatenate(normal_diff_values) + normal_vmax = float(np.percentile(normal_diff_stack, 99.0)) + if (not np.isfinite(normal_vmax)) or normal_vmax <= 0.0: + normal_vmax = 1.0 + else: + normal_vmax = 1.0 + + normal_fig, normal_axes = plt.subplots(2, 3, figsize=(18, 10), dpi=300) + ax_gt_normals, ax_gaussian_normals, ax_mesh_normals = normal_axes[0] + ax_gaussian_normal_diff, ax_mesh_normal_diff, ax_normal_text = normal_axes[1] + + ax_gt_normals.imshow(gt_normals_rgb) + ax_gt_normals.set_title("GT normals") + ax_gt_normals.axis("off") + ax_gaussian_normals.imshow(gaussian_normals_rgb) + ax_gaussian_normals.set_title("Gaussian normals") + ax_gaussian_normals.axis("off") + ax_mesh_normals.imshow(mesh_normals_rgb) + ax_mesh_normals.set_title("Mesh normals") + ax_mesh_normals.axis("off") + + gaussian_normals_masked = np.ma.array(gaussian_normal_diff, mask=~gaussian_normal_mask) + mesh_normals_masked = np.ma.array(mesh_normal_diff, mask=~mesh_normal_mask) + im_gaussian_normal = ax_gaussian_normal_diff.imshow( + gaussian_normals_masked, + cmap="magma", + vmin=0.0, + vmax=normal_vmax, + ) + ax_gaussian_normal_diff.set_title("‖Gaussian-GT‖") + ax_gaussian_normal_diff.axis("off") + normal_fig.colorbar(im_gaussian_normal, ax=ax_gaussian_normal_diff, fraction=0.046, pad=0.04) + + im_mesh_normal = ax_mesh_normal_diff.imshow( + mesh_normals_masked, + cmap="magma", + vmin=0.0, + vmax=normal_vmax, + ) + ax_mesh_normal_diff.set_title("‖Mesh-GT‖") + ax_mesh_normal_diff.axis("off") + normal_fig.colorbar(im_mesh_normal, ax=ax_mesh_normal_diff, fraction=0.046, pad=0.04) + + ax_normal_text.axis("off") + text_lines = [ + f"Iter {iteration:04d} view {view_index}", + f"Loss: total={loss_summary['loss']:.4f} depth={loss_summary['depth_loss']:.4f} normal={loss_summary['normal_loss']:.4f}", + f"Mesh loss={loss_summary['mesh_loss']:.4f} (depth={loss_summary['mesh_depth_loss']:.4f}, 
normal={loss_summary['mesh_normal_loss']:.4f})", + f"Occupancy: centers={loss_summary['occupied_loss']:.4f} labels={loss_summary['labels_loss']:.4f}", + f"Depth metrics: mae={depth_stats['mae']:.4f} rmse={depth_stats['rmse']:.4f}", + f"Normal metrics: valid_px={normal_stats['valid_px']:.0f} cos={normal_stats['mean_cos']:.4f}", + f"Grad norm={loss_summary['grad_norm']:.4f}", + ] + ax_normal_text.text(0.0, 1.0, "\n".join(text_lines), va="top") + normal_fig.tight_layout() + normal_fig.savefig(detail_dir / f"detail_normals_iter_{iteration:04d}.png", dpi=300) + plt.close(normal_fig) + +def main(): + parser = ArgumentParser(description="桥梁场景高斯到网格迭代分析脚本") + parser.add_argument( + "--num_iterations", + type=int, + default=5, + help="执行循环的次数(未启用 --lock_view_repeat 时生效)", + ) + parser.add_argument("--ma_beta", type=float, default=0.8, help="loss 滑动平均系数") + + # ========== 新增:优化配置文件参数 ========== + parser.add_argument( + "--opt_config", + type=str, + default="default", + help="优化配置名称或完整路径(默认 default,查找 configs/optimization/default.yaml)", + ) + + # ========== 保留旧参数以向后兼容,但会被YAML配置覆盖 ========== + parser.add_argument( + "--depth_loss_weight", type=float, default=None, help="(已弃用,请使用 --opt_config) 深度一致性项权重" + ) + parser.add_argument( + "--normal_loss_weight", type=float, default=None, help="(已弃用,请使用 --opt_config) 法线一致性项权重" + ) + parser.add_argument( + "--lr", + "--learning_rate", + dest="lr", + type=float, + default=None, + help="(已弃用,请使用 --opt_config) 仅优化 XYZ 的学习率", + ) + parser.add_argument( + "--depth_clip_min", + type=float, + default=None, + help="(已弃用,请使用 --opt_config) 深度最小裁剪值", + ) + parser.add_argument( + "--depth_clip_max", + type=float, + default=None, + help="(已弃用,请使用 --opt_config) 深度最大裁剪值", + ) + parser.add_argument( + "--mesh_depth_weight", + type=float, + default=None, + help="(已弃用,请使用 --opt_config) mesh 深度项权重覆盖", + ) + parser.add_argument( + "--mesh_normal_weight", + type=float, + default=None, + help="(已弃用,请使用 --opt_config) mesh 法线项权重覆盖", + ) + + parser.add_argument( + "--delaunay_reset_interval", + type=int, + default=1000, + help="每隔多少次迭代重建一次 Delaunay(<=0 表示每次重建)", + ) + parser.add_argument( + "--mesh_config", + type=str, + default="medium", + help="mesh 配置名称或路径(默认 medium)", + ) + parser.add_argument( + "--save_interval", + type=int, + default=None, + help="保存可视化/npz 的间隔,默认与 Delaunay 重建间隔相同", + ) + parser.add_argument( + "--detail_interval", + type=int, + default=None, + help="详细调试图像保存间隔(单位:迭代,未设置则禁用)", + ) + parser.add_argument( + "--heatmap_dir", + type=str, + default="yufu2mesh_outputs", + help="保存热力图等输出的目录", + ) + parser.add_argument( + "--depth_gt_dir", + type=str, + default="bridge_small/depth", + help="Discoverse 深度 npy 相对路径(根目录为 milo/data,亦可填绝对路径)", + ) + parser.add_argument( + "--ply_path", + type=str, + default="bridge_small/yufu_bridge_small.ply", + help="初始高斯 PLY 路径(相对于 milo/data,可填绝对路径)", + ) + parser.add_argument( + "--camera_poses_json", + type=str, + default="bridge_small/camera_poses_cam1.json", + help="相机位姿 JSON 路径(相对于 milo/data,可填绝对路径)", + ) + parser.add_argument( + "--normal_cache_dir", + type=str, + default=None, + help="法线缓存目录,默认为 runs//normal_gt", + ) + parser.add_argument( + "--skip_normal_gt_generation", + action="store_true", + help="已存在缓存时跳过初始法线 GT 预计算", + ) + parser.add_argument("--seed", type=int, default=0, help="控制随机性的种子") + parser.add_argument( + "--lock_view_repeat", + type=int, + default=None, + help="启用视角锁定调试模式时指定同一视角连续迭代次数,启用后总迭代数=视角数量×该值并忽略 --num_iterations;未提供则关闭该模式", + ) + parser.add_argument( + "--log_interval", + type=int, + default=100, + 
help="控制迭代日志的打印频率(默认每次迭代打印)", + ) + parser.add_argument( + "--warn_until_iter", + type=int, + default=3000, + help="surface sampling warmup 迭代数(用于 mesh downsample)", + ) + parser.add_argument( + "--imp_metric", + type=str, + default="outdoor", + choices=["outdoor", "indoor"], + help="surface sampling 的重要性度量类型", + ) + parser.add_argument( + "--mesh_start_iter", + type=int, + default=2000, + help="mesh 正则起始迭代(默认 2000,避免冷启动阶段干扰)", + ) + parser.add_argument( + "--mesh_update_interval", + type=int, + default=5, + help="mesh 正则重建/回传间隔,>1 可减少 DMTet 抖动(默认 5)", + ) + pipe = PipelineParams(parser) + args = parser.parse_args() + + # ========== 加载优化配置 ========== + print(f"[INFO] 加载优化配置:{args.opt_config}") + opt_config = load_optimization_config(args.opt_config) + + # 兼容性:如果命令行指定了旧参数,发出警告并使用YAML配置 + if args.depth_loss_weight is not None: + print(f"[WARNING] --depth_loss_weight 已弃用,将使用YAML配置中的值") + if args.normal_loss_weight is not None: + print(f"[WARNING] --normal_loss_weight 已弃用,将使用YAML配置中的值") + if args.lr is not None: + print(f"[WARNING] --lr 已弃用,将使用YAML配置中的学习率设置") + if args.depth_clip_min is not None: + print(f"[WARNING] --depth_clip_min 已弃用,将使用YAML配置中的值") + if args.depth_clip_max is not None: + print(f"[WARNING] --depth_clip_max 已弃用,将使用YAML配置中的值") + if args.mesh_depth_weight is not None: + print(f"[WARNING] --mesh_depth_weight 已弃用,将使用YAML配置中的值") + if args.mesh_normal_weight is not None: + print(f"[WARNING] --mesh_normal_weight 已弃用,将使用YAML配置中的值") + + lock_view_mode = args.lock_view_repeat is not None + lock_repeat = max(1, args.lock_view_repeat) if lock_view_mode else 1 + + pipe.debug = getattr(args, "debug", False) + + # 所有输出固定写入 milo/runs/ 下,便于管理实验产物 + base_run_dir = MILO_DIR / "runs" + output_dir = base_run_dir / args.heatmap_dir + output_dir.mkdir(parents=True, exist_ok=True) + iteration_image_dir = output_dir / "iteration_images" + iteration_image_dir.mkdir(parents=True, exist_ok=True) + lock_view_output_dir: Optional[Path] = None + if lock_view_mode: + lock_view_output_dir = output_dir / "lock_view_repeat" + lock_view_output_dir.mkdir(parents=True, exist_ok=True) + + detail_interval: Optional[int] = None + detail_image_dir: Optional[Path] = None + if args.detail_interval is not None: + if args.detail_interval <= 0: + print("[WARNING] --detail_interval <= 0,已禁用详细调试输出。") + else: + detail_interval = max(1, args.detail_interval) + detail_image_dir = output_dir / "detail_images" + detail_image_dir.mkdir(parents=True, exist_ok=True) + print(f"[INFO] 启用详细调试输出:每 {detail_interval} 次迭代写入 detail_images。") + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + depth_gt_dir = resolve_data_path(args.depth_gt_dir) + ply_path = resolve_data_path(args.ply_path) + camera_poses_json = resolve_data_path(args.camera_poses_json) + print( + f"[INFO] 使用数据路径:depth_gt={depth_gt_dir}, ply={ply_path}, camera_json={camera_poses_json}" + ) + + gaussians = GaussianModel(sh_degree=0, learn_occupancy=True) + gaussians.load_ply(str(ply_path)) + + # ========== 使用新的优化配置设置参数 ========== + print("[INFO] 配置高斯参数优化...") + optimizer, loss_weights = setup_gaussian_optimization(gaussians, opt_config) + + height = 720 + width = 1280 + fov_y_deg = 75.0 + + train_cameras = load_cameras_from_json( + json_path=str(camera_poses_json), + image_height=height, + image_width=width, + fov_y_deg=fov_y_deg, + ) + num_views = len(train_cameras) + print(f"[INFO] 成功加载 {num_views} 个相机视角。") + if lock_view_mode: + total_iterations = num_views * 
lock_repeat + print( + f"[INFO] 启用视角锁定调试:每个视角连续 {lock_repeat} 次,总迭代数 {total_iterations}(已忽略 --num_iterations)。" + ) + else: + total_iterations = args.num_iterations + + device = gaussians._xyz.device + background = torch.tensor([0.0, 0.0, 0.0], device=device) + + # ========== 应用优化配置到mesh和深度处理 ========== + mesh_config = load_mesh_config(args.mesh_config) + mesh_config["start_iter"] = max(0, args.mesh_start_iter) + mesh_config["stop_iter"] = max(mesh_config.get("stop_iter", total_iterations), total_iterations) + mesh_config["mesh_update_interval"] = max(1, args.mesh_update_interval) + mesh_config["delaunay_reset_interval"] = args.delaunay_reset_interval + + # 从优化配置中获取mesh权重 + mesh_reg_config = opt_config["mesh_regularization"] + mesh_config["depth_weight"] = mesh_reg_config["depth_weight"] + mesh_config["normal_weight"] = mesh_reg_config["normal_weight"] + print(f"[INFO] Mesh正则化权重: depth={mesh_config['depth_weight']}, normal={mesh_config['normal_weight']}") + + # 这里默认沿用 surface 采样以对齐训练阶段;如仅需快速分析,也可以切换为 random 提升速度。 + mesh_config["delaunay_sampling_method"] = "surface" + + scene_wrapper = ManualScene(train_cameras) + + ensure_gaussian_occupancy(gaussians) + if gaussians.spatial_lr_scale <= 0: + gaussians.spatial_lr_scale = 1.0 + gaussians.set_occupancy_mode(mesh_config.get("occupancy_mode", "occupancy_shift")) + + render_view, render_for_sdf = build_render_functions(gaussians, pipe, background) + + # 从优化配置中获取深度处理参数 + depth_proc_config = opt_config["depth_processing"] + depth_clip_min = depth_proc_config.get("clip_min") + depth_clip_max = depth_proc_config.get("clip_max") + if depth_clip_min is not None and depth_clip_min <= 0.0: + depth_clip_min = None + print(f"[INFO] 深度裁剪范围: min={depth_clip_min}, max={depth_clip_max}") + + depth_provider = DepthProvider( + depth_root=depth_gt_dir, + image_height=height, + image_width=width, + device=device, + clip_min=depth_clip_min, + clip_max=depth_clip_max, + ) + + normal_cache_dir = Path(args.normal_cache_dir) if args.normal_cache_dir else (output_dir / "normal_gt") + normal_cache = NormalGroundTruthCache( + cache_dir=normal_cache_dir, + image_height=height, + image_width=width, + device=device, + ) + if args.skip_normal_gt_generation: + missing = [idx for idx in range(num_views) if not normal_cache.has(idx)] + if missing: + raise RuntimeError( + f"跳过法线 GT 预计算被拒绝,仍有 {len(missing)} 个视角缺少缓存(示例 {missing[:5]})。" + ) + else: + print("[INFO] 开始预计算初始法线 GT(仅进行一次,若存在缓存会自动跳过)。") + normal_cache.ensure_all(train_cameras, render_view) + normal_cache.clear_memory_cache() + + mesh_renderer, mesh_state = initialize_mesh_regularization(scene_wrapper, mesh_config) + mesh_state["reset_delaunay_samples"] = True + mesh_state["reset_sdf_values"] = True + + # optimizer已在setup_gaussian_optimization中创建 + mesh_args = SimpleNamespace( + warn_until_iter=args.warn_until_iter, + imp_metric=args.imp_metric, + ) + + # 输出loss权重配置 + print(f"[INFO] Loss权重配置:") + for loss_name, weight in loss_weights.items(): + print(f" > {loss_name}: {weight}") + + # 记录整个迭代过程中的指标与梯度,结束时统一写入 npz/曲线 + stats_history: Dict[str, List[float]] = { + "iteration": [], + "depth_loss": [], + "normal_loss": [], + "mesh_loss": [], + "mesh_depth_loss": [], + "mesh_normal_loss": [], + "occupied_centers_loss": [], + "occupancy_labels_loss": [], + "depth_mae": [], + "depth_rmse": [], + "normal_mean_cos": [], + "normal_valid_px": [], + "grad_norm": [], + } + + moving_loss = None + previous_depth: Dict[int, np.ndarray] = {} + previous_normals: Dict[int, np.ndarray] = {} + camera_stack = list(range(num_views)) + 
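+    # A descriptive note (not from the original patch): views are drawn without
+    # replacement — the stack is shuffled here and refilled/reshuffled inside the
+    # loop once every view has been used; lock-view mode bypasses it via lock_sequence.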
random.shuffle(camera_stack) + lock_sequence = list(range(num_views)) if lock_view_mode else [] + lock_view_ptr = 0 + lock_repeat_ptr = 0 + save_interval = args.save_interval if args.save_interval is not None else args.delaunay_reset_interval + if save_interval is None or save_interval <= 0: + save_interval = 1 + log_interval = max(1, args.log_interval) + mesh_export_warned = False + + for iteration in range(total_iterations): + optimizer.zero_grad(set_to_none=True) + if lock_view_mode: + view_index = lock_sequence[lock_view_ptr] + else: + if not camera_stack: + camera_stack = list(range(num_views)) + random.shuffle(camera_stack) + view_index = camera_stack.pop() + viewpoint = train_cameras[view_index] + + training_pkg = render_view(viewpoint) + gt_depth_tensor, gt_depth_mask = depth_provider.get(view_index) + depth_loss_tensor, depth_stats = compute_depth_loss_tensor( + pred_depth=training_pkg["median_depth"], + gt_depth=gt_depth_tensor, + gt_mask=gt_depth_mask, + ) + gt_normals_tensor = normal_cache.get(view_index) + normal_loss_tensor, normal_stats = compute_normal_loss_tensor( + pred_normals=training_pkg["normal"], + gt_normals=gt_normals_tensor, + base_mask=gt_depth_mask, + ) + def _zero_mesh_pkg() -> Dict[str, Any]: + zero = torch.zeros((), device=device) + depth_zero = torch.zeros_like(training_pkg["median_depth"]) + normal_zero = torch.zeros( + training_pkg["median_depth"].shape[-2], + training_pkg["median_depth"].shape[-1], + 3, + device=device, + ) + return { + "mesh_loss": zero, + "mesh_depth_loss": zero, + "mesh_normal_loss": zero, + "occupied_centers_loss": zero, + "occupancy_labels_loss": zero, + "updated_state": mesh_state, + "mesh_render_pkg": { + "depth": depth_zero, + "normals": normal_zero, + }, + } + + mesh_active = iteration >= mesh_config["start_iter"] + if mesh_active: + mesh_pkg = compute_mesh_regularization( + iteration=iteration, + render_pkg=training_pkg, + viewpoint_cam=viewpoint, + viewpoint_idx=view_index, + gaussians=gaussians, + scene=scene_wrapper, + pipe=pipe, + background=background, + kernel_size=0.0, + config=mesh_config, + mesh_renderer=mesh_renderer, + mesh_state=mesh_state, + render_func=render_for_sdf, + weight_adjustment=1.0, + args=mesh_args, + integrate_func=integrate_radegs, + ) + else: + mesh_pkg = _zero_mesh_pkg() + + mesh_state = mesh_pkg["updated_state"] + mesh_loss_tensor = mesh_pkg["mesh_loss"] + + # ========== 使用YAML配置的loss权重 ========== + total_loss = ( + loss_weights["depth"] * depth_loss_tensor + + loss_weights["normal"] * normal_loss_tensor + + mesh_loss_tensor + ) + depth_loss_value = float(depth_loss_tensor.detach().item()) + normal_loss_value = float(normal_loss_tensor.detach().item()) + mesh_loss_value = float(mesh_loss_tensor.detach().item()) + loss_value = float(total_loss.detach().item()) + + if total_loss.requires_grad: + total_loss.backward() + # 计算所有可训练参数的总梯度范数 + total_grad_norm_sq = 0.0 + for param_group in optimizer.param_groups: + for param in param_group["params"]: + if param.grad is not None: + total_grad_norm_sq += param.grad.detach().norm().item() ** 2 + grad_norm = float(total_grad_norm_sq ** 0.5) + optimizer.step() + else: + optimizer.zero_grad(set_to_none=True) + grad_norm = float("nan") + + mesh_render_pkg = mesh_pkg["mesh_render_pkg"] + mesh_depth_map = prepare_depth_map(mesh_render_pkg["depth"]) + mesh_normals_map = prepare_normals(mesh_render_pkg["normals"]) + gaussian_depth_map = prepare_depth_map(training_pkg["median_depth"]) + gaussian_normals_map = prepare_normals(training_pkg["normal"]) + gt_depth_map = 
depth_provider.as_numpy(view_index) + gt_normals_map = prepare_normals(gt_normals_tensor) + + mesh_valid = np.isfinite(mesh_depth_map) & (mesh_depth_map > 0.0) + gaussian_valid = np.isfinite(gaussian_depth_map) & (gaussian_depth_map > 0.0) + gt_valid = np.isfinite(gt_depth_map) & (gt_depth_map > 0.0) + overlap_mask = gaussian_valid & gt_valid + + depth_delta = gaussian_depth_map - gt_depth_map + if overlap_mask.any(): + delta_abs = np.abs(depth_delta[overlap_mask]) + diff_mean = float(delta_abs.mean()) + diff_max = float(delta_abs.max()) + diff_rmse = float(np.sqrt(np.mean(depth_delta[overlap_mask] ** 2))) + else: + diff_mean = diff_max = diff_rmse = float("nan") + + mesh_depth_loss = float(mesh_pkg["mesh_depth_loss"].item()) + mesh_normal_loss = float(mesh_pkg["mesh_normal_loss"].item()) + occupied_loss = float(mesh_pkg["occupied_centers_loss"].item()) + labels_loss = float(mesh_pkg["occupancy_labels_loss"].item()) + + moving_loss = ( + loss_value + if moving_loss is None + else args.ma_beta * moving_loss + (1 - args.ma_beta) * loss_value + ) + + stats_history["iteration"].append(float(iteration)) + stats_history["depth_loss"].append(depth_loss_value) + stats_history["normal_loss"].append(normal_loss_value) + stats_history["mesh_loss"].append(mesh_loss_value) + stats_history["mesh_depth_loss"].append(mesh_depth_loss) + stats_history["mesh_normal_loss"].append(mesh_normal_loss) + stats_history["occupied_centers_loss"].append(occupied_loss) + stats_history["occupancy_labels_loss"].append(labels_loss) + stats_history["depth_mae"].append(depth_stats["mae"]) + stats_history["depth_rmse"].append(depth_stats["rmse"]) + stats_history["normal_mean_cos"].append(normal_stats["mean_cos"]) + stats_history["normal_valid_px"].append(float(normal_stats["valid_px"])) + stats_history["grad_norm"].append(grad_norm) + + def _fmt(value: float) -> str: + return f"{value:.6f}" + + if (iteration % log_interval == 0) or (iteration == total_iterations - 1): + print( + "[INFO] Iter {iter:02d} | loss={total} (depth={depth}, normal={normal}, mesh={mesh}) | ma_loss={ma}".format( + iter=iteration, + total=_fmt(loss_value), + depth=_fmt(depth_loss_value), + normal=_fmt(normal_loss_value), + mesh=_fmt(mesh_loss_value), + ma=f"{moving_loss:.6f}", + ) + ) + + should_save = (save_interval <= 0) or (iteration % save_interval == 0) + should_save_detail = ( + detail_interval is not None and (iteration % detail_interval == 0) + ) + shared_min: Optional[float] = None + shared_max: Optional[float] = None + if should_save or should_save_detail: + valid_values: List[np.ndarray] = [] + if mesh_valid.any(): + valid_values.append(mesh_depth_map[mesh_valid].reshape(-1)) + if gaussian_valid.any(): + valid_values.append(gaussian_depth_map[gaussian_valid].reshape(-1)) + if gt_valid.any(): + valid_values.append(gt_depth_map[gt_valid].reshape(-1)) + if valid_values: + all_valid = np.concatenate(valid_values) + shared_min = float(all_valid.min()) + shared_max = float(all_valid.max()) + else: + shared_min, shared_max = 0.0, 1.0 + + loss_summary = { + "depth_loss": depth_loss_value, + "normal_loss": normal_loss_value, + "mesh_loss": mesh_loss_value, + "mesh_depth_loss": mesh_depth_loss, + "mesh_normal_loss": mesh_normal_loss, + "occupied_loss": occupied_loss, + "labels_loss": labels_loss, + "loss": loss_value, + "grad_norm": grad_norm, + } + + if should_save: + if shared_min is None or shared_max is None: + shared_min, shared_max = 0.0, 1.0 + + gaussian_depth_vis_path = iteration_image_dir / f"gaussian_depth_vis_iter_{iteration:02d}.png" + 
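+            # Added comment for clarity: the GT / Gaussian / mesh depth PNGs below all use the
+            # same vmin/vmax (shared_min/shared_max), so the three maps for this iteration are
+            # directly comparable.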
plt.imsave( + gaussian_depth_vis_path, + gaussian_depth_map, + cmap="viridis", + vmin=shared_min, + vmax=shared_max, + ) + + depth_vis_path = iteration_image_dir / f"mesh_depth_vis_iter_{iteration:02d}.png" + plt.imsave( + depth_vis_path, + mesh_depth_map, + cmap="viridis", + vmin=shared_min, + vmax=shared_max, + ) + + gt_depth_vis_path = iteration_image_dir / f"gt_depth_vis_iter_{iteration:02d}.png" + plt.imsave( + gt_depth_vis_path, + gt_depth_map, + cmap="viridis", + vmin=shared_min, + vmax=shared_max, + ) + + normal_vis_path = iteration_image_dir / f"mesh_normal_vis_iter_{iteration:02d}.png" + mesh_normals_rgb = normals_to_rgb(mesh_normals_map) + save_normal_visualization(mesh_normals_rgb, normal_vis_path) + + gaussian_normal_vis_path = iteration_image_dir / f"gaussian_normal_vis_iter_{iteration:02d}.png" + gaussian_normals_rgb = normals_to_rgb(gaussian_normals_map) + save_normal_visualization(gaussian_normals_rgb, gaussian_normal_vis_path) + + gt_normal_vis_path = iteration_image_dir / f"gt_normal_vis_iter_{iteration:02d}.png" + gt_normals_rgb = normals_to_rgb(gt_normals_map) + save_normal_visualization(gt_normals_rgb, gt_normal_vis_path) + + output_npz = output_dir / f"mesh_render_iter_{iteration:02d}.npz" + np.savez( + output_npz, + mesh_depth=mesh_depth_map, + gaussian_depth=gaussian_depth_map, + depth_gt=gt_depth_map, + mesh_normals=mesh_normals_map, + gaussian_normals=gaussian_normals_map, + normal_gt=gt_normals_map, + depth_loss=depth_loss_value, + normal_loss=normal_loss_value, + mesh_loss=mesh_loss_value, + mesh_depth_loss=mesh_depth_loss, + mesh_normal_loss=mesh_normal_loss, + occupied_centers_loss=occupied_loss, + occupancy_labels_loss=labels_loss, + loss=loss_value, + moving_loss=moving_loss, + depth_mae=depth_stats["mae"], + depth_rmse=depth_stats["rmse"], + normal_valid_px=normal_stats["valid_px"], + normal_mean_cos=normal_stats["mean_cos"], + grad_norm=grad_norm, + iteration=iteration, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + if overlap_mask.any(): + depth_diff_vis = np.full_like( + gaussian_depth_map, np.nan, dtype=np.float32 + ) + depth_diff_vis[overlap_mask] = np.abs(depth_delta[overlap_mask]) + save_heatmap( + depth_diff_vis, + iteration_image_dir / f"depth_diff_iter_{iteration:02d}.png", + f"|Pred-GT| iter {iteration}", + ) + + composite_path = iteration_image_dir / f"comparison_iter_{iteration:02d}.png" + fig, axes = plt.subplots(2, 3, figsize=(18, 10), dpi=300) + ax_gt_depth, ax_gaussian_depth, ax_mesh_depth = axes[0] + ax_gt_normals, ax_gaussian_normals, ax_mesh_normals = axes[1] + + im0 = ax_gt_depth.imshow(gt_depth_map, cmap="viridis", vmin=shared_min, vmax=shared_max) + ax_gt_depth.set_title("GT depth") + ax_gt_depth.axis("off") + fig.colorbar(im0, ax=ax_gt_depth, fraction=0.046, pad=0.04) + + im1 = ax_gaussian_depth.imshow(gaussian_depth_map, cmap="viridis", vmin=shared_min, vmax=shared_max) + ax_gaussian_depth.set_title("Gaussian depth") + ax_gaussian_depth.axis("off") + fig.colorbar(im1, ax=ax_gaussian_depth, fraction=0.046, pad=0.04) + + im2 = ax_mesh_depth.imshow(mesh_depth_map, cmap="viridis", vmin=shared_min, vmax=shared_max) + ax_mesh_depth.set_title("Mesh depth") + ax_mesh_depth.axis("off") + fig.colorbar(im2, ax=ax_mesh_depth, fraction=0.046, pad=0.04) + + ax_gt_normals.imshow(gt_normals_rgb) + ax_gt_normals.set_title("GT normals") + ax_gt_normals.axis("off") + + ax_gaussian_normals.imshow(gaussian_normals_rgb) + ax_gaussian_normals.set_title("Gaussian normals") + ax_gaussian_normals.axis("off") + + 
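+            # Added comment for clarity: bottom-right panel of the 2x3 composite figure —
+            # normals rendered from the extracted mesh.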
ax_mesh_normals.imshow(mesh_normals_rgb) + ax_mesh_normals.set_title("Mesh normals") + ax_mesh_normals.axis("off") + + info_lines = [ + f"Iteration: {iteration:02d}", + f"View index: {view_index}", + f"GT depth valid px: {int(gt_valid.sum())}", + f"Gaussian depth valid px: {int(gaussian_valid.sum())}", + f"|Pred - GT| mean={diff_mean:.3f}, max={diff_max:.3f}, RMSE={diff_rmse:.3f}", + f"Depth loss={_fmt(depth_loss_value)} (w={loss_weights['depth']:.2f}, mae={depth_stats['mae']:.3f}, rmse={depth_stats['rmse']:.3f})", + f"Normal loss={_fmt(normal_loss_value)} (w={loss_weights['normal']:.2f}, px={normal_stats['valid_px']}, cos={normal_stats['mean_cos']:.3f})", + f"Mesh loss={_fmt(mesh_loss_value)}", + f"Mesh depth loss={_fmt(mesh_depth_loss)} mesh normal loss={_fmt(mesh_normal_loss)}", + f"Occupied centers={_fmt(occupied_loss)} labels={_fmt(labels_loss)}", + ] + fig.suptitle("\n".join(info_lines), fontsize=12, y=0.98) + fig.tight_layout(rect=[0, 0, 1, 0.94]) + fig.savefig(composite_path, dpi=300) + plt.close(fig) + + mesh_ready_for_export = mesh_active or mesh_state.get("delaunay_tets") is not None + if mesh_ready_for_export: + with torch.no_grad(): + export_mesh_from_state( + gaussians=gaussians, + mesh_state=mesh_state, + output_path=output_dir / f"mesh_iter_{iteration:02d}.ply", + reference_camera=None, + ) + else: + if not mesh_export_warned: + print( + "[INFO] Mesh 正则尚未启动或尚未生成 Delaunay,已跳过网格导出以避免全量三角化。" + ) + mesh_export_warned = True + + if should_save_detail and detail_image_dir is not None: + save_detail_visualizations( + iteration=iteration, + detail_dir=detail_image_dir, + view_index=view_index, + gaussian_depth_map=gaussian_depth_map, + mesh_depth_map=mesh_depth_map, + gt_depth_map=gt_depth_map, + gaussian_normals_map=gaussian_normals_map, + mesh_normals_map=mesh_normals_map, + gt_normals_map=gt_normals_map, + gaussian_valid=gaussian_valid, + mesh_valid=mesh_valid, + gt_valid=gt_valid, + shared_min=shared_min if shared_min is not None else 0.0, + shared_max=shared_max if shared_max is not None else 1.0, + depth_stats=depth_stats, + normal_stats=normal_stats, + loss_summary=loss_summary, + ) + + if lock_view_mode and lock_view_output_dir is not None: + if view_index in previous_depth: + prev_depth_map = previous_depth[view_index] + depth_valid_mask = ( + np.isfinite(prev_depth_map) + & np.isfinite(gaussian_depth_map) + & (prev_depth_map > 0.0) + & (gaussian_depth_map > 0.0) + ) + if depth_valid_mask.any(): + depth_delta = np.abs(gaussian_depth_map - prev_depth_map) + depth_diff = np.full_like( + gaussian_depth_map, np.nan, dtype=np.float32 + ) + depth_diff[depth_valid_mask] = depth_delta[depth_valid_mask] + save_heatmap( + depth_diff, + lock_view_output_dir / f"depth_diff_iter_{iteration:02d}_temporal.png", + f"Depth Δ iter {iteration}", + ) + if view_index in previous_normals: + prev_normals_map = previous_normals[view_index] + normal_valid_mask = np.all( + np.isfinite(prev_normals_map), axis=-1 + ) & np.all(np.isfinite(gaussian_normals_map), axis=-1) + if normal_valid_mask.any(): + normal_delta = gaussian_normals_map - prev_normals_map + if normal_delta.ndim == 3: + normal_diff = np.linalg.norm(normal_delta, axis=-1) + else: + normal_diff = np.abs(normal_delta) + normal_diff_vis = np.full_like( + normal_diff, np.nan, dtype=np.float32 + ) + normal_diff_vis[normal_valid_mask] = normal_diff[normal_valid_mask] + save_heatmap( + normal_diff_vis, + lock_view_output_dir / f"normal_diff_iter_{iteration:02d}_temporal.png", + f"Normal Δ iter {iteration}", + ) + + if lock_view_mode: + 
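+            # Added comment for clarity: cache this view's Gaussian depth/normal maps so the
+            # next visit to the same view (lock-view mode) can emit the temporal-difference
+            # heatmaps handled above.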
previous_depth[view_index] = gaussian_depth_map + previous_normals[view_index] = gaussian_normals_map + lock_repeat_ptr += 1 + if lock_repeat_ptr >= lock_repeat: + lock_repeat_ptr = 0 + lock_view_ptr = (lock_view_ptr + 1) % num_views + with torch.no_grad(): + # 输出完整指标轨迹及汇总曲线,方便任务结束后快速复盘 + history_npz = output_dir / "metrics_history.npz" + np.savez( + history_npz, + **{k: np.asarray(v, dtype=np.float32) for k, v in stats_history.items()}, + ) + summary_fig = output_dir / "metrics_summary.png" + if stats_history["iteration"]: + fig, axes = plt.subplots(2, 2, figsize=(16, 10), dpi=200) + iters = np.asarray(stats_history["iteration"]) + axes[0, 0].plot(iters, stats_history["depth_loss"], label="depth") + axes[0, 0].plot(iters, stats_history["normal_loss"], label="normal") + axes[0, 0].plot(iters, stats_history["mesh_loss"], label="mesh") + axes[0, 0].set_title("Total losses") + axes[0, 0].set_xlabel("Iteration") + axes[0, 0].legend() + + axes[0, 1].plot(iters, stats_history["mesh_depth_loss"], label="mesh depth") + axes[0, 1].plot(iters, stats_history["mesh_normal_loss"], label="mesh normal") + axes[0, 1].plot(iters, stats_history["occupied_centers_loss"], label="occupied centers") + axes[0, 1].plot(iters, stats_history["occupancy_labels_loss"], label="occupancy labels") + axes[0, 1].set_title("Mesh regularization components") + axes[0, 1].set_xlabel("Iteration") + axes[0, 1].legend() + + axes[1, 0].plot(iters, stats_history["depth_mae"], label="depth MAE") + axes[1, 0].plot(iters, stats_history["depth_rmse"], label="depth RMSE") + axes[1, 0].set_title("Depth metrics") + axes[1, 0].set_xlabel("Iteration") + axes[1, 0].legend() + + axes[1, 1].plot(iters, stats_history["normal_mean_cos"], label="mean cos") + axes[1, 1].plot(iters, stats_history["grad_norm"], label="grad norm") + axes[1, 1].set_title("Normals / Gradients") + axes[1, 1].set_xlabel("Iteration") + axes[1, 1].legend() + + fig.tight_layout() + fig.savefig(summary_fig) + plt.close(fig) + print(f"[INFO] 已保存曲线汇总:{summary_fig}") + print(f"[INFO] 记录所有迭代指标到 {history_npz}") + final_mesh_path = output_dir / "mesh_final.ply" + final_gaussian_path = output_dir / "gaussians_final.ply" + print(f"[INFO] 导出最终 mesh 到 {final_mesh_path}") + export_mesh_from_state( + gaussians=gaussians, + mesh_state=mesh_state, + output_path=final_mesh_path, + reference_camera=None, + ) + print(f"[INFO] 导出最终高斯到 {final_gaussian_path}") + gaussians.save_ply(str(final_gaussian_path)) + print("[INFO] 循环结束,所有结果已写入输出目录。") + + +if __name__ == "__main__": + main() diff --git a/milo/yufu2mesh_new_backup_20250114.py b/milo/yufu2mesh_new_backup_20250114.py new file mode 100644 index 0000000..007300f --- /dev/null +++ b/milo/yufu2mesh_new_backup_20250114.py @@ -0,0 +1,1204 @@ +from pathlib import Path +import math +import json +import random +from typing import Any, Dict, List, Optional, Sequence, Tuple +from types import SimpleNamespace + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +import trimesh +import yaml + +from argparse import ArgumentParser + +from functional import ( + compute_delaunay_triangulation, + extract_mesh, + frustum_cull_mesh, +) +from scene.cameras import Camera +from scene.gaussian_model import GaussianModel +from gaussian_renderer import render_full +from gaussian_renderer.radegs import render_radegs, integrate_radegs +from arguments import PipelineParams +from regularization.regularizer.mesh import ( + initialize_mesh_regularization, + compute_mesh_regularization, +) +from 
regularization.sdf.learnable import convert_occupancy_to_sdf +from utils.geometry_utils import flatten_voronoi_features + +# DISCOVER-SE 相机轨迹使用 OpenGL 右手坐标系(相机前方为 -Z,向上为 +Y), +# 而 MILo/colmap 渲染管线假设的是前方 +Z、向上 -Y。需要在读入时做一次轴翻转。 +OPENGL_TO_COLMAP = np.diag([1.0, -1.0, -1.0]).astype(np.float32) + + +class DepthProvider: + """负责加载并缓存 Discoverse 深度图,统一形状、裁剪和掩码。""" + + def __init__( + self, + depth_root: Path, + image_height: int, + image_width: int, + device: torch.device, + clip_min: Optional[float] = None, + clip_max: Optional[float] = None, + ) -> None: + self.depth_root = Path(depth_root) + if not self.depth_root.is_dir(): + raise FileNotFoundError(f"深度目录不存在:{self.depth_root}") + self.image_height = image_height + self.image_width = image_width + self.device = device + self.clip_min = clip_min + self.clip_max = clip_max + self._cache: Dict[int, torch.Tensor] = {} + self._mask_cache: Dict[int, torch.Tensor] = {} + + def _file_for_index(self, view_index: int) -> Path: + return self.depth_root / f"depth_img_0_{view_index}.npy" + + def _load_numpy(self, file_path: Path) -> np.ndarray: + depth_np = np.load(file_path) + depth_np = np.squeeze(depth_np) + if depth_np.ndim != 2: + raise ValueError(f"{file_path} 深度数组维度异常:{depth_np.shape}") + if depth_np.shape != (self.image_height, self.image_width): + raise ValueError( + f"{file_path} 深度分辨率应为 {(self.image_height, self.image_width)},当前为 {depth_np.shape}" + ) + if self.clip_min is not None or self.clip_max is not None: + min_val = self.clip_min if self.clip_min is not None else None + max_val = self.clip_max if self.clip_max is not None else None + depth_np = np.clip( + depth_np, + min_val if min_val is not None else depth_np.min(), + max_val if max_val is not None else depth_np.max(), + ) + return depth_np.astype(np.float32) + + def get(self, view_index: int) -> Tuple[torch.Tensor, torch.Tensor]: + """返回 (depth_tensor, valid_mask),均在 GPU 上。""" + if view_index not in self._cache: + file_path = self._file_for_index(view_index) + if not file_path.is_file(): + raise FileNotFoundError(f"缺少深度文件:{file_path}") + depth_np = self._load_numpy(file_path) + depth_tensor = torch.from_numpy(depth_np).to(self.device) + valid_mask = torch.isfinite(depth_tensor) & (depth_tensor > 0.0) + if self.clip_min is not None: + valid_mask &= depth_tensor >= self.clip_min + if self.clip_max is not None: + valid_mask &= depth_tensor <= self.clip_max + self._cache[view_index] = depth_tensor + self._mask_cache[view_index] = valid_mask + return self._cache[view_index], self._mask_cache[view_index] + + def as_numpy(self, view_index: int) -> np.ndarray: + depth_tensor, _ = self.get(view_index) + return depth_tensor.detach().cpu().numpy() + + +class NormalGroundTruthCache: + """缓存以初始高斯生成的法线 GT,避免训练阶段重复渲染。""" + + def __init__( + self, + cache_dir: Path, + image_height: int, + image_width: int, + device: torch.device, + ) -> None: + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.image_height = image_height + self.image_width = image_width + self.device = device + self._memory_cache: Dict[int, torch.Tensor] = {} + + def _file_path(self, view_index: int) -> Path: + return self.cache_dir / f"normal_view_{view_index:04d}.npy" + + def has(self, view_index: int) -> bool: + return self._file_path(view_index).is_file() + + def store(self, view_index: int, normal_tensor: torch.Tensor) -> None: + normals_np = prepare_normals(normal_tensor) + np.save(self._file_path(view_index), normals_np.astype(np.float16)) + + def get(self, view_index: int) -> 
torch.Tensor: + if view_index not in self._memory_cache: + path = self._file_path(view_index) + if not path.is_file(): + raise FileNotFoundError( + f"未找到视角 {view_index} 的法线缓存:{path},请先完成预计算。" + ) + normals_np = np.load(path) + expected_shape = (self.image_height, self.image_width, 3) + if normals_np.shape != expected_shape: + raise ValueError( + f"{path} 法线缓存尺寸应为 {expected_shape},当前为 {normals_np.shape}" + ) + normals_tensor = ( + torch.from_numpy(normals_np.astype(np.float32)) + .permute(2, 0, 1) + .to(self.device) + ) + self._memory_cache[view_index] = normals_tensor + return self._memory_cache[view_index] + + def clear_memory_cache(self) -> None: + self._memory_cache.clear() + + def ensure_all(self, cameras: Sequence[Camera], render_fn) -> None: + total = len(cameras) + for idx, camera in enumerate(cameras): + if self.has(idx): + continue + with torch.no_grad(): + pkg = render_fn(camera) + normal_tensor = pkg["normal"] + self.store(idx, normal_tensor) + if (idx + 1) % 10 == 0 or idx + 1 == total: + print(f"[INFO] 预计算法线缓存 {idx + 1}/{total}") + + +def compute_depth_loss_tensor( + pred_depth: torch.Tensor, + gt_depth: torch.Tensor, + gt_mask: torch.Tensor, +) -> Tuple[torch.Tensor, Dict[str, float]]: + """返回深度 L1 损失及统计信息。""" + if pred_depth.dim() == 3: + pred_depth = pred_depth.squeeze(0) + if pred_depth.shape != gt_depth.shape: + raise ValueError( + f"预测深度尺寸 {pred_depth.shape} 与 GT {gt_depth.shape} 不一致。" + ) + valid_mask = gt_mask & torch.isfinite(pred_depth) & (pred_depth > 0.0) + valid_pixels = int(valid_mask.sum().item()) + if valid_pixels == 0: + zero = torch.zeros((), device=pred_depth.device) + stats = {"valid_px": 0, "mae": float("nan"), "rmse": float("nan")} + return zero, stats + diff = pred_depth - gt_depth + abs_diff = diff.abs() + loss = abs_diff[valid_mask].mean() + rmse = torch.sqrt((diff[valid_mask] ** 2).mean()) + stats = { + "valid_px": valid_pixels, + "mae": float(abs_diff[valid_mask].mean().detach().item()), + "rmse": float(rmse.detach().item()), + } + return loss, stats + + +def compute_normal_loss_tensor( + pred_normals: torch.Tensor, + gt_normals: torch.Tensor, + base_mask: torch.Tensor, +) -> Tuple[torch.Tensor, Dict[str, float]]: + """基于余弦相似度的法线损失。""" + if pred_normals.dim() != 3 or pred_normals.shape[0] != 3: + raise ValueError(f"预测法线维度应为 (3,H,W),当前为 {pred_normals.shape}") + if gt_normals.shape != pred_normals.shape: + raise ValueError( + f"法线 GT 尺寸 {gt_normals.shape} 与预测 {pred_normals.shape} 不一致。" + ) + gt_mask = torch.isfinite(gt_normals).all(dim=0) + pred_mask = torch.isfinite(pred_normals).all(dim=0) + valid_mask = base_mask & gt_mask & pred_mask + valid_pixels = int(valid_mask.sum().item()) + if valid_pixels == 0: + zero = torch.zeros((), device=pred_normals.device) + stats = {"valid_px": 0, "mean_cos": float("nan")} + return zero, stats + + pred_unit = F.normalize(pred_normals, dim=0) + gt_unit = F.normalize(gt_normals, dim=0) + cos_sim = (pred_unit * gt_unit).sum(dim=0).clamp(-1.0, 1.0) + loss_map = (1.0 - cos_sim) * valid_mask + loss = loss_map.sum() / valid_mask.sum() + stats = {"valid_px": valid_pixels, "mean_cos": float(cos_sim[valid_mask].mean().item())} + return loss, stats + + +class ManualScene: + """最小 Scene 封装,提供 mesh regularization 所需的接口。""" + + def __init__(self, cameras: Sequence[Camera]): + self._train_cameras = list(cameras) + + def getTrainCameras(self, scale: float = 1.0) -> List[Camera]: + return list(self._train_cameras) + + def getTrainCameras_warn_up( + self, + iteration: int, + warn_until_iter: int, + scale: float = 1.0, + scale2: float 
= 2.0, + ) -> List[Camera]: + return list(self._train_cameras) + + +def build_render_functions( + gaussians: GaussianModel, + pipe: PipelineParams, + background: torch.Tensor, +): + """构建训练/重建所需的渲染器:训练走 render_full 以获得精确深度梯度,SDF 仍沿用 RaDe-GS。""" + + def render_view(view: Camera) -> Dict[str, torch.Tensor]: + pkg = render_full( + viewpoint_camera=view, + pc=gaussians, + pipe=pipe, + bg_color=background, + compute_expected_normals=False, + compute_expected_depth=True, + compute_accurate_median_depth_gradient=True, + ) + if "area_max" not in pkg: + pkg["area_max"] = torch.zeros_like(pkg["radii"]) + return pkg + + def render_for_sdf( + view: Camera, + gaussians_override: Optional[GaussianModel] = None, + pipeline_override: Optional[PipelineParams] = None, + background_override: Optional[torch.Tensor] = None, + kernel_size: float = 0.0, + require_depth: bool = True, + require_coord: bool = False, + ): + pc_obj = gaussians if gaussians_override is None else gaussians_override + pipe_obj = pipe if pipeline_override is None else pipeline_override + bg_color = background if background_override is None else background_override + pkg = render_radegs( + viewpoint_camera=view, + pc=pc_obj, + pipe=pipe_obj, + bg_color=bg_color, + kernel_size=kernel_size, + scaling_modifier=1.0, + require_coord=require_coord, + require_depth=require_depth, + ) + if "area_max" not in pkg: + pkg["area_max"] = torch.zeros_like(pkg["radii"]) + return { + "render": pkg["render"].detach(), + "median_depth": pkg["median_depth"].detach(), + } + + return render_view, render_for_sdf + + +def load_mesh_config(config_name: str) -> Dict[str, Any]: + """支持直接传文件路径或 configs/mesh/.yaml。""" + candidate = Path(config_name) + if not candidate.is_file(): + candidate = Path(__file__).resolve().parent / "configs" / "mesh" / f"{config_name}.yaml" + if not candidate.is_file(): + raise FileNotFoundError(f"无法找到 mesh 配置:{config_name}") + with candidate.open("r", encoding="utf-8") as fh: + return yaml.safe_load(fh) + + +def ensure_gaussian_occupancy(gaussians: GaussianModel) -> None: + """mesh regularization 依赖 9 维 occupancy 网格,此处在推理环境补齐缓冲。""" + needs_init = ( + not getattr(gaussians, "learn_occupancy", False) + or not hasattr(gaussians, "_occupancy_shift") + or gaussians._occupancy_shift.numel() == 0 + ) + if needs_init: + gaussians.learn_occupancy = True + base = torch.zeros((gaussians._xyz.shape[0], 9), device=gaussians._xyz.device) + shift = torch.zeros_like(base) + gaussians._base_occupancy = nn.Parameter(base.requires_grad_(False), requires_grad=False) + gaussians._occupancy_shift = nn.Parameter(shift.requires_grad_(True)) + + +def export_mesh_from_state( + gaussians: GaussianModel, + mesh_state: Dict[str, Any], + output_path: Path, + reference_camera: Optional[Camera] = None, +) -> None: + """根据当前 mesh_state 导出网格,并可选做视椎裁剪。""" + gaussian_idx = mesh_state.get("delaunay_xyz_idx") + delaunay_tets = mesh_state.get("delaunay_tets") + + if delaunay_tets is None: + delaunay_tets = compute_delaunay_triangulation( + means=gaussians.get_xyz, + scales=gaussians.get_scaling, + rotations=gaussians.get_rotation, + gaussian_idx=gaussian_idx, + ) + + occupancy = ( + gaussians.get_occupancy + if gaussian_idx is None + else gaussians.get_occupancy[gaussian_idx] + ) + pivots_sdf = convert_occupancy_to_sdf(flatten_voronoi_features(occupancy)) + + mesh = extract_mesh( + delaunay_tets=delaunay_tets, + pivots_sdf=pivots_sdf, + means=gaussians.get_xyz, + scales=gaussians.get_scaling, + rotations=gaussians.get_rotation, + gaussian_idx=gaussian_idx, + ) + + mesh_to_save 
= mesh if reference_camera is None else frustum_cull_mesh(mesh, reference_camera) + verts = mesh_to_save.verts.detach().cpu().numpy() + faces = mesh_to_save.faces.detach().cpu().numpy() + trimesh.Trimesh(vertices=verts, faces=faces, process=False).export(output_path) + + +def quaternion_to_rotation_matrix(quaternion: Sequence[float]) -> np.ndarray: + """将单位四元数转换为 3x3 旋转矩阵。""" + # 这里显式转换 DISCOVERSE 导出的四元数,确保后续符合 MILo 的旋转约定 + q = np.asarray(quaternion, dtype=np.float64) + if q.shape != (4,): + raise ValueError(f"四元数需要包含 4 个分量,当前形状为 {q.shape}") + w, x, y, z = q + xx = x * x + yy = y * y + zz = z * z + xy = x * y + xz = x * z + yz = y * z + wx = w * x + wy = w * y + wz = w * z + rotation = np.array( + [ + [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)], + [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)], + [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)], + ] + ) + return rotation + + +def freeze_gaussian_model(model: GaussianModel) -> None: + """显式关闭高斯模型中参数的梯度。""" + # 推理阶段冻结高斯参数,后续循环只做前向评估 + tensor_attrs = [ + "_xyz", + "_features_dc", + "_features_rest", + "_scaling", + "_rotation", + "_opacity", + ] + for attr in tensor_attrs: + value = getattr(model, attr, None) + if isinstance(value, torch.Tensor): + value.requires_grad_(False) + + +def prepare_depth_map(depth_tensor: torch.Tensor) -> np.ndarray: + """将深度张量转为二维 numpy 数组。""" + # 统一 squeeze 逻辑,防止 Matplotlib 因 shape 异常报错 + depth_np = depth_tensor.detach().cpu().numpy() + depth_np = np.squeeze(depth_np) + if depth_np.ndim == 1: + depth_np = np.expand_dims(depth_np, axis=0) + return depth_np + + +def prepare_normals(normal_tensor: torch.Tensor) -> np.ndarray: + """将法线张量转换为 HxWx3 的 numpy 数组。""" + # 兼容渲染输出为 (3,H,W) 或 (H,W,3) 的两种格式 + normals_np = normal_tensor.detach().cpu().numpy() + normals_np = np.squeeze(normals_np) + if normals_np.ndim == 3 and normals_np.shape[0] == 3: + normals_np = np.transpose(normals_np, (1, 2, 0)) + if normals_np.ndim == 2: + normals_np = normals_np[..., None] + return normals_np + + +def normals_to_rgb(normals: np.ndarray) -> np.ndarray: + """将 [-1,1] 范围的法线向量映射到 [0,1] 以便可视化。""" + normals = np.nan_to_num(normals, nan=0.0, posinf=0.0, neginf=0.0) + rgb = 0.5 * (normals + 1.0) + return np.clip(rgb, 0.0, 1.0).astype(np.float32) + + +def save_normal_visualization(normal_rgb: np.ndarray, output_path: Path) -> None: + """保存法线可视化图像。""" + plt.imsave(output_path, normal_rgb) + + +def load_cameras_from_json( + json_path: str, + image_height: int, + image_width: int, + fov_y_deg: float, +) -> List[Camera]: + """按照 MILo 的视图约定读取相机文件,并进行坐标系转换。""" + # 自定义读取 DISCOVERSE 风格 JSON,统一转换到 COLMAP 世界->相机坐标 + pose_path = Path(json_path) + if not pose_path.is_file(): + raise FileNotFoundError(f"未找到相机 JSON:{json_path}") + + with pose_path.open("r", encoding="utf-8") as fh: + camera_list = json.load(fh) + + if isinstance(camera_list, dict): + for key in ("frames", "poses", "camera_poses"): + if key in camera_list and isinstance(camera_list[key], list): + camera_list = camera_list[key] + break + else: + raise ValueError(f"{json_path} 中的 JSON 结构不包含可识别的相机列表。") + + if not isinstance(camera_list, list) or not camera_list: + raise ValueError(f"{json_path} 中没有有效的相机条目。") + + fov_y = math.radians(fov_y_deg) + aspect_ratio = image_width / image_height + fov_x = 2.0 * math.atan(aspect_ratio * math.tan(fov_y * 0.5)) + + cameras: List[Camera] = [] + for idx, entry in enumerate(camera_list): + if "quaternion" in entry: + rotation_c2w = quaternion_to_rotation_matrix(entry["quaternion"]).astype(np.float32) + elif "rotation" in entry: + rotation_c2w = 
np.asarray(entry["rotation"], dtype=np.float32) + else: + raise KeyError(f"相机条目 {idx} 未提供 quaternion 或 rotation。") + + if rotation_c2w.shape != (3, 3): + raise ValueError(f"相机条目 {idx} 的旋转矩阵形状应为 (3,3),实际为 {rotation_c2w.shape}") + + # DISCOVER-SE 的 quaternion/rotation 直接导入后,渲染出来的 PNG 会上下翻转, + # 说明其前进方向仍是 OpenGL 的 -Z。通过右乘 diag(1,-1,-1) 将其显式转换到 + # MILo/colmap 的坐标系,使得后续投影矩阵与深度图一致。 + rotation_c2w = rotation_c2w @ OPENGL_TO_COLMAP + + if "position" in entry: + camera_center = np.asarray(entry["position"], dtype=np.float32) + if camera_center.shape != (3,): + raise ValueError(f"相机条目 {idx} 的 position 应为 3 维向量,实际为 {camera_center.shape}") + rotation_w2c = rotation_c2w.T + translation = (-rotation_w2c @ camera_center).astype(np.float32) + elif "translation" in entry: + translation = np.asarray(entry["translation"], dtype=np.float32) + # 如果 JSON 已直接存储 colmap 风格的 T(即世界到相机),这里假设它与旋转 + # 一样来自 OpenGL 坐标。严格来说也应执行同样的轴变换,但现有数据集只有 + # position 字段;为避免重复转换,这里只做类型检查并保留原值。 + elif "tvec" in entry: + translation = np.asarray(entry["tvec"], dtype=np.float32) + else: + raise KeyError(f"相机条目 {idx} 未提供 position/translation/tvec 信息。") + + if translation.shape != (3,): + raise ValueError(f"相机条目 {idx} 的平移向量应为长度 3,实际为 {translation.shape}") + + image_name = ( + entry.get("name") + or entry.get("image_name") + or entry.get("img_name") + or f"view_{idx:04d}" + ) + + camera = Camera( + colmap_id=str(idx), + R=rotation_c2w, + T=translation, + FoVx=fov_x, + FoVy=fov_y, + image=torch.zeros(3, image_height, image_width), + gt_alpha_mask=None, + image_name=image_name, + uid=idx, + data_device="cuda", + ) + cameras.append(camera) + return cameras + + +def save_heatmap(data: np.ndarray, output_path: Path, title: str) -> None: + """将二维数据保存为热力图,便于直观观察差异。""" + # 迭代间深度 / 法线差分可视化,快速定位局部变化 + plt.figure(figsize=(6, 4)) + plt.imshow(data, cmap="inferno") + plt.title(title) + plt.colorbar() + plt.tight_layout() + plt.savefig(output_path) + plt.close() + + +def main(): + parser = ArgumentParser(description="桥梁场景高斯到网格迭代分析脚本") + parser.add_argument("--num_iterations", type=int, default=5, help="执行循环的次数") + parser.add_argument("--ma_beta", type=float, default=0.8, help="loss 滑动平均系数") + parser.add_argument( + "--depth_loss_weight", type=float, default=0.3, help="深度一致性项权重(默认 0.3)" + ) + parser.add_argument( + "--normal_loss_weight", type=float, default=0.05, help="法线一致性项权重(默认 0.05)" + ) + parser.add_argument( + "--lr", + "--learning_rate", + dest="lr", + type=float, + default=1e-3, + help="XYZ 学习率(默认 1e-3)", + ) + parser.add_argument( + "--shape_lr", + type=float, + default=5e-4, + help="缩放/旋转/不透明度的学习率(默认 5e-4)", + ) + parser.add_argument( + "--delaunay_reset_interval", + type=int, + default=1000, + help="每隔多少次迭代重建一次 Delaunay(<=0 表示每次重建)", + ) + parser.add_argument( + "--mesh_config", + type=str, + default="medium", + help="mesh 配置名称或路径(默认 medium)", + ) + parser.add_argument( + "--save_interval", + type=int, + default=None, + help="保存可视化/npz 的间隔,默认与 Delaunay 重建间隔相同", + ) + parser.add_argument( + "--heatmap_dir", + type=str, + default="yufu2mesh_outputs", + help="保存热力图等输出的目录", + ) + parser.add_argument( + "--depth_gt_dir", + type=str, + default="/home/zoyo/Desktop/MILo_rtx50/milo/data/bridge_clean/depth", + help="Discoverse 深度 npy 所在目录", + ) + parser.add_argument( + "--depth_clip_min", + type=float, + default=0.0, + help="深度最小裁剪值,<=0 表示不裁剪", + ) + parser.add_argument( + "--depth_clip_max", + type=float, + default=None, + help="深度最大裁剪值,None 表示不裁剪", + ) + parser.add_argument( + "--normal_cache_dir", + type=str, + default=None, + help="法线缓存目录,默认为 
runs//normal_gt", + ) + parser.add_argument( + "--skip_normal_gt_generation", + action="store_true", + help="已存在缓存时跳过初始法线 GT 预计算", + ) + parser.add_argument("--seed", type=int, default=0, help="控制随机性的种子") + parser.add_argument( + "--lock_view_index", + type=int, + default=None, + help="固定视角索引,仅在指定时输出热力图", + ) + parser.add_argument( + "--log_interval", + type=int, + default=100, + help="控制迭代日志的打印频率(默认每次迭代打印)", + ) + parser.add_argument( + "--warn_until_iter", + type=int, + default=3000, + help="surface sampling warmup 迭代数(用于 mesh downsample)", + ) + parser.add_argument( + "--imp_metric", + type=str, + default="outdoor", + choices=["outdoor", "indoor"], + help="surface sampling 的重要性度量类型", + ) + parser.add_argument( + "--mesh_start_iter", + type=int, + default=2000, + help="mesh 正则起始迭代(默认 2000,避免冷启动阶段干扰)", + ) + parser.add_argument( + "--mesh_update_interval", + type=int, + default=5, + help="mesh 正则重建/回传间隔,>1 可减少 DMTet 抖动(默认 5)", + ) + parser.add_argument( + "--mesh_depth_weight", + type=float, + default=0.1, + help="mesh 深度项权重覆盖(默认 0.1,原配置通常为 0.05)", + ) + parser.add_argument( + "--mesh_normal_weight", + type=float, + default=0.1, + help="mesh 法线项权重覆盖(默认 0.1,原配置通常为 0.05)", + ) + parser.add_argument( + "--disable_shape_training", + action="store_true", + help="禁用缩放/旋转/不透明度的优化,仅用于调试", + ) + pipe = PipelineParams(parser) + args = parser.parse_args() + + pipe.debug = getattr(args, "debug", False) + + # 所有输出固定写入 milo/runs/ 下,便于管理实验产物 + base_run_dir = Path(__file__).resolve().parent / "runs" + output_dir = base_run_dir / args.heatmap_dir + output_dir.mkdir(parents=True, exist_ok=True) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + ply_path = "/home/zoyo/Desktop/MILo_rtx50/milo/data/bridge_clean/yufu_bridge_cleaned.ply" + camera_poses_json = "/home/zoyo/Desktop/MILo_rtx50/milo/data/bridge_clean/camera_poses_cam1.json" + + gaussians = GaussianModel(sh_degree=0, learn_occupancy=True) + gaussians.load_ply(ply_path) + freeze_attrs = [ + "_features_dc", + "_features_rest", + "_base_occupancy", + "_occupancy_shift", + ] + for attr in freeze_attrs: + value = getattr(gaussians, attr, None) + if isinstance(value, torch.Tensor): + value.requires_grad_(False) + gaussians._xyz.requires_grad_(True) + + shape_trainable = [] + shape_attr_list = ["_scaling", "_rotation", "_opacity"] + if not args.disable_shape_training: + for attr in shape_attr_list: + value = getattr(gaussians, attr, None) + if isinstance(value, torch.Tensor): + value.requires_grad_(True) + shape_trainable.append(value) + else: + for attr in shape_attr_list: + value = getattr(gaussians, attr, None) + if isinstance(value, torch.Tensor): + value.requires_grad_(False) + + height = 720 + width = 1280 + fov_y_deg = 75.0 + + train_cameras = load_cameras_from_json( + json_path=camera_poses_json, + image_height=height, + image_width=width, + fov_y_deg=fov_y_deg, + ) + num_views = len(train_cameras) + print(f"[INFO] 成功加载 {num_views} 个相机视角。") + + device = gaussians._xyz.device + background = torch.tensor([0.0, 0.0, 0.0], device=device) + + mesh_config = load_mesh_config(args.mesh_config) + mesh_config["start_iter"] = max(0, args.mesh_start_iter) + mesh_config["stop_iter"] = max(mesh_config.get("stop_iter", args.num_iterations), args.num_iterations) + mesh_config["mesh_update_interval"] = max(1, args.mesh_update_interval) + mesh_config["delaunay_reset_interval"] = args.delaunay_reset_interval + mesh_config["depth_weight"] = 
args.mesh_depth_weight + mesh_config["normal_weight"] = args.mesh_normal_weight + # 这里默认沿用 surface 采样以对齐训练阶段;如仅需快速分析,也可以切换为 random 提升速度。 + mesh_config["delaunay_sampling_method"] = "surface" + + scene_wrapper = ManualScene(train_cameras) + + ensure_gaussian_occupancy(gaussians) + if gaussians.spatial_lr_scale <= 0: + gaussians.spatial_lr_scale = 1.0 + gaussians.set_occupancy_mode(mesh_config.get("occupancy_mode", "occupancy_shift")) + + render_view, render_for_sdf = build_render_functions(gaussians, pipe, background) + + depth_clip_min = args.depth_clip_min if args.depth_clip_min > 0.0 else None + depth_clip_max = args.depth_clip_max + depth_provider = DepthProvider( + depth_root=Path(args.depth_gt_dir), + image_height=height, + image_width=width, + device=device, + clip_min=depth_clip_min, + clip_max=depth_clip_max, + ) + + normal_cache_dir = Path(args.normal_cache_dir) if args.normal_cache_dir else (output_dir / "normal_gt") + normal_cache = NormalGroundTruthCache( + cache_dir=normal_cache_dir, + image_height=height, + image_width=width, + device=device, + ) + if args.skip_normal_gt_generation: + missing = [idx for idx in range(num_views) if not normal_cache.has(idx)] + if missing: + raise RuntimeError( + f"跳过法线 GT 预计算被拒绝,仍有 {len(missing)} 个视角缺少缓存(示例 {missing[:5]})。" + ) + else: + print("[INFO] 开始预计算初始法线 GT(仅进行一次,若存在缓存会自动跳过)。") + normal_cache.ensure_all(train_cameras, render_view) + normal_cache.clear_memory_cache() + + mesh_renderer, mesh_state = initialize_mesh_regularization(scene_wrapper, mesh_config) + mesh_state["reset_delaunay_samples"] = True + mesh_state["reset_sdf_values"] = True + + param_groups = [{"params": [gaussians._xyz], "lr": args.lr}] + if shape_trainable: + param_groups.append({"params": shape_trainable, "lr": args.shape_lr}) + optimizer = torch.optim.Adam(param_groups) + mesh_args = SimpleNamespace( + warn_until_iter=args.warn_until_iter, + imp_metric=args.imp_metric, + ) + + # 记录整个迭代过程中的指标与梯度,结束时统一写入 npz/曲线 + stats_history: Dict[str, List[float]] = { + "iteration": [], + "depth_loss": [], + "normal_loss": [], + "mesh_loss": [], + "mesh_depth_loss": [], + "mesh_normal_loss": [], + "occupied_centers_loss": [], + "occupancy_labels_loss": [], + "depth_mae": [], + "depth_rmse": [], + "normal_mean_cos": [], + "normal_valid_px": [], + "grad_norm": [], + } + + moving_loss = None + previous_depth: Dict[int, np.ndarray] = {} + previous_normals: Dict[int, np.ndarray] = {} + camera_stack = list(range(num_views)) + random.shuffle(camera_stack) + save_interval = args.save_interval if args.save_interval is not None else args.delaunay_reset_interval + if save_interval is None or save_interval <= 0: + save_interval = 1 + + for iteration in range(args.num_iterations): + optimizer.zero_grad(set_to_none=True) + if args.lock_view_index is not None: + view_index = args.lock_view_index % num_views + else: + if not camera_stack: + camera_stack = list(range(num_views)) + random.shuffle(camera_stack) + view_index = camera_stack.pop() + viewpoint = train_cameras[view_index] + + training_pkg = render_view(viewpoint) + gt_depth_tensor, gt_depth_mask = depth_provider.get(view_index) + depth_loss_tensor, depth_stats = compute_depth_loss_tensor( + pred_depth=training_pkg["median_depth"], + gt_depth=gt_depth_tensor, + gt_mask=gt_depth_mask, + ) + gt_normals_tensor = normal_cache.get(view_index) + normal_loss_tensor, normal_stats = compute_normal_loss_tensor( + pred_normals=training_pkg["normal"], + gt_normals=gt_normals_tensor, + base_mask=gt_depth_mask, + ) + def _zero_mesh_pkg() -> Dict[str, 
Any]: + zero = torch.zeros((), device=device) + depth_zero = torch.zeros_like(training_pkg["median_depth"]) + normal_zero = torch.zeros_like(training_pkg["normal"].permute(1, 2, 0)) + return { + "mesh_loss": zero, + "mesh_depth_loss": zero, + "mesh_normal_loss": zero, + "occupied_centers_loss": zero, + "occupancy_labels_loss": zero, + "updated_state": mesh_state, + "mesh_render_pkg": { + "depth": depth_zero, + "normals": normal_zero, + }, + } + + mesh_active = iteration >= mesh_config["start_iter"] + if mesh_active: + mesh_pkg = compute_mesh_regularization( + iteration=iteration, + render_pkg=training_pkg, + viewpoint_cam=viewpoint, + viewpoint_idx=view_index, + gaussians=gaussians, + scene=scene_wrapper, + pipe=pipe, + background=background, + kernel_size=0.0, + config=mesh_config, + mesh_renderer=mesh_renderer, + mesh_state=mesh_state, + render_func=render_for_sdf, + weight_adjustment=1.0, + args=mesh_args, + integrate_func=integrate_radegs, + ) + else: + mesh_pkg = _zero_mesh_pkg() + + mesh_state = mesh_pkg["updated_state"] + mesh_loss_tensor = mesh_pkg["mesh_loss"] + total_loss = ( + args.depth_loss_weight * depth_loss_tensor + + args.normal_loss_weight * normal_loss_tensor + + mesh_loss_tensor + ) + depth_loss_value = float(depth_loss_tensor.detach().item()) + normal_loss_value = float(normal_loss_tensor.detach().item()) + mesh_loss_value = float(mesh_loss_tensor.detach().item()) + loss_value = float(total_loss.detach().item()) + + if total_loss.requires_grad: + total_loss.backward() + grad_norm = float(gaussians._xyz.grad.detach().norm().item()) + optimizer.step() + else: + optimizer.zero_grad(set_to_none=True) + grad_norm = float("nan") + + mesh_render_pkg = mesh_pkg["mesh_render_pkg"] + mesh_depth_map = prepare_depth_map(mesh_render_pkg["depth"]) + mesh_normals_map = prepare_normals(mesh_render_pkg["normals"]) + gaussian_depth_map = prepare_depth_map(training_pkg["median_depth"]) + gaussian_normals_map = prepare_normals(training_pkg["normal"]) + gt_depth_map = depth_provider.as_numpy(view_index) + gt_normals_map = prepare_normals(gt_normals_tensor) + + mesh_valid = np.isfinite(mesh_depth_map) & (mesh_depth_map > 0.0) + gaussian_valid = np.isfinite(gaussian_depth_map) & (gaussian_depth_map > 0.0) + gt_valid = np.isfinite(gt_depth_map) & (gt_depth_map > 0.0) + overlap_mask = gaussian_valid & gt_valid + + depth_delta = gaussian_depth_map - gt_depth_map + if overlap_mask.any(): + delta_abs = np.abs(depth_delta[overlap_mask]) + diff_mean = float(delta_abs.mean()) + diff_max = float(delta_abs.max()) + diff_rmse = float(np.sqrt(np.mean(depth_delta[overlap_mask] ** 2))) + else: + diff_mean = diff_max = diff_rmse = float("nan") + + mesh_depth_loss = float(mesh_pkg["mesh_depth_loss"].item()) + mesh_normal_loss = float(mesh_pkg["mesh_normal_loss"].item()) + occupied_loss = float(mesh_pkg["occupied_centers_loss"].item()) + labels_loss = float(mesh_pkg["occupancy_labels_loss"].item()) + + moving_loss = ( + loss_value + if moving_loss is None + else args.ma_beta * moving_loss + (1 - args.ma_beta) * loss_value + ) + + stats_history["iteration"].append(float(iteration)) + stats_history["depth_loss"].append(depth_loss_value) + stats_history["normal_loss"].append(normal_loss_value) + stats_history["mesh_loss"].append(mesh_loss_value) + stats_history["mesh_depth_loss"].append(mesh_depth_loss) + stats_history["mesh_normal_loss"].append(mesh_normal_loss) + stats_history["occupied_centers_loss"].append(occupied_loss) + stats_history["occupancy_labels_loss"].append(labels_loss) + 
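+        # Added comment for clarity: record per-view depth/normal accuracy and the XYZ
+        # gradient norm alongside the loss terms for the end-of-run summary curves.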
stats_history["depth_mae"].append(depth_stats["mae"]) + stats_history["depth_rmse"].append(depth_stats["rmse"]) + stats_history["normal_mean_cos"].append(normal_stats["mean_cos"]) + stats_history["normal_valid_px"].append(float(normal_stats["valid_px"])) + stats_history["grad_norm"].append(grad_norm) + + def _fmt(value: float) -> str: + return f"{value:.6f}" + + if (iteration % max(1, args.log_interval) == 0) or (iteration == args.num_iterations - 1): + print( + "[INFO] Iter {iter:02d} | loss={total} (depth={depth}, normal={normal}, mesh={mesh}) | ma_loss={ma}".format( + iter=iteration, + total=_fmt(loss_value), + depth=_fmt(depth_loss_value), + normal=_fmt(normal_loss_value), + mesh=_fmt(mesh_loss_value), + ma=f"{moving_loss:.6f}", + ) + ) + + should_save = (save_interval <= 0) or (iteration % save_interval == 0) + if should_save: + valid_values: List[np.ndarray] = [] + if mesh_valid.any(): + valid_values.append(mesh_depth_map[mesh_valid].reshape(-1)) + if gaussian_valid.any(): + valid_values.append(gaussian_depth_map[gaussian_valid].reshape(-1)) + if gt_valid.any(): + valid_values.append(gt_depth_map[gt_valid].reshape(-1)) + if valid_values: + all_valid = np.concatenate(valid_values) + shared_min = float(all_valid.min()) + shared_max = float(all_valid.max()) + else: + shared_min, shared_max = 0.0, 1.0 + + gaussian_depth_vis_path = output_dir / f"gaussian_depth_vis_iter_{iteration:02d}.png" + plt.imsave( + gaussian_depth_vis_path, + gaussian_depth_map, + cmap="viridis", + vmin=shared_min, + vmax=shared_max, + ) + + depth_vis_path = output_dir / f"mesh_depth_vis_iter_{iteration:02d}.png" + plt.imsave( + depth_vis_path, + mesh_depth_map, + cmap="viridis", + vmin=shared_min, + vmax=shared_max, + ) + + gt_depth_vis_path = output_dir / f"gt_depth_vis_iter_{iteration:02d}.png" + plt.imsave( + gt_depth_vis_path, + gt_depth_map, + cmap="viridis", + vmin=shared_min, + vmax=shared_max, + ) + + normal_vis_path = output_dir / f"mesh_normal_vis_iter_{iteration:02d}.png" + mesh_normals_rgb = normals_to_rgb(mesh_normals_map) + save_normal_visualization(mesh_normals_rgb, normal_vis_path) + + gaussian_normal_vis_path = output_dir / f"gaussian_normal_vis_iter_{iteration:02d}.png" + gaussian_normals_rgb = normals_to_rgb(gaussian_normals_map) + save_normal_visualization(gaussian_normals_rgb, gaussian_normal_vis_path) + + gt_normal_vis_path = output_dir / f"gt_normal_vis_iter_{iteration:02d}.png" + gt_normals_rgb = normals_to_rgb(gt_normals_map) + save_normal_visualization(gt_normals_rgb, gt_normal_vis_path) + + output_npz = output_dir / f"mesh_render_iter_{iteration:02d}.npz" + np.savez( + output_npz, + mesh_depth=mesh_depth_map, + gaussian_depth=gaussian_depth_map, + depth_gt=gt_depth_map, + mesh_normals=mesh_normals_map, + gaussian_normals=gaussian_normals_map, + normal_gt=gt_normals_map, + depth_loss=depth_loss_value, + normal_loss=normal_loss_value, + mesh_loss=mesh_loss_value, + mesh_depth_loss=mesh_depth_loss, + mesh_normal_loss=mesh_normal_loss, + occupied_centers_loss=occupied_loss, + occupancy_labels_loss=labels_loss, + loss=loss_value, + moving_loss=moving_loss, + depth_mae=depth_stats["mae"], + depth_rmse=depth_stats["rmse"], + normal_valid_px=normal_stats["valid_px"], + normal_mean_cos=normal_stats["mean_cos"], + grad_norm=grad_norm, + iteration=iteration, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + if overlap_mask.any(): + depth_diff_vis = np.zeros_like(gaussian_depth_map) + depth_diff_vis[overlap_mask] = depth_delta[overlap_mask] + save_heatmap( + np.abs(depth_diff_vis), + 
output_dir / f"depth_diff_iter_{iteration:02d}.png", + f"|Pred-GT| iter {iteration}", + ) + + composite_path = output_dir / f"comparison_iter_{iteration:02d}.png" + fig, axes = plt.subplots(2, 3, figsize=(18, 10), dpi=300) + ax_gt_depth, ax_gaussian_depth, ax_mesh_depth = axes[0] + ax_gt_normals, ax_gaussian_normals, ax_mesh_normals = axes[1] + + im0 = ax_gt_depth.imshow(gt_depth_map, cmap="viridis", vmin=shared_min, vmax=shared_max) + ax_gt_depth.set_title("GT depth") + ax_gt_depth.axis("off") + fig.colorbar(im0, ax=ax_gt_depth, fraction=0.046, pad=0.04) + + im1 = ax_gaussian_depth.imshow(gaussian_depth_map, cmap="viridis", vmin=shared_min, vmax=shared_max) + ax_gaussian_depth.set_title("Gaussian depth") + ax_gaussian_depth.axis("off") + fig.colorbar(im1, ax=ax_gaussian_depth, fraction=0.046, pad=0.04) + + im2 = ax_mesh_depth.imshow(mesh_depth_map, cmap="viridis", vmin=shared_min, vmax=shared_max) + ax_mesh_depth.set_title("Mesh depth") + ax_mesh_depth.axis("off") + fig.colorbar(im2, ax=ax_mesh_depth, fraction=0.046, pad=0.04) + + ax_gt_normals.imshow(gt_normals_rgb) + ax_gt_normals.set_title("GT normals") + ax_gt_normals.axis("off") + + ax_gaussian_normals.imshow(gaussian_normals_rgb) + ax_gaussian_normals.set_title("Gaussian normals") + ax_gaussian_normals.axis("off") + + ax_mesh_normals.imshow(mesh_normals_rgb) + ax_mesh_normals.set_title("Mesh normals") + ax_mesh_normals.axis("off") + + info_lines = [ + f"Iteration: {iteration:02d}", + f"View index: {view_index}", + f"GT depth valid px: {int(gt_valid.sum())}", + f"Gaussian depth valid px: {int(gaussian_valid.sum())}", + f"|Pred - GT| mean={diff_mean:.3f}, max={diff_max:.3f}, RMSE={diff_rmse:.3f}", + f"Depth loss={_fmt(depth_loss_value)} (w={args.depth_loss_weight:.2f}, mae={depth_stats['mae']:.3f}, rmse={depth_stats['rmse']:.3f})", + f"Normal loss={_fmt(normal_loss_value)} (w={args.normal_loss_weight:.2f}, px={normal_stats['valid_px']}, cos={normal_stats['mean_cos']:.3f})", + f"Mesh loss={_fmt(mesh_loss_value)}", + f"Mesh depth loss={_fmt(mesh_depth_loss)} mesh normal loss={_fmt(mesh_normal_loss)}", + f"Occupied centers={_fmt(occupied_loss)} labels={_fmt(labels_loss)}", + ] + fig.suptitle("\n".join(info_lines), fontsize=12, y=0.98) + fig.tight_layout(rect=[0, 0, 1, 0.94]) + fig.savefig(composite_path, dpi=300) + plt.close(fig) + + if args.lock_view_index is not None: + if view_index in previous_depth: + depth_diff = np.abs(gaussian_depth_map - previous_depth[view_index]) + save_heatmap( + depth_diff, + output_dir / f"depth_diff_iter_{iteration:02d}_temporal.png", + f"Depth Δ iter {iteration}", + ) + if view_index in previous_normals: + normal_delta = gaussian_normals_map - previous_normals[view_index] + if normal_delta.ndim == 3: + normal_diff = np.linalg.norm(normal_delta, axis=-1) + else: + normal_diff = np.abs(normal_delta) + save_heatmap( + normal_diff, + output_dir / f"normal_diff_iter_{iteration:02d}_temporal.png", + f"Normal Δ iter {iteration}", + ) + + with torch.no_grad(): + export_mesh_from_state( + gaussians=gaussians, + mesh_state=mesh_state, + output_path=output_dir / f"mesh_iter_{iteration:02d}.ply", + reference_camera=None, + ) + + if args.lock_view_index is not None: + previous_depth[view_index] = gaussian_depth_map + previous_normals[view_index] = gaussian_normals_map + with torch.no_grad(): + # 输出完整指标轨迹及汇总曲线,方便任务结束后快速复盘 + history_npz = output_dir / "metrics_history.npz" + np.savez( + history_npz, + **{k: np.asarray(v, dtype=np.float32) for k, v in stats_history.items()}, + ) + summary_fig = output_dir / 
"metrics_summary.png" + if stats_history["iteration"]: + fig, axes = plt.subplots(2, 2, figsize=(16, 10), dpi=200) + iters = np.asarray(stats_history["iteration"]) + axes[0, 0].plot(iters, stats_history["depth_loss"], label="depth") + axes[0, 0].plot(iters, stats_history["normal_loss"], label="normal") + axes[0, 0].plot(iters, stats_history["mesh_loss"], label="mesh") + axes[0, 0].set_title("Total losses") + axes[0, 0].set_xlabel("Iteration") + axes[0, 0].legend() + + axes[0, 1].plot(iters, stats_history["mesh_depth_loss"], label="mesh depth") + axes[0, 1].plot(iters, stats_history["mesh_normal_loss"], label="mesh normal") + axes[0, 1].plot(iters, stats_history["occupied_centers_loss"], label="occupied centers") + axes[0, 1].plot(iters, stats_history["occupancy_labels_loss"], label="occupancy labels") + axes[0, 1].set_title("Mesh regularization components") + axes[0, 1].set_xlabel("Iteration") + axes[0, 1].legend() + + axes[1, 0].plot(iters, stats_history["depth_mae"], label="depth MAE") + axes[1, 0].plot(iters, stats_history["depth_rmse"], label="depth RMSE") + axes[1, 0].set_title("Depth metrics") + axes[1, 0].set_xlabel("Iteration") + axes[1, 0].legend() + + axes[1, 1].plot(iters, stats_history["normal_mean_cos"], label="mean cos") + axes[1, 1].plot(iters, stats_history["grad_norm"], label="grad norm") + axes[1, 1].set_title("Normals / Gradients") + axes[1, 1].set_xlabel("Iteration") + axes[1, 1].legend() + + fig.tight_layout() + fig.savefig(summary_fig) + plt.close(fig) + print(f"[INFO] 已保存曲线汇总:{summary_fig}") + print(f"[INFO] 记录所有迭代指标到 {history_npz}") + final_mesh_path = output_dir / "mesh_final.ply" + final_gaussian_path = output_dir / "gaussians_final.ply" + print(f"[INFO] 导出最终 mesh 到 {final_mesh_path}") + export_mesh_from_state( + gaussians=gaussians, + mesh_state=mesh_state, + output_path=final_mesh_path, + reference_camera=None, + ) + print(f"[INFO] 导出最终高斯到 {final_gaussian_path}") + gaussians.save_ply(str(final_gaussian_path)) + print("[INFO] 循环结束,所有结果已写入输出目录。") + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index 380ca65..ff9da00 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,19 @@ +# Core dependencies +torch==2.7.1 +torchvision==0.22.1 +torchaudio==2.7.1 +numpy==2.3.3 +pillow==11.3.0 +scipy==1.16.2 + +# Computer Vision and 3D processing open3d==0.19.0 trimesh==4.6.8 +pymeshlab==2025.7 scikit-image==0.24.0 opencv-python==4.11.0.86 plyfile==1.1 -tqdm==4.67.1 \ No newline at end of file + +# Utilities +tqdm==4.67.1 +matplotlib==3.10.7 \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h b/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h index 3ff1b54..344b618 100755 --- a/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h +++ b/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h @@ -10,6 +10,9 @@ */ #pragma once +#include +#include +#include #include #include diff --git a/submodules/diff-gaussian-rasterization_gof/cuda_rasterizer/rasterizer_impl.h b/submodules/diff-gaussian-rasterization_gof/cuda_rasterizer/rasterizer_impl.h index 647f1a9..5e924fa 100644 --- a/submodules/diff-gaussian-rasterization_gof/cuda_rasterizer/rasterizer_impl.h +++ b/submodules/diff-gaussian-rasterization_gof/cuda_rasterizer/rasterizer_impl.h @@ -10,6 +10,9 @@ */ #pragma once +#include +#include +#include #include #include diff --git a/submodules/diff-gaussian-rasterization_ms/cuda_rasterizer/rasterizer_impl.h 
b/submodules/diff-gaussian-rasterization_ms/cuda_rasterizer/rasterizer_impl.h index 5c76881..2a1f616 100644 --- a/submodules/diff-gaussian-rasterization_ms/cuda_rasterizer/rasterizer_impl.h +++ b/submodules/diff-gaussian-rasterization_ms/cuda_rasterizer/rasterizer_impl.h @@ -10,6 +10,9 @@ */ #pragma once +#include +#include +#include #include #include diff --git a/submodules/simple-knn/simple_knn.cu b/submodules/simple-knn/simple_knn.cu index e72e4c9..b998aaf 100644 --- a/submodules/simple-knn/simple_knn.cu +++ b/submodules/simple-knn/simple_knn.cu @@ -1,3 +1,5 @@ +#include +#include /* * Copyright (C) 2023, Inria * GRAPHDECO research group, https://team.inria.fr/graphdeco @@ -20,7 +22,6 @@ #include #include #include -#define __CUDACC__ #include #include diff --git a/submodules/simple-knn/spatial.cu b/submodules/simple-knn/spatial.cu index 1a6a654..207fc3a 100644 --- a/submodules/simple-knn/spatial.cu +++ b/submodules/simple-knn/spatial.cu @@ -20,7 +20,7 @@ distCUDA2(const torch::Tensor& points) auto float_opts = points.options().dtype(torch::kFloat32); torch::Tensor means = torch::full({P}, 0.0, float_opts); - SimpleKNN::knn(P, (float3*)points.contiguous().data(), means.contiguous().data()); + SimpleKNN::knn(P, (float3*)points.contiguous().data_ptr(), means.contiguous().data_ptr()); return means; } \ No newline at end of file diff --git a/submodules/tetra_triangulation/src/force_abi.h b/submodules/tetra_triangulation/src/force_abi.h new file mode 100644 index 0000000..db5ecf3 --- /dev/null +++ b/submodules/tetra_triangulation/src/force_abi.h @@ -0,0 +1,38 @@ +// force_abi.h +// ----------------------------------------------------------------------------- +// Force GNU libstdc++ to use the *new* C++11 ABI (_GLIBCXX_USE_CXX11_ABI = 1) +// for this translation unit. **Must be included before ANY standard headers.** +// +// Why? +// - PyTorch ≥ 2.6 official binaries are built with the new C++11 ABI (=1). +// - If your extension is compiled with the old ABI (=0), you’ll hit runtime +// linker errors such as undefined symbol: c10::detail::torchCheckFail(...RKSs). +// +// Usage: +// Place this as the VERY FIRST include in each .cpp that builds your extension: +// #include "force_abi.h" +// #include +// ... +// ----------------------------------------------------------------------------- + +#pragma once + +// Only meaningful for GCC's libstdc++; harmless elsewhere. +#if defined(__GNUC__) && !defined(_LIBCPP_VERSION) + +// If the macro was already defined (e.g., by compiler flags), reset it first. +# ifdef _GLIBCXX_USE_CXX11_ABI +# undef _GLIBCXX_USE_CXX11_ABI +# endif +// Enforce the new (C++11) ABI. +# define _GLIBCXX_USE_CXX11_ABI 1 + +// Optional sanity check: if some libstdc++ internals are already visible, +// it likely means a standard header slipped in before this file. In that case +// overriding the ABI here won't affect those already-included headers. +# if defined(_GLIBCXX_RELEASE) || defined(__GLIBCXX__) || defined(_GLIBCXX_BEGIN_NAMESPACE_VERSION) +# warning "force_abi.h should be included BEFORE any standard library headers." 
+# endif + +#endif // defined(__GNUC__) && !defined(_LIBCPP_VERSION) + diff --git a/submodules/tetra_triangulation/src/py_binding.cpp b/submodules/tetra_triangulation/src/py_binding.cpp index 0b2d43f..8fafa7e 100755 --- a/submodules/tetra_triangulation/src/py_binding.cpp +++ b/submodules/tetra_triangulation/src/py_binding.cpp @@ -1,3 +1,4 @@ +#include "force_abi.h" #include #include diff --git a/submodules/tetra_triangulation/src/triangulation.cpp b/submodules/tetra_triangulation/src/triangulation.cpp index 8044295..f083085 100755 --- a/submodules/tetra_triangulation/src/triangulation.cpp +++ b/submodules/tetra_triangulation/src/triangulation.cpp @@ -1,3 +1,4 @@ +#include "force_abi.h" #include "triangulation.h" #include @@ -66,4 +67,4 @@ std::vector triangulate(size_t num_points, float3* points) { // 0, max_depth); return cells; -} \ No newline at end of file +}