diff --git a/.gitignore b/.gitignore index 544d455..dbb7ff0 100755 --- a/.gitignore +++ b/.gitignore @@ -206,4 +206,13 @@ dmypy.json # Pyre type checker .pyre/ -learnableearthparser/fast_sampler/_sampler.c \ No newline at end of file +learnableearthparser/fast_sampler/_sampler.c +/.idea/inspectionProfiles/profiles_settings.xml +/.idea/.gitignore +/.idea/MILo_rtx50.iml +/.idea/misc.xml +/.idea/modules.xml +/.idea/vcs.xml +milo/data/* +!milo/data/.gitkeep +/milo/runs/* diff --git a/README.md b/README.md old mode 100755 new mode 100644 index b9ced4a..4c35562 --- a/README.md +++ b/README.md @@ -1,654 +1,484 @@ -
+def iter_mesh_paths(root: Path, recursive: bool) -> Iterable[Path]:
+    if root.is_file():
+        if root.suffix.lower() in VALID_EXTENSIONS:
+            yield root
+        return
+
+    if not root.is_dir():
+        print(f"[WARN] {root} is neither a file nor a directory; skipping.")
+        return
+
+    pattern = "**/*" if recursive else "*"
+    for path in root.glob(pattern):
+        if path.is_file() and path.suffix.lower() in VALID_EXTENSIONS:
+            yield path
+
+
+def convert_to_convex_hull(
+    mesh_path: Path,
+    suffix: str,
+    overwrite: bool,
+) -> Path:
+    mesh = trimesh.load(mesh_path, force="mesh", process=False)
+
+    convex = mesh.convex_hull
+    output_path = mesh_path.with_name(f"{mesh_path.stem}{suffix}{mesh_path.suffix}")
+    if output_path.exists() and not overwrite:
+        raise FileExistsError(f"{output_path} already exists; use --overwrite to replace it.")
+    convex.export(output_path)
+    return output_path
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Convert a single mesh, or every mesh in a directory, to its convex hull; `_convex` is appended to the file name.",
+    )
+    parser.add_argument("input_path", type=Path, help="Path to a mesh file or a directory.")
+    parser.add_argument(
+        "--recursive",
+        action="store_true",
+        help="If input_path is a directory, recurse into all subdirectories.",
+    )
+    parser.add_argument(
+        "--suffix",
+        default="_convex",
+        help="Suffix appended to output file names (default: _convex).",
+    )
+    parser.add_argument(
+        "--overwrite",
+        action="store_true",
+        help="Overwrite the target file if it already exists.",
+    )
+    return parser.parse_args()
+
+
+def main() -> int:
+    args = parse_args()
+    mesh_paths = list(iter_mesh_paths(args.input_path, args.recursive))
+    if not mesh_paths:
+        print(f"[WARN] No mesh files found under {args.input_path} (supported extensions: {', '.join(VALID_EXTENSIONS)}).")
+        return 1
+
+    for path in mesh_paths:
+        try:
+            output = convert_to_convex_hull(
+                path,
+                suffix=args.suffix,
+                overwrite=args.overwrite,
+            )
+            print(f"[INFO] {path} -> {output}")
+        except Exception as exc:  # noqa: BLE001
+            print(f"[ERROR] Failed to process {path}: {exc}", file=sys.stderr)
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/milo/scripts/run_depth_sweep.py b/milo/scripts/run_depth_sweep.py
new file mode 100644
index 0000000..67f8c60
--- /dev/null
+++ b/milo/scripts/run_depth_sweep.py
@@ -0,0 +1,225 @@
+#!/usr/bin/env python3
+"""
+Run a small sweep of depth-only training configurations.
+
+Each configuration executes `milo/depth_train.py` with 1000 iterations by default,
+optionally restricting training to a single camera for reproducibility. After the
+runs finish, a compact summary of the final metrics is emitted to stdout and
+saved under `output/depth_sweep_summary.txt`.
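+
+Example invocation (paths are placeholders):
+
+    python milo/scripts/run_depth_sweep.py \
+        --ply_path scene.ply \
+        --camera_poses cameras.json \
+        --depth_dir depths/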
+""" + +from __future__ import annotations + +import argparse +import json +import os +import subprocess +import sys +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional + + +DEFAULT_CONFIGS: List[Dict] = [ + { + "name": "lr0.7_clip7_dense", + "initial_lr_scale": 0.70, + "depth_clip_min": 0.1, + "depth_clip_max": 7.0, + "enable_densification": True, + }, + { + "name": "lr0.8_clip7_dense", + "initial_lr_scale": 0.80, + "depth_clip_min": 0.2, + "depth_clip_max": 7.0, + "enable_densification": True, + }, + { + "name": "lr0.9_clip6p5_dense", + "initial_lr_scale": 0.90, + "depth_clip_min": 0.3, + "depth_clip_max": 6.5, + "enable_densification": True, + }, + { + "name": "lr1.0_clip6_dense", + "initial_lr_scale": 1.00, + "depth_clip_min": 0.3, + "depth_clip_max": 6.0, + "enable_densification": True, + }, + { + "name": "lr1.15_clip6_dense", + "initial_lr_scale": 1.15, + "depth_clip_min": 0.4, + "depth_clip_max": 6.0, + "enable_densification": True, + }, + { + "name": "lr1.3_clip5p5_dense", + "initial_lr_scale": 1.30, + "depth_clip_min": 0.5, + "depth_clip_max": 5.5, + "enable_densification": True, + }, + { + "name": "lr0.75_clip7_no_dense", + "initial_lr_scale": 0.75, + "depth_clip_min": 0.1, + "depth_clip_max": 7.0, + "enable_densification": False, + }, + { + "name": "lr1.0_clip6_no_dense", + "initial_lr_scale": 1.0, + "depth_clip_min": 0.3, + "depth_clip_max": 6.0, + "enable_densification": False, + }, +] + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Sweep depth training configurations.") + parser.add_argument("--ply_path", required=True, type=Path, help="Input Gaussian PLY.") + parser.add_argument("--camera_poses", required=True, type=Path, help="Camera JSON file.") + parser.add_argument("--depth_dir", required=True, type=Path, help="Directory of depth .npy files.") + parser.add_argument("--output_root", type=Path, default=Path("runs/depth_sweep"), help="Base directory for sweep outputs.") + parser.add_argument("--iterations", type=int, default=1000, help="Iterations per configuration.") + parser.add_argument("--fixed_view_idx", type=int, default=0, help="Camera index to lock during sweep (-1 = random shuffling).") + parser.add_argument("--cuda_blocking", action="store_true", help="Set CUDA_LAUNCH_BLOCKING=1 for each run.") + parser.add_argument("--extra_arg", action="append", default=[], help="Extra CLI arguments passed verbatim to depth_train.py.") + parser.add_argument("--resume_if_exists", action="store_true", help="Skip configs whose output directory already exists.") + return parser + + +def ensure_directory(path: Path) -> None: + path.mkdir(parents=True, exist_ok=True) + + +def run_depth_train( + script_path: Path, + cfg: Dict, + args: argparse.Namespace, + run_dir: Path, +) -> int: + cmd: List[str] = [ + sys.executable, + str(script_path), + "--ply_path", + str(args.ply_path), + "--camera_poses", + str(args.camera_poses), + "--depth_dir", + str(args.depth_dir), + "--output_dir", + str(run_dir), + "--iterations", + str(args.iterations), + "--initial_lr_scale", + str(cfg["initial_lr_scale"]), + "--log_depth_stats", + ] + if cfg.get("depth_clip_min", 0.0) > 0.0: + cmd.extend(["--depth_clip_min", str(cfg["depth_clip_min"])]) + if cfg.get("depth_clip_max") is not None: + cmd.extend(["--depth_clip_max", str(cfg["depth_clip_max"])]) + if cfg.get("enable_densification", False): + cmd.append("--enable_densification") + if args.fixed_view_idx >= 0: + cmd.extend(["--fixed_view_idx", 
str(args.fixed_view_idx)])
+    for extra in args.extra_arg:
+        cmd.append(extra)
+
+    env = os.environ.copy()
+    if args.cuda_blocking:
+        env["CUDA_LAUNCH_BLOCKING"] = "1"
+
+    print(f"[SWEEP] Running {cfg['name']} -> {run_dir}")
+    print("  Command:", " ".join(cmd))
+    sys.stdout.flush()
+    result = subprocess.run(cmd, env=env)
+    if result.returncode != 0:
+        print(f"[SWEEP] {cfg['name']} exited with code {result.returncode}")
+    return result.returncode
+
+
+def read_final_metrics(run_dir: Path) -> Optional[Dict]:
+    # The training loop appends one JSON object per iteration; keep the last one.
+    log_path = run_dir / "logs" / "losses.jsonl"
+    if not log_path.exists():
+        return None
+    last_line: Optional[str] = None
+    with open(log_path, "r", encoding="utf-8") as log_file:
+        for line in log_file:
+            last_line = line.strip()
+    if not last_line:
+        return None
+    try:
+        data = json.loads(last_line)
+    except json.JSONDecodeError:
+        return None
+    return {
+        "iteration": data.get("iteration"),
+        "depth_loss": data.get("depth_loss"),
+        "pred_depth_mean": data.get("pred_depth_mean"),
+        "target_depth_mean": data.get("target_depth_mean"),
+        "pred_depth_max": data.get("pred_depth_max"),
+        "pred_depth_min": data.get("pred_depth_min"),
+        "target_depth_max": data.get("target_depth_max"),
+        "target_depth_min": data.get("target_depth_min"),
+    }
+
+
+def main() -> None:
+    parser = build_parser()
+    args = parser.parse_args()
+
+    script_path = Path("milo/depth_train.py").resolve()
+    ensure_directory(args.output_root)
+    ensure_directory(Path("output"))
+
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    summary_lines: List[str] = []
+
+    for cfg in DEFAULT_CONFIGS:
+        run_dir = args.output_root / f"{timestamp}_{cfg['name']}"
+        if run_dir.exists() and args.resume_if_exists:
+            print(f"[SWEEP] Skipping {cfg['name']} (directory exists).")
+            metrics = read_final_metrics(run_dir)
+        else:
+            ensure_directory(run_dir)
+            exit_code = run_depth_train(script_path, cfg, args, run_dir)
+            if exit_code != 0:
+                summary_lines.append(f"{cfg['name']}: FAILED (code {exit_code})")
+                continue
+            metrics = read_final_metrics(run_dir)
+
+        if not metrics:
+            summary_lines.append(f"{cfg['name']}: missing/invalid log")
+            continue
+
+        summary_lines.append(
+            "{name}: depth_loss={loss:.4f} pred_mean={p_mean:.4f} target_mean={t_mean:.4f}".format(
+                name=cfg["name"],
+                loss=metrics.get("depth_loss", float("nan")),
+                p_mean=metrics.get("pred_depth_mean", float("nan")),
+                t_mean=metrics.get("target_depth_mean", float("nan")),
+            )
+        )
+
+    summary_path = Path("output") / "depth_sweep_summary.txt"
+    with open(summary_path, "a", encoding="utf-8") as summary_file:
+        summary_file.write(f"# Sweep {timestamp}\n")
+        for line in summary_lines:
+            summary_file.write(line + "\n")
+        summary_file.write("\n")
+
+    print("\n[SWEEP] Summary:")
+    for line in summary_lines:
+        print("  -", line)
+    print(f"[SWEEP] Full summary appended to {summary_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/milo/scripts/scannet_to_colmap.py b/milo/scripts/scannet_to_colmap.py
new file mode 100644
index 0000000..16ed700
--- /dev/null
+++ b/milo/scripts/scannet_to_colmap.py
@@ -0,0 +1,459 @@
+#!/usr/bin/env python3
+"""
+Convert a ScanNet scene (RGB-D + camera trajectory) into the plain COLMAP layout required by Milo training.
+
+The script:
+1. Parses the .sens file for RGB frames and camera-to-world poses.
+2. Applies axisAlignment (when present) to keep the conventional coordinates of public ScanNet scenes.
+3. Writes the selected frames unchanged as JPEGs and generates COLMAP cameras.txt / images.txt.
+4. Extracts a point cloud from *_vh_clean_2.ply (falling back to *_vh_clean.ply) and writes points3D.bin/txt.
+
+The output matches milo/data/Ignatius (images + sparse/0).
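+
+Example (illustrative paths):
+    python milo/scripts/scannet_to_colmap.py \
+        --scene-root /data/scannet/scans/scene0000_00 \
+        --output milo/data/scene0000_00 \
+        --frame-step 10 --points-max 200000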
+"""
+
+from __future__ import annotations
+
+import argparse
+import io
+import struct
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+import numpy as np
+from plyfile import PlyData
+
+try:
+    from PIL import Image
+except ImportError:
+    Image = None
+
+
+@dataclass
+class ScanNetFrame:
+    """Minimal container keeping only what the Milo conversion needs."""
+
+    index: int
+    camera_to_world: np.ndarray  # (4, 4)
+    timestamp_color: int
+    timestamp_depth: int
+    color_bytes: bytes
+
+
+class ScanNetSensorData:
+    """Parses the ScanNet .sens binary RGB-D trajectory directly."""
+
+    def __init__(self, sens_path: Path):
+        self.sens_path = Path(sens_path)
+        self._fh = None
+        self.version = None
+        self.sensor_name = ""
+        self.intrinsic_color = None
+        self.extrinsic_color = None
+        self.intrinsic_depth = None
+        self.extrinsic_depth = None
+        self.color_compression = None
+        self.depth_compression = None
+        self.color_width = 0
+        self.color_height = 0
+        self.depth_width = 0
+        self.depth_height = 0
+        self.depth_shift = 0.0
+        self.num_frames = 0
+
+    def __enter__(self) -> "ScanNetSensorData":
+        self._fh = self.sens_path.open("rb")
+        self._read_header()
+        return self
+
+    def __exit__(self, exc_type, exc, tb):
+        if self._fh:
+            self._fh.close()
+            self._fh = None
+
+    def _read_header(self) -> None:
+        fh = self._fh
+        assert fh is not None
+        read = fh.read
+        # Header layout follows the public ScanNet SensorData format.
+        self.version = struct.unpack("<I", read(4))[0]
+        name_length = struct.unpack("<Q", read(8))[0]
+        self.sensor_name = read(name_length).decode("utf-8", errors="replace")
+        self.intrinsic_color = np.frombuffer(read(16 * 4), dtype="<f4").reshape(4, 4)
+        self.extrinsic_color = np.frombuffer(read(16 * 4), dtype="<f4").reshape(4, 4)
+        self.intrinsic_depth = np.frombuffer(read(16 * 4), dtype="<f4").reshape(4, 4)
+        self.extrinsic_depth = np.frombuffer(read(16 * 4), dtype="<f4").reshape(4, 4)
+        self.color_compression = struct.unpack("<i", read(4))[0]
+        self.depth_compression = struct.unpack("<i", read(4))[0]
+        self.color_width = struct.unpack("<I", read(4))[0]
+        self.color_height = struct.unpack("<I", read(4))[0]
+        self.depth_width = struct.unpack("<I", read(4))[0]
+        self.depth_height = struct.unpack("<I", read(4))[0]
+        self.depth_shift = struct.unpack("<f", read(4))[0]
+        self.num_frames = struct.unpack("<Q", read(8))[0]
+
+    def iter_frames(self) -> Iterator[ScanNetFrame]:
+        if self._fh is None:
+            raise RuntimeError("Sensor file is not opened. Use within a context manager.")
+
+        for frame_idx in range(self.num_frames):
+            mat = np.frombuffer(self._fh.read(16 * 4), dtype="<f4").reshape(4, 4)
+            timestamp_color = struct.unpack("<Q", self._fh.read(8))[0]
+            timestamp_depth = struct.unpack("<Q", self._fh.read(8))[0]
+            color_size = struct.unpack("<Q", self._fh.read(8))[0]
+            depth_size = struct.unpack("<Q", self._fh.read(8))[0]
+            color_bytes = self._fh.read(color_size)
+            self._fh.seek(depth_size, io.SEEK_CUR)  # depth payload is not used here
+            yield ScanNetFrame(
+                index=frame_idx,
+                camera_to_world=mat.astype(np.float64),
+                timestamp_color=timestamp_color,
+                timestamp_depth=timestamp_depth,
+                color_bytes=color_bytes,
+            )
+
+
+def rotmat_to_qvec(R: np.ndarray) -> np.ndarray:
+    """Rotation matrix -> quaternion, copied from COLMAP's read_write_model."""
+    R = np.asarray(R, dtype=np.float64)
+    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
+    K = np.array(
+        [
+            [Rxx - Ryy - Rzz, 0, 0, 0],
+            [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
+            [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
+            [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz],
+        ]
+    ) / 3.0
+    eigvals, eigvecs = np.linalg.eigh(K)
+    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
+    if qvec[0] < 0:
+        qvec *= -1
+    return qvec / np.linalg.norm(qvec)
+
+
+def parse_axis_alignment(meta_txt: Path) -> np.ndarray:
+    """Parse the axisAlignment entry from the scene meta file.
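+
+    The meta file stores the alignment as one line of 16 floats (a row-major
+    4x4 homogeneous matrix), e.g.:
+        axisAlignment = m00 m01 m02 m03 m10 ... m33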
行,没有则返回单位阵。""" + if not meta_txt.is_file(): + return np.eye(4, dtype=np.float64) + + axis = None + with meta_txt.open("r", encoding="utf-8") as fh: + for line in fh: + if line.strip().startswith("axisAlignment"): + values = line.split("=", 1)[1].strip().split() + if len(values) != 16: + raise ValueError(f"axisAlignment 需要 16 个数,当前 {len(values)}") + axis = np.array([float(v) for v in values], dtype=np.float64).reshape(4, 4) + break + if axis is None: + axis = np.eye(4, dtype=np.float64) + return axis + + +def infer_scene_id(scene_root: Path) -> str: + """默认场景目录名就是 sceneXXXX_YY;否则尝试寻找唯一的 *.sens。""" + if scene_root.name.startswith("scene") and "_" in scene_root.name: + return scene_root.name + sens_files = list(scene_root.glob("*.sens")) + if len(sens_files) != 1: + raise ValueError("无法唯一确定 scene id,请使用 --scene-id") + return sens_files[0].stem + + +def find_point_cloud(scene_root: Path, scene_id: str, override: Optional[Path]) -> Path: + if override: + pc_path = Path(override) + if not pc_path.is_file(): + raise FileNotFoundError(f"指定点云 {pc_path} 不存在") + return pc_path + candidates = [ + scene_root / f"{scene_id}_vh_clean_2.ply", + scene_root / f"{scene_id}_vh_clean.ply", + ] + for cand in candidates: + if cand.is_file(): + return cand + raise FileNotFoundError("找不到 *_vh_clean_*.ply 点云,请用 --points-source 指定") + + +def load_point_cloud( + ply_path: Path, + stride: int = 1, + max_points: Optional[int] = None, + seed: int = 0, + transform: Optional[np.ndarray] = None, +) -> Tuple[np.ndarray, np.ndarray]: + ply = PlyData.read(str(ply_path)) + verts = ply["vertex"].data + xyz = np.vstack([verts["x"], verts["y"], verts["z"]]).T.astype(np.float64) + if {"red", "green", "blue"}.issubset(verts.dtype.names): + colors = np.vstack([verts["red"], verts["green"], verts["blue"]]).T.astype(np.uint8) + else: + colors = np.full((xyz.shape[0], 3), 255, dtype=np.uint8) + + idx = np.arange(xyz.shape[0]) + if stride > 1: + idx = idx[::stride] + if max_points is not None and idx.size > max_points: + rng = np.random.default_rng(seed) + idx = rng.choice(idx, size=max_points, replace=False) + xyz = xyz[idx] + if transform is not None: + if transform.shape != (4, 4): + raise ValueError("transform must be 4x4 homogeneous matrix") + homo = np.concatenate([xyz, np.ones((xyz.shape[0], 1), dtype=np.float64)], axis=1) + xyz = (homo @ transform.T)[:, :3] + return xyz, colors[idx] + + +def ensure_output_dirs(output_root: Path) -> Tuple[Path, Path]: + images_dir = output_root / "images" + sparse_dir = output_root / "sparse" / "0" + if output_root.exists(): + existing = [p for p in output_root.iterdir() if not p.name.startswith(".")] + if existing: + raise FileExistsError(f"{output_root} 已存在且非空,请指定新的输出目录") + sparse_dir.mkdir(parents=True, exist_ok=True) + images_dir.mkdir(parents=True, exist_ok=True) + return images_dir, sparse_dir + + +@dataclass +class ImageRecord: + image_id: int + name: str + qvec: np.ndarray + tvec: np.ndarray + frame_index: int + + +def write_cameras_txt(path: Path, camera_id: int, width: int, height: int, fx: float, fy: float, cx: float, cy: float) -> None: + with path.open("w", encoding="utf-8") as fh: + fh.write("# Camera list with one line of data per camera:\n") + fh.write("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n") + fh.write("# Number of cameras: 1\n") + fh.write(f"{camera_id} PINHOLE {width} {height} {fx:.9f} {fy:.9f} {cx:.9f} {cy:.9f}\n") + + +def write_images_txt(path: Path, camera_id: int, records: Sequence[ImageRecord]) -> None: + with path.open("w", encoding="utf-8") as fh: + fh.write("# 
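+# Note: the binary layout above mirrors COLMAP's points3D.bin format
+# (read_write_model.py): uint64 point count, then per point
+# uint64 POINT3D_ID, 3x float64 XYZ, 3x uint8 RGB, float64 ERROR,
+# uint64 TRACK length (zero here because no observations are stored).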
+
+
+def resize_color_bytes(
+    jpeg_bytes: bytes,
+    target_size: Tuple[int, int],
+    quality: int,
+    filter_name: str,
+) -> bytes:
+    if Image is None:
+        raise RuntimeError("Pillow is not installed, so --resize-width/--resize-height cannot be used. Run `pip install pillow` first.")
+    width, height = target_size
+    resample_map = {
+        "nearest": Image.NEAREST,
+        "bilinear": Image.BILINEAR,
+        "bicubic": Image.BICUBIC,
+        "lanczos": Image.LANCZOS,
+    }
+    resample = resample_map[filter_name]
+    with Image.open(io.BytesIO(jpeg_bytes)) as img:
+        img = img.convert("RGB")
+        resized = img.resize((width, height), resample=resample)
+        buffer = io.BytesIO()
+        resized.save(buffer, format="JPEG", quality=quality)
+        return buffer.getvalue()
+
+
+def convert_scene(args: argparse.Namespace) -> None:
+    if args.frame_step <= 0:
+        raise ValueError("frame-step must be a positive integer")
+    if args.start_frame < 0:
+        raise ValueError("start-frame cannot be negative")
+    if args.points_stride <= 0:
+        raise ValueError("points-stride must be a positive integer")
+    resize_dims: Optional[Tuple[int, int]] = None
+    if args.resize_width is not None or args.resize_height is not None:
+        if args.resize_width is None or args.resize_height is None:
+            raise ValueError("--resize-width/--resize-height must be specified together")
+        if args.resize_width <= 0 or args.resize_height <= 0:
+            raise ValueError("resize dimensions must be positive integers")
+        if not 1 <= args.resize_jpeg_quality <= 100:
+            raise ValueError("resize-jpeg-quality must be in [1, 100]")
+        resize_dims = (args.resize_width, args.resize_height)
+
+    scene_root = Path(args.scene_root).resolve()
+    scene_id = args.scene_id or infer_scene_id(scene_root)
+    sens_path = scene_root / f"{scene_id}.sens"
+    if not sens_path.is_file():
+        raise FileNotFoundError(f"Could not find {sens_path}")
+    meta_path = scene_root / f"{scene_id}.txt"
+    axis = parse_axis_alignment(meta_path) if args.apply_axis_alignment else np.eye(4, dtype=np.float64)
+    images_dir, sparse_dir = ensure_output_dirs(Path(args.output).resolve())
+    point_cloud_path = find_point_cloud(scene_root, scene_id, args.points_source)
+
+    print(f"[INFO] Converting scene {scene_id} -> {args.output}")
+    print(f"[INFO] Using point cloud: {point_cloud_path}")
+
+    with ScanNetSensorData(sens_path) as sensor:
+        if sensor.color_compression != 2:
+            raise NotImplementedError(f".sens files with color_compression={sensor.color_compression} are not supported yet")
+
+        fx = float(sensor.intrinsic_color[0, 0])
+        fy = float(sensor.intrinsic_color[1, 1])
+        cx = float(sensor.intrinsic_color[0, 2])
+        cy = float(sensor.intrinsic_color[1, 2])
+        camera_id = 1
+
+        selected: List[ImageRecord] = []
+        next_image_id = 1
+        max_frames = args.max_frames if args.max_frames and args.max_frames > 0 else None
+        start = args.start_frame
+        for frame in sensor.iter_frames():
+            if frame.index < start:
+                continue
+            if (frame.index - start) % args.frame_step != 0:
+                continue
+            if max_frames is not None and len(selected) >= max_frames:
+                break
+
+            c2w = frame.camera_to_world
+            if args.apply_axis_alignment:
+                c2w = axis @ c2w
+            if not np.all(np.isfinite(c2w)):
+                print(f"[WARN] Skipping frame {frame.index}: pose contains NaN")
+                continue
+
+            w2c = np.linalg.inv(c2w)
+            rot = w2c[:3, :3]
+            tvec = w2c[:3, 3]
+            qvec = rotmat_to_qvec(rot)
+            image_name = f"frame_{frame.index:06d}.jpg"
+            image_path = images_dir / image_name
+            if resize_dims is None:
+                with image_path.open("wb") as im_fh:
+                    im_fh.write(frame.color_bytes)
+            else:
+                resized_bytes = resize_color_bytes(
+                    frame.color_bytes,
+                    resize_dims,
+                    args.resize_jpeg_quality,
+                    args.resize_filter,
+                )
+                with image_path.open("wb") as im_fh:
+                    im_fh.write(resized_bytes)
+
+            selected.append(
+                ImageRecord(
+                    image_id=next_image_id,
+                    name=image_name,
+                    qvec=qvec,
+                    tvec=tvec,
+                    frame_index=frame.index,
+                )
+            )
+            next_image_id += 1
+            if len(selected) % 100 == 0:
+                print(f"[INFO] Wrote {len(selected)} images so far")
+
+    if not selected:
+        raise RuntimeError("No frame satisfied the sampling conditions; check the start/step/max arguments.")
+
+    cams_txt = sparse_dir / "cameras.txt"
+    write_cameras_txt(cams_txt, camera_id, sensor.color_width, sensor.color_height, fx, fy, cx, cy)
+    imgs_txt = sparse_dir / "images.txt"
+    write_images_txt(imgs_txt, camera_id, selected)
+
+    xyz, rgb = load_point_cloud(
+        point_cloud_path,
+        stride=args.points_stride,
+        max_points=args.points_max,
+        seed=args.points_seed,
+        transform=axis if args.apply_axis_alignment else None,
+    )
+    write_points3d_txt(sparse_dir / "points3D.txt", xyz, rgb)
+    write_points3d_bin(sparse_dir / "points3D.bin", xyz, rgb)
+
+    print(f"[INFO] Conversion finished: {len(selected)} images, {xyz.shape[0]} points.")
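+
+# Resulting layout (matching milo/data/Ignatius):
+#   <output>/images/frame_000000.jpg ...
+#   <output>/sparse/0/{cameras.txt, images.txt, points3D.txt, points3D.bin}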
+
+
+def build_arg_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Convert a ScanNet scene into the COLMAP layout required by Milo.")
+    parser.add_argument("--scene-root", required=True, help="Directory containing sceneXXXX_YY.* (the .sens/.txt/.ply files).")
+    parser.add_argument("--output", required=True, help="Output directory (must be absent or empty; images/ and sparse/0/ will be created).")
+    parser.add_argument("--scene-id", help="Optional explicit sceneXXXX_YY. Defaults to the directory name or is auto-inferred.")
+    parser.add_argument("--start-frame", type=int, default=0, help="Frame index at which sampling starts (default 0).")
+    parser.add_argument("--frame-step", type=int, default=1, help="Frame sampling step, e.g. 5 keeps 1 frame out of every 5.")
+    parser.add_argument("--max-frames", type=int, help="Maximum number of frames to export; defaults to all.")
+    parser.add_argument("--no-axis-alignment", dest="apply_axis_alignment", action="store_false", help="Do not apply axisAlignment.")
+    parser.set_defaults(apply_axis_alignment=True)
+    parser.add_argument("--points-source", type=Path, help="Custom point cloud .ply path (defaults to auto-detecting *_vh_clean_2.ply).")
+    parser.add_argument("--points-stride", type=int, default=1, help="Point cloud subsampling stride (1 keeps every point).")
+    parser.add_argument("--points-max", type=int, help="Maximum number of points to keep.")
+    parser.add_argument("--points-seed", type=int, default=0, help="Random seed used for point cloud subsampling.")
+    parser.add_argument("--resize-width", type=int, help="Optional target width (pixels) for the exported RGB images. Requires pillow.")
+    parser.add_argument("--resize-height", type=int, help="Optional target height (pixels) for the exported RGB images.")
+    parser.add_argument(
+        "--resize-filter",
+        choices=["nearest", "bilinear", "bicubic", "lanczos"],
+        default="lanczos",
+        help="Resampling filter used when resizing (default lanczos).",
+    )
+    parser.add_argument(
+        "--resize-jpeg-quality",
+        type=int,
+        default=95,
+        help="JPEG quality used when re-encoding resized images (1-100, default 95).",
+    )
+    return parser
+
+
+def main():
+    parser = build_arg_parser()
+    args = parser.parse_args()
+    convert_scene(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/milo/scripts/verify_camera_poses.py b/milo/scripts/verify_camera_poses.py
new file mode 100644
index 0000000..d139ec0
--- /dev/null
+++ b/milo/scripts/verify_camera_poses.py
@@ -0,0 +1,160 @@
+#!/usr/bin/env python3
+"""
+Verification script to check camera pose interpretation.
+This script loads the camera poses and PLY file to verify whether Gaussians
+are in front of or behind the cameras.
+"""
+
+import json
+import numpy as np
+from pathlib import Path
+
+
+def quaternion_to_rotation_matrix(q):
+    """Convert quaternion [w, x, y, z] to rotation matrix."""
+    q = np.asarray(q, dtype=np.float64)
+    w, x, y, z = q
+    xx = x * x
+    yy = y * y
+    zz = z * z
+    xy = x * y
+    xz = x * z
+    yz = y * z
+    wx = w * x
+    wy = w * y
+    wz = w * z
+    rotation = np.array(
+        [
+            [1.0 - 2.0 * (yy + zz), 2.0 * (xy - wz), 2.0 * (xz + wy)],
+            [2.0 * (xy + wz), 1.0 - 2.0 * (xx + zz), 2.0 * (yz - wx)],
+            [2.0 * (xz - wy), 2.0 * (yz + wx), 1.0 - 2.0 * (xx + yy)],
+        ],
+        dtype=np.float32,
+    )
+    return rotation
+
+
+def load_ply_points(ply_path):
+    """Load point positions from PLY file."""
+    from plyfile import PlyData
+
+    plydata = PlyData.read(ply_path)
+    vertices = plydata['vertex']
+    positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
+    return positions
+
+
+def check_camera_pose_current(camera_entry, points):
+    """Check with CURRENT (incorrect) pose interpretation."""
+    quaternion = camera_entry["quaternion"]
+    camera_center = np.array(camera_entry["position"], dtype=np.float32)
+
+    # Current code (WRONG according to codex):
+    rotation = quaternion_to_rotation_matrix(quaternion)
+    translation = -rotation.T @ camera_center
+
+    # Transform points to camera space
+    # With current interpretation: R is passed as-is, T is translation
+    # getWorld2View2 expects R to be C2W and builds W2C as [R.T | T]
+    # So W2C rotation is rotation.T, W2C translation is translation
+    R_w2c = rotation.T
+    t_w2c = translation
+
+    # Transform points: p_cam = R_w2c @ p_world + t_w2c
+    points_cam = (R_w2c @ points.T).T + t_w2c
+
+    # Check z coordinate (positive = in front of camera)
+    in_front = np.sum(points_cam[:, 2] > 0)
+    total = len(points)
+    fraction = in_front / total
+
+    return fraction, in_front, total
+
+
+def check_camera_pose_corrected(camera_entry, points):
+    """Check with CORRECTED pose interpretation."""
+    quaternion = camera_entry["quaternion"]
+    camera_center = np.array(camera_entry["position"], dtype=np.float32)
+
+    # Corrected code (as suggested by codex):
+    rotation_w2c = quaternion_to_rotation_matrix(quaternion)
+    rotation_c2w = rotation_w2c.T
+    translation_w2c = -rotation_w2c @ camera_center
+
+    # With corrected interpretation: R_c2w is passed, T is translation_w2c
+    # getWorld2View2 builds W2C as [R_c2w.T | T] = [R_w2c | translation_w2c]
+    R_w2c = rotation_c2w.T  # = rotation_w2c
+    t_w2c = translation_w2c
+
+    # Transform points: p_cam = R_w2c @ p_world + t_w2c
+    points_cam = (R_w2c @ points.T).T + t_w2c
+
+    # Check z coordinate (positive = in front of camera)
+    in_front = np.sum(points_cam[:, 2] > 0)
+    total = len(points)
+    fraction = in_front / total
+
+    return fraction, in_front, total
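+
+# The two interpretations above differ only in which frame the quaternion
+# encodes. For a camera centered at C in world coordinates, the
+# world-to-camera transform is p_cam = R_w2c @ p_world + t_w2c with
+# t_w2c = -R_w2c @ C, so reading the quaternion as camera-to-world instead
+# of world-to-camera flips which side of the camera the scene lands on.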
print(f"Loaded {len(camera_entries)} cameras") + + print(f"\nLoading point cloud from {ply_path}") + points = load_ply_points(ply_path) + print(f"Loaded {len(points)} points") + + # Check first camera with both interpretations + print("\n" + "="*80) + print("CHECKING CAMERA 0 (traj_0_cam0)") + print("="*80) + + camera_0 = camera_entries[0] + print(f"Camera position: {camera_0['position']}") + print(f"Camera quaternion: {camera_0['quaternion']}") + + print("\n--- Current (WRONG) interpretation ---") + frac_current, in_front_current, total = check_camera_pose_current(camera_0, points) + print(f"Points in front of camera: {in_front_current}/{total} ({frac_current*100:.2f}%)") + + print("\n--- Corrected interpretation ---") + frac_corrected, in_front_corrected, total = check_camera_pose_corrected(camera_0, points) + print(f"Points in front of camera: {in_front_corrected}/{total} ({frac_corrected*100:.2f}%)") + + # Check a few more cameras + print("\n" + "="*80) + print("CHECKING FIRST 5 CAMERAS") + print("="*80) + print(f"{'Camera':<15} {'Current (wrong)':<20} {'Corrected':<20}") + print("-" * 80) + + for i in range(min(5, len(camera_entries))): + camera = camera_entries[i] + frac_current, _, _ = check_camera_pose_current(camera, points) + frac_corrected, _, _ = check_camera_pose_corrected(camera, points) + print(f"{camera['name']:<15} {frac_current*100:>6.2f}% {frac_corrected*100:>6.2f}%") + + print("\n" + "="*80) + print("CONCLUSION") + print("="*80) + if frac_current < 0.5 and frac_corrected > 0.5: + print("✓ Codex's analysis is CORRECT!") + print(" - Current code: Most points are BEHIND cameras (wrong)") + print(" - Corrected code: Most points are IN FRONT of cameras (correct)") + print("\nThe quaternions should be interpreted as world→camera rotations,") + print("and the fix suggested by codex is needed.") + else: + print("✗ Results don't match codex's analysis.") + print(" Further investigation needed.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/milo/test_opt_config.py b/milo/test_opt_config.py new file mode 100644 index 0000000..e7bedb5 --- /dev/null +++ b/milo/test_opt_config.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +""" +测试优化配置文件加载功能 +""" +import sys +from pathlib import Path + +# 添加当前目录到路径 +sys.path.insert(0, str(Path(__file__).parent)) + +from yufu2mesh_new import load_optimization_config + +def test_config(config_name: str): + """测试加载指定的配置文件""" + print(f"\n{'='*60}") + print(f"测试配置: {config_name}") + print('='*60) + + try: + config = load_optimization_config(config_name) + + print("\n✓ 配置加载成功!") + print("\n高斯参数配置:") + print("-" * 40) + for param_name, param_cfg in config["gaussian_params"].items(): + trainable = param_cfg.get("trainable", False) + lr = param_cfg.get("lr", 0.0) + status = "✓ 可训练" if trainable else "✗ 冻结" + print(f" {param_name:20s} {status:10s} lr={lr:.6f}") + + print("\nLoss权重配置:") + print("-" * 40) + for loss_name, weight in config["loss_weights"].items(): + print(f" {loss_name:20s} {weight:.3f}") + + print("\n深度处理配置:") + print("-" * 40) + depth_cfg = config["depth_processing"] + print(f" clip_min: {depth_cfg.get('clip_min')}") + print(f" clip_max: {depth_cfg.get('clip_max')}") + + print("\nMesh正则化配置:") + print("-" * 40) + mesh_cfg = config["mesh_regularization"] + print(f" depth_weight: {mesh_cfg.get('depth_weight')}") + print(f" normal_weight: {mesh_cfg.get('normal_weight')}") + + return True + + except Exception as e: + print(f"\n✗ 配置加载失败: {e}") + import traceback + traceback.print_exc() + return False + +def main(): + 
"""测试所有预设配置""" + configs_to_test = [ + "default", + "xyz_only", + "xyz_geometry", + "xyz_occupancy", + "full" + ] + + print("开始测试优化配置加载功能...") + + results = {} + for config_name in configs_to_test: + results[config_name] = test_config(config_name) + + print("\n" + "="*60) + print("测试总结") + print("="*60) + for config_name, success in results.items(): + status = "✓ 通过" if success else "✗ 失败" + print(f" {config_name:20s} {status}") + + all_passed = all(results.values()) + if all_passed: + print("\n🎉 所有配置测试通过!") + return 0 + else: + print("\n⚠️ 部分配置测试失败") + return 1 + +if __name__ == "__main__": + sys.exit(main()) diff --git a/milo/train.py b/milo/train.py index 2f0464a..4fb60cd 100644 --- a/milo/train.py +++ b/milo/train.py @@ -2,6 +2,8 @@ import sys import gc import yaml +import json +import random from functools import partial BASE_DIR = os.path.dirname(os.path.abspath(__file__)) ROOT_DIR = os.path.abspath(os.path.join(BASE_DIR, '..')) @@ -208,6 +210,9 @@ def training( culling=gaussians._culling[:,viewpoint_cam.uid], ) + if "area_max" not in render_pkg: + render_pkg["area_max"] = torch.zeros_like(render_pkg["radii"]) + # ---Compute losses--- image, viewspace_point_tensor, visibility_filter, radii = ( render_pkg["render"], render_pkg["viewspace_points"], diff --git a/milo/useless_maybe/depth_guided_refine.py b/milo/useless_maybe/depth_guided_refine.py new file mode 100644 index 0000000..605aff4 --- /dev/null +++ b/milo/useless_maybe/depth_guided_refine.py @@ -0,0 +1,602 @@ +#!/usr/bin/env python3 +"""Depth-guided refinement of Gaussian SDFs and mesh extraction. + +This script optimizes a pretrained Gaussian Splat (PLY) using per-view depth maps. +Compared to `iterative_occupancy_refine.py`, Gaussian geometry is trainable and +the supervision comes directly from depth instead of RGB images. 
+""" + +from __future__ import annotations + +import argparse +import json +import os +import random +import re +from dataclasses import dataclass +from typing import Dict, List, Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + +from arguments import OptimizationParams, PipelineParams +from gaussian_renderer import integrate_radegs +from milo.useless_maybe.ply2mesh import ( + ManualScene, + export_mesh_from_gaussians, + initialize_mesh_regularization, + load_cameras_from_json, + build_render_functions, +) +from regularization.regularizer.mesh import compute_mesh_regularization +from scene.gaussian_model import GaussianModel, SparseGaussianAdam +from utils.general_utils import get_expon_lr_func +from torch.nn.utils import clip_grad_norm_ + + +def ensure_learnable_occupancy(gaussians: GaussianModel) -> None: + """Ensure occupancy buffers exist and the shift tensor is trainable.""" + if not gaussians.learn_occupancy or not hasattr(gaussians, "_occupancy_shift"): + device = gaussians._xyz.device + n_pts = gaussians._xyz.shape[0] + base = torch.zeros((n_pts, 9), device=device) + shift = torch.zeros_like(base) + gaussians.learn_occupancy = True + gaussians._base_occupancy = torch.nn.Parameter(base.requires_grad_(False), requires_grad=False) + gaussians._occupancy_shift = torch.nn.Parameter(shift.requires_grad_(True)) + gaussians.set_occupancy_mode("occupancy_shift") + gaussians._occupancy_shift.requires_grad_(True) + + +def extract_loss_scalars(metrics: Dict) -> Dict[str, float]: + scalars: Dict[str, float] = {} + for key, value in metrics.items(): + if not key.endswith("_loss"): + continue + scalar: Optional[float] = None + if isinstance(value, torch.Tensor): + if value.ndim == 0: + scalar = float(value.item()) + elif isinstance(value, (float, int)): + scalar = float(value) + if scalar is not None: + scalars[key] = scalar + return scalars + + +def export_iteration_state( + iteration: int, + gaussians: GaussianModel, + mesh_state: Dict, + output_dir: str, + reference_camera=None, +) -> None: + os.makedirs(output_dir, exist_ok=True) + mesh_path = os.path.join(output_dir, f"mesh_iter_{iteration:05d}.ply") + ply_path = os.path.join(output_dir, f"gaussians_iter_{iteration:05d}.ply") + + export_mesh_from_gaussians( + gaussians=gaussians, + mesh_state=mesh_state, + output_path=mesh_path, + reference_camera=reference_camera, + ) + gaussians.save_ply(ply_path) + + +def natural_key(path: str) -> List[object]: + """Split path into text/number tokens for natural sorting.""" + return [ + int(token) if token.isdigit() else token + for token in re.split(r"(\d+)", path) + if token + ] + + +@dataclass +class DepthRecord: + depth: torch.Tensor # (1, H, W) on CPU + valid_mask: torch.Tensor # (1, H, W) on CPU, float mask in {0,1} + + +class DepthMapProvider: + """Loads and serves depth maps corresponding to camera viewpoints.""" + + def __init__( + self, + depth_dir: str, + cameras: Sequence, + depth_scale: float = 1.0, + depth_offset: float = 0.0, + clip_min: Optional[float] = None, + clip_max: Optional[float] = None, + ) -> None: + if not os.path.isdir(depth_dir): + raise FileNotFoundError(f"Depth directory not found: {depth_dir}") + self.depth_dir = depth_dir + self.depth_scale = depth_scale + self.depth_offset = depth_offset + self.clip_min = clip_min + self.clip_max = clip_max + + file_list = [f for f in os.listdir(depth_dir) if f.endswith(".npy")] + if not file_list: + raise ValueError(f"No depth npy files found in {depth_dir}") + + # Index depth files by 
(cam_idx, frame_idx) when possible. + pattern = re.compile(r"depth_img_(\d+)_(\d+)\.npy$") + indexed_files: Dict[Tuple[int, int], str] = {} + for filename in file_list: + match = pattern.match(filename) + if match: + cam_idx = int(match.group(1)) + frame_idx = int(match.group(2)) + indexed_files[(cam_idx, frame_idx)] = filename + + # Fallback: natural sorted list for sequential mapping. + natural_sorted_files = sorted(file_list, key=natural_key) + + self.depth_height: Optional[int] = None + self.depth_width: Optional[int] = None + self.global_min: float = float("inf") + self.global_max: float = float("-inf") + self.global_valid_pixels: int = 0 + + self.records: List[DepthRecord] = [] + for cam_idx, cam in enumerate(cameras): + depth_path = self._resolve_path( + cam.image_name if hasattr(cam, "image_name") else str(cam_idx), + cam_idx, + indexed_files, + natural_sorted_files, + ) + full_path = os.path.join(depth_dir, depth_path) + depth_np = np.load(full_path) + if depth_np.ndim == 3 and depth_np.shape[-1] == 1: + depth_np = depth_np[..., 0] + if depth_np.ndim == 2: + depth_np = depth_np[None, ...] # (1, H, W) + elif depth_np.ndim == 3 and depth_np.shape[0] == 1: + pass # already (1, H, W) + else: + raise ValueError(f"Unexpected depth shape {depth_np.shape} in {full_path}") + + depth_tensor = torch.from_numpy(depth_np.astype(np.float32)) + depth_tensor = depth_tensor * depth_scale + depth_offset + + if clip_min is not None or clip_max is not None: + depth_tensor = depth_tensor.clamp( + min=clip_min if clip_min is not None else float("-inf"), + max=clip_max if clip_max is not None else float("inf"), + ) + + valid_mask = (depth_tensor > 0.0).float() + # Track global statistics for diagnostics. + if self.depth_height is None: + self.depth_height, self.depth_width = depth_tensor.shape[-2:] + valid_values = depth_tensor[valid_mask > 0.5] + if valid_values.numel() > 0: + self.global_min = min(self.global_min, float(valid_values.min().item())) + self.global_max = max(self.global_max, float(valid_values.max().item())) + self.global_valid_pixels += int(valid_values.numel()) + + self.records.append(DepthRecord(depth=depth_tensor.contiguous(), valid_mask=valid_mask)) + + if len(self.records) != len(cameras): + raise RuntimeError("Depth map count does not match number of cameras.") + if self.global_min == float("inf"): + self.global_min = 0.0 + self.global_max = 0.0 + + def _resolve_path( + self, + camera_name: str, + camera_idx: int, + indexed_files: Dict[Tuple[int, int], str], + fallback_files: List[str], + ) -> str: + match = re.search(r"traj_(\d+)_cam(\d+)", camera_name) + if match: + frame_idx = int(match.group(1)) + cam_idx = int(match.group(2)) + candidate = indexed_files.get((cam_idx, frame_idx)) + if candidate: + return candidate + # Fallback to cam index with ordered list. + if camera_idx >= len(fallback_files): + raise IndexError( + f"Camera index {camera_idx} exceeds depth file count {len(fallback_files)}." 
+ ) + return fallback_files[camera_idx] + + def get(self, index: int, device: torch.device) -> DepthRecord: + record = self.records[index] + depth = record.depth.to(device, non_blocking=True) + valid = record.valid_mask.to(device, non_blocking=True) + return DepthRecord(depth=depth, valid_mask=valid) + + +def compute_depth_loss( + predicted: torch.Tensor, + target: torch.Tensor, + valid_mask: torch.Tensor, + epsilon: float = 1e-8, +) -> Tuple[torch.Tensor, float, float, int]: + """Compute masked L1 loss and return (loss, mean_abs_error, valid_fraction, valid_pixels).""" + if predicted.shape != target.shape: + target = F.interpolate( + target.unsqueeze(0), + size=predicted.shape[-2:], + mode="bilinear", + align_corners=True, + ).squeeze(0) + valid_mask = F.interpolate( + valid_mask.unsqueeze(0), + size=predicted.shape[-2:], + mode="nearest", + ).squeeze(0) + + valid = valid_mask > 0.5 + valid_pixels = valid.sum().item() + if valid_pixels == 0: + zero = torch.zeros((), device=predicted.device, dtype=predicted.dtype) + return zero, 0.0, 0.0, 0 + + diff = (predicted - target).abs() * valid_mask + loss = diff.sum() / (valid_mask.sum() + epsilon) + mae = diff.sum().item() / (valid_pixels + epsilon) + valid_fraction = valid_pixels / valid_mask.numel() + return loss, mae, valid_fraction, int(valid_pixels) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Depth-guided Gaussian refinement with mesh regularization.") + parser.add_argument("--ply_path", type=str, required=True, help="Input Gaussian PLY.") + parser.add_argument("--camera_poses", type=str, required=True, help="Camera pose JSON.") + parser.add_argument("--depth_dir", type=str, required=True, help="Directory containing per-view depth .npy files.") + parser.add_argument("--mesh_config", type=str, default="medium", help="Mesh regularization config name.") + parser.add_argument("--iterations", type=int, default=5000, help="Number of optimization steps.") + parser.add_argument("--depth_loss_weight", type=float, default=1.0, help="(Deprecated) Depth loss multiplier; kept for backward compatibility.") + parser.add_argument("--mesh_loss_weight", type=float, default=1.0, help="(Deprecated) Mesh loss multiplier; kept for backward compatibility.") + parser.add_argument("--occupancy_lr_scale", type=float, default=1.0, help="Multiplier applied to occupancy LR.") + parser.add_argument("--image_width", type=int, default=1280, help="Rendered image width.") + parser.add_argument("--image_height", type=int, default=720, help="Rendered image height.") + parser.add_argument("--fov_y", type=float, default=75.0, help="Vertical field-of-view in degrees.") + parser.add_argument("--seed", type=int, default=0, help="Random seed.") + parser.add_argument("--output_dir", type=str, default="./depth_refine_output", help="Directory to store outputs.") + parser.add_argument("--log_interval", type=int, default=100, help="Logging interval.") + parser.add_argument("--export_interval", type=int, default=1000, help="Mesh export interval (0 disables periodic export).") + parser.add_argument("--depth_scale", type=float, default=1.0, help="Scale factor applied to loaded depth maps.") + parser.add_argument("--depth_offset", type=float, default=0.0, help="Offset applied to loaded depth maps after scaling.") + parser.add_argument("--depth_clip_min", type=float, default=None, help="Clip depth to minimum value (after scaling).") + parser.add_argument("--depth_clip_max", type=float, default=None, help="Clip depth to maximum value 
(after scaling).") + parser.add_argument("--freeze_colors", dest="freeze_colors", action="store_true", help="Freeze SH features during optimization.") + parser.add_argument("--no-freeze_colors", dest="freeze_colors", action="store_false", help="Allow SH features to be optimized.") + parser.set_defaults(freeze_colors=True) + parser.add_argument("--grad_clip_norm", type=float, default=0.0, help="Apply gradient clipping with given norm (0 disables).") + parser.add_argument("--initial_lr_scale", type=float, default=1.0, help="Multiplier for position lr_init.") + parser.add_argument("--device", type=str, default="cuda", help="Compute device.") + parser.add_argument("--mesh_start_iter", type=int, default=1, help="Iteration at which mesh regularization starts.") + parser.add_argument("--mesh_stop_iter", type=int, default=None, help="Optional iteration to stop mesh regularization.") + parser.add_argument("--warn_until_iter", type=int, default=3000, help="Warmup iterations for surface sampling.") + parser.add_argument("--imp_metric", type=str, default="outdoor", choices=["outdoor", "indoor"], help="Importance metric for surface sampling.") + parser.add_argument("--depth_loss_epsilon", type=float, default=1e-6, help="Numerical epsilon for depth loss denominator.") + return parser + + +def main() -> None: + parser = build_arg_parser() + args = parser.parse_args() + + if not torch.cuda.is_available() and args.device.startswith("cuda"): + raise RuntimeError("CUDA device is required for depth-guided refinement.") + + device = torch.device(args.device) + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + cameras = load_cameras_from_json( + json_path=args.camera_poses, + image_height=args.image_height, + image_width=args.image_width, + fov_y_deg=args.fov_y, + ) + print(f"[INFO] Loaded {len(cameras)} cameras from {args.camera_poses}.") + + depth_provider = DepthMapProvider( + depth_dir=args.depth_dir, + cameras=cameras, + depth_scale=args.depth_scale, + depth_offset=args.depth_offset, + clip_min=args.depth_clip_min, + clip_max=args.depth_clip_max, + ) + print(f"[INFO] Loaded {len(depth_provider.records)} depth maps from {args.depth_dir}.") + if depth_provider.depth_height is not None: + depth_h, depth_w = depth_provider.depth_height, depth_provider.depth_width + if depth_h != args.image_height or depth_w != args.image_width: + print( + f"[WARNING] Depth resolution ({depth_w}x{depth_h}) differs from render resolution " + f"({args.image_width}x{args.image_height}). Depth maps will be interpolated." + ) + if depth_provider.global_valid_pixels == 0: + print("[WARNING] No valid depth pixels found across dataset; depth supervision will be ineffective.") + else: + print( + "[INFO] Depth value range after scaling: " + f"{depth_provider.global_min:.4f} – {depth_provider.global_max:.4f} " + f"({depth_provider.global_valid_pixels} valid pixels)." 
+ ) + + scene = ManualScene(cameras) + + gaussians = GaussianModel( + sh_degree=0, + use_mip_filter=False, + learn_occupancy=True, + use_appearance_network=False, + ) + gaussians.load_ply(args.ply_path) + print(f"[INFO] Loaded {gaussians._xyz.shape[0]} Gaussians from {args.ply_path}.") + + ensure_learnable_occupancy(gaussians) + gaussians.init_culling(len(cameras)) + if gaussians.spatial_lr_scale <= 0: + gaussians.spatial_lr_scale = 1.0 + + mesh_config = load_mesh_config( + name=args.mesh_config, + start_iter_override=args.mesh_start_iter, + stop_iter_override=args.mesh_stop_iter, + total_iterations=args.iterations, + ) + occupancy_mode = mesh_config.get("occupancy_mode", "occupancy_shift") + if occupancy_mode != "occupancy_shift": + raise ValueError( + f"Depth-guided refinement requires occupancy_mode 'occupancy_shift', got '{occupancy_mode}'. " + "Please adjust the mesh configuration." + ) + gaussians.set_occupancy_mode(occupancy_mode) + print( + "[INFO] Mesh config '{name}': start_iter={start}, stop_iter={stop}, n_max_points_in_delaunay={limit}".format( + name=args.mesh_config, + start=mesh_config.get("start_iter"), + stop=mesh_config.get("stop_iter"), + limit=mesh_config.get("n_max_points_in_delaunay"), + ) + ) + + opt_parser = argparse.ArgumentParser() + opt_params = OptimizationParams(opt_parser) + opt_params.iterations = args.iterations + opt_params.position_lr_init *= args.initial_lr_scale + opt_params.position_lr_final *= args.initial_lr_scale + + gaussians.training_setup(opt_params) + + if args.freeze_colors: + gaussians._features_dc.requires_grad_(False) + gaussians._features_rest.requires_grad_(False) + + lr_xyz_init = opt_params.position_lr_init * gaussians.spatial_lr_scale + + param_groups = [ + {"params": [gaussians._xyz], "lr": lr_xyz_init, "name": "xyz"}, + {"params": [gaussians._opacity], "lr": opt_params.opacity_lr, "name": "opacity"}, + {"params": [gaussians._scaling], "lr": opt_params.scaling_lr, "name": "scaling"}, + {"params": [gaussians._rotation], "lr": opt_params.rotation_lr, "name": "rotation"}, + ] + if not args.freeze_colors: + param_groups.append({"params": [gaussians._features_dc], "lr": opt_params.feature_lr, "name": "f_dc"}) + param_groups.append({"params": [gaussians._features_rest], "lr": opt_params.feature_lr / 20.0, "name": "f_rest"}) + if gaussians.learn_occupancy: + param_groups.append({"params": [gaussians._occupancy_shift], "lr": opt_params.opacity_lr * args.occupancy_lr_scale, "name": "occupancy_shift"}) + + gaussians.optimizer = SparseGaussianAdam(param_groups, lr=0.0, eps=1e-15) + gaussians.xyz_scheduler_args = get_expon_lr_func( + lr_init=lr_xyz_init, + lr_final=opt_params.position_lr_final * gaussians.spatial_lr_scale, + lr_delay_mult=opt_params.position_lr_delay_mult, + max_steps=opt_params.position_lr_max_steps, + ) + + background = torch.zeros(3, dtype=torch.float32, device=device) + pipe_parser = argparse.ArgumentParser() + pipe = PipelineParams(pipe_parser) + render_view, render_for_sdf = build_render_functions(gaussians, pipe, background) + mesh_renderer, mesh_state = initialize_mesh_regularization(scene, mesh_config) + mesh_state["reset_delaunay_samples"] = True + mesh_state["reset_sdf_values"] = True + + runtime_args = argparse.Namespace( + warn_until_iter=args.warn_until_iter, + imp_metric=args.imp_metric, + depth_reinit_iter=getattr(args, "depth_reinit_iter", args.warn_until_iter), + ) + + os.makedirs(args.output_dir, exist_ok=True) + log_dir = os.path.join(args.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + loss_log_path 
= os.path.join(log_dir, "losses.jsonl") + + ema_depth_loss = None + ema_mesh_loss = None + pending_view_indices: List[int] = [] + printed_depth_diagnostics = False + + with open(loss_log_path, "w", encoding="utf-8") as loss_log_file: + for iteration in range(1, args.iterations + 1): + if not pending_view_indices: + pending_view_indices = list(range(len(cameras))) + random.shuffle(pending_view_indices) + + view_idx = pending_view_indices.pop() + viewpoint = cameras[view_idx] + + depth_record = depth_provider.get(view_idx, device) + render_pkg = render_view(viewpoint) + + pred_depth = render_pkg["median_depth"] + depth_loss, depth_mae, valid_fraction, valid_pixels = compute_depth_loss( + predicted=pred_depth, + target=depth_record.depth, + valid_mask=depth_record.valid_mask, + epsilon=args.depth_loss_epsilon, + ) + + if valid_pixels == 0: + skipped_record = { + "iteration": iteration, + "view_index": view_idx, + "skipped": True, + "skipped_reason": "invalid_depth", + } + loss_log_file.write(json.dumps(skipped_record) + "\n") + loss_log_file.flush() + if iteration % args.log_interval == 0 or iteration == 1: + print(f"[Iter {iteration:05d}] skipped view {view_idx} due to invalid depth.") + continue + + total_loss = depth_loss + + mesh_pkg = compute_mesh_regularization( + iteration=iteration, + render_pkg=render_pkg, + viewpoint_cam=viewpoint, + viewpoint_idx=view_idx, + gaussians=gaussians, + scene=scene, + pipe=pipe, + background=background, + kernel_size=0.0, + config=mesh_config, + mesh_renderer=mesh_renderer, + mesh_state=mesh_state, + render_func=render_for_sdf, + weight_adjustment=100.0 / max(args.iterations, 1), + args=runtime_args, + integrate_func=integrate_radegs, + ) + mesh_state = mesh_pkg["updated_state"] + mesh_loss_tensor = mesh_pkg["mesh_loss"] + mesh_loss = mesh_loss_tensor + + if not printed_depth_diagnostics: + depth_valid = depth_record.depth[depth_record.valid_mask > 0.5] + print( + "[DIAG] First valid depth batch: " + f"depth range {float(depth_valid.min().item()):.4f} – {float(depth_valid.max().item()):.4f}, " + f"predicted range {float(pred_depth.min().item()):.4f} – {float(pred_depth.max().item()):.4f}" + ) + print(f"[DIAG] Gaussian spatial_lr_scale: {gaussians.spatial_lr_scale:.6f}") + mesh_loss_unweighted = mesh_loss_tensor.item() + mesh_loss_weighted_diag = mesh_loss.item() + print( + f"[DIAG] Initial losses — depth_loss={depth_loss.item():.6e}, " + f"mesh_loss_raw={mesh_loss_unweighted:.6e}, " + f"mesh_loss_weighted={mesh_loss_weighted_diag:.6e}" + ) + printed_depth_diagnostics = True + + total_loss = total_loss + mesh_loss + + gaussians.optimizer.zero_grad(set_to_none=True) + total_loss.backward() + if args.grad_clip_norm > 0.0: + trainable_params: List[torch.Tensor] = [] + for group in gaussians.optimizer.param_groups: + for param in group.get("params", []): + if isinstance(param, torch.Tensor) and param.requires_grad: + trainable_params.append(param) + if trainable_params: + clip_grad_norm_(trainable_params, args.grad_clip_norm) + gaussians.update_learning_rate(iteration) + visibility = render_pkg["visibility_filter"] + radii = render_pkg["radii"] + gaussians.optimizer.step(visibility, radii.shape[0]) + + total_loss_value = float(total_loss.item()) + depth_loss_value = float(depth_loss.item()) + mesh_loss_value = float(mesh_loss_tensor.item()) + weighted_mesh_loss_value = mesh_loss_value + + ema_depth_loss = depth_loss_value if ema_depth_loss is None else (0.9 * ema_depth_loss + 0.1 * depth_loss_value) + ema_mesh_loss = weighted_mesh_loss_value if ema_mesh_loss 
is None else (0.9 * ema_mesh_loss + 0.1 * weighted_mesh_loss_value) + + iteration_record = { + "iteration": iteration, + "view_index": view_idx, + "total_loss": total_loss_value, + "depth_loss": depth_loss_value, + "mesh_loss_raw": mesh_loss_value, + "mesh_loss_weighted": weighted_mesh_loss_value, + "ema_depth_loss": ema_depth_loss, + "ema_mesh_loss": ema_mesh_loss, + "depth_mae": depth_mae, + "valid_fraction": valid_fraction, + "valid_pixels": valid_pixels, + } + iteration_record.update(extract_loss_scalars(mesh_pkg)) + loss_log_file.write(json.dumps(iteration_record) + "\n") + loss_log_file.flush() + + if args.export_interval > 0 and iteration % args.export_interval == 0: + export_iteration_state( + iteration=iteration, + gaussians=gaussians, + mesh_state=mesh_state, + output_dir=args.output_dir, + reference_camera=None, + ) + + if iteration % args.log_interval == 0 or iteration == 1: + print( + "[Iter {iter:05d}] loss={loss:.6f} depth={depth:.6f} mesh={mesh:.6f} " + "depth_mae={mae:.6f} valid={valid:.3f}".format( + iter=iteration, + loss=total_loss_value, + depth=depth_loss_value, + mesh=weighted_mesh_loss_value, + mae=depth_mae, + valid=valid_fraction, + ) + ) + + final_dir = os.path.join(args.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + export_iteration_state( + iteration=args.iterations, + gaussians=gaussians, + mesh_state=mesh_state, + output_dir=final_dir, + reference_camera=None, + ) + print(f"[INFO] Depth-guided refinement completed. Results saved to {args.output_dir}.") + + +def load_mesh_config( + name: str, + start_iter_override: Optional[int] = None, + stop_iter_override: Optional[int] = None, + total_iterations: Optional[int] = None, +) -> Dict: + from milo.useless_maybe.ply2mesh import load_mesh_config_file + + config = load_mesh_config_file(name) + if start_iter_override is not None: + config["start_iter"] = max(1, start_iter_override) + else: + config["start_iter"] = max(1, config.get("start_iter", 1)) + if stop_iter_override is not None: + config["stop_iter"] = stop_iter_override + elif total_iterations is not None: + config["stop_iter"] = max(config.get("stop_iter", total_iterations), total_iterations) + config["stop_iter"] = max(config.get("stop_iter", config["start_iter"]), config["start_iter"]) + return config + + +if __name__ == "__main__": + main() diff --git a/milo/useless_maybe/depth_train.py b/milo/useless_maybe/depth_train.py new file mode 100644 index 0000000..aafeba7 --- /dev/null +++ b/milo/useless_maybe/depth_train.py @@ -0,0 +1,940 @@ +#!/usr/bin/env python3 +""" +Depth-supervised training loop for 3D Gaussian Splatting. + +This script mirrors the original MILo image-supervised training pipeline, but +replaces the photometric loss with a depth reconstruction objective fed by +per-view depth maps. It supports mesh-in-the-loop regularization, gaussian +densification/simplification, and periodic exports for inspection. 
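+
+Example (illustrative paths; the flags match those used by run_depth_sweep.py):
+    python milo/depth_train.py \
+        --ply_path gaussians.ply --camera_poses cameras.json --depth_dir depths/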
+""" + +from __future__ import annotations + +import argparse +import json +import math +import os +import random +import re +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +import yaml +from torch.nn.utils import clip_grad_norm_ +import trimesh + +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +ROOT_DIR = os.path.abspath(os.path.join(BASE_DIR, "..")) +sys.path.append(ROOT_DIR) + +from arguments import OptimizationParams, PipelineParams # noqa: E402 +from gaussian_renderer import render_simp # noqa: E402 +from gaussian_renderer.radegs import render_radegs as render_radegs # noqa: E402 +from gaussian_renderer.radegs import integrate_radegs as integrate # noqa: E402 +from regularization.regularizer.mesh import initialize_mesh_regularization # noqa: E402 +from regularization.regularizer.mesh import compute_mesh_regularization # noqa: E402 +from regularization.regularizer.mesh import reset_mesh_state_at_next_iteration # noqa: E402 +from scene.cameras import Camera # noqa: E402 +from scene.gaussian_model import GaussianModel # noqa: E402 +from utils.geometry_utils import flatten_voronoi_features # noqa: E402 +from utils.general_utils import safe_state # noqa: E402 +from functional import extract_mesh, compute_delaunay_triangulation # noqa: E402 +from functional.mesh import frustum_cull_mesh # noqa: E402 +from regularization.sdf.learnable import convert_occupancy_to_sdf # noqa: E402 + + +def quaternion_to_rotation_matrix(q: Sequence[float]) -> np.ndarray: + q = np.asarray(q, dtype=np.float64) + if q.shape != (4,): + raise ValueError("Quaternion must have shape (4,)") + w, x, y, z = q + xx = x * x + yy = y * y + zz = z * z + xy = x * y + xz = x * z + yz = y * z + wx = w * x + wy = w * y + wz = w * z + rotation = np.array( + [ + [1.0 - 2.0 * (yy + zz), 2.0 * (xy - wz), 2.0 * (xz + wy)], + [2.0 * (xy + wz), 1.0 - 2.0 * (xx + zz), 2.0 * (yz - wx)], + [2.0 * (xz - wy), 2.0 * (yz + wx), 1.0 - 2.0 * (xx + yy)], + ], + dtype=np.float32, + ) + return rotation + + +def load_cameras_from_json( + json_path: str, + image_height: int, + image_width: int, + fov_y_deg: float, + data_device: str, +) -> List[Camera]: + if not os.path.isfile(json_path): + raise FileNotFoundError(f"Camera JSON not found: {json_path}") + with open(json_path, "r", encoding="utf-8") as f: + entries = json.load(f) + if not entries: + raise ValueError(f"No camera entries in {json_path}") + + fov_y = math.radians(fov_y_deg) + aspect = image_width / image_height + fov_x = 2.0 * math.atan(aspect * math.tan(fov_y * 0.5)) + + cameras: List[Camera] = [] + for idx, entry in enumerate(entries): + if "quaternion" in entry: + rotation = quaternion_to_rotation_matrix(entry["quaternion"]) + elif "rotation" in entry: + rotation = np.asarray(entry["rotation"], dtype=np.float32) + if rotation.shape != (3, 3): + raise ValueError(f"Camera entry {idx} rotation must be 3x3") + else: + raise KeyError(f"Camera entry {idx} missing rotation or quaternion.") + + if "tvec" in entry: + translation = np.asarray(entry["tvec"], dtype=np.float32) + elif "translation" in entry: + translation = np.asarray(entry["translation"], dtype=np.float32) + elif "position" in entry: + camera_center = np.asarray(entry["position"], dtype=np.float32) + if camera_center.shape != (3,): + raise ValueError(f"Camera entry {idx} position must be length-3.") + translation = -rotation.T @ camera_center + else: + raise 
KeyError(f"Camera entry {idx} missing translation/position.") + + if translation.shape != (3,): + raise ValueError(f"Camera entry {idx} translation must be length-3.") + + image_name = ( + entry.get("name") + or entry.get("img_name") + or entry.get("image_name") + or f"view_{idx:04d}" + ) + + camera = Camera( + colmap_id=str(idx), + R=rotation, + T=translation, + FoVx=fov_x, + FoVy=fov_y, + image=torch.zeros(3, image_height, image_width), + gt_alpha_mask=None, + image_name=image_name, + uid=idx, + data_device=data_device, + ) + cameras.append(camera) + return cameras + + +def _clone_camera_for_scale(camera: Camera, scale: float) -> Camera: + if math.isclose(scale, 1.0): + return camera + + new_height = max(1, int(round(camera.image_height / scale))) + new_width = max(1, int(round(camera.image_width / scale))) + blank_image = torch.zeros(3, new_height, new_width, dtype=torch.float32) + + # Camera expects rotation/translation as numpy arrays; reuse existing values. + return Camera( + colmap_id=camera.colmap_id, + R=camera.R, + T=camera.T, + FoVx=camera.FoVx, + FoVy=camera.FoVy, + image=blank_image, + gt_alpha_mask=None, + image_name=camera.image_name, + uid=camera.uid, + data_device=str(camera.data_device), + ) + + +def _build_scaled_cameras( + cameras: Sequence[Camera], + scales: Sequence[float] = (1.0, 2.0), +) -> Dict[float, List[Camera]]: + scaled: Dict[float, List[Camera]] = {} + for scale in scales: + if math.isclose(scale, 1.0): + scaled[float(scale)] = list(cameras) + else: + scaled[float(scale)] = [_clone_camera_for_scale(cam, scale) for cam in cameras] + return scaled + + +class ManualScene: + """Minimal adapter exposing camera access expected by mesh regularizer.""" + + def __init__(self, cameras_by_scale: Dict[float, Sequence[Camera]]): + if 1.0 not in cameras_by_scale: + raise ValueError("At least scale 1.0 cameras must be provided.") + self._train_cameras: Dict[float, List[Camera]] = { + float(scale): list(cam_list) for scale, cam_list in cameras_by_scale.items() + } + + def getTrainCameras(self, scale: float = 1.0): + scale_key = float(scale) + if scale_key not in self._train_cameras: + scale_key = 1.0 + return list(self._train_cameras[scale_key]) + + def getTrainCameras_warn_up( + self, + iteration: int, + warn_until_iter: int, + scale: float = 1.0, + scale2: float = 2.0, + ): + preferred = scale2 if iteration <= warn_until_iter and float(scale2) in self._train_cameras else scale + fallback_scale = float(preferred) if float(preferred) in self._train_cameras else 1.0 + return list(self._train_cameras[fallback_scale]) + + +def build_render_functions( + gaussians: GaussianModel, + pipe: PipelineParams, + background: torch.Tensor, +): + def _render( + view: Camera, + pc_obj: GaussianModel, + pipe_obj: PipelineParams, + bg_color: torch.Tensor, + *, + kernel_size: float = 0.0, + require_coord: bool = False, + require_depth: bool = True, + ): + pkg = render_radegs( + viewpoint_camera=view, + pc=pc_obj, + pipe=pipe_obj, + bg_color=bg_color, + kernel_size=kernel_size, + scaling_modifier=1.0, + require_coord=require_coord, + require_depth=require_depth, + ) + if "area_max" not in pkg: + pkg["area_max"] = torch.zeros_like(pkg["radii"]) + return pkg + + def render_view(view: Camera): + return _render(view, gaussians, pipe, background) + + def render_for_sdf( + view: Camera, + gaussians_override: Optional[GaussianModel] = None, + pipeline_override: Optional[PipelineParams] = None, + background_override: Optional[torch.Tensor] = None, + kernel_size: float = 0.0, + require_depth: bool = 
True, + require_coord: bool = False, + ): + pc_obj = gaussians if gaussians_override is None else gaussians_override + pipe_obj = pipe if pipeline_override is None else pipeline_override + bg_color = background if background_override is None else background_override + pkg = _render( + view, + pc_obj, + pipe_obj, + bg_color, + kernel_size=kernel_size, + require_coord=require_coord, + require_depth=require_depth, + ) + return { + "render": pkg["render"].detach(), + "median_depth": pkg["median_depth"].detach(), + } + + return render_view, render_for_sdf + + +def export_mesh_from_gaussians( + gaussians: GaussianModel, + mesh_state: Dict, + output_path: str, + reference_camera: Optional[Camera] = None, +) -> None: + delaunay_tets = mesh_state.get("delaunay_tets") + gaussian_idx = mesh_state.get("delaunay_xyz_idx") + if delaunay_tets is None: + delaunay_tets = compute_delaunay_triangulation( + means=gaussians.get_xyz, + scales=gaussians.get_scaling, + rotations=gaussians.get_rotation, + gaussian_idx=gaussian_idx, + ) + + occupancy = ( + gaussians.get_occupancy + if gaussian_idx is None + else gaussians.get_occupancy[gaussian_idx] + ) + pivots_sdf = convert_occupancy_to_sdf(flatten_voronoi_features(occupancy)) + + mesh = extract_mesh( + delaunay_tets=delaunay_tets, + pivots_sdf=pivots_sdf, + means=gaussians.get_xyz, + scales=gaussians.get_scaling, + rotations=gaussians.get_rotation, + gaussian_idx=gaussian_idx, + ) + + mesh_to_export = mesh + if reference_camera is not None: + mesh_to_export = frustum_cull_mesh(mesh, reference_camera) + + verts = mesh_to_export.verts.detach().cpu().numpy() + faces = mesh_to_export.faces.detach().cpu().numpy() + trimesh.Trimesh(vertices=verts, faces=faces, process=False).export(output_path) + + +def load_mesh_config_file(name: str) -> Dict: + config_path = os.path.join(BASE_DIR, "configs", "mesh", f"{name}.yaml") + if not os.path.isfile(config_path): + raise FileNotFoundError(f"Mesh config not found: {config_path}") + with open(config_path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) + + +def ensure_learnable_occupancy(gaussians: GaussianModel) -> None: + """Ensure occupancy buffers exist and shifts are trainable.""" + if not gaussians.learn_occupancy or not hasattr(gaussians, "_occupancy_shift"): + device = gaussians._xyz.device + n_pts = gaussians._xyz.shape[0] + base = torch.zeros((n_pts, 9), device=device) + shift = torch.zeros_like(base) + gaussians.learn_occupancy = True + gaussians._base_occupancy = torch.nn.Parameter( + base.requires_grad_(False), requires_grad=False + ) + gaussians._occupancy_shift = torch.nn.Parameter(shift.requires_grad_(True)) + gaussians.set_occupancy_mode("occupancy_shift") + gaussians._occupancy_shift.requires_grad_(True) + + +@dataclass +class DepthRecord: + depth: torch.Tensor # (1, H, W) + valid_mask: torch.Tensor # (1, H, W) + + +class DepthMapProvider: + """Loads depth maps and matches them to cameras via naming convention.""" + + def __init__( + self, + depth_dir: Path, + cameras: Sequence, + depth_scale: float = 1.0, + depth_offset: float = 0.0, + clip_min: Optional[float] = None, + clip_max: Optional[float] = None, + ) -> None: + if not depth_dir.is_dir(): + raise FileNotFoundError(f"Depth directory not found: {depth_dir}") + + file_list = sorted([f.name for f in depth_dir.iterdir() if f.suffix == ".npy"]) + if not file_list: + raise ValueError(f"No depth .npy files found in {depth_dir}") + + pattern = re.compile(r"depth_img_(\d+)_(\d+)\.npy$") + indexed: Dict[Tuple[int, int], str] = {} + for filename in 
file_list:
+            match = pattern.match(filename)
+            if match:
+                cam_idx = int(match.group(1))
+                frame_idx = int(match.group(2))
+                indexed[(cam_idx, frame_idx)] = filename
+
+        fallback_files = sorted(file_list, key=self._natural_key)
+
+        self.depth_scale = depth_scale
+        self.depth_offset = depth_offset
+        self.clip_min = clip_min
+        self.clip_max = clip_max
+
+        self.depth_height: Optional[int] = None
+        self.depth_width: Optional[int] = None
+        self.global_min: float = float("inf")
+        self.global_max: float = float("-inf")
+        self.global_valid_pixels: int = 0
+        self.records: List[DepthRecord] = []
+
+        for cam_index, camera in enumerate(cameras):
+            depth_path = self._resolve_path(
+                camera_name=getattr(camera, "image_name", str(cam_index)),
+                camera_idx=cam_index,
+                indexed_files=indexed,
+                fallback_files=fallback_files,
+            )
+            full_path = depth_dir / depth_path
+            depth_np = np.load(full_path)
+            if depth_np.ndim == 3 and depth_np.shape[-1] == 1:
+                depth_np = depth_np[..., 0]
+            if depth_np.ndim == 2:
+                depth_np = depth_np[None, ...]
+            elif depth_np.ndim == 3 and depth_np.shape[0] == 1:
+                pass
+            else:
+                raise ValueError(f"Unexpected depth shape {depth_np.shape} in {full_path}")
+
+            depth = torch.from_numpy(depth_np.astype(np.float32))
+            depth = depth * depth_scale + depth_offset
+            if clip_min is not None or clip_max is not None:
+                depth = depth.clamp(
+                    min=clip_min if clip_min is not None else float("-inf"),
+                    max=clip_max if clip_max is not None else float("inf"),
+                )
+
+            mask = (depth > 0.0).float()
+            if self.depth_height is None:
+                self.depth_height, self.depth_width = depth.shape[-2:]
+
+            valid_values = depth[mask > 0.5]
+            if valid_values.numel() > 0:
+                self.global_min = min(self.global_min, float(valid_values.min()))
+                self.global_max = max(self.global_max, float(valid_values.max()))
+                self.global_valid_pixels += int(valid_values.numel())
+
+            self.records.append(DepthRecord(depth=depth.contiguous(), valid_mask=mask))
+
+        if self.global_min == float("inf"):
+            self.global_min = 0.0
+            self.global_max = 0.0
+
+    @staticmethod
+    def _natural_key(path: str) -> List[object]:
+        tokens = re.split(r"(\d+)", Path(path).stem)
+        return [int(tok) if tok.isdigit() else tok for tok in tokens if tok]
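Depth files are matched to cameras by naming convention: a camera named `traj_{frame}_cam{cam}` looks up `depth_img_{cam}_{frame}.npy` in the index, falling back to natural-sort order by camera position. A small illustration of the two regexes involved (hypothetical names, not from the diff):

    import re

    pattern = re.compile(r"depth_img_(\d+)_(\d+)\.npy$")
    m = pattern.match("depth_img_2_15.npy")
    print(m.group(1), m.group(2))  # "2" "15" -> camera 2, frame 15

    m = re.search(r"traj_(\d+)_cam(\d+)", "traj_15_cam2")
    print(m.group(1), m.group(2))  # "15" "2" -> frame 15, camera 2
    # Camera "traj_15_cam2" therefore resolves to depth_img_2_15.npy
    # via the (cam_idx, frame_idx) index built above.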
+    @staticmethod
+    def _resolve_path(
+        camera_name: str,
+        camera_idx: int,
+        indexed_files: Dict[Tuple[int, int], str],
+        fallback_files: Sequence[str],
+    ) -> str:
+        match = re.search(r"traj_(\d+)_cam(\d+)", camera_name)
+        if match:
+            frame_idx = int(match.group(1))
+            cam_idx = int(match.group(2))
+            candidate = indexed_files.get((cam_idx, frame_idx))
+            if candidate:
+                return candidate
+        if camera_idx >= len(fallback_files):
+            raise IndexError(
+                f"Camera index {camera_idx} exceeds number of depth files ({len(fallback_files)})."
+            )
+        return fallback_files[camera_idx]
+
+    def __len__(self) -> int:
+        return len(self.records)
+
+    def get(self, index: int, device: torch.device) -> DepthRecord:
+        record = self.records[index]
+        return DepthRecord(
+            depth=record.depth.to(device, non_blocking=True),
+            valid_mask=record.valid_mask.to(device, non_blocking=True),
+        )
+
+
+def compute_depth_loss(
+    predicted: torch.Tensor,
+    target: torch.Tensor,
+    mask: torch.Tensor,
+    epsilon: float,
+) -> Tuple[torch.Tensor, float, float, int]:
+    if predicted.shape != target.shape:
+        target = F.interpolate(
+            target.unsqueeze(0),
+            size=predicted.shape[-2:],
+            mode="bilinear",
+            align_corners=True,
+        ).squeeze(0)
+        mask = F.interpolate(
+            mask.unsqueeze(0),
+            size=predicted.shape[-2:],
+            mode="nearest",
+        ).squeeze(0)
+
+    valid = mask > 0.5
+    valid_pixels = int(valid.sum().item())
+    if valid_pixels == 0:
+        zero = torch.zeros((), device=predicted.device, dtype=predicted.dtype)
+        return zero, 0.0, 0.0, 0
+
+    diff = (predicted - target).abs() * mask
+    loss = diff.sum() / (mask.sum() + epsilon)
+    mae = diff.sum().item() / (valid_pixels + epsilon)
+    valid_fraction = valid_pixels / mask.numel()
+    return loss, mae, valid_fraction, valid_pixels
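compute_depth_loss averages the absolute depth error over valid pixels only; pixels with zero sensor depth contribute nothing. A minimal sketch of its semantics with hand-picked toy tensors (assumes the function above is importable; shapes are illustrative only):

    import torch

    pred = torch.tensor([[[2.0, 4.0], [6.0, 8.0]]])  # (1, 2, 2) rendered depth
    gt = torch.tensor([[[1.0, 4.0], [0.0, 8.0]]])    # (1, 2, 2) sensor depth
    mask = (gt > 0).float()                          # zero-depth pixels are invalid
    loss, mae, valid_fraction, valid_pixels = compute_depth_loss(pred, gt, mask, epsilon=1e-6)
    # masked |error| sums to |2-1| + |4-4| + |8-8| = 1 over 3 valid pixels,
    # so loss ~= mae ~= 0.3333, valid_fraction = 0.75, valid_pixels = 3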
+
+
+class DepthTrainer:
+    """Orchestrates depth-supervised optimization of a Gaussian model."""
+
+    def __init__(self, args: argparse.Namespace) -> None:
+        self.args = args
+        self.device = torch.device(args.device)
+        self._prepare_seeds(args.seed)
+
+        base_cameras = load_cameras_from_json(
+            json_path=args.camera_poses,
+            image_height=args.image_height,
+            image_width=args.image_width,
+            fov_y_deg=args.fov_y,
+            data_device=args.data_device,
+        )
+        print(f"[INFO] Loaded {len(base_cameras)} cameras.")
+        self.cameras_by_scale = _build_scaled_cameras(base_cameras, scales=(1.0, 2.0))
+        self.scene = ManualScene(self.cameras_by_scale)
+        self.cameras = self.cameras_by_scale[1.0]
+
+        self.fixed_view_idx = args.fixed_view_idx
+        if self.fixed_view_idx is not None:
+            if not (0 <= self.fixed_view_idx < len(self.cameras)):
+                raise ValueError(
+                    f"fixed_view_idx {self.fixed_view_idx} out of bounds for {len(self.cameras)} cameras."
+                )
+
+        depth_dir = Path(args.depth_dir)
+        self.depth_provider = DepthMapProvider(
+            depth_dir=depth_dir,
+            cameras=self.cameras,
+            depth_scale=args.depth_scale,
+            depth_offset=args.depth_offset,
+            clip_min=args.depth_clip_min,
+            clip_max=args.depth_clip_max,
+        )
+        if self.depth_provider.global_valid_pixels == 0:
+            raise RuntimeError("No valid depth pixels found across the dataset.")
+        print(
+            "[INFO] Depth statistics after scaling: "
+            f"{self.depth_provider.global_min:.4f} – {self.depth_provider.global_max:.4f} "
+            f"({self.depth_provider.global_valid_pixels} valid pixels)"
+        )
+
+        self.scene.cameras_extent = self._estimate_extent(args.ply_path)
+
+        self.gaussians = GaussianModel(
+            sh_degree=args.sh_degree,
+            use_mip_filter=not args.disable_mip_filter,
+            learn_occupancy=True,
+            use_appearance_network=False,
+        )
+        self.gaussians.load_ply(args.ply_path)
+        ensure_learnable_occupancy(self.gaussians)
+        self.gaussians.init_culling(len(self.cameras))
+        if self.gaussians.spatial_lr_scale <= 0:
+            self.gaussians.spatial_lr_scale = 1.0
+
+        opt_parser = argparse.ArgumentParser(add_help=False)
+        opt_params = OptimizationParams(opt_parser)
+        opt_params.iterations = args.iterations
+        opt_params.position_lr_init *= args.initial_lr_scale
+        opt_params.position_lr_final *= args.initial_lr_scale
+        self.gaussians.training_setup(opt_params)
+        if args.freeze_colors:
+            if hasattr(self.gaussians, "_features_dc"):
+                self.gaussians._features_dc.requires_grad_(False)
+            if hasattr(self.gaussians, "_features_rest"):
+                self.gaussians._features_rest.requires_grad_(False)
+
+        self.background = torch.zeros(3, dtype=torch.float32, device=self.device)
+        pipe_parser = argparse.ArgumentParser(add_help=False)
+        self.pipe = PipelineParams(pipe_parser)
+        self.pipe.compute_cov3D_python = args.compute_cov3d_python
+        self.pipe.convert_SHs_python = args.convert_shs_python
+        self.pipe.debug = args.debug
+
+        self.render_view, self.render_for_mesh = build_render_functions(
+            self.gaussians, self.pipe, self.background
+        )
+        self.mesh_enabled = args.mesh_regularization
+        if self.mesh_enabled:
+            mesh_config = self._load_mesh_config(
+                args.mesh_config, args.mesh_start_iter, args.mesh_stop_iter, args.iterations
+            )
+            occupancy_mode = mesh_config.get("occupancy_mode", "occupancy_shift")
+            if occupancy_mode != "occupancy_shift":
+                raise ValueError(
+                    f"Mesh config '{args.mesh_config}' must use occupancy_mode 'occupancy_shift', got '{occupancy_mode}'."
+ ) + self.gaussians.set_occupancy_mode(occupancy_mode) + self.mesh_renderer, self.mesh_state = initialize_mesh_regularization( + self.scene, + mesh_config, + ) + self.mesh_state["reset_delaunay_samples"] = True + self.mesh_state["reset_sdf_values"] = True + self.mesh_config = mesh_config + self.runtime_args = argparse.Namespace( + warn_until_iter=args.warn_until_iter, + imp_metric=args.imp_metric, + depth_reinit_iter=args.depth_reinit_iter, + ) + self._warmup_mesh_visibility() + else: + self.mesh_renderer = None + self.mesh_state = {} + self.mesh_config = {} + self.runtime_args = None + + self.optimizer = self.gaussians.optimizer + self.opt_params = opt_params + + self.output_dir = Path(args.output_dir) + (self.output_dir / "logs").mkdir(parents=True, exist_ok=True) + self.loss_log_path = self.output_dir / "logs" / "losses.jsonl" + self.pending_indices: List[int] = [] + self.ema_depth: Optional[float] = None + self.ema_mesh: Optional[float] = None + self.printed_depth_diag = False + self.log_depth_stats = bool(args.log_depth_stats or self.fixed_view_idx is not None) + + @staticmethod + def _prepare_seeds(seed: int) -> None: + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + def _estimate_extent(self, ply_path: str) -> float: + import trimesh + + mesh = trimesh.load(ply_path, process=False) + if hasattr(mesh, "vertices"): + vertices = np.asarray(mesh.vertices) + center = vertices.mean(axis=0) + radius = np.linalg.norm(vertices - center, axis=1).max() + return float(radius) + raise ValueError("Could not estimate scene extent from PLY.") + + def _load_mesh_config( + self, + name: str, + start_iter_override: Optional[int], + stop_iter_override: Optional[int], + total_iterations: int, + ) -> Dict: + config = load_mesh_config_file(name) + if start_iter_override is not None: + config["start_iter"] = max(1, start_iter_override) + if stop_iter_override is not None: + config["stop_iter"] = stop_iter_override + else: + config["stop_iter"] = max(config.get("stop_iter", total_iterations), total_iterations) + config["stop_iter"] = max(config["stop_iter"], config.get("start_iter", 1)) + if "occupancy_mode" not in config: + config["occupancy_mode"] = "occupancy_shift" + self.mesh_config = config + return config + + def _check_gaussian_numerics(self, label: str) -> None: + """Detect NaNs/Infs or extreme magnitudes before hitting CUDA kernels.""" + stats = { + "xyz": self.gaussians.get_xyz, + "scaling": self.gaussians.get_scaling, + "rotation": self.gaussians.get_rotation, + "opacity": self.gaussians.get_opacity, + } + for name, tensor in stats.items(): + if not torch.isfinite(tensor).all(): + invalid_mask = ~torch.isfinite(tensor) + num_bad = int(invalid_mask.sum().item()) + example_idx = invalid_mask.nonzero(as_tuple=False)[:5].flatten().tolist() + raise RuntimeError( + f"[NUMERIC] Detected {num_bad} non-finite entries in '{name}' " + f"during {label}. 
Sample indices: {example_idx}" + ) + max_abs = tensor.abs().max().item() + if max_abs > 1e6: + print( + f"[WARN] Large magnitude detected in '{name}' during {label}: " + f"{max_abs:.3e}" + ) + + def _warmup_mesh_visibility(self) -> None: + warmup_views = self.scene.getTrainCameras_warn_up( + iteration=1, + warn_until_iter=self.args.warn_until_iter, + scale=1.0, + scale2=2.0, + ) + for view in warmup_views: + render_simp( + view, + self.gaussians, + self.pipe, + self.background, + culling=self.gaussians._culling[:, view.uid], + ) + + def _select_view(self) -> int: + if self.fixed_view_idx is not None: + return self.fixed_view_idx + if not self.pending_indices: + self.pending_indices = list(range(len(self.cameras))) + random.shuffle(self.pending_indices) + return self.pending_indices.pop() + + def _log_iteration( + self, + iteration: int, + view_idx: int, + total_loss: float, + depth_loss: float, + mesh_loss: float, + depth_mae: float, + valid_fraction: float, + valid_pixels: int, + extra: Dict[str, float], + ) -> None: + record = { + "iteration": iteration, + "view_index": view_idx, + "total_loss": total_loss, + "depth_loss": depth_loss, + "mesh_loss": mesh_loss, + "depth_mae": depth_mae, + "valid_fraction": valid_fraction, + "valid_pixels": valid_pixels, + } + if self.ema_depth is not None: + record["ema_depth_loss"] = self.ema_depth + if self.ema_mesh is not None: + record["ema_mesh_loss"] = self.ema_mesh + record.update(extra) + with open(self.loss_log_path, "a", encoding="utf-8") as f: + f.write(json.dumps(record) + "\n") + + def run(self) -> None: + os.makedirs(self.output_dir, exist_ok=True) + with open(self.loss_log_path, "w", encoding="utf-8"): + pass + + for iteration in range(1, self.args.iterations + 1): + self.gaussians.update_learning_rate(iteration) + view_idx = self._select_view() + viewpoint = self.cameras[view_idx] + + depth_record = self.depth_provider.get(view_idx, self.device) + render_pkg = self.render_view(viewpoint) + pred_depth = render_pkg["median_depth"] + + depth_loss, depth_mae, valid_fraction, valid_pixels = compute_depth_loss( + predicted=pred_depth, + target=depth_record.depth, + mask=depth_record.valid_mask, + epsilon=self.args.depth_loss_epsilon, + ) + if valid_pixels == 0: + if iteration % self.args.log_interval == 0 or iteration == 1: + print(f"[Iter {iteration:05d}] skip view {view_idx} (no valid depth)") + continue + + mask_valid = depth_record.valid_mask.to(pred_depth.device) > 0.5 + pred_valid = pred_depth[mask_valid] + target_valid = depth_record.depth[mask_valid] + + if not self.printed_depth_diag: + if target_valid.numel() > 0 and pred_valid.numel() > 0: + print( + "[DIAG] First depth batch — target range {t_min:.4f} – {t_max:.4f}, " + "predicted range {p_min:.4f} – {p_max:.4f}".format( + t_min=float(target_valid.min().item()), + t_max=float(target_valid.max().item()), + p_min=float(pred_valid.min().item()), + p_max=float(pred_valid.max().item()), + ) + ) + self.printed_depth_diag = True + + total_loss = depth_loss + mesh_loss_tensor = torch.zeros_like(depth_loss) + mesh_pkg: Dict[str, torch.Tensor] = {} + mesh_active = self.mesh_enabled and iteration >= self.mesh_config.get("start_iter", 1) + if mesh_active: + self._check_gaussian_numerics(f"iter_{iteration}_before_mesh") + mesh_pkg = compute_mesh_regularization( + iteration=iteration, + render_pkg=render_pkg, + viewpoint_cam=viewpoint, + viewpoint_idx=view_idx, + gaussians=self.gaussians, + scene=self.scene, + pipe=self.pipe, + background=self.background, + kernel_size=0.0, + 
config=self.mesh_config, + mesh_renderer=self.mesh_renderer, + mesh_state=self.mesh_state, + render_func=self.render_for_mesh, + weight_adjustment=100.0 / max(self.args.iterations, 1), + args=self.runtime_args, + integrate_func=integrate, + ) + mesh_loss_tensor = mesh_pkg["mesh_loss"] + self.mesh_state = mesh_pkg["updated_state"] + total_loss = total_loss + mesh_loss_tensor + + self.optimizer.zero_grad(set_to_none=True) + total_loss.backward() + if self.args.grad_clip_norm > 0.0: + params: List[torch.Tensor] = [] + for group in self.optimizer.param_groups: + for param in group.get("params", []): + if isinstance(param, torch.Tensor) and param.requires_grad: + params.append(param) + if params: + clip_grad_norm_(params, self.args.grad_clip_norm) + + visibility = render_pkg["visibility_filter"] + radii = render_pkg["radii"] + self.optimizer.step(visibility, radii.shape[0]) + + total_val = float(total_loss.item()) + depth_val = float(depth_loss.item()) + mesh_val = float(mesh_loss_tensor.item()) + self.ema_depth = depth_val if self.ema_depth is None else (0.9 * self.ema_depth + 0.1 * depth_val) + self.ema_mesh = mesh_val if self.ema_mesh is None else (0.9 * self.ema_mesh + 0.1 * mesh_val) + + extra = {k: float(v.item()) for k, v in mesh_pkg.items() if hasattr(v, "item") and k.endswith("_loss")} + if self.log_depth_stats and target_valid.numel() > 0: + extra.update( + { + "pred_depth_min": float(pred_valid.min().item()), + "pred_depth_max": float(pred_valid.max().item()), + "pred_depth_mean": float(pred_valid.mean().item()), + "pred_depth_std": float(pred_valid.std(unbiased=False).item()), + "target_depth_min": float(target_valid.min().item()), + "target_depth_max": float(target_valid.max().item()), + "target_depth_mean": float(target_valid.mean().item()), + "target_depth_std": float(target_valid.std(unbiased=False).item()), + } + ) + self._log_iteration( + iteration=iteration, + view_idx=view_idx, + total_loss=total_val, + depth_loss=depth_val, + mesh_loss=mesh_val, + depth_mae=depth_mae, + valid_fraction=valid_fraction, + valid_pixels=valid_pixels, + extra=extra, + ) + + if iteration % self.args.log_interval == 0 or iteration == 1: + print( + "[Iter {iter:05d}] loss={loss:.6f} depth={depth:.6f} mesh={mesh:.6f} " + "mae={mae:.6f} valid={valid:.3f}".format( + iter=iteration, + loss=total_val, + depth=depth_val, + mesh=mesh_val, + mae=depth_mae, + valid=valid_fraction, + ) + ) + + if mesh_active and mesh_pkg.get("gaussians_changed", False): + self.mesh_state = reset_mesh_state_at_next_iteration(self.mesh_state) + + if self.args.export_interval > 0 and iteration % self.args.export_interval == 0: + self._export_state(iteration) + + self._export_state(self.args.iterations, final=True) + + def _sink_path(self, iteration: int, final: bool = False) -> Path: + + target_dir = self.output_dir / ("final" if final else f"iter_{iteration:05d}") + target_dir.mkdir(parents=True, exist_ok=True) + return target_dir + + def _export_state(self, iteration: int, final: bool = False) -> None: + target_dir = self._sink_path(iteration, final) + ply_path = target_dir / f"gaussians_iter_{iteration:05d}.ply" + save_mesh = ( + self.mesh_enabled + and (iteration >= self.mesh_config.get("start_iter", 1) or final) + and self.mesh_state + ) + if save_mesh and self.mesh_state.get("delaunay_tets") is not None: + mesh_path = target_dir / f"mesh_iter_{iteration:05d}.ply" + export_mesh_from_gaussians( + gaussians=self.gaussians, + mesh_state=self.mesh_state, + output_path=str(mesh_path), + reference_camera=None, + ) + 
self.gaussians.save_ply(str(ply_path)) + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Depth-supervised training for Gaussian Splatting.") + parser.add_argument("--ply_path", type=str, required=True, help="Initial Gaussian PLY file.") + parser.add_argument("--camera_poses", type=str, required=True, help="Camera pose JSON compatible with ply2mesh.load_cameras_from_json.") + parser.add_argument("--depth_dir", type=str, required=True, help="Folder of per-view depth .npy files.") + parser.add_argument("--output_dir", type=str, default="./depth_training_output", help="Directory for logs and exports.") + parser.add_argument("--iterations", type=int, default=5000, help="Number of optimization steps.") + parser.add_argument("--seed", type=int, default=0, help="Random seed.") + parser.add_argument("--device", type=str, default="cuda", help="PyTorch device identifier.") + parser.add_argument("--data_device", type=str, default="cpu", help="Device to store camera image tensors.") + parser.add_argument("--image_width", type=int, default=1280, help="Rendered width.") + parser.add_argument("--image_height", type=int, default=720, help="Rendered height.") + parser.add_argument("--fov_y", type=float, default=75.0, help="Vertical field of view in degrees.") + parser.add_argument("--depth_scale", type=float, default=1.0, help="Scale factor applied to loaded depth maps.") + parser.add_argument("--depth_offset", type=float, default=0.0, help="Additive offset applied to depth.") + parser.add_argument("--depth_clip_min", type=float, default=None, help="Minimum depth after scaling.") + parser.add_argument("--depth_clip_max", type=float, default=None, help="Maximum depth after scaling.") + parser.add_argument("--depth_loss_epsilon", type=float, default=1e-6, help="Stability epsilon for depth loss denominator.") + parser.add_argument("--mesh_config", type=str, default="medium", help="Mesh-in-the-loop configuration name.") + parser.add_argument("--mesh_start_iter", type=int, default=1, help="Iteration to start mesh regularization.") + parser.add_argument("--mesh_stop_iter", type=int, default=None, help="Iteration to stop mesh regularization.") + parser.add_argument("--export_interval", type=int, default=1000, help="Export mesh/ply every N iterations.") + parser.add_argument("--log_interval", type=int, default=100, help="Console log interval.") + parser.add_argument("--grad_clip_norm", type=float, default=0.0, help="Gradient clipping norm (0 disables).") + parser.add_argument("--initial_lr_scale", type=float, default=1.0, help="Scaling factor for position learning rate.") + parser.add_argument("--convert_shs_python", action="store_true", help="Use PyTorch SH conversion (debug only).") + parser.add_argument("--compute_cov3d_python", action="store_true", help="Use PyTorch covariance (debug only).") + parser.add_argument("--debug", action="store_true", help="Enable renderer debug outputs.") + parser.add_argument("--disable_mip_filter", action="store_true", help="Disable 3D Mip filter.") + parser.add_argument("--sh_degree", type=int, default=0, help="Spherical harmonic degree for Gaussian colors.") + parser.add_argument("--mesh_regularization", action="store_true", help="Enable mesh-in-the-loop regularization.") + parser.add_argument("--freeze_colors", dest="freeze_colors", action="store_true", help="Freeze SH features during depth training.", default=True) + parser.add_argument("--no-freeze_colors", dest="freeze_colors", action="store_false", help="Allow SH features to 
be optimized.") + parser.add_argument("--warn_until_iter", type=int, default=3000, help="Warmup iterations for densification/mesh utilities.") + parser.add_argument("--imp_metric", type=str, default="outdoor", choices=["outdoor", "indoor"], help="Importance metric for mesh sampling heuristics.") + parser.add_argument("--depth_reinit_iter", type=int, default=2000, help="Iteration to trigger optional depth reinitialization routines.") + parser.add_argument("--fixed_view_idx", type=int, default=None, help="If provided, always train on this camera index (for debugging).") + parser.add_argument("--log_depth_stats", action="store_true", help="Record detailed depth statistics per iteration.") + return parser + + +def main() -> None: + parser = build_parser() + args = parser.parse_args() + safe_state(False) + trainer = DepthTrainer(args) + trainer.run() + + +if __name__ == "__main__": + main() diff --git a/milo/useless_maybe/downsample_colmap_points.py b/milo/useless_maybe/downsample_colmap_points.py new file mode 100644 index 0000000..8180ecb --- /dev/null +++ b/milo/useless_maybe/downsample_colmap_points.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 +"""Downsample COLMAP points3D.bin (and optionally regenerate points3D.ply).""" + +from __future__ import annotations + +import argparse +import os +import shutil +import struct +from pathlib import Path +from typing import Optional + +import numpy as np +from plyfile import PlyData, PlyElement + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--input-bin", + required=True, + help="Path to the COLMAP points3D.bin file", + ) + parser.add_argument( + "--output-bin", + help=( + "Path to the output points3D.bin file. Defaults to overwriting the input " + "after creating a .bak backup." + ), + ) + parser.add_argument( + "--target", + type=int, + default=4_000_000, + help="Maximum number of points to keep (default: 4,000,000)", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducible sampling (default: 42)", + ) + parser.add_argument( + "--mode", + choices=("random", "radius"), + default="random", + help=( + "Sampling strategy: 'random' selects uniformly at random; " + "'radius' keeps points closest to the centroid (default: random)" + ), + ) + parser.add_argument( + "--ply-output", + help=( + "Optional path to write a matching points3D.ply. Defaults to the sibling " + "points3D.ply next to the binary file when omitted. Use '--ply-output '' ' to skip." 
+        ),
+    )
+    return parser.parse_args()
+
+
+# COLMAP points3D.bin stores, per point: uint64 id, 3x float64 xyz,
+# 3x uint8 rgb, float64 reprojection error, then a uint64 track length
+# followed by 8 bytes (int32 image id + int32 point2D idx) per track element.
+RECORD_STRUCT = struct.Struct("<Q3d3Bd")
+TRACK_LEN_STRUCT = struct.Struct("<Q")
+
+
+def gather_positions(path: Path, num_points: int) -> np.ndarray:
+    """Read point positions once to compute spatial statistics."""
+    positions = np.empty((num_points, 3), dtype=np.float32)
+    with path.open("rb") as fin:
+        fin.seek(8)  # Skip count header
+        for idx in range(num_points):
+            record_bytes = fin.read(RECORD_STRUCT.size)
+            if not record_bytes:
+                raise EOFError("Unexpected end of file while gathering positions")
+            _, x, y, z, _, _, _, _ = RECORD_STRUCT.unpack(record_bytes)
+            positions[idx] = (x, y, z)
+            (track_len,) = TRACK_LEN_STRUCT.unpack(fin.read(TRACK_LEN_STRUCT.size))
+            fin.seek(8 * track_len, os.SEEK_CUR)
+    return positions
+
+
+def pick_indices(
+    total: int,
+    target: int,
+    rng: np.random.Generator,
+    mode: str,
+    positions: np.ndarray | None = None,
+) -> np.ndarray:
+    if target <= 0 or total <= target:
+        return np.arange(total, dtype=np.int64)
+    if mode == "random":
+        indices = rng.choice(total, size=target, replace=False)
+    elif mode == "radius":
+        if positions is None:
+            raise ValueError("Positions are required for radius-based sampling.")
+        centroid = positions.mean(axis=0, dtype=np.float64)
+        dists = np.sum((positions - centroid) ** 2, axis=1)
+        partition = np.argpartition(dists, target - 1)[:target]
+        indices = np.sort(partition)
+    else:
+        raise ValueError(f"Unknown sampling mode: {mode}")
+    if not np.all(np.diff(indices) >= 0):
+        indices.sort()
+    return indices.astype(np.int64)
+
+
+def write_ply(path: Path, positions: np.ndarray, colors: np.ndarray) -> None:
+    count = positions.shape[0]
+    dtype = [
+        ("x", "f4"),
+        ("y", "f4"),
+        ("z", "f4"),
+        ("nx", "f4"),
+        ("ny", "f4"),
+        ("nz", "f4"),
+        ("red", "u1"),
+        ("green", "u1"),
+        ("blue", "u1"),
+    ]
+    elements = np.empty(count, dtype=dtype)
+    elements["x"] = positions[:, 0]
+    elements["y"] = positions[:, 1]
+    elements["z"] = positions[:, 2]
+    elements["nx"] = 0.0
+    elements["ny"] = 0.0
+    elements["nz"] = 0.0
+    elements["red"] = colors[:, 0]
+    elements["green"] = colors[:, 1]
+    elements["blue"] = colors[:, 2]
+    ply = PlyData([PlyElement.describe(elements, "vertex")], text=False)
+    ply.write(str(path))
+
+
+def main() -> None:
+    args = parse_args()
+    input_bin = Path(args.input_bin).expanduser().resolve()
+    output_bin = Path(args.output_bin).expanduser().resolve() if args.output_bin else input_bin
+
+    if not input_bin.exists():
+        raise FileNotFoundError(f"Input binary file not found: {input_bin}")
+
+    rng = np.random.default_rng(args.seed)
+
+    with input_bin.open("rb") as fin:
+        num_points = struct.unpack("<Q", fin.read(8))[0]
+    print(f"[INFO] Input points: {num_points}")
+
+    positions = gather_positions(input_bin, num_points) if args.mode == "radius" and args.target > 0 else None
+
+    indices = pick_indices(num_points, args.target, rng, args.mode, positions)
+    kept = int(indices.shape[0])
+    keep_mask = np.zeros(num_points, dtype=bool)
+    keep_mask[indices] = True
+    print(f"[INFO] Keeping {kept} points.")
+
+    if output_bin == input_bin:
+        backup = input_bin.with_suffix(input_bin.suffix + ".bak")
+        if not backup.exists():
+            shutil.copyfile(input_bin, backup)
+    tmp_bin = output_bin.with_name(output_bin.name + ".tmp")
+
+    kept_positions = np.empty((kept, 3), dtype=np.float32)
+    kept_colors = np.empty((kept, 3), dtype=np.uint8)
+
+    with input_bin.open("rb") as fin, tmp_bin.open("wb") as fout:
+        fin.seek(8)
+        fout.write(struct.pack("<Q", kept))
+        out_idx = 0
+        for idx in range(num_points):
+            record_bytes = fin.read(RECORD_STRUCT.size)
+            track_len_bytes = fin.read(TRACK_LEN_STRUCT.size)
+            (track_len,) = TRACK_LEN_STRUCT.unpack(track_len_bytes)
+            track_bytes = fin.read(8 * track_len)
+            if not keep_mask[idx]:
+                continue
+            fout.write(record_bytes + track_len_bytes + track_bytes)
+            _, x, y, z, r, g, b, _ = RECORD_STRUCT.unpack(record_bytes)
+            kept_positions[out_idx] = (x, y, z)
+            kept_colors[out_idx] = (r, g, b)
+            out_idx += 1
+
+    tmp_bin.replace(output_bin)
+    print(f"[INFO] Wrote {kept} points to {output_bin}.")
+
+    if args.ply_output != "":
+        ply_path = Path(args.ply_output) if args.ply_output else output_bin.with_name("points3D.ply")
+        write_ply(ply_path, kept_positions, kept_colors)
+        print(f"[INFO] Wrote matching PLY to {ply_path}.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/milo/useless_maybe/downsample_ply.py b/milo/useless_maybe/downsample_ply.py
new file mode 100644
--- /dev/null
+++ b/milo/useless_maybe/downsample_ply.py
+#!/usr/bin/env python3
+"""Randomly downsample a PLY point cloud to a target vertex count."""
+
+from __future__ import annotations
+
+import argparse
+import os
+import shutil
+from pathlib import Path
+
+import numpy as np
+from plyfile import PlyData, PlyElement
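A quick sanity check of the `pick_indices` semantics above under both modes (toy data; assumes the function is importable as reconstructed):

    import numpy as np

    rng = np.random.default_rng(0)
    pts = np.array(
        [[0.0, 0.0, 0.0], [10.0, 0.0, 0.0], [0.2, 0.0, 0.0], [0.1, 0.1, 0.0]],
        dtype=np.float32,
    )

    rand_idx = pick_indices(total=4, target=2, rng=rng, mode="random")
    assert rand_idx.shape == (2,) and np.all(np.diff(rand_idx) >= 0)  # sorted subset

    rad_idx = pick_indices(total=4, target=2, rng=rng, mode="radius", positions=pts)
    # Centroid is roughly (2.575, 0.025, 0); the two nearest points are indices 2 and 3.
    assert set(rad_idx.tolist()) == {2, 3}

Sorting the sampled indices keeps the surviving records in their original on-disk order, which lets the writer stream through the file in a single pass.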
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--input", required=True, help="Path to the source PLY file")
+    parser.add_argument(
+        "--output",
+        help=(
+            "Path to the output PLY file. If omitted, the input is overwritten "
+            "after creating a backup with suffix .bak"
+        ),
+    )
+    parser.add_argument(
+        "--target",
+        type=int,
+        default=4_000_000,
+        help="Desired maximum number of points (default: 4,000,000)",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="Random seed for reproducible sampling (default: 42)",
+    )
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+    input_path = Path(args.input).expanduser().resolve()
+    if args.output:
+        output_path = Path(args.output).expanduser().resolve()
+    else:
+        output_path = input_path
+
+    if not input_path.exists():
+        raise FileNotFoundError(f"Input PLY not found: {input_path}")
+
+    print(f"[INFO] Loading PLY: {input_path}")
+    ply = PlyData.read(str(input_path))
+
+    if "vertex" not in ply:
+        raise ValueError("Input PLY does not contain a vertex element")
+
+    vertex_data = ply["vertex"]
+    total_vertices = len(vertex_data)
+    target = max(int(args.target), 0)
+    print(f"[INFO] Total vertices: {total_vertices}")
+    print(f"[INFO] Target vertices: {target}")
+
+    if target == 0 or total_vertices <= target:
+        print("[INFO] No downsampling needed.")
+        if output_path != input_path:
+            print(f"[INFO] Copying file to {output_path}")
+            shutil.copyfile(input_path, output_path)
+        else:
+            print("[INFO] Input already satisfies target; nothing to do.")
+        return
+
+    rng = np.random.default_rng(args.seed)
+    print("[INFO] Sampling indices...")
+    sample_indices = rng.choice(total_vertices, size=target, replace=False)
+    sample_indices.sort()
+    downsampled_vertex = vertex_data[sample_indices]
+
+    print("[INFO] Preparing PLY structure...")
+    new_vertex_element = PlyElement.describe(downsampled_vertex, "vertex")
+    new_ply = PlyData([new_vertex_element], text=ply.text, byte_order=ply.byte_order)
+    new_ply.comments = ply.comments
+
+    if output_path == input_path:
+        backup_path = input_path.with_suffix(input_path.suffix + ".bak")
+        if not backup_path.exists():
+            print(f"[INFO] Creating backup at {backup_path}")
+            shutil.copyfile(input_path, backup_path)
+        else:
+            print(f"[WARNING] Backup already exists at {backup_path}; reusing it.")
+
+    os.makedirs(output_path.parent, exist_ok=True)
+    print(f"[INFO] Writing downsampled PLY to {output_path}")
+    new_ply.write(str(output_path))
+    print("[INFO] Done.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/milo/useless_maybe/iterative_occupancy_refine.py b/milo/useless_maybe/iterative_occupancy_refine.py
new file mode 100644
index 0000000..2eaacf9
--- /dev/null
+++ b/milo/useless_maybe/iterative_occupancy_refine.py
@@ -0,0 +1,410 @@
+#!/usr/bin/env python3
+"""Iteratively refine learnable occupancy (SDF) while keeping Gaussian geometry fixed."""
+
+import argparse
+import json
+import os
+import random
+from types import SimpleNamespace
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from arguments import PipelineParams
+from gaussian_renderer.radegs import integrate_radegs
+from ply2mesh import (
+    ManualScene,
+    load_cameras_from_json,
+    freeze_gaussian_rigid_parameters,
+    build_render_functions,
+    load_mesh_config_file,
+    export_mesh_from_gaussians,
+)
+from regularization.regularizer.mesh import initialize_mesh_regularization, compute_mesh_regularization
+from regularization.sdf.learnable import convert_occupancy_to_sdf
+from scene.gaussian_model import GaussianModel
+from utils.geometry_utils import flatten_voronoi_features
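The whole script reduces to one optimization pattern: every rigid Gaussian parameter is frozen and a single occupancy-shift tensor is trained with its own Adam instance. A minimal runnable sketch of that pattern (stand-in tensors and a placeholder loss, not the real GaussianModel or mesh loss):

    import torch

    # Stand-in for GaussianModel buffers: geometry frozen, occupancy shift trainable.
    base_occupancy = torch.zeros(100, 9)                           # fixed buffer
    occupancy_shift = (0.1 * torch.randn(100, 9)).requires_grad_(True)  # only trainable leaf

    optimizer = torch.optim.Adam([occupancy_shift], lr=1e-3)
    for _ in range(10):
        occ = torch.sigmoid(base_occupancy + occupancy_shift)
        loss = ((occ - 0.2) ** 2).mean()   # placeholder for the weighted mesh loss
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

Because the optimizer only ever sees `occupancy_shift`, the Gaussian positions, scales, rotations, and colors are guaranteed to stay bit-identical to the input PLY.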
+
+
+def ensure_learnable_occupancy(gaussians: GaussianModel) -> None:
+    """Ensure occupancy buffers exist and only the shift is trainable."""
+    if not gaussians.learn_occupancy or not hasattr(gaussians, "_occupancy_shift"):
+        device = gaussians._xyz.device
+        n_pts = gaussians._xyz.shape[0]
+        base = torch.zeros((n_pts, 9), device=device)
+        shift = torch.zeros_like(base)
+        gaussians.learn_occupancy = True
+        gaussians._base_occupancy = torch.nn.Parameter(base.requires_grad_(False), requires_grad=False)
+        gaussians._occupancy_shift = torch.nn.Parameter(shift.requires_grad_(True))
+        gaussians.set_occupancy_mode("occupancy_shift")
+        gaussians._occupancy_shift.requires_grad_(True)
+
+
+def extract_loss_scalars(metrics: dict) -> dict:
+    """Extract scalar loss values from the mesh regularization outputs."""
+    scalars = {}
+    for key, value in metrics.items():
+        if not key.endswith("_loss"):
+            continue
+        scalar = None
+        if isinstance(value, torch.Tensor):
+            if value.ndim == 0:
+                scalar = float(value.item())
+        elif isinstance(value, (float, int)):
+            scalar = float(value)
+        if scalar is not None:
+            scalars[key] = scalar
+    return scalars
+
+
+def export_iteration_state(
+    iteration: int,
+    gaussians: GaussianModel,
+    mesh_state: dict,
+    output_dir: str,
+    reference_camera=None,
+) -> None:
+    os.makedirs(output_dir, exist_ok=True)
+    mesh_path = os.path.join(output_dir, f"mesh_iter_{iteration:05d}.ply")
+    ply_path = os.path.join(output_dir, f"gaussians_iter_{iteration:05d}.ply")
+
+    export_mesh_from_gaussians(
+        gaussians=gaussians,
+        mesh_state=mesh_state,
+        output_path=mesh_path,
+        reference_camera=reference_camera,
+    )
+    gaussians.save_ply(ply_path)
+
+
+def build_arg_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Occupancy-only refinement from a pretrained Gaussian PLY.")
+    parser.add_argument("--ply_path", type=str, required=True, help="Input Gaussian PLY (assumed geometrically correct).")
+    parser.add_argument("--camera_poses", type=str, required=True, help="JSON with camera poses matching the scene.")
+    parser.add_argument("--mesh_config", type=str, default="default", help="Mesh regularization config name.")
+    parser.add_argument("--iterations", type=int, default=2000, help="Number of optimization steps.")
+    parser.add_argument("--occupancy_lr", type=float, default=0.001, help="Learning rate for occupancy shift.")
+    parser.add_argument(
+        "--mesh_loss_weight",
+        type=float,
+        default=5.0,
+        help="Global weight applied to the mesh regularization loss.",
+    )
+    parser.add_argument("--log_interval", type=int, default=100, help="Logging interval.")
+    parser.add_argument("--export_interval", type=int, default=1000, help="Mesh export interval (0 disables periodic export).")
+    parser.add_argument("--output_dir", type=str, default="./occupancy_refine_output", help="Directory to store outputs.")
+    parser.add_argument("--fov_y", type=float, default=75.0, help="Vertical field-of-view in degrees.")
+    parser.add_argument("--image_width", type=int, default=1280, help="Rendered image width.")
+    parser.add_argument("--image_height", type=int, default=720, help="Rendered image height.")
+    parser.add_argument(
+        "--background",
+        type=float,
+        nargs=3,
+        default=(0.0, 0.0, 0.0),
+        help="Background color used for rendering (RGB).",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed.")
+    parser.add_argument("--mesh_start_iter", type=int, default=1, help="Iteration at which mesh regularization starts.")
+    parser.add_argument(
+        "--mesh_stop_iter",
+        type=int,
+        default=None,
+        help="Iteration after which mesh regularization stops (defaults to total iterations).",
+    )
parser.add_argument("--warn_until_iter", type=int, default=3000, help="Warmup iterations for surface sampling.") + parser.add_argument( + "--imp_metric", + type=str, + default="outdoor", + choices=["outdoor", "indoor"], + help="Importance metric used for surface sampling.", + ) + parser.add_argument( + "--cull_on_export", + action="store_true", + help="Frustum cull meshes using the first camera before export.", + ) + parser.add_argument( + "--sdf_log_samples", + type=int, + default=32, + help="Number of SDF values recorded per iteration (0 disables sampling).", + ) + parser.add_argument( + "--loss_log_filename", + type=str, + default="losses.jsonl", + help="Filename used for per-iteration loss logs.", + ) + parser.add_argument( + "--sdf_log_filename", + type=str, + default="sdf_samples.jsonl", + help="Filename used for per-iteration SDF sample logs.", + ) + parser.add_argument( + "--surface_gaussians_filename", + type=str, + default="surface_gaussians_initial.ply", + help="Filename for the first batch of surface Gaussians (empty string disables export).", + ) + return parser + + +def main() -> None: + parser = build_arg_parser() + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA device is required for occupancy refinement.") + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + cameras = load_cameras_from_json( + json_path=args.camera_poses, + image_height=args.image_height, + image_width=args.image_width, + fov_y_deg=args.fov_y, + ) + print(f"[INFO] Loaded {len(cameras)} cameras from {args.camera_poses}.") + + scene = ManualScene(cameras) + + mesh_config = load_mesh_config_file(args.mesh_config) + mesh_config["start_iter"] = max(1, args.mesh_start_iter) + if args.mesh_stop_iter is not None: + mesh_config["stop_iter"] = args.mesh_stop_iter + else: + mesh_config["stop_iter"] = max(mesh_config.get("stop_iter", args.iterations), args.iterations) + + pipe_parser = argparse.ArgumentParser() + pipe = PipelineParams(pipe_parser) + background = torch.tensor(args.background, dtype=torch.float32, device="cuda") + + gaussians = GaussianModel( + sh_degree=0, + use_mip_filter=False, + learn_occupancy=True, + use_appearance_network=False, + ) + gaussians.load_ply(args.ply_path) + print(f"[INFO] Loaded {gaussians._xyz.shape[0]} Gaussians from {args.ply_path}.") + + ensure_learnable_occupancy(gaussians) + + if gaussians.spatial_lr_scale <= 0: + gaussians.spatial_lr_scale = 1.0 + + gaussians.init_culling(len(cameras)) + gaussians.set_occupancy_mode(mesh_config.get("occupancy_mode", "occupancy_shift")) + freeze_gaussian_rigid_parameters(gaussians) + + optimizer = torch.optim.Adam([gaussians._occupancy_shift], lr=args.occupancy_lr) + + render_view, render_for_sdf = build_render_functions(gaussians, pipe, background) + mesh_renderer, mesh_state = initialize_mesh_regularization(scene, mesh_config) + mesh_state["reset_delaunay_samples"] = True + mesh_state["reset_sdf_values"] = True + surface_gaussians_path = None + if args.surface_gaussians_filename: + surface_gaussians_path = os.path.join(args.output_dir, args.surface_gaussians_filename) + print(f"[INFO] Will export first sampled surface Gaussians to {surface_gaussians_path}.") + mesh_state["surface_sample_export_path"] = surface_gaussians_path + mesh_state["surface_sample_saved"] = False + mesh_state["surface_sample_saved_iter"] = None + + runtime_args = SimpleNamespace( + warn_until_iter=args.warn_until_iter, + imp_metric=args.imp_metric, + depth_reinit_iter=getattr(args, 
"depth_reinit_iter", args.warn_until_iter), + ) + + os.makedirs(args.output_dir, exist_ok=True) + log_dir = os.path.join(args.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + loss_log_path = os.path.join(log_dir, args.loss_log_filename) + sdf_log_path = os.path.join(log_dir, args.sdf_log_filename) + + ema_loss = None + pending_view_indices: list[int] = [] + sdf_sample_indices_tensor = None # Stored on the same device as pivots_sdf_flat + sdf_sample_indices_list = None + + with open(loss_log_path, "w", encoding="utf-8") as loss_log_file, open( + sdf_log_path, "w", encoding="utf-8" + ) as sdf_log_file: + # Iterate through all cameras without replacement; reshuffle when one pass finishes. + for iteration in range(1, args.iterations + 1): + if not pending_view_indices: + pending_view_indices = list(range(len(cameras))) + random.shuffle(pending_view_indices) + + view_idx = pending_view_indices.pop() + viewpoint = cameras[view_idx] + + render_pkg = render_view(viewpoint) + + mesh_pkg = compute_mesh_regularization( + iteration=iteration, + render_pkg=render_pkg, + viewpoint_cam=viewpoint, + viewpoint_idx=view_idx, + gaussians=gaussians, + scene=scene, + pipe=pipe, + background=background, + kernel_size=0.0, + config=mesh_config, + mesh_renderer=mesh_renderer, + mesh_state=mesh_state, + render_func=render_for_sdf, + weight_adjustment=100.0 / max(args.iterations, 1), + args=runtime_args, + integrate_func=integrate_radegs, + ) + mesh_state = mesh_pkg["updated_state"] + + with torch.no_grad(): + current_occ = torch.sigmoid(gaussians._base_occupancy + gaussians._occupancy_shift) + pivots_sdf = convert_occupancy_to_sdf(flatten_voronoi_features(current_occ)) + pivots_sdf_flat = pivots_sdf.view(-1).detach() + if pivots_sdf_flat.numel() > 0: + sdf_mean = float(pivots_sdf_flat.mean().item()) + sdf_std = float(pivots_sdf_flat.std(unbiased=False).item()) + else: + sdf_mean = 0.0 + sdf_std = 0.0 + + sample_indices_list = [] + sample_values_list = [] + if args.sdf_log_samples > 0 and pivots_sdf_flat.numel() > 0: + sample_count = min(args.sdf_log_samples, pivots_sdf_flat.numel()) + need_refresh = sdf_sample_indices_tensor is None or sdf_sample_indices_tensor.numel() != sample_count + if not need_refresh: + max_index = int(sdf_sample_indices_tensor.max().item()) + need_refresh = max_index >= pivots_sdf_flat.numel() + if need_refresh: + # Draw once so the same subset of pivots is tracked across iterations. 
+ sdf_sample_indices_tensor = torch.randperm( + pivots_sdf_flat.shape[0], device=pivots_sdf_flat.device + )[:sample_count] + sdf_sample_indices_list = sdf_sample_indices_tensor.detach().cpu().tolist() + else: + if sdf_sample_indices_tensor.device != pivots_sdf_flat.device: + sdf_sample_indices_tensor = sdf_sample_indices_tensor.to( + pivots_sdf_flat.device, non_blocking=True + ) + sample_values = pivots_sdf_flat[sdf_sample_indices_tensor] + sample_indices_list = sdf_sample_indices_list or [] + sample_values_list = sample_values.cpu().tolist() + + raw_mesh_loss = mesh_pkg["mesh_loss"] + loss = args.mesh_loss_weight * raw_mesh_loss + loss_value = float(loss.item()) + raw_loss_value = float(raw_mesh_loss.item()) + + loss_scalars = extract_loss_scalars(mesh_pkg) + skip_iteration = ( + mesh_pkg.get("mesh_triangles") is not None and mesh_pkg["mesh_triangles"].numel() == 0 + ) + + iteration_record = { + "iteration": iteration, + "view_index": view_idx, + "total_loss": loss_value, + "raw_mesh_loss": raw_loss_value, + "sdf_mean": sdf_mean, + "sdf_std": sdf_std, + "skipped": bool(skip_iteration), + } + if ema_loss is not None: + iteration_record["ema_loss"] = ema_loss + iteration_record.update(loss_scalars) + + sdf_record = { + "iteration": iteration, + "sdf_mean": sdf_mean, + "sdf_std": sdf_std, + "sample_count": len(sample_values_list), + "sample_indices": sample_indices_list, + "sample_values": sample_values_list, + } + + if skip_iteration: + iteration_record["skipped_reason"] = "empty_mesh" + loss_log_file.write(json.dumps(iteration_record) + "\n") + loss_log_file.flush() + sdf_record["skipped"] = True + sdf_log_file.write(json.dumps(sdf_record) + "\n") + sdf_log_file.flush() + print(f"[WARNING] Empty mesh at iteration {iteration}; skipping optimizer step.") + continue + + optimizer.zero_grad(set_to_none=True) + loss.backward() + optimizer.step() + + ema_loss = loss_value if ema_loss is None else (0.9 * ema_loss + 0.1 * loss_value) + iteration_record["ema_loss"] = ema_loss + + loss_log_file.write(json.dumps(iteration_record) + "\n") + loss_log_file.flush() + + sdf_log_file.write(json.dumps(sdf_record) + "\n") + sdf_log_file.flush() + + if iteration % args.log_interval == 0 or iteration == 1: + mesh_depth = loss_scalars.get("mesh_depth_loss", 0.0) + mesh_normal = loss_scalars.get("mesh_normal_loss", 0.0) + occupied_centers = loss_scalars.get("occupied_centers_loss", 0.0) + occupancy_labels = loss_scalars.get("occupancy_labels_loss", 0.0) + + print( + "[Iter {iter:05d}] loss={loss:.6f} ema={ema:.6f} depth={depth:.6f} " + "normal={normal:.6f} occ_centers={centers:.6f} labels={labels:.6f} " + "sdf_mean={sdf_mean:.6f} mesh_raw={raw_mesh:.6f}".format( + iter=iteration, + loss=loss_value, + ema=ema_loss, + depth=mesh_depth, + normal=mesh_normal, + centers=occupied_centers, + labels=occupancy_labels, + sdf_mean=sdf_mean, + raw_mesh=raw_loss_value, + ) + ) + + if args.export_interval > 0 and iteration % args.export_interval == 0: + export_iteration_state( + iteration=iteration, + gaussians=gaussians, + mesh_state=mesh_state, + output_dir=args.output_dir, + reference_camera=cameras[0] if args.cull_on_export else None, + ) + + if surface_gaussians_path and not mesh_state.get("surface_sample_saved", False): + print( + "[WARNING] Requested export of surface Gaussians but no samples were saved. " + "Verify surface sampling settings." 
+        )
+
+    final_dir = os.path.join(args.output_dir, "final")
+    os.makedirs(final_dir, exist_ok=True)
+    export_iteration_state(
+        iteration=args.iterations,
+        gaussians=gaussians,
+        mesh_state=mesh_state,
+        output_dir=final_dir,
+        reference_camera=cameras[0] if args.cull_on_export else None,
+    )
+
+    print(f"[INFO] Occupancy refinement completed. Results saved to {args.output_dir}.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/milo/mesh_extract_integration.py b/milo/useless_maybe/mesh_extract_integration.py
similarity index 100%
rename from milo/mesh_extract_integration.py
rename to milo/useless_maybe/mesh_extract_integration.py
diff --git a/milo/useless_maybe/ply2mesh.py b/milo/useless_maybe/ply2mesh.py
new file mode 100644
index 0000000..d16f1c6
--- /dev/null
+++ b/milo/useless_maybe/ply2mesh.py
@@ -0,0 +1,459 @@
+import os
+import json
+import math
+import random
+from argparse import ArgumentParser
+from typing import List, Optional, Sequence
+
+import yaml
+from types import SimpleNamespace
+import torch
+import torch.nn as nn
+import numpy as np
+import trimesh
+
+from arguments import PipelineParams
+from scene.cameras import Camera
+from scene.gaussian_model import GaussianModel
+from regularization.regularizer.mesh import initialize_mesh_regularization, compute_mesh_regularization
+from functional import extract_mesh, compute_delaunay_triangulation
+from functional.mesh import frustum_cull_mesh
+from regularization.sdf.learnable import convert_occupancy_to_sdf
+from utils.geometry_utils import flatten_voronoi_features
+from gaussian_renderer.radegs import render_radegs, integrate_radegs
+
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+def quaternion_to_rotation_matrix(q: Sequence[float]) -> np.ndarray:
+    """Convert a unit quaternion [w, x, y, z] to a 3x3 rotation matrix."""
+    q = np.asarray(q, dtype=np.float64)
+    if q.shape != (4,):
+        raise ValueError("Quaternion must have shape (4,)")
+    w, x, y, z = q
+    xx = x * x
+    yy = y * y
+    zz = z * z
+    xy = x * y
+    xz = x * z
+    yz = y * z
+    wx = w * x
+    wy = w * y
+    wz = w * z
+    rotation = np.array(
+        [
+            [1.0 - 2.0 * (yy + zz), 2.0 * (xy - wz), 2.0 * (xz + wy)],
+            [2.0 * (xy + wz), 1.0 - 2.0 * (xx + zz), 2.0 * (yz - wx)],
+            [2.0 * (xz - wy), 2.0 * (yz + wx), 1.0 - 2.0 * (xx + yy)],
+        ],
+        dtype=np.float32,
+    )
+    return rotation
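As a quick sanity check of the conversion formula above (illustrative snippet, not part of the diff): a rotation of 90 degrees about +z corresponds to the quaternion [cos 45deg, 0, 0, sin 45deg], and the function reproduces the standard z-rotation matrix:

    import math
    import numpy as np

    half = math.radians(45.0)
    q = [math.cos(half), 0.0, 0.0, math.sin(half)]  # 90 deg about +z, order [w, x, y, z]
    R = quaternion_to_rotation_matrix(q)
    expected = np.array(
        [[0.0, -1.0, 0.0],
         [1.0, 0.0, 0.0],
         [0.0, 0.0, 1.0]],
        dtype=np.float32,
    )
    assert np.allclose(R, expected, atol=1e-6)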
+
+
+class ManualScene:
+    """Minimal scene wrapper exposing the API expected by mesh regularization utilities."""
+
+    def __init__(self, cameras: Sequence[Camera]):
+        self._train_cameras = list(cameras)
+
+    def getTrainCameras(self, scale: float = 1.0):
+        return list(self._train_cameras)
+
+    def getTrainCameras_warn_up(
+        self,
+        iteration: int,
+        warn_until_iter: int,
+        scale: float = 1.0,
+        scale2: float = 2.0,
+    ):
+        return list(self._train_cameras)
+
+
+def load_cameras_from_json(
+    json_path: str,
+    image_height: int,
+    image_width: int,
+    fov_y_deg: float,
+) -> List[Camera]:
+    if not os.path.isfile(json_path):
+        raise FileNotFoundError(f"Camera JSON not found: {json_path}")
+
+    with open(json_path, "r", encoding="utf-8") as f:
+        camera_entries = json.load(f)
+
+    if not camera_entries:
+        raise ValueError(f"No camera entries found in {json_path}")
+
+    fov_y = math.radians(fov_y_deg)
+    aspect_ratio = image_width / image_height
+    fov_x = 2.0 * math.atan(aspect_ratio * math.tan(fov_y * 0.5))
+
+    cameras: List[Camera] = []
+    for idx, entry in enumerate(camera_entries):
+        if "quaternion" in entry:
+            rotation = quaternion_to_rotation_matrix(entry["quaternion"])
+        elif "rotation" in entry:
+            rotation = np.asarray(entry["rotation"], dtype=np.float32)
+            if rotation.shape != (3, 3):
+                raise ValueError(f"Camera entry {idx} rotation must be 3x3, got {rotation.shape}")
+        else:
+            raise KeyError(f"Camera entry {idx} must provide either 'quaternion' or 'rotation'.")
+
+        translation = None
+        if "tvec" in entry:
+            translation = np.asarray(entry["tvec"], dtype=np.float32)
+        elif "translation" in entry:
+            translation = np.asarray(entry["translation"], dtype=np.float32)
+        elif "position" in entry:
+            camera_center = np.asarray(entry["position"], dtype=np.float32)
+            if camera_center.shape != (3,):
+                raise ValueError(f"Camera entry {idx} position must be length-3, got shape {camera_center.shape}")
+            # Camera expects world-to-view translation (COLMAP convention t = -R * C).
+            rotation_w2c = rotation.T  # rotation is camera-to-world
+            translation = -rotation_w2c @ camera_center
+        else:
+            raise KeyError(f"Camera entry {idx} must provide 'position', 'translation', or 'tvec'.")
+
+        if translation.shape != (3,):
+            raise ValueError(f"Camera entry {idx} translation must be length-3, got shape {translation.shape}")
+
+        image_name = (
+            entry.get("name")
+            or entry.get("img_name")
+            or entry.get("image_name")
+            or f"view_{idx:04d}"
+        )
+        camera = Camera(
+            colmap_id=str(idx),
+            R=rotation,
+            T=translation,
+            FoVx=fov_x,
+            FoVy=fov_y,
+            image=torch.zeros(3, image_height, image_width),
+            gt_alpha_mask=None,
+            image_name=image_name,
+            uid=idx,
+            data_device="cuda",
+        )
+        cameras.append(camera)
+    return cameras
+
+
+def freeze_gaussian_rigid_parameters(gaussians: GaussianModel) -> None:
+    """Disable gradients for geometric and appearance parameters, keeping occupancy shift trainable."""
+    freeze_attrs = [
+        "_xyz",
+        "_features_dc",
+        "_features_rest",
+        "_opacity",
+        "_scaling",
+        "_rotation",
+    ]
+    for attr in freeze_attrs:
+        param = getattr(gaussians, attr, None)
+        if isinstance(param, nn.Parameter):
+            param.requires_grad_(False)
+
+    if hasattr(gaussians, "_base_occupancy") and isinstance(gaussians._base_occupancy, nn.Parameter):
+        gaussians._base_occupancy.requires_grad_(False)
+    if hasattr(gaussians, "_occupancy_shift") and isinstance(gaussians._occupancy_shift, nn.Parameter):
+        gaussians._occupancy_shift.requires_grad_(True)
+
+
+def build_render_functions(
+    gaussians: GaussianModel,
+    pipe: PipelineParams,
+    background: torch.Tensor,
+):
+    def _render(
+        view: Camera,
+        pc_obj: GaussianModel,
+        pipe_obj: PipelineParams,
+        bg_color: torch.Tensor,
+        *,
+        kernel_size: float = 0.0,
+        require_coord: bool = False,
+        require_depth: bool = True,
+    ):
+        pkg = render_radegs(
+            viewpoint_camera=view,
+            pc=pc_obj,
+            pipe=pipe_obj,
+            bg_color=bg_color,
+            kernel_size=kernel_size,
+            scaling_modifier=1.0,
+            require_coord=require_coord,
+            require_depth=require_depth,
+        )
+        if "area_max" not in pkg:
+            pkg["area_max"] = torch.zeros_like(pkg["radii"])
+        return pkg
+
+    def render_view(view: Camera):
+        return _render(view, gaussians, pipe, background)
+
+    def render_for_sdf(
+        view: Camera,
+        gaussians_override: Optional[GaussianModel] = None,
+        pipeline_override: Optional[PipelineParams] = None,
+        background_override: Optional[torch.Tensor] = None,
+        kernel_size: float = 0.0,
+        require_depth: bool = True,
+        require_coord: bool = False,
+    ):
+        pc_obj = gaussians if gaussians_override is None else gaussians_override
+        pipe_obj = pipe if pipeline_override is None else pipeline_override
+        bg_color = background if background_override is None else background_override
+        pkg = _render(
+            view,
+            pc_obj,
+            pipe_obj,
+            bg_color,
kernel_size=kernel_size, + require_coord=require_coord, + require_depth=require_depth, + ) + return { + "render": pkg["render"].detach(), + "median_depth": pkg["median_depth"].detach(), + } + + return render_view, render_for_sdf + + +def load_mesh_config_file(name: str) -> dict: + config_path = os.path.join(BASE_DIR, "configs", "mesh", f"{name}.yaml") + if not os.path.isfile(config_path): + raise FileNotFoundError(f"Mesh config not found: {config_path}") + with open(config_path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) + + +def export_mesh_from_gaussians( + gaussians: GaussianModel, + mesh_state: dict, + output_path: str, + reference_camera: Optional[Camera] = None, +) -> None: + delaunay_tets = mesh_state.get("delaunay_tets") + gaussian_idx = mesh_state.get("delaunay_xyz_idx") + + if delaunay_tets is None: + delaunay_tets = compute_delaunay_triangulation( + means=gaussians.get_xyz, + scales=gaussians.get_scaling, + rotations=gaussians.get_rotation, + gaussian_idx=gaussian_idx, + ) + + occupancy = ( + gaussians.get_occupancy + if gaussian_idx is None + else gaussians.get_occupancy[gaussian_idx] + ) + pivots_sdf = convert_occupancy_to_sdf(flatten_voronoi_features(occupancy)) + + mesh = extract_mesh( + delaunay_tets=delaunay_tets, + pivots_sdf=pivots_sdf, + means=gaussians.get_xyz, + scales=gaussians.get_scaling, + rotations=gaussians.get_rotation, + gaussian_idx=gaussian_idx, + ) + + mesh_to_save = mesh + if reference_camera is not None: + mesh_to_save = frustum_cull_mesh(mesh, reference_camera) + + verts = mesh_to_save.verts.detach().cpu().numpy() + faces = mesh_to_save.faces.detach().cpu().numpy() + trimesh.Trimesh(vertices=verts, faces=faces, process=False).export(output_path) + + +def train_occupancy_only(args) -> None: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA device is required for occupancy fine-tuning.") + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + cameras = load_cameras_from_json( + json_path=args.camera_poses, + image_height=args.image_height, + image_width=args.image_width, + fov_y_deg=args.fov_y, + ) + print(f"[INFO] Loaded {len(cameras)} cameras from {args.camera_poses}.") + + scene = ManualScene(cameras) + + mesh_config = load_mesh_config_file(args.mesh_config) + mesh_config["start_iter"] = max(1, args.mesh_start_iter) + if args.mesh_stop_iter is not None: + mesh_config["stop_iter"] = args.mesh_stop_iter + else: + mesh_config["stop_iter"] = max(mesh_config.get("stop_iter", args.iterations), args.iterations) + + pipe_parser = ArgumentParser() + pipe = PipelineParams(pipe_parser) + + background = torch.tensor(args.background, dtype=torch.float32, device="cuda") + + gaussians = GaussianModel( + sh_degree=0, + use_mip_filter=False, + learn_occupancy=True, + use_appearance_network=False, + ) + gaussians.load_ply(args.ply_path) + print(f"[INFO] Loaded {gaussians._xyz.shape[0]} Gaussians from {args.ply_path}.") + + if not gaussians.learn_occupancy or not hasattr(gaussians, "_occupancy_shift"): + print("[INFO] PLY does not provide occupancy buffers; initializing them to zeros.") + gaussians.learn_occupancy = True + base_occupancy = torch.zeros((gaussians._xyz.shape[0], 9), device=gaussians._xyz.device) + occupancy_shift = torch.zeros_like(base_occupancy) + gaussians._base_occupancy = nn.Parameter(base_occupancy.requires_grad_(False), requires_grad=False) + gaussians._occupancy_shift = nn.Parameter(occupancy_shift.requires_grad_(True)) + + if gaussians.spatial_lr_scale <= 0: + gaussians.spatial_lr_scale = 
1.0 + + gaussians.init_culling(len(cameras)) + gaussians.set_occupancy_mode(mesh_config.get("occupancy_mode", "occupancy_shift")) + freeze_gaussian_rigid_parameters(gaussians) + + optimizer = torch.optim.Adam([gaussians._occupancy_shift], lr=args.occupancy_lr) + + render_view, render_for_sdf = build_render_functions(gaussians, pipe, background) + mesh_renderer, mesh_state = initialize_mesh_regularization(scene, mesh_config) + mesh_state["reset_delaunay_samples"] = True + mesh_state["reset_sdf_values"] = True + + runtime_args = SimpleNamespace( + warn_until_iter=args.warn_until_iter, + imp_metric=args.imp_metric, + depth_reinit_iter=getattr(args, 'depth_reinit_iter', args.warn_until_iter), + ) + + os.makedirs(args.output_dir, exist_ok=True) + + ema_loss: Optional[float] = None + + for iteration in range(1, args.iterations + 1): + view_idx = random.randrange(len(cameras)) + viewpoint = cameras[view_idx] + + render_pkg = render_view(viewpoint) + + mesh_pkg = compute_mesh_regularization( + iteration=iteration, + render_pkg=render_pkg, + viewpoint_cam=viewpoint, + viewpoint_idx=view_idx, + gaussians=gaussians, + scene=scene, + pipe=pipe, + background=background, + kernel_size=0.0, + config=mesh_config, + mesh_renderer=mesh_renderer, + mesh_state=mesh_state, + render_func=render_for_sdf, + weight_adjustment=100.0 / max(args.iterations, 1), + args=runtime_args, + integrate_func=integrate_radegs, + ) + mesh_state = mesh_pkg["updated_state"] + + if mesh_pkg.get("mesh_triangles") is not None and mesh_pkg["mesh_triangles"].numel() == 0: + print(f"[WARNING] Empty mesh at iteration {iteration}; skipping optimizer step.") + continue + + loss = mesh_pkg["mesh_loss"] + optimizer.zero_grad(set_to_none=True) + loss.backward() + optimizer.step() + + loss_value = float(loss.item()) + ema_loss = loss_value if ema_loss is None else (0.9 * ema_loss + 0.1 * loss_value) + + if iteration % args.log_interval == 0 or iteration == 1: + print( + "[Iter {iter:05d}] loss={loss:.6f} ema={ema:.6f} depth={depth:.6f} " + "normal={normal:.6f} occ_centers={centers:.6f} labels={labels:.6f}".format( + iter=iteration, + loss=loss_value, + ema=ema_loss, + depth=mesh_pkg["mesh_depth_loss"].item(), + normal=mesh_pkg["mesh_normal_loss"].item(), + centers=mesh_pkg["occupied_centers_loss"].item(), + labels=mesh_pkg["occupancy_labels_loss"].item(), + ) + ) + + if args.export_interval > 0 and iteration % args.export_interval == 0: + iteration_dir = os.path.join(args.output_dir, f"iter_{iteration:05d}") + os.makedirs(iteration_dir, exist_ok=True) + gaussians.save_ply(os.path.join(iteration_dir, "point_cloud.ply")) + export_mesh_from_gaussians( + gaussians=gaussians, + mesh_state=mesh_state, + output_path=os.path.join(iteration_dir, "mesh.ply"), + reference_camera=cameras[0] if args.cull_on_export else None, + ) + + final_dir = os.path.join(args.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + gaussians.save_ply(os.path.join(final_dir, "point_cloud.ply")) + export_mesh_from_gaussians( + gaussians=gaussians, + mesh_state=mesh_state, + output_path=os.path.join(final_dir, "mesh.ply"), + reference_camera=cameras[0] if args.cull_on_export else None, + ) + + print(f"[INFO] Occupancy-only training completed. 
Results saved to {args.output_dir}.") + + +def build_arg_parser() -> ArgumentParser: + parser = ArgumentParser(description="Occupancy-only fine-tuning from a pretrained Gaussian PLY.") + parser.add_argument("--ply_path", type=str, required=True, help="Input PLY file with pretrained Gaussians.") + parser.add_argument("--camera_poses", type=str, required=True, help="JSON file containing camera poses.") + parser.add_argument("--mesh_config", type=str, default="default", help="Mesh regularization config name.") + parser.add_argument("--iterations", type=int, default=2000, help="Number of optimization steps.") + parser.add_argument("--occupancy_lr", type=float, default=0.01, help="Learning rate for occupancy shift.") + parser.add_argument("--log_interval", type=int, default=50, help="Console logging interval.") + parser.add_argument("--export_interval", type=int, default=200, help="Mesh/PLY export interval.") + parser.add_argument("--output_dir", type=str, default="./ply2mesh_output", help="Directory to store outputs.") + parser.add_argument("--fov_y", type=float, default=75.0, help="Vertical field-of-view in degrees.") + parser.add_argument("--image_width", type=int, default=1280, help="Rendered image width.") + parser.add_argument("--image_height", type=int, default=720, help="Rendered image height.") + parser.add_argument( + "--background", + type=float, + nargs=3, + default=(0.0, 0.0, 0.0), + help="Background color used for rendering (RGB).", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed.") + parser.add_argument("--mesh_start_iter", type=int, default=1, help="Iteration at which mesh regularization starts.") + parser.add_argument( + "--mesh_stop_iter", + type=int, + default=None, + help="Iteration after which mesh regularization stops (defaults to total iterations).", + ) + parser.add_argument("--warn_until_iter", type=int, default=3000, help="Warmup iterations for surface sampling.") + parser.add_argument("--imp_metric", type=str, default="outdoor", choices=["outdoor", "indoor"], help="Importance metric for surface sampling.") + parser.add_argument("--cull_on_export", action="store_true", help="Enable frustum culling using the first camera before exporting meshes.") + return parser + + +if __name__ == "__main__": + argument_parser = build_arg_parser() + parsed_args = argument_parser.parse_args() + train_occupancy_only(parsed_args) diff --git a/milo/useless_maybe/yufu2mesh_iterative.py b/milo/useless_maybe/yufu2mesh_iterative.py new file mode 100644 index 0000000..b95f7bf --- /dev/null +++ b/milo/useless_maybe/yufu2mesh_iterative.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python3 +"""Iteratively optimize SDF pivots so that the extracted mesh adheres to the provided Gaussian point cloud.""" + +import argparse +import json +import math +import os +from typing import List + +import numpy as np +import torch +import torch.nn.functional as F +import trimesh + +from arguments import PipelineParams +from functional import ( + sample_gaussians_on_surface, + extract_gaussian_pivots, + compute_initial_sdf_values, + compute_delaunay_triangulation, + extract_mesh, +) +from gaussian_renderer.radegs import render_radegs +from regularization.sdf.learnable import convert_sdf_to_occupancy +from scene.cameras import Camera +from scene.gaussian_model import GaussianModel + + +def quaternion_to_rotation_matrix(q: List[float]) -> np.ndarray: + """Convert unit quaternion [w, x, y, z] to a rotation matrix.""" + w, x, y, z = q + xx, yy, zz = x * x, y * y, z * z + xy, xz, yz = x * y, x * z, y 
* z + wx, wy, wz = w * x, w * y, w * z + return np.array([ + [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)], + [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)], + [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)], + ]) + + +def load_cameras( + poses_json: str, + height: int, + width: int, + fov_y_deg: float, + device: str, +) -> List[Camera]: + with open(poses_json, "r", encoding="utf-8") as f: + poses = json.load(f) + + fov_y = math.radians(fov_y_deg) + aspect = width / height + fov_x = 2.0 * math.atan(aspect * math.tan(fov_y / 2.0)) + + cameras: List[Camera] = [] + for idx, info in enumerate(poses): + cam = Camera( + colmap_id=str(idx), + R=quaternion_to_rotation_matrix(info["quaternion"]), + T=np.asarray(info["position"]), + FoVx=fov_x, + FoVy=fov_y, + image=torch.empty(3, height, width), + gt_alpha_mask=None, + image_name=info.get("name", f"view_{idx:05d}"), + uid=idx, + data_device=device, + ) + cameras.append(cam) + return cameras + + +def build_render_function( + gaussians: GaussianModel, + pipe: PipelineParams, + background: torch.Tensor, +): + def render_func(view: Camera): + render_pkg = render_radegs( + viewpoint_camera=view, + pc=gaussians, + pipe=pipe, + bg_color=background, + kernel_size=0.0, + scaling_modifier=1.0, + require_coord=False, + require_depth=True, + ) + return {"render": render_pkg["render"], "depth": render_pkg["median_depth"]} + + return render_func + + +def sample_tensor(tensor: torch.Tensor, max_samples: int) -> torch.Tensor: + if max_samples <= 0 or tensor.shape[0] <= max_samples: + return tensor + idx = torch.randperm(tensor.shape[0], device=tensor.device)[:max_samples] + return tensor[idx] + + +def export_mesh(mesh, path: str) -> None: + verts = mesh.verts.detach().cpu().numpy() + faces = mesh.faces.detach().cpu().numpy() + trimesh.Trimesh(vertices=verts, faces=faces, process=False).export(path) + + +def main(): + parser = argparse.ArgumentParser(description="Iteratively refine SDF pivots using Chamfer supervision from the Gaussian cloud.") + parser.add_argument("--ply_path", type=str, required=True, help="Perfect Gaussian PLY.") + parser.add_argument("--camera_poses", type=str, required=True, help="Camera pose JSON.") + parser.add_argument("--output_dir", type=str, default="./iter_occ_refine", help="Output directory.") + parser.add_argument("--iterations", type=int, default=400, help="Number of SDF optimization steps.") + parser.add_argument("--lr", type=float, default=1e-2, help="Learning rate for SDF pivots.") + parser.add_argument("--reg_weight", type=float, default=5e-4, help="L2 regularization weight towards the initial SDF.") + parser.add_argument("--mesh_sample_count", type=int, default=4096, help="Number of mesh vertices sampled per step.") + parser.add_argument("--gaussian_sample_count", type=int, default=4096, help="Number of Gaussian centers sampled per step.") + parser.add_argument("--surface_sample_limit", type=int, default=400000, help="Maximum Gaussians kept for Delaunay pivots.") + parser.add_argument("--clamp_sdf", type=float, default=1.0, help="Clamp range for SDF values.") + parser.add_argument("--log_interval", type=int, default=10, help="Logging interval.") + parser.add_argument("--export_interval", type=int, default=100, help="Mesh export interval (0 disables periodic export).") + parser.add_argument("--image_height", type=int, default=720, help="Renderer image height.") + parser.add_argument("--image_width", type=int, default=1280, help="Renderer image width.") + parser.add_argument("--fov_y", type=float, default=75.0, 
help="Vertical FoV in degrees.") + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA device is required for occupancy refinement.") + + device = "cuda" + os.makedirs(args.output_dir, exist_ok=True) + + cameras = load_cameras( + poses_json=args.camera_poses, + height=args.image_height, + width=args.image_width, + fov_y_deg=args.fov_y, + device=device, + ) + + gaussians = GaussianModel( + sh_degree=0, + use_mip_filter=False, + learn_occupancy=True, + use_appearance_network=False, + ) + gaussians.load_ply(args.ply_path) + + pipe_parser = argparse.ArgumentParser() + pipe = PipelineParams(pipe_parser) + background = torch.tensor([0.0, 0.0, 0.0], device=device) + render_func = build_render_function(gaussians, pipe, background) + + with torch.no_grad(): + means = gaussians.get_xyz.detach().clone() + scales = gaussians.get_scaling.detach().clone() + rotations = gaussians.get_rotation.detach().clone() + + with torch.no_grad(): + surface_gaussians_idx = sample_gaussians_on_surface( + views=cameras, + means=means, + scales=scales, + rotations=rotations, + opacities=gaussians.get_opacity, + n_max_samples=args.surface_sample_limit, + scene_type="outdoor", + ) + + if surface_gaussians_idx.numel() == 0: + raise RuntimeError("Surface sampling returned zero Gaussians.") + + surface_means = means[surface_gaussians_idx].detach() + + initial_sdf = compute_initial_sdf_values( + views=cameras, + render_func=render_func, + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, + ).detach() + + pivots, _ = extract_gaussian_pivots( + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, + ) + + delaunay_tets = compute_delaunay_triangulation( + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, + ) + + learned_sdf = torch.nn.Parameter(initial_sdf.clone()) + optimizer = torch.optim.Adam([learned_sdf], lr=args.lr) + + for iteration in range(1, args.iterations + 1): + mesh = extract_mesh( + delaunay_tets=delaunay_tets, + pivots_sdf=learned_sdf, + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, + ) + + mesh_verts = mesh.verts + if mesh_verts.numel() == 0: + print(f"[Iter {iteration:05d}] Empty mesh, skipping update.") + continue + + sampled_mesh_pts = sample_tensor(mesh_verts, args.mesh_sample_count) + sampled_gaussian_pts = sample_tensor(surface_means, args.gaussian_sample_count) + + with torch.no_grad(): + nn_idx_forward = torch.cdist( + sampled_mesh_pts.detach(), + sampled_gaussian_pts.detach(), + p=2, + ).argmin(dim=1) + nn_idx_backward = torch.cdist( + sampled_gaussian_pts, + sampled_mesh_pts.detach(), + p=2, + ).argmin(dim=1) + + nearest_gauss = sampled_gaussian_pts[nn_idx_forward] + nearest_mesh = sampled_mesh_pts[nn_idx_backward] + + loss_forward = torch.mean(torch.sum((sampled_mesh_pts - nearest_gauss) ** 2, dim=1)) + loss_backward = torch.mean(torch.sum((sampled_gaussian_pts - nearest_mesh) ** 2, dim=1)) + chamfer_loss = loss_forward + loss_backward + + reg_loss = F.mse_loss(learned_sdf, initial_sdf) + loss = chamfer_loss + args.reg_weight * reg_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + with torch.no_grad(): + learned_sdf.clamp_(-args.clamp_sdf, args.clamp_sdf) + + if iteration % args.log_interval == 0 or iteration == 1: + print( + f"[Iter {iteration:05d}] chamfer={chamfer_loss.item():.6f} " + f"reg={reg_loss.item():.6f} total={loss.item():.6f} " + f"|mesh|={sampled_mesh_pts.shape[0]} 
|gauss|={sampled_gaussian_pts.shape[0]}" + ) + + if args.export_interval > 0 and iteration % args.export_interval == 0: + export_mesh( + mesh=mesh, + path=os.path.join(args.output_dir, f"mesh_iter_{iteration:05d}.ply"), + ) + + final_mesh = extract_mesh( + delaunay_tets=delaunay_tets, + pivots_sdf=learned_sdf, + means=means, + scales=scales, + rotations=rotations, + gaussian_idx=surface_gaussians_idx, + ) + + export_mesh(final_mesh, os.path.join(args.output_dir, "final_mesh.ply")) + + with torch.no_grad(): + final_occ = convert_sdf_to_occupancy(learned_sdf.detach()).view(-1, 9) + base_occ = convert_sdf_to_occupancy(initial_sdf).view(-1, 9) + gaussians.learn_occupancy = True + total_gaussians = gaussians._xyz.shape[0] + base_buffer = base_occ.new_zeros((total_gaussians, 9)) + shift_buffer = base_occ.new_zeros((total_gaussians, 9)) + surface_idx = surface_gaussians_idx.long() + base_buffer.index_copy_(0, surface_idx, base_occ) + shift_buffer.index_copy_(0, surface_idx, final_occ - base_occ) + gaussians._base_occupancy = torch.nn.Parameter(base_buffer, requires_grad=False) + gaussians._occupancy_shift = torch.nn.Parameter(shift_buffer, requires_grad=False) + gaussians.save_ply(os.path.join(args.output_dir, "refined_gaussians.ply")) + + print(f"[INFO] Optimization complete. Results saved to {args.output_dir}.") + + +if __name__ == "__main__": + main() diff --git a/milo/utils/log_utils.py b/milo/utils/log_utils.py index 905d108..2430d27 100644 --- a/milo/utils/log_utils.py +++ b/milo/utils/log_utils.py @@ -1,5 +1,5 @@ import os -from typing import List, Union, Dict, Any +from typing import List, Union, Dict, Any, Optional import torch import numpy as np import math @@ -117,6 +117,22 @@ def make_log_figure( return log_images_dict +def save_inter_figure(depth_diff: torch.Tensor, normal_diff: torch.Tensor, save_path: str): + plt.figure(figsize=(12, 6)) + plt.suptitle("inter") + plt.subplot(1, 2, 1) + plt.imshow(depth_diff.cpu(), cmap="Spectral") + plt.title("depth") + plt.colorbar() + plt.subplot(1, 2, 2) + plt.imshow(normal_diff.cpu(), cmap="Spectral", vmin=0.0, vmax=2.0) + plt.title("normal") + plt.colorbar() + plt.tight_layout(rect=[0, 0, 1, 0.95]) + plt.savefig(save_path) + plt.close() + + def training_report(iteration, l1_loss, testing_iterations, scene, renderFunc, renderArgs): # Report test and samples of training set if iteration in testing_iterations: @@ -263,6 +279,35 @@ def log_training_progress( )) / 2. ) titles_to_log.append(f"Mesh Normals {viewpoint_idx}") + + # Depth/Normal diff for the current logging view + mesh_depth_map = mesh_render_pkg["depth"].detach().squeeze() + gauss_depth_map = render_pkg["median_depth"].detach().squeeze() + valid_depth_mask = (mesh_depth_map > 0) & (gauss_depth_map > 0) + depth_diff = torch.zeros_like(mesh_depth_map) + depth_diff[valid_depth_mask] = (mesh_depth_map - gauss_depth_map).abs()[valid_depth_mask] + + mesh_normals_view = fix_normal_map( + viewpoint_cam, + mesh_render_pkg["normals"].detach(), + normal_in_view_space=True, + ) + gauss_normals_view = fix_normal_map( + viewpoint_cam, + render_pkg["normal"].detach(), + normal_in_view_space=True, + ) + if mesh_normals_view.shape[0] == 3: + mesh_normals_view = mesh_normals_view.permute(1, 2, 0) + if gauss_normals_view.shape[0] == 3: + gauss_normals_view = gauss_normals_view.permute(1, 2, 0) + normal_dot = (mesh_normals_view * gauss_normals_view).sum(dim=-1).clamp(-1., 1.) + normal_diff = (1. 
- normal_dot) * valid_depth_mask.float()
+
+            images_to_log.append(depth_diff)
+            titles_to_log.append(f"Depth Diff {viewpoint_idx}")
+            images_to_log.append(normal_diff)
+            titles_to_log.append(f"Normal Diff {viewpoint_idx}")
 
         log_images_dict = make_log_figure(
             images=images_to_log,
@@ -319,4 +364,4 @@ def log_training_progress(
         ema_mesh_depth_loss_for_log, ema_mesh_normal_loss_for_log,
         ema_occupied_centers_loss_for_log, ema_occupancy_labels_loss_for_log,
         ema_depth_order_loss_for_log
-    )
\ No newline at end of file
+    )
diff --git a/milo/yufu2mesh.py b/milo/yufu2mesh.py
new file mode 100644
index 0000000..9fa015b
--- /dev/null
+++ b/milo/yufu2mesh.py
@@ -0,0 +1,219 @@
+from functional import (
+    sample_gaussians_on_surface,
+    extract_gaussian_pivots,
+    compute_initial_sdf_values,
+    compute_delaunay_triangulation,
+    extract_mesh,
+    frustum_cull_mesh,
+)
+from scene.cameras import Camera
+from scene.gaussian_model import GaussianModel
+import json, math, torch, trimesh
+import numpy as np
+from arguments import ModelParams, PipelineParams, OptimizationParams, read_config
+
+
+def quaternion_to_rotation_matrix(q):
+    """
+    Convert a unit quaternion into a 3x3 rotation matrix.
+
+    Args:
+        q: a length-4 list or array [w, x, y, z]
+
+    Returns:
+        R: the rotation matrix as a 3x3 NumPy array.
+    """
+    w, x, y, z = q
+    # Precompute the repeated products so each matrix entry is evaluated once.
+    xx = x * x
+    yy = y * y
+    zz = z * z
+    xy = x * y
+    xz = x * z
+    yz = y * z
+    wx = w * x
+    wy = w * y
+    wz = w * z
+    R = np.array([
+        [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)],
+        [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)],
+        [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)]
+    ])
+    return R
+
+# Load or initialize a 3DGS-like model and training cameras
+ply_path = "/home/zoyo/Desktop/MILo_rtx50/milo/data/Bridge/yufu_bridge_cleaned.ply"
+camera_poses_json = "/home/zoyo/Desktop/MILo_rtx50/milo/data/Bridge/camera_poses_cam1.json"
+# Load the camera poses once (the earlier duplicate json.load was redundant).
+with open(camera_poses_json, 'r') as fcc_file:
+    fcc_data = json.load(fcc_file)
+print(f"[INFO] Loaded {len(fcc_data)} camera poses from {camera_poses_json}.")
+
+gaussians = GaussianModel(
+    sh_degree=0,
+    # use_mip_filter=use_mip_filter,
+    # learn_occupancy=args.mesh_regularization,
+    # use_appearance_network=args.decoupled_appearance,
+    )
+gaussians.load_ply(ply_path)
+train_cameras = []
+height = 720
+width = 1280
+fov_y = math.radians(75)
+# fov_x = math.radians(108)
+aspect_ratio = width / height
+fov_x = 2 * math.atan(aspect_ratio * math.tan(fov_y / 2))
+for i in range(len(fcc_data)):
+    camera_info = fcc_data[i]
+    # Note: `position` is passed to Camera.T unchanged here; the newer loaders
+    # in this repo first convert it into a world-to-view translation
+    # (t = -R^T @ C), so check the pose convention before reusing this loop.
+    camera = Camera(
+        colmap_id=str(i),
+        R=quaternion_to_rotation_matrix(camera_info['quaternion']),
+        T=np.asarray(camera_info['position']),
+        FoVx=fov_x,
+        FoVy=fov_y,
+        image=torch.empty(3, height, width),
+        gt_alpha_mask=None,
+        image_name=camera_info['name'],
+        uid=i,
+        data_device='cuda',
+    )
+    train_cameras.append(camera)
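+
+# Optional sanity check (not part of the original script): rotations built
+# from unit quaternions should be orthonormal, so a malformed pose file can
+# be caught early with something along these lines:
+#   for cam in train_cameras:
+#       R = np.asarray(cam.R)
+#       assert np.allclose(R @ R.T, np.eye(3), atol=1e-4)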
+
+# Define a rendering wrapper following this template. It will be used only
+# for initializing SDF values. The wrapper should accept just a camera as
+# input, and return a dictionary with "render" and "depth" keys.
+from gaussian_renderer.radegs import render_radegs
+
+
+from argparse import ArgumentParser, Namespace
+parser = ArgumentParser(description="Training script parameters")
+# Use a store_true flag: argparse's type=bool treats any non-empty string
+# (including "False") as True.
+parser.add_argument("--bug", action="store_true")
+pipe = PipelineParams(parser)
+background = torch.tensor([0., 0., 0.], device="cuda")
+def render_func(view):
+    render_pkg = render_radegs(
+        viewpoint_camera=view,
+        pc=gaussians,
+        pipe=pipe,
+        bg_color=background,
+        kernel_size=0.0,
+        scaling_modifier=1.0,
+        require_coord=False,
+        require_depth=True
+    )
+    return {
+        "render": render_pkg["render"],
+        "depth": render_pkg["median_depth"],
+    }
+
+# Only the parameters of the Gaussians are needed for extracting the mesh.
+means = gaussians.get_xyz
+scales = gaussians.get_scaling
+rotations = gaussians.get_rotation
+opacities = gaussians.get_opacity
+
+# Sample Gaussians on the surface.
+# Should be performed only once, or just once in a while.
+# In this example, we sample at most 600_000 Gaussians.
+surface_gaussians_idx = sample_gaussians_on_surface(
+    views=train_cameras,
+    means=means,
+    scales=scales,
+    rotations=rotations,
+    opacities=opacities,
+    n_max_samples=600_000,
+    scene_type='outdoor',
+)
+
+# Compute initial SDF values for pivots. Should be performed only once.
+# In the paper, we propose to learn optimal SDF values by maximizing the
+# consistency between volumetric renderings and surface mesh renderings.
+initial_pivots_sdf = compute_initial_sdf_values(
+    views=train_cameras,
+    render_func=render_func,
+    means=means,
+    scales=scales,
+    rotations=rotations,
+    gaussian_idx=surface_gaussians_idx,
+)
+
+# Compute Delaunay Triangulation.
+# Can be performed once in a while.
+delaunay_tets = compute_delaunay_triangulation(
+    means=means,
+    scales=scales,
+    rotations=rotations,
+    gaussian_idx=surface_gaussians_idx,
+)
+
+# Differentiably extract a mesh from Gaussian parameters, including initial
+# or updated SDF values for the Gaussian pivots.
+# This function is differentiable with respect to the parameters of the Gaussians,
+# as well as the SDF values. Can be performed at every training iteration.
+mesh = extract_mesh(
+    delaunay_tets=delaunay_tets,
+    pivots_sdf=initial_pivots_sdf,
+    means=means,
+    scales=scales,
+    rotations=rotations,
+    gaussian_idx=surface_gaussians_idx,
+)
+
+
+# You can now apply any differentiable operation on the extracted mesh,
+# and backpropagate gradients back to the Gaussians!
+# In the paper, we propose to use differentiable mesh rendering.
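+# Sketch of the gradient path (an illustration, not part of this script): any
+# scalar function of mesh.verts backpropagates to the SDF values once they
+# are a leaf tensor:
+#   learnable_sdf = initial_pivots_sdf.detach().clone().requires_grad_(True)
+#   mesh = extract_mesh(delaunay_tets=delaunay_tets, pivots_sdf=learnable_sdf,
+#                       means=means, scales=scales, rotations=rotations,
+#                       gaussian_idx=surface_gaussians_idx)
+#   mesh.verts.pow(2).sum().backward()   # populates learnable_sdf.grad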
+from scene.mesh import MeshRasterizer, MeshRenderer
+renderer = MeshRenderer(MeshRasterizer(cameras=train_cameras))
+
+# We cull the mesh based on the view frustum for more efficiency
+i_view = np.random.randint(0, len(train_cameras))
+refined_mesh = frustum_cull_mesh(mesh, train_cameras[i_view])
+
+mesh_render_pkg = renderer(
+    refined_mesh,
+    cam_idx=i_view,
+    return_depth=True, return_normals=True
+)
+mesh_depth = mesh_render_pkg["depth"]
+mesh_normals = mesh_render_pkg["normals"]
+
+# Convert the tensors to numpy arrays before saving
+save_dict = {}
+for key, value in mesh_render_pkg.items():
+    if isinstance(value, torch.Tensor):
+        save_dict[key] = value.detach().cpu().numpy()
+    else:
+        save_dict[key] = value
+
+np.savez("mesh_render_output.npz", **save_dict)
+
+# Save the meshes
+# import trimesh
+
+# Extract vertices and faces from the mesh object
+refined_vertices = refined_mesh.verts.detach().cpu().numpy()
+refined_faces = refined_mesh.faces.detach().cpu().numpy()
+
+# Build a trimesh object and export it
+refined_mesh_obj = trimesh.Trimesh(vertices=refined_vertices, faces=refined_faces)
+
+# # Export as OBJ:
+# mesh_obj.export('extracted_mesh.obj')
+
+# Or export as PLY:
+refined_mesh_obj.export(f'refined_mesh_{len(fcc_data)}.ply')
+
+vertices = mesh.verts.detach().cpu().numpy()
+faces = mesh.faces.detach().cpu().numpy()
+
+# Build a trimesh object and export it
+mesh_obj = trimesh.Trimesh(vertices=vertices, faces=faces)
+
+# # Export as OBJ:
+# mesh_obj.export('extracted_mesh.obj')
+
+# Or export as PLY:
+mesh_obj.export(f'mesh_{len(fcc_data)}.ply')
+
+# # Or export as STL:
+# mesh_obj.export('extracted_mesh.stl')
\ No newline at end of file
diff --git a/milo/yufu2mesh_new.py b/milo/yufu2mesh_new.py
new file mode 100644
index 0000000..1595dc1
--- /dev/null
+++ b/milo/yufu2mesh_new.py
@@ -0,0 +1,1664 @@
+from pathlib import Path
+import math
+import json
+import random
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+from types import SimpleNamespace
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import matplotlib.pyplot as plt
+import trimesh
+import yaml
+
+from argparse import ArgumentParser
+
+from functional import (
+    compute_delaunay_triangulation,
+    extract_mesh,
+    frustum_cull_mesh,
+)
+from scene.cameras import Camera
+from scene.gaussian_model import GaussianModel
+from gaussian_renderer import render_full
+from gaussian_renderer.radegs import render_radegs, integrate_radegs
+from arguments import PipelineParams
+from regularization.regularizer.mesh import (
+    initialize_mesh_regularization,
+    compute_mesh_regularization,
+)
+from regularization.sdf.learnable import convert_occupancy_to_sdf
+from utils.geometry_utils import flatten_voronoi_features
+
+# DISCOVERSE camera trajectories use the OpenGL right-handed convention
+# (camera forward is -Z, up is +Y), while the MILo/COLMAP rendering pipeline
+# assumes forward +Z and up -Y. The axes are flipped once at load time.
+OPENGL_TO_COLMAP = np.diag([1.0, -1.0, -1.0]).astype(np.float32)
+
+# Centralize the script root and data directories to simplify relative paths.
+MILO_DIR = Path(__file__).resolve().parent
+DATA_ROOT = MILO_DIR / "data"
+
+
+class DepthProvider:
+    """Loads and caches Discoverse depth maps with uniform shape, clipping, and validity masks."""
+
+    def __init__(
+        self,
+        depth_root: Path,
+        image_height: int,
+        image_width: int,
+        device: torch.device,
+        clip_min: Optional[float] = None,
+        clip_max: Optional[float] = None,
+    ) -> None:
+        self.depth_root = Path(depth_root)
+        if not self.depth_root.is_dir():
+            raise FileNotFoundError(f"Depth directory does not exist: {self.depth_root}")
+        self.image_height = image_height
+        self.image_width = image_width
+        self.device = device
+        self.clip_min = clip_min
+        self.clip_max = clip_max
+        self._cache: Dict[int, torch.Tensor] = {}
+        self._mask_cache: Dict[int, torch.Tensor] = {}
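+    # Depth files are expected to follow the Discoverse naming scheme used by
+    # _file_for_index below: depth_img_0_<view_index>.npy, e.g. view 12 maps
+    # to depth_img_0_12.npy.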
+    def _file_for_index(self, view_index: int) -> Path:
+        return self.depth_root / f"depth_img_0_{view_index}.npy"
+
+    def _load_numpy(self, file_path: Path) -> np.ndarray:
+        depth_np = np.load(file_path)
+        depth_np = np.squeeze(depth_np)
+        if depth_np.ndim != 2:
+            raise ValueError(f"{file_path}: depth array has unexpected shape {depth_np.shape}")
+        if depth_np.shape != (self.image_height, self.image_width):
+            raise ValueError(
+                f"{file_path}: expected depth resolution {(self.image_height, self.image_width)}, got {depth_np.shape}"
+            )
+        if self.clip_min is not None or self.clip_max is not None:
+            # np.clip leaves a bound open when it is None, so no fallback to
+            # depth_np.min()/max() is needed here.
+            depth_np = np.clip(depth_np, self.clip_min, self.clip_max)
+        return depth_np.astype(np.float32)
+
+    def get(self, view_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Return (depth_tensor, valid_mask), both on the GPU."""
+        if view_index not in self._cache:
+            file_path = self._file_for_index(view_index)
+            if not file_path.is_file():
+                raise FileNotFoundError(f"Missing depth file: {file_path}")
+            depth_np = self._load_numpy(file_path)
+            depth_tensor = torch.from_numpy(depth_np).to(self.device)
+            valid_mask = torch.isfinite(depth_tensor) & (depth_tensor > 0.0)
+            if self.clip_min is not None:
+                valid_mask &= depth_tensor >= self.clip_min
+            if self.clip_max is not None:
+                valid_mask &= depth_tensor <= self.clip_max
+            self._cache[view_index] = depth_tensor
+            self._mask_cache[view_index] = valid_mask
+        return self._cache[view_index], self._mask_cache[view_index]
+
+    def as_numpy(self, view_index: int) -> np.ndarray:
+        depth_tensor, _ = self.get(view_index)
+        return depth_tensor.detach().cpu().numpy()
+
+
+class NormalGroundTruthCache:
+    """Caches normal GT rendered from the initial Gaussians so training never re-renders it."""
+
+    def __init__(
+        self,
+        cache_dir: Path,
+        image_height: int,
+        image_width: int,
+        device: torch.device,
+    ) -> None:
+        self.cache_dir = Path(cache_dir)
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        self.image_height = image_height
+        self.image_width = image_width
+        self.device = device
+        self._memory_cache: Dict[int, torch.Tensor] = {}
+
+    def _file_path(self, view_index: int) -> Path:
+        return self.cache_dir / f"normal_view_{view_index:04d}.npy"
+
+    def has(self, view_index: int) -> bool:
+        return self._file_path(view_index).is_file()
+
+    def store(self, view_index: int, normal_tensor: torch.Tensor) -> None:
+        normals_np = prepare_normals(normal_tensor)
+        np.save(self._file_path(view_index), normals_np.astype(np.float16))
+
+    def get(self, view_index: int) -> torch.Tensor:
+        if view_index not in self._memory_cache:
+            path = self._file_path(view_index)
+            if not path.is_file():
+                raise FileNotFoundError(
+                    f"Normal cache for view {view_index} not found: {path}; run the precomputation first."
+                )
+            normals_np = np.load(path)
+            expected_shape = (self.image_height, self.image_width, 3)
+            if normals_np.shape != expected_shape:
+                raise ValueError(
+                    f"{path}: cached normals should have shape {expected_shape}, got {normals_np.shape}"
+                )
+            normals_tensor = (
+                torch.from_numpy(normals_np.astype(np.float32))
+                .permute(2, 0, 1)
+                .to(self.device)
+            )
+            self._memory_cache[view_index] = normals_tensor
+        return self._memory_cache[view_index]
+
+    def clear_memory_cache(self) -> None:
+        self._memory_cache.clear()
+
+    def ensure_all(self, cameras: Sequence[Camera], render_fn) -> None:
+        total = len(cameras)
+        for idx, camera in enumerate(cameras):
+            if self.has(idx):
+                continue
+            with torch.no_grad():
+                pkg = render_fn(camera)
+                normal_tensor = pkg["normal"]
+            self.store(idx, normal_tensor)
+            if (idx + 1) % 10 == 0 or idx + 1 == total:
+                print(f"[INFO] Precomputing normal cache {idx + 1}/{total}")
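+
+# Usage sketch for the two caches above (paths and clip values are
+# placeholders, not the script's defaults):
+#   provider = DepthProvider(Path("data/bridge_small/depth"), 720, 1280,
+#                            torch.device("cuda"), clip_min=0.3, clip_max=6.0)
+#   depth, mask = provider.get(0)   # cached GPU tensors on repeat calls
+#   normal_cache = NormalGroundTruthCache(Path("runs/exp/normal_gt"),
+#                                         720, 1280, torch.device("cuda"))
+#   normal_cache.ensure_all(train_cameras, render_view)   # precompute once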
+
+
+def compute_depth_loss_tensor(
+    pred_depth: torch.Tensor,
+    gt_depth: torch.Tensor,
+    gt_mask: torch.Tensor,
+) -> Tuple[torch.Tensor, Dict[str, float]]:
+    """Return the depth L1 loss together with summary statistics."""
+    if pred_depth.dim() == 3:
+        pred_depth = pred_depth.squeeze(0)
+    if pred_depth.shape != gt_depth.shape:
+        raise ValueError(
+            f"Predicted depth shape {pred_depth.shape} does not match GT {gt_depth.shape}."
+        )
+    valid_mask = gt_mask & torch.isfinite(pred_depth) & (pred_depth > 0.0)
+    valid_pixels = int(valid_mask.sum().item())
+    if valid_pixels == 0:
+        zero = torch.zeros((), device=pred_depth.device)
+        stats = {"valid_px": 0, "mae": float("nan"), "rmse": float("nan")}
+        return zero, stats
+    diff = pred_depth - gt_depth
+    abs_diff = diff.abs()
+    loss = abs_diff[valid_mask].mean()
+    rmse = torch.sqrt((diff[valid_mask] ** 2).mean())
+    stats = {
+        "valid_px": valid_pixels,
+        "mae": float(abs_diff[valid_mask].mean().detach().item()),
+        "rmse": float(rmse.detach().item()),
+    }
+    return loss, stats
+
+
+def compute_normal_loss_tensor(
+    pred_normals: torch.Tensor,
+    gt_normals: torch.Tensor,
+    base_mask: torch.Tensor,
+) -> Tuple[torch.Tensor, Dict[str, float]]:
+    """Normal loss based on cosine similarity."""
+    if pred_normals.dim() != 3 or pred_normals.shape[0] != 3:
+        raise ValueError(f"Predicted normals must have shape (3, H, W), got {pred_normals.shape}")
+    if gt_normals.shape != pred_normals.shape:
+        raise ValueError(
+            f"GT normal shape {gt_normals.shape} does not match prediction {pred_normals.shape}."
+        )
+    gt_mask = torch.isfinite(gt_normals).all(dim=0)
+    pred_mask = torch.isfinite(pred_normals).all(dim=0)
+    valid_mask = base_mask & gt_mask & pred_mask
+    valid_pixels = int(valid_mask.sum().item())
+    if valid_pixels == 0:
+        zero = torch.zeros((), device=pred_normals.device)
+        stats = {"valid_px": 0, "mean_cos": float("nan")}
+        return zero, stats
+
+    pred_unit = F.normalize(pred_normals, dim=0)
+    gt_unit = F.normalize(gt_normals, dim=0)
+    cos_sim = (pred_unit * gt_unit).sum(dim=0).clamp(-1.0, 1.0)
+    loss_map = (1.0 - cos_sim) * valid_mask
+    loss = loss_map.sum() / valid_mask.sum()
+    stats = {"valid_px": valid_pixels, "mean_cos": float(cos_sim[valid_mask].mean().item())}
+    return loss, stats
+
+
+class ManualScene:
+    """Minimal Scene wrapper exposing only the interface mesh regularization needs."""
+
+    def __init__(self, cameras: Sequence[Camera]):
+        self._train_cameras = list(cameras)
+
+    def getTrainCameras(self, scale: float = 1.0) -> List[Camera]:
+        return list(self._train_cameras)
+
+    def getTrainCameras_warn_up(
+        self,
+        iteration: int,
+        warn_until_iter: int,
+        scale: float = 1.0,
+        scale2: float = 2.0,
+    ) -> List[Camera]:
+        return list(self._train_cameras)
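+
+# ManualScene only has to satisfy the two camera getters that the mesh
+# regularization helpers call; a quick duck-typing check (illustrative):
+#   scene = ManualScene(train_cameras)
+#   assert len(scene.getTrainCameras()) == len(train_cameras)
+#   assert scene.getTrainCameras_warn_up(0, 3000) == scene.getTrainCameras()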
+
+
+def build_render_functions(
+    gaussians: GaussianModel,
+    pipe: PipelineParams,
+    background: torch.Tensor,
+):
+    """Build the renderers used for training and reconstruction: training goes
+    through render_full for accurate depth gradients, while SDF initialization
+    still uses RaDe-GS."""
+
+    def render_view(view: Camera) -> Dict[str, torch.Tensor]:
+        pkg = render_full(
+            viewpoint_camera=view,
+            pc=gaussians,
+            pipe=pipe,
+            bg_color=background,
+            compute_expected_normals=False,
+            compute_expected_depth=True,
+            compute_accurate_median_depth_gradient=True,
+        )
+        if "area_max" not in pkg:
+            pkg["area_max"] = torch.zeros_like(pkg["radii"])
+        return pkg
+
+    def render_for_sdf(
+        view: Camera,
+        gaussians_override: Optional[GaussianModel] = None,
+        pipeline_override: Optional[PipelineParams] = None,
+        background_override: Optional[torch.Tensor] = None,
+        kernel_size: float = 0.0,
+        require_depth: bool = True,
+        require_coord: bool = False,
+    ):
+        pc_obj = gaussians if gaussians_override is None else gaussians_override
+        pipe_obj = pipe if pipeline_override is None else pipeline_override
+        bg_color = background if background_override is None else background_override
+        pkg = render_radegs(
+            viewpoint_camera=view,
+            pc=pc_obj,
+            pipe=pipe_obj,
+            bg_color=bg_color,
+            kernel_size=kernel_size,
+            scaling_modifier=1.0,
+            require_coord=require_coord,
+            require_depth=require_depth,
+        )
+        if "area_max" not in pkg:
+            pkg["area_max"] = torch.zeros_like(pkg["radii"])
+        return {
+            "render": pkg["render"].detach(),
+            "median_depth": pkg["median_depth"].detach(),
+        }
+
+    return render_view, render_for_sdf
+
+
+def load_mesh_config(config_name: str) -> Dict[str, Any]:
+    """Accept either a direct file path or a name under configs/mesh/<name>.yaml."""
+    candidate = Path(config_name)
+    if not candidate.is_file():
+        candidate = Path(__file__).resolve().parent / "configs" / "mesh" / f"{config_name}.yaml"
+    if not candidate.is_file():
+        raise FileNotFoundError(f"Mesh config not found: {config_name}")
+    with candidate.open("r", encoding="utf-8") as fh:
+        return yaml.safe_load(fh)
+
+
+def load_optimization_config(config_name: str) -> Dict[str, Any]:
+    """Load an optimization config from a direct file path or configs/optimization/<name>.yaml."""
+    candidate = Path(config_name)
+    if not candidate.is_file():
+        candidate = Path(__file__).resolve().parent / "configs" / "optimization" / f"{config_name}.yaml"
+    if not candidate.is_file():
+        raise FileNotFoundError(f"Optimization config not found: {config_name}")
+    with candidate.open("r", encoding="utf-8") as fh:
+        config = yaml.safe_load(fh)
+
+    # Validate the config structure
+    required_keys = ["gaussian_params", "loss_weights", "depth_processing", "mesh_regularization"]
+    for key in required_keys:
+        if key not in config:
+            raise ValueError(f"Optimization config is missing required key: {key}")
+
+    return config
+
+
+def setup_gaussian_optimization(
+    gaussians: GaussianModel,
+    opt_config: Dict[str, Any],
+) -> Tuple[torch.optim.Optimizer, Dict[str, float]]:
+    """
+    Configure which Gaussian parameters are trainable and build the optimizer.
+
+    Args:
+        gaussians: the Gaussian model instance
+        opt_config: the optimization config dict
+
+    Returns:
+        optimizer: the configured optimizer
+        loss_weights: dict of loss weights
+    """
+    param_groups = []
+    params_config = opt_config["gaussian_params"]
+
+    # Iterate over all configured parameters
+    for param_name, param_cfg in params_config.items():
+        if not hasattr(gaussians, param_name):
+            print(f"[WARNING] Gaussian model has no attribute {param_name}; skipping")
+            continue
+
+        param_tensor = getattr(gaussians, param_name)
+        if not isinstance(param_tensor, torch.Tensor):
+            print(f"[WARNING] {param_name} is not a tensor; skipping")
+            continue
+
+        trainable = param_cfg.get("trainable", False)
+        lr = param_cfg.get("lr", 0.0)
+
+        # Toggle gradients
+        param_tensor.requires_grad_(trainable)
+
+        # Add to the optimizer parameter groups when trainable with lr > 0
+        if trainable and lr > 0.0:
+            param_groups.append({
+                "params": [param_tensor],
+                "lr": lr,
+                "name": param_name
+            })
+            print(f"[INFO] Parameter {param_name}: trainable=True, lr={lr}")
+        else:
+            print(f"[INFO] Parameter {param_name}: trainable=False")
+
+    if not param_groups:
+        raise ValueError("No trainable parameters! Check the optimization config file.")
+
+    # Create the optimizer
+    optimizer = torch.optim.Adam(param_groups)
+
+    # Extract the loss weights
+    loss_weights = opt_config["loss_weights"]
+
+    return optimizer, loss_weights
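+
+# Layout implied by load_optimization_config and setup_gaussian_optimization
+# for configs/optimization/<name>.yaml (the keys come from the code above;
+# the concrete values here are placeholders, not recommended settings):
+#   gaussian_params:
+#     _xyz:     {trainable: true,  lr: 0.00016}
+#     _opacity: {trainable: false, lr: 0.0}
+#   loss_weights:
+#     depth: 1.0
+#     normal: 0.1
+#   depth_processing:
+#     clip_min: 0.3
+#     clip_max: 6.0
+#   mesh_regularization:
+#     depth_weight: 1.0
+#     normal_weight: 0.1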
+
+
+def ensure_gaussian_occupancy(gaussians: GaussianModel) -> None:
+    """Mesh regularization relies on a 9-dim occupancy grid; fill in the buffers here for checkpoints that lack them."""
+    needs_init = (
+        not getattr(gaussians, "learn_occupancy", False)
+        or not hasattr(gaussians, "_occupancy_shift")
+        or gaussians._occupancy_shift.numel() == 0
+    )
+    if needs_init:
+        gaussians.learn_occupancy = True
+        base = torch.zeros((gaussians._xyz.shape[0], 9), device=gaussians._xyz.device)
+        shift = torch.zeros_like(base)
+        gaussians._base_occupancy = nn.Parameter(base.requires_grad_(False), requires_grad=False)
+        gaussians._occupancy_shift = nn.Parameter(shift.requires_grad_(True))
+
+
+def export_mesh_from_state(
+    gaussians: GaussianModel,
+    mesh_state: Dict[str, Any],
+    output_path: Path,
+    reference_camera: Optional[Camera] = None,
+) -> None:
+    """Export a mesh from the current mesh_state, optionally frustum-culled."""
+    gaussian_idx = mesh_state.get("delaunay_xyz_idx")
+    delaunay_tets = mesh_state.get("delaunay_tets")
+
+    if delaunay_tets is None:
+        delaunay_tets = compute_delaunay_triangulation(
+            means=gaussians.get_xyz,
+            scales=gaussians.get_scaling,
+            rotations=gaussians.get_rotation,
+            gaussian_idx=gaussian_idx,
+        )
+
+    occupancy = (
+        gaussians.get_occupancy
+        if gaussian_idx is None
+        else gaussians.get_occupancy[gaussian_idx]
+    )
+    pivots_sdf = convert_occupancy_to_sdf(flatten_voronoi_features(occupancy))
+
+    mesh = extract_mesh(
+        delaunay_tets=delaunay_tets,
+        pivots_sdf=pivots_sdf,
+        means=gaussians.get_xyz,
+        scales=gaussians.get_scaling,
+        rotations=gaussians.get_rotation,
+        gaussian_idx=gaussian_idx,
+    )
+
+    mesh_to_save = mesh if reference_camera is None else frustum_cull_mesh(mesh, reference_camera)
+    verts = mesh_to_save.verts.detach().cpu().numpy()
+    faces = mesh_to_save.faces.detach().cpu().numpy()
+    trimesh.Trimesh(vertices=verts, faces=faces, process=False).export(output_path)
+
+
+def quaternion_to_rotation_matrix(quaternion: Sequence[float]) -> np.ndarray:
+    """Convert a unit quaternion into a 3x3 rotation matrix."""
+    # Explicitly convert the quaternions exported by DISCOVERSE so downstream
+    # code matches MILo's rotation convention.
+    q = np.asarray(quaternion, dtype=np.float64)
+    if q.shape != (4,):
+        raise ValueError(f"Quaternion must have 4 components, got shape {q.shape}")
+    w, x, y, z = q
+    xx = x * x
+    yy = y * y
+    zz = z * z
+    xy = x * y
+    xz = x * z
+    yz = y * z
+    wx = w * x
+    wy = w * y
+    wz = w * z
+    rotation = np.array(
+        [
+            [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)],
+            [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)],
+            [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)],
+        ]
+    )
+    return rotation
+
+
+def freeze_gaussian_model(model: GaussianModel) -> None:
+    """Explicitly disable gradients on the Gaussian model's parameters."""
+    # Freeze the Gaussians for inference; later loops only run forward passes.
+    tensor_attrs = [
+        "_xyz",
+        "_features_dc",
+        "_features_rest",
+        "_scaling",
+        "_rotation",
+        "_opacity",
+    ]
+    for attr in tensor_attrs:
+        value = getattr(model, attr, None)
+        if isinstance(value, torch.Tensor):
+            value.requires_grad_(False)
+
+
+def prepare_depth_map(depth_tensor: torch.Tensor) -> np.ndarray:
+    """Convert a depth tensor into a 2D numpy array."""
+    # Centralize the squeeze logic so Matplotlib never fails on odd shapes.
+    depth_np = depth_tensor.detach().cpu().numpy()
+    depth_np = np.squeeze(depth_np)
+    if depth_np.ndim == 1:
+        depth_np = np.expand_dims(depth_np, axis=0)
+    return depth_np
+
+
+def prepare_normals(normal_tensor: torch.Tensor) -> np.ndarray:
+    """Convert a normal tensor into an HxWx3 numpy array."""
+    # Accept render outputs in either (3, H, W) or (H, W, 3) layout.
+    normals_np = normal_tensor.detach().cpu().numpy()
+    normals_np = np.squeeze(normals_np)
+    if normals_np.ndim == 3 and normals_np.shape[0] == 3:
+        normals_np = np.transpose(normals_np, (1, 2, 0))
+    if normals_np.ndim == 2:
+        normals_np = normals_np[..., None]
+    return normals_np
+
+
+def normals_to_rgb(normals: np.ndarray) -> np.ndarray:
+    """Map normal vectors from [-1, 1] into [0, 1] for visualization."""
+    normals = np.nan_to_num(normals, nan=0.0, posinf=0.0, neginf=0.0)
+    rgb = 0.5 * (normals + 1.0)
+    return np.clip(rgb, 0.0, 1.0).astype(np.float32)
+
+
+def save_normal_visualization(normal_rgb: np.ndarray, output_path: Path) -> None:
+    """Save the normal visualization image."""
+    plt.imsave(output_path, normal_rgb)
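+
+# Quick check of the normals_to_rgb mapping (illustrative): a normal pointing
+# straight at the camera, (0, 0, 1), maps to 0.5 * ((0, 0, 1) + 1) =
+# (0.5, 0.5, 1.0), the familiar lavender tint of normal maps.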
+
+
+def load_cameras_from_json(
+    json_path: str,
+    image_height: int,
+    image_width: int,
+    fov_y_deg: float,
+) -> List[Camera]:
+    """Read the camera file following MILo's view conventions and convert coordinate frames."""
+    # Custom loader for DISCOVERSE-style JSON; poses are converted uniformly
+    # to COLMAP world-to-camera coordinates.
+    pose_path = Path(json_path)
+    if not pose_path.is_file():
+        raise FileNotFoundError(f"Camera JSON not found: {json_path}")
+
+    with pose_path.open("r", encoding="utf-8") as fh:
+        camera_list = json.load(fh)
+
+    if isinstance(camera_list, dict):
+        for key in ("frames", "poses", "camera_poses"):
+            if key in camera_list and isinstance(camera_list[key], list):
+                camera_list = camera_list[key]
+                break
+        else:
+            raise ValueError(f"The JSON structure in {json_path} does not contain a recognizable camera list.")
+
+    if not isinstance(camera_list, list) or not camera_list:
+        raise ValueError(f"No valid camera entries in {json_path}.")
+
+    fov_y = math.radians(fov_y_deg)
+    aspect_ratio = image_width / image_height
+    fov_x = 2.0 * math.atan(aspect_ratio * math.tan(fov_y * 0.5))
+
+    cameras: List[Camera] = []
+    for idx, entry in enumerate(camera_list):
+        if "quaternion" in entry:
+            rotation_c2w = quaternion_to_rotation_matrix(entry["quaternion"]).astype(np.float32)
+        elif "rotation" in entry:
+            rotation_c2w = np.asarray(entry["rotation"], dtype=np.float32)
+        else:
+            raise KeyError(f"Camera entry {idx} provides neither 'quaternion' nor 'rotation'.")
+
+        if rotation_c2w.shape != (3, 3):
+            raise ValueError(f"Camera entry {idx}: rotation matrix must have shape (3, 3), got {rotation_c2w.shape}")
+
+        # Importing the DISCOVERSE quaternion/rotation as-is renders PNGs that
+        # are flipped vertically, which shows the forward axis is still
+        # OpenGL's -Z. Right-multiplying by diag(1, -1, -1) converts it
+        # explicitly to the MILo/COLMAP frame, so the projection matrices stay
+        # consistent with the depth maps.
+        rotation_c2w = rotation_c2w @ OPENGL_TO_COLMAP
+
+        if "position" in entry:
+            camera_center = np.asarray(entry["position"], dtype=np.float32)
+            if camera_center.shape != (3,):
+                raise ValueError(f"Camera entry {idx}: position must be a 3-vector, got {camera_center.shape}")
+            rotation_w2c = rotation_c2w.T
+            translation = (-rotation_w2c @ camera_center).astype(np.float32)
+        elif "translation" in entry:
+            translation = np.asarray(entry["translation"], dtype=np.float32)
+            # If the JSON already stores a COLMAP-style T (world-to-camera), it
+            # is assumed to come from OpenGL coordinates like the rotation.
+            # Strictly speaking the same axis flip should apply, but the
+            # current datasets only carry a position field; to avoid a double
+            # conversion, only the type is checked and the value is kept.
+        elif "tvec" in entry:
+            translation = np.asarray(entry["tvec"], dtype=np.float32)
+        else:
+            raise KeyError(f"Camera entry {idx} provides none of position/translation/tvec.")
+
+        if translation.shape != (3,):
+            raise ValueError(f"Camera entry {idx}: translation must have length 3, got {translation.shape}")
+
+        image_name = (
+            entry.get("name")
+            or entry.get("image_name")
+            or entry.get("img_name")
+            or f"view_{idx:04d}"
+        )
+
+        camera = Camera(
+            colmap_id=str(idx),
+            R=rotation_c2w,
+            T=translation,
+            FoVx=fov_x,
+            FoVy=fov_y,
+            image=torch.zeros(3, image_height, image_width),
+            gt_alpha_mask=None,
+            image_name=image_name,
+            uid=idx,
+            data_device="cuda",
+        )
+        cameras.append(camera)
+    return cameras
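+
+# Worked example of the axis flip above (illustrative): with an identity
+# OpenGL rotation, rotation_c2w @ OPENGL_TO_COLMAP == diag(1, -1, -1), i.e.
+# the camera's +Y (up) and +Z axes are negated so the camera looks down +Z
+# with -Y up, matching the COLMAP convention the projection code assumes.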
+
+
+def save_heatmap(data: np.ndarray, output_path: Path, title: str) -> None:
+    """Save 2D data as a heatmap for quick visual inspection of differences."""
+    plt.figure(figsize=(6, 4))
+    finite_mask = np.isfinite(data)
+    if finite_mask.any():
+        finite_values = data[finite_mask]
+        vmax = float(np.percentile(finite_values, 99.0))
+        if (not np.isfinite(vmax)) or (vmax <= 0.0):
+            vmax = 1.0
+    else:
+        vmax = 1.0
+    masked_data = np.ma.array(data, mask=~finite_mask)
+    plt.imshow(masked_data, cmap="inferno", vmin=0.0, vmax=vmax)
+    plt.title(title)
+    plt.colorbar()
+    plt.tight_layout()
+    plt.savefig(output_path)
+    plt.close()
+
+
+def resolve_data_path(user_path: str) -> Path:
+    """Resolve relative paths under milo/data while still accepting absolute paths."""
+    if not user_path:
+        raise ValueError("Path argument must not be empty.")
+    path = Path(user_path).expanduser()
+    if path.is_absolute():
+        return path
+    return (DATA_ROOT / path).resolve()
+
+
+def save_detail_visualizations(
+    iteration: int,
+    detail_dir: Path,
+    view_index: int,
+    gaussian_depth_map: np.ndarray,
+    mesh_depth_map: np.ndarray,
+    gt_depth_map: np.ndarray,
+    gaussian_normals_map: np.ndarray,
+    mesh_normals_map: np.ndarray,
+    gt_normals_map: np.ndarray,
+    gaussian_valid: np.ndarray,
+    mesh_valid: np.ndarray,
+    gt_valid: np.ndarray,
+    shared_min: float,
+    shared_max: float,
+    depth_stats: Dict[str, float],
+    normal_stats: Dict[str, float],
+    loss_summary: Dict[str, float],
+) -> None:
+    """Save detailed debug visualizations for per-iteration inspection."""
+    detail_dir.mkdir(parents=True, exist_ok=True)
+
+    shared_min = shared_min if np.isfinite(shared_min) else 0.0
+    shared_max = shared_max if np.isfinite(shared_max) else 1.0
+    if not np.isfinite(shared_max) or shared_max <= shared_min:
+        shared_max = shared_min + 1.0
+
+    gaussian_mask = gaussian_valid & gt_valid
+    mesh_mask = mesh_valid & gt_valid
+
+    gaussian_depth_diff = np.full(gaussian_depth_map.shape, np.nan, dtype=np.float32)
+    mesh_depth_diff = np.full(mesh_depth_map.shape, np.nan, dtype=np.float32)
+    if gaussian_mask.any():
+        gaussian_depth_diff[gaussian_mask] = np.abs(
+            gaussian_depth_map[gaussian_mask] - gt_depth_map[gaussian_mask]
+        )
+    if mesh_mask.any():
+        mesh_depth_diff[mesh_mask] = np.abs(mesh_depth_map[mesh_mask] - gt_depth_map[mesh_mask])
+
+    diff_values: List[np.ndarray] = []
+    if gaussian_mask.any():
+        diff_values.append(gaussian_depth_diff[gaussian_mask].reshape(-1))
+    if mesh_mask.any():
+        diff_values.append(mesh_depth_diff[mesh_mask].reshape(-1))
+    if diff_values:
+        diff_stack = np.concatenate(diff_values)
+        diff_vmax = float(np.percentile(diff_stack, 99.0))
+        if (not np.isfinite(diff_vmax)) or diff_vmax <= 0.0:
+            diff_vmax = 1.0
+    else:
+        diff_vmax = 1.0
+
+    depth_fig, depth_axes = plt.subplots(2, 3, figsize=(18, 10), dpi=300)
+    ax_gt_depth, ax_gaussian_depth, ax_mesh_depth = depth_axes[0]
+    ax_gaussian_diff, ax_mesh_diff, ax_depth_hist = depth_axes[1]
+
+    im_gt = ax_gt_depth.imshow(gt_depth_map, cmap="viridis", vmin=shared_min, vmax=shared_max)
+    ax_gt_depth.set_title("GT depth")
+    ax_gt_depth.axis("off")
+    depth_fig.colorbar(im_gt, ax=ax_gt_depth, fraction=0.046, pad=0.04)
+
+    im_gaussian = ax_gaussian_depth.imshow(
+        gaussian_depth_map,
+        cmap="viridis",
+        vmin=shared_min,
+        vmax=shared_max,
+    )
+    ax_gaussian_depth.set_title("Gaussian depth")
+    ax_gaussian_depth.axis("off")
+    depth_fig.colorbar(im_gaussian, ax=ax_gaussian_depth, fraction=0.046, pad=0.04)
+
+    im_mesh = ax_mesh_depth.imshow(
+        mesh_depth_map,
+        cmap="viridis",
+        vmin=shared_min,
+        vmax=shared_max,
+    )
+    ax_mesh_depth.set_title("Mesh depth")
+    ax_mesh_depth.axis("off")
+    depth_fig.colorbar(im_mesh, ax=ax_mesh_depth, fraction=0.046, pad=0.04)
+
+    gaussian_diff_masked = np.ma.array(gaussian_depth_diff, mask=~gaussian_mask)
+    mesh_diff_masked = np.ma.array(mesh_depth_diff, mask=~mesh_mask)
+    ax_gaussian_diff.imshow(gaussian_diff_masked, cmap="magma", vmin=0.0, vmax=diff_vmax)
+    ax_gaussian_diff.set_title("|Gaussian - GT|")
+    ax_gaussian_diff.axis("off")
+    ax_mesh_diff.imshow(mesh_diff_masked, cmap="magma", vmin=0.0, vmax=diff_vmax)
+    ax_mesh_diff.set_title("|Mesh - GT|")
+    ax_mesh_diff.axis("off")
+
+    ax_depth_hist.set_title("Depth diff histogram")
+    ax_depth_hist.set_xlabel("Absolute difference")
+    ax_depth_hist.set_ylabel("Count")
+    if diff_values:
+        if gaussian_mask.any():
+            ax_depth_hist.hist(
+                gaussian_depth_diff[gaussian_mask].reshape(-1),
+                bins=60,
+                alpha=0.6,
+                label="Gaussian",
+            )
+        if mesh_mask.any():
+            ax_depth_hist.hist(
mesh_depth_diff[mesh_mask].reshape(-1), + bins=60, + alpha=0.6, + label="Mesh", + ) + ax_depth_hist.legend() + else: + ax_depth_hist.text(0.5, 0.5, "No valid depth diffs", ha="center", va="center") + depth_fig.suptitle( + f"Iter {iteration:04d} view {view_index} | depth_loss={loss_summary['depth_loss']:.4f}" + f" | normal_loss={loss_summary['normal_loss']:.4f} | mesh_loss={loss_summary['mesh_loss']:.4f}", + fontsize=12, + ) + depth_fig.tight_layout(rect=[0, 0, 1, 0.95]) + depth_fig.savefig(detail_dir / f"detail_depth_iter_{iteration:04d}.png", dpi=300) + plt.close(depth_fig) + + gaussian_normals_rgb = normals_to_rgb(gaussian_normals_map) + mesh_normals_rgb = normals_to_rgb(mesh_normals_map) + gt_normals_rgb = normals_to_rgb(gt_normals_map) + + gaussian_normal_mask = np.all(np.isfinite(gaussian_normals_map), axis=-1) & np.all( + np.isfinite(gt_normals_map), axis=-1 + ) + mesh_normal_mask = np.all(np.isfinite(mesh_normals_map), axis=-1) & np.all( + np.isfinite(gt_normals_map), axis=-1 + ) + gaussian_normal_diff = np.linalg.norm(gaussian_normals_map - gt_normals_map, axis=-1) + mesh_normal_diff = np.linalg.norm(mesh_normals_map - gt_normals_map, axis=-1) + gaussian_normal_diff = np.where(gaussian_normal_mask, gaussian_normal_diff, np.nan) + mesh_normal_diff = np.where(mesh_normal_mask, mesh_normal_diff, np.nan) + + normal_diff_values: List[np.ndarray] = [] + if gaussian_normal_mask.any(): + normal_diff_values.append(gaussian_normal_diff[gaussian_normal_mask].reshape(-1)) + if mesh_normal_mask.any(): + normal_diff_values.append(mesh_normal_diff[mesh_normal_mask].reshape(-1)) + if normal_diff_values: + normal_diff_stack = np.concatenate(normal_diff_values) + normal_vmax = float(np.percentile(normal_diff_stack, 99.0)) + if (not np.isfinite(normal_vmax)) or normal_vmax <= 0.0: + normal_vmax = 1.0 + else: + normal_vmax = 1.0 + + normal_fig, normal_axes = plt.subplots(2, 3, figsize=(18, 10), dpi=300) + ax_gt_normals, ax_gaussian_normals, ax_mesh_normals = normal_axes[0] + ax_gaussian_normal_diff, ax_mesh_normal_diff, ax_normal_text = normal_axes[1] + + ax_gt_normals.imshow(gt_normals_rgb) + ax_gt_normals.set_title("GT normals") + ax_gt_normals.axis("off") + ax_gaussian_normals.imshow(gaussian_normals_rgb) + ax_gaussian_normals.set_title("Gaussian normals") + ax_gaussian_normals.axis("off") + ax_mesh_normals.imshow(mesh_normals_rgb) + ax_mesh_normals.set_title("Mesh normals") + ax_mesh_normals.axis("off") + + gaussian_normals_masked = np.ma.array(gaussian_normal_diff, mask=~gaussian_normal_mask) + mesh_normals_masked = np.ma.array(mesh_normal_diff, mask=~mesh_normal_mask) + im_gaussian_normal = ax_gaussian_normal_diff.imshow( + gaussian_normals_masked, + cmap="magma", + vmin=0.0, + vmax=normal_vmax, + ) + ax_gaussian_normal_diff.set_title("‖Gaussian-GT‖") + ax_gaussian_normal_diff.axis("off") + normal_fig.colorbar(im_gaussian_normal, ax=ax_gaussian_normal_diff, fraction=0.046, pad=0.04) + + im_mesh_normal = ax_mesh_normal_diff.imshow( + mesh_normals_masked, + cmap="magma", + vmin=0.0, + vmax=normal_vmax, + ) + ax_mesh_normal_diff.set_title("‖Mesh-GT‖") + ax_mesh_normal_diff.axis("off") + normal_fig.colorbar(im_mesh_normal, ax=ax_mesh_normal_diff, fraction=0.046, pad=0.04) + + ax_normal_text.axis("off") + text_lines = [ + f"Iter {iteration:04d} view {view_index}", + f"Loss: total={loss_summary['loss']:.4f} depth={loss_summary['depth_loss']:.4f} normal={loss_summary['normal_loss']:.4f}", + f"Mesh loss={loss_summary['mesh_loss']:.4f} (depth={loss_summary['mesh_depth_loss']:.4f}, 
normal={loss_summary['mesh_normal_loss']:.4f})",
+        f"Occupancy: centers={loss_summary['occupied_loss']:.4f} labels={loss_summary['labels_loss']:.4f}",
+        f"Depth metrics: mae={depth_stats['mae']:.4f} rmse={depth_stats['rmse']:.4f}",
+        f"Normal metrics: valid_px={normal_stats['valid_px']:.0f} cos={normal_stats['mean_cos']:.4f}",
+        f"Grad norm={loss_summary['grad_norm']:.4f}",
+    ]
+    ax_normal_text.text(0.0, 1.0, "\n".join(text_lines), va="top")
+    normal_fig.tight_layout()
+    normal_fig.savefig(detail_dir / f"detail_normals_iter_{iteration:04d}.png", dpi=300)
+    plt.close(normal_fig)
+
+
+def main():
+    parser = ArgumentParser(description="Iterative Gaussian-to-mesh analysis script for the bridge scene")
+    parser.add_argument(
+        "--num_iterations",
+        type=int,
+        default=5,
+        help="Number of loop iterations (used when --lock_view_repeat is not enabled)",
+    )
+    parser.add_argument("--ma_beta", type=float, default=0.8, help="Moving-average coefficient for the loss")
+
+    # ========== New: optimization config file parameter ==========
+    parser.add_argument(
+        "--opt_config",
+        type=str,
+        default="default",
+        help="Optimization config name or full path (default 'default', resolved to configs/optimization/default.yaml)",
+    )
+
+    # ========== Legacy arguments kept for backward compatibility; the YAML config overrides them ==========
+    parser.add_argument(
+        "--depth_loss_weight", type=float, default=None, help="(deprecated, use --opt_config) depth consistency term weight"
+    )
+    parser.add_argument(
+        "--normal_loss_weight", type=float, default=None, help="(deprecated, use --opt_config) normal consistency term weight"
+    )
+    parser.add_argument(
+        "--lr",
+        "--learning_rate",
+        dest="lr",
+        type=float,
+        default=None,
+        help="(deprecated, use --opt_config) learning rate when optimizing XYZ only",
+    )
+    parser.add_argument(
+        "--depth_clip_min",
+        type=float,
+        default=None,
+        help="(deprecated, use --opt_config) minimum depth clip value",
+    )
+    parser.add_argument(
+        "--depth_clip_max",
+        type=float,
+        default=None,
+        help="(deprecated, use --opt_config) maximum depth clip value",
+    )
+    parser.add_argument(
+        "--mesh_depth_weight",
+        type=float,
+        default=None,
+        help="(deprecated, use --opt_config) mesh depth term weight override",
+    )
+    parser.add_argument(
+        "--mesh_normal_weight",
+        type=float,
+        default=None,
+        help="(deprecated, use --opt_config) mesh normal term weight override",
+    )
+
+    parser.add_argument(
+        "--delaunay_reset_interval",
+        type=int,
+        default=1000,
+        help="Rebuild the Delaunay triangulation every N iterations (<=0 rebuilds every iteration)",
+    )
+    parser.add_argument(
+        "--mesh_config",
+        type=str,
+        default="medium",
+        help="Mesh config name or path (default medium)",
+    )
+    parser.add_argument(
+        "--save_interval",
+        type=int,
+        default=None,
+        help="Interval for saving visualizations/npz files; defaults to the Delaunay rebuild interval",
+    )
+    parser.add_argument(
+        "--detail_interval",
+        type=int,
+        default=None,
+        help="Interval (in iterations) for saving detailed debug images; disabled when unset",
+    )
+    parser.add_argument(
+        "--heatmap_dir",
+        type=str,
+        default="yufu2mesh_outputs",
+        help="Directory for heatmaps and other outputs",
+    )
+    parser.add_argument(
+        "--depth_gt_dir",
+        type=str,
+        default="bridge_small/depth",
+        help="Relative path to the Discoverse depth .npy files (rooted at milo/data; absolute paths also work)",
+    )
+    parser.add_argument(
+        "--ply_path",
+        type=str,
+        default="bridge_small/yufu_bridge_small.ply",
+        help="Initial Gaussian PLY path (relative to milo/data; absolute paths also work)",
+    )
+    parser.add_argument(
+        "--camera_poses_json",
+        type=str,
+        default="bridge_small/camera_poses_cam1.json",
+        help="Camera pose JSON path (relative to milo/data; absolute paths also work)",
+    )
+    parser.add_argument(
+        "--normal_cache_dir",
+        type=str,
+        default=None,
+        help="Normal cache directory; defaults to runs/<heatmap_dir>/normal_gt",
+    )
+    parser.add_argument(
+        "--skip_normal_gt_generation",
+        action="store_true",
+        help="Skip the initial normal GT precomputation when the cache already exists",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Seed controlling all randomness")
+    parser.add_argument(
+        "--lock_view_repeat",
+        type=int,
+        default=None,
+        help="View-locking debug mode: repeat each view this many consecutive iterations; total iterations become num_views x this value and --num_iterations is ignored. Unset disables the mode",
+    )
+    parser.add_argument(
+        "--log_interval",
+        type=int,
+        default=100,
+        help="Iteration logging frequency (default: every 100 iterations)",
+    )
+    parser.add_argument(
+        "--warn_until_iter",
+        type=int,
+        default=3000,
+        help="Warmup iterations for surface sampling (used for mesh downsampling)",
+    )
+    parser.add_argument(
+        "--imp_metric",
+        type=str,
+        default="outdoor",
+        choices=["outdoor", "indoor"],
+        help="Importance metric type for surface sampling",
+    )
+    parser.add_argument(
+        "--mesh_start_iter",
+        type=int,
+        default=2000,
+        help="Iteration at which mesh regularization starts (default 2000 to avoid cold-start interference)",
+    )
+    parser.add_argument(
+        "--mesh_update_interval",
+        type=int,
+        default=5,
+        help="Mesh regularization rebuild/backprop interval; values >1 reduce DMTet jitter (default 5)",
+    )
+    pipe = PipelineParams(parser)
+    args = parser.parse_args()
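+
+    # Example invocation (the flags come from the parser above; paths and
+    # values are placeholders, not recommended settings):
+    #   python milo/yufu2mesh_new.py --opt_config default --mesh_config medium \
+    #       --num_iterations 2000 --detail_interval 200 --heatmap_dir exp01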
help="控制迭代日志的打印频率(默认每次迭代打印)", + ) + parser.add_argument( + "--warn_until_iter", + type=int, + default=3000, + help="surface sampling warmup 迭代数(用于 mesh downsample)", + ) + parser.add_argument( + "--imp_metric", + type=str, + default="outdoor", + choices=["outdoor", "indoor"], + help="surface sampling 的重要性度量类型", + ) + parser.add_argument( + "--mesh_start_iter", + type=int, + default=2000, + help="mesh 正则起始迭代(默认 2000,避免冷启动阶段干扰)", + ) + parser.add_argument( + "--mesh_update_interval", + type=int, + default=5, + help="mesh 正则重建/回传间隔,>1 可减少 DMTet 抖动(默认 5)", + ) + pipe = PipelineParams(parser) + args = parser.parse_args() + + # ========== 加载优化配置 ========== + print(f"[INFO] 加载优化配置:{args.opt_config}") + opt_config = load_optimization_config(args.opt_config) + + # 兼容性:如果命令行指定了旧参数,发出警告并使用YAML配置 + if args.depth_loss_weight is not None: + print(f"[WARNING] --depth_loss_weight 已弃用,将使用YAML配置中的值") + if args.normal_loss_weight is not None: + print(f"[WARNING] --normal_loss_weight 已弃用,将使用YAML配置中的值") + if args.lr is not None: + print(f"[WARNING] --lr 已弃用,将使用YAML配置中的学习率设置") + if args.depth_clip_min is not None: + print(f"[WARNING] --depth_clip_min 已弃用,将使用YAML配置中的值") + if args.depth_clip_max is not None: + print(f"[WARNING] --depth_clip_max 已弃用,将使用YAML配置中的值") + if args.mesh_depth_weight is not None: + print(f"[WARNING] --mesh_depth_weight 已弃用,将使用YAML配置中的值") + if args.mesh_normal_weight is not None: + print(f"[WARNING] --mesh_normal_weight 已弃用,将使用YAML配置中的值") + + lock_view_mode = args.lock_view_repeat is not None + lock_repeat = max(1, args.lock_view_repeat) if lock_view_mode else 1 + + pipe.debug = getattr(args, "debug", False) + + # 所有输出固定写入 milo/runs/ 下,便于管理实验产物 + base_run_dir = MILO_DIR / "runs" + output_dir = base_run_dir / args.heatmap_dir + output_dir.mkdir(parents=True, exist_ok=True) + iteration_image_dir = output_dir / "iteration_images" + iteration_image_dir.mkdir(parents=True, exist_ok=True) + lock_view_output_dir: Optional[Path] = None + if lock_view_mode: + lock_view_output_dir = output_dir / "lock_view_repeat" + lock_view_output_dir.mkdir(parents=True, exist_ok=True) + + detail_interval: Optional[int] = None + detail_image_dir: Optional[Path] = None + if args.detail_interval is not None: + if args.detail_interval <= 0: + print("[WARNING] --detail_interval <= 0,已禁用详细调试输出。") + else: + detail_interval = max(1, args.detail_interval) + detail_image_dir = output_dir / "detail_images" + detail_image_dir.mkdir(parents=True, exist_ok=True) + print(f"[INFO] 启用详细调试输出:每 {detail_interval} 次迭代写入 detail_images。") + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + depth_gt_dir = resolve_data_path(args.depth_gt_dir) + ply_path = resolve_data_path(args.ply_path) + camera_poses_json = resolve_data_path(args.camera_poses_json) + print( + f"[INFO] 使用数据路径:depth_gt={depth_gt_dir}, ply={ply_path}, camera_json={camera_poses_json}" + ) + + gaussians = GaussianModel(sh_degree=0, learn_occupancy=True) + gaussians.load_ply(str(ply_path)) + + # ========== 使用新的优化配置设置参数 ========== + print("[INFO] 配置高斯参数优化...") + optimizer, loss_weights = setup_gaussian_optimization(gaussians, opt_config) + + height = 720 + width = 1280 + fov_y_deg = 75.0 + + train_cameras = load_cameras_from_json( + json_path=str(camera_poses_json), + image_height=height, + image_width=width, + fov_y_deg=fov_y_deg, + ) + num_views = len(train_cameras) + print(f"[INFO] 成功加载 {num_views} 个相机视角。") + if lock_view_mode: + total_iterations = num_views * 
lock_repeat + print( + f"[INFO] 启用视角锁定调试:每个视角连续 {lock_repeat} 次,总迭代数 {total_iterations}(已忽略 --num_iterations)。" + ) + else: + total_iterations = args.num_iterations + + device = gaussians._xyz.device + background = torch.tensor([0.0, 0.0, 0.0], device=device) + + # ========== 应用优化配置到mesh和深度处理 ========== + mesh_config = load_mesh_config(args.mesh_config) + mesh_config["start_iter"] = max(0, args.mesh_start_iter) + mesh_config["stop_iter"] = max(mesh_config.get("stop_iter", total_iterations), total_iterations) + mesh_config["mesh_update_interval"] = max(1, args.mesh_update_interval) + mesh_config["delaunay_reset_interval"] = args.delaunay_reset_interval + + # 从优化配置中获取mesh权重 + mesh_reg_config = opt_config["mesh_regularization"] + mesh_config["depth_weight"] = mesh_reg_config["depth_weight"] + mesh_config["normal_weight"] = mesh_reg_config["normal_weight"] + print(f"[INFO] Mesh正则化权重: depth={mesh_config['depth_weight']}, normal={mesh_config['normal_weight']}") + + # 这里默认沿用 surface 采样以对齐训练阶段;如仅需快速分析,也可以切换为 random 提升速度。 + mesh_config["delaunay_sampling_method"] = "surface" + + scene_wrapper = ManualScene(train_cameras) + + ensure_gaussian_occupancy(gaussians) + if gaussians.spatial_lr_scale <= 0: + gaussians.spatial_lr_scale = 1.0 + gaussians.set_occupancy_mode(mesh_config.get("occupancy_mode", "occupancy_shift")) + + render_view, render_for_sdf = build_render_functions(gaussians, pipe, background) + + # 从优化配置中获取深度处理参数 + depth_proc_config = opt_config["depth_processing"] + depth_clip_min = depth_proc_config.get("clip_min") + depth_clip_max = depth_proc_config.get("clip_max") + if depth_clip_min is not None and depth_clip_min <= 0.0: + depth_clip_min = None + print(f"[INFO] 深度裁剪范围: min={depth_clip_min}, max={depth_clip_max}") + + depth_provider = DepthProvider( + depth_root=depth_gt_dir, + image_height=height, + image_width=width, + device=device, + clip_min=depth_clip_min, + clip_max=depth_clip_max, + ) + + normal_cache_dir = Path(args.normal_cache_dir) if args.normal_cache_dir else (output_dir / "normal_gt") + normal_cache = NormalGroundTruthCache( + cache_dir=normal_cache_dir, + image_height=height, + image_width=width, + device=device, + ) + if args.skip_normal_gt_generation: + missing = [idx for idx in range(num_views) if not normal_cache.has(idx)] + if missing: + raise RuntimeError( + f"跳过法线 GT 预计算被拒绝,仍有 {len(missing)} 个视角缺少缓存(示例 {missing[:5]})。" + ) + else: + print("[INFO] 开始预计算初始法线 GT(仅进行一次,若存在缓存会自动跳过)。") + normal_cache.ensure_all(train_cameras, render_view) + normal_cache.clear_memory_cache() + + mesh_renderer, mesh_state = initialize_mesh_regularization(scene_wrapper, mesh_config) + mesh_state["reset_delaunay_samples"] = True + mesh_state["reset_sdf_values"] = True + + # optimizer已在setup_gaussian_optimization中创建 + mesh_args = SimpleNamespace( + warn_until_iter=args.warn_until_iter, + imp_metric=args.imp_metric, + ) + + # 输出loss权重配置 + print(f"[INFO] Loss权重配置:") + for loss_name, weight in loss_weights.items(): + print(f" > {loss_name}: {weight}") + + # 记录整个迭代过程中的指标与梯度,结束时统一写入 npz/曲线 + stats_history: Dict[str, List[float]] = { + "iteration": [], + "depth_loss": [], + "normal_loss": [], + "mesh_loss": [], + "mesh_depth_loss": [], + "mesh_normal_loss": [], + "occupied_centers_loss": [], + "occupancy_labels_loss": [], + "depth_mae": [], + "depth_rmse": [], + "normal_mean_cos": [], + "normal_valid_px": [], + "grad_norm": [], + } + + moving_loss = None + previous_depth: Dict[int, np.ndarray] = {} + previous_normals: Dict[int, np.ndarray] = {} + camera_stack = list(range(num_views)) + 
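# Note: the shuffled index stack below gives epoch-style sampling without
+    # replacement -- every camera is visited exactly once before any view
+    # repeats, so per-view metrics stay comparable across passes.
+    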
random.shuffle(camera_stack) + lock_sequence = list(range(num_views)) if lock_view_mode else [] + lock_view_ptr = 0 + lock_repeat_ptr = 0 + save_interval = args.save_interval if args.save_interval is not None else args.delaunay_reset_interval + if save_interval is None or save_interval <= 0: + save_interval = 1 + log_interval = max(1, args.log_interval) + mesh_export_warned = False + + for iteration in range(total_iterations): + optimizer.zero_grad(set_to_none=True) + if lock_view_mode: + view_index = lock_sequence[lock_view_ptr] + else: + if not camera_stack: + camera_stack = list(range(num_views)) + random.shuffle(camera_stack) + view_index = camera_stack.pop() + viewpoint = train_cameras[view_index] + + training_pkg = render_view(viewpoint) + gt_depth_tensor, gt_depth_mask = depth_provider.get(view_index) + depth_loss_tensor, depth_stats = compute_depth_loss_tensor( + pred_depth=training_pkg["median_depth"], + gt_depth=gt_depth_tensor, + gt_mask=gt_depth_mask, + ) + gt_normals_tensor = normal_cache.get(view_index) + normal_loss_tensor, normal_stats = compute_normal_loss_tensor( + pred_normals=training_pkg["normal"], + gt_normals=gt_normals_tensor, + base_mask=gt_depth_mask, + ) + def _zero_mesh_pkg() -> Dict[str, Any]: + zero = torch.zeros((), device=device) + depth_zero = torch.zeros_like(training_pkg["median_depth"]) + normal_zero = torch.zeros( + training_pkg["median_depth"].shape[-2], + training_pkg["median_depth"].shape[-1], + 3, + device=device, + ) + return { + "mesh_loss": zero, + "mesh_depth_loss": zero, + "mesh_normal_loss": zero, + "occupied_centers_loss": zero, + "occupancy_labels_loss": zero, + "updated_state": mesh_state, + "mesh_render_pkg": { + "depth": depth_zero, + "normals": normal_zero, + }, + } + + mesh_active = iteration >= mesh_config["start_iter"] + if mesh_active: + mesh_pkg = compute_mesh_regularization( + iteration=iteration, + render_pkg=training_pkg, + viewpoint_cam=viewpoint, + viewpoint_idx=view_index, + gaussians=gaussians, + scene=scene_wrapper, + pipe=pipe, + background=background, + kernel_size=0.0, + config=mesh_config, + mesh_renderer=mesh_renderer, + mesh_state=mesh_state, + render_func=render_for_sdf, + weight_adjustment=1.0, + args=mesh_args, + integrate_func=integrate_radegs, + ) + else: + mesh_pkg = _zero_mesh_pkg() + + mesh_state = mesh_pkg["updated_state"] + mesh_loss_tensor = mesh_pkg["mesh_loss"] + + # ========== 使用YAML配置的loss权重 ========== + total_loss = ( + loss_weights["depth"] * depth_loss_tensor + + loss_weights["normal"] * normal_loss_tensor + + mesh_loss_tensor + ) + depth_loss_value = float(depth_loss_tensor.detach().item()) + normal_loss_value = float(normal_loss_tensor.detach().item()) + mesh_loss_value = float(mesh_loss_tensor.detach().item()) + loss_value = float(total_loss.detach().item()) + + if total_loss.requires_grad: + total_loss.backward() + # 计算所有可训练参数的总梯度范数 + total_grad_norm_sq = 0.0 + for param_group in optimizer.param_groups: + for param in param_group["params"]: + if param.grad is not None: + total_grad_norm_sq += param.grad.detach().norm().item() ** 2 + grad_norm = float(total_grad_norm_sq ** 0.5) + optimizer.step() + else: + optimizer.zero_grad(set_to_none=True) + grad_norm = float("nan") + + mesh_render_pkg = mesh_pkg["mesh_render_pkg"] + mesh_depth_map = prepare_depth_map(mesh_render_pkg["depth"]) + mesh_normals_map = prepare_normals(mesh_render_pkg["normals"]) + gaussian_depth_map = prepare_depth_map(training_pkg["median_depth"]) + gaussian_normals_map = prepare_normals(training_pkg["normal"]) + gt_depth_map = 
depth_provider.as_numpy(view_index) + gt_normals_map = prepare_normals(gt_normals_tensor) + + mesh_valid = np.isfinite(mesh_depth_map) & (mesh_depth_map > 0.0) + gaussian_valid = np.isfinite(gaussian_depth_map) & (gaussian_depth_map > 0.0) + gt_valid = np.isfinite(gt_depth_map) & (gt_depth_map > 0.0) + overlap_mask = gaussian_valid & gt_valid + + depth_delta = gaussian_depth_map - gt_depth_map + if overlap_mask.any(): + delta_abs = np.abs(depth_delta[overlap_mask]) + diff_mean = float(delta_abs.mean()) + diff_max = float(delta_abs.max()) + diff_rmse = float(np.sqrt(np.mean(depth_delta[overlap_mask] ** 2))) + else: + diff_mean = diff_max = diff_rmse = float("nan") + + mesh_depth_loss = float(mesh_pkg["mesh_depth_loss"].item()) + mesh_normal_loss = float(mesh_pkg["mesh_normal_loss"].item()) + occupied_loss = float(mesh_pkg["occupied_centers_loss"].item()) + labels_loss = float(mesh_pkg["occupancy_labels_loss"].item()) + + moving_loss = ( + loss_value + if moving_loss is None + else args.ma_beta * moving_loss + (1 - args.ma_beta) * loss_value + ) + + stats_history["iteration"].append(float(iteration)) + stats_history["depth_loss"].append(depth_loss_value) + stats_history["normal_loss"].append(normal_loss_value) + stats_history["mesh_loss"].append(mesh_loss_value) + stats_history["mesh_depth_loss"].append(mesh_depth_loss) + stats_history["mesh_normal_loss"].append(mesh_normal_loss) + stats_history["occupied_centers_loss"].append(occupied_loss) + stats_history["occupancy_labels_loss"].append(labels_loss) + stats_history["depth_mae"].append(depth_stats["mae"]) + stats_history["depth_rmse"].append(depth_stats["rmse"]) + stats_history["normal_mean_cos"].append(normal_stats["mean_cos"]) + stats_history["normal_valid_px"].append(float(normal_stats["valid_px"])) + stats_history["grad_norm"].append(grad_norm) + + def _fmt(value: float) -> str: + return f"{value:.6f}" + + if (iteration % log_interval == 0) or (iteration == total_iterations - 1): + print( + "[INFO] Iter {iter:02d} | loss={total} (depth={depth}, normal={normal}, mesh={mesh}) | ma_loss={ma}".format( + iter=iteration, + total=_fmt(loss_value), + depth=_fmt(depth_loss_value), + normal=_fmt(normal_loss_value), + mesh=_fmt(mesh_loss_value), + ma=f"{moving_loss:.6f}", + ) + ) + + should_save = (save_interval <= 0) or (iteration % save_interval == 0) + should_save_detail = ( + detail_interval is not None and (iteration % detail_interval == 0) + ) + shared_min: Optional[float] = None + shared_max: Optional[float] = None + if should_save or should_save_detail: + valid_values: List[np.ndarray] = [] + if mesh_valid.any(): + valid_values.append(mesh_depth_map[mesh_valid].reshape(-1)) + if gaussian_valid.any(): + valid_values.append(gaussian_depth_map[gaussian_valid].reshape(-1)) + if gt_valid.any(): + valid_values.append(gt_depth_map[gt_valid].reshape(-1)) + if valid_values: + all_valid = np.concatenate(valid_values) + shared_min = float(all_valid.min()) + shared_max = float(all_valid.max()) + else: + shared_min, shared_max = 0.0, 1.0 + + loss_summary = { + "depth_loss": depth_loss_value, + "normal_loss": normal_loss_value, + "mesh_loss": mesh_loss_value, + "mesh_depth_loss": mesh_depth_loss, + "mesh_normal_loss": mesh_normal_loss, + "occupied_loss": occupied_loss, + "labels_loss": labels_loss, + "loss": loss_value, + "grad_norm": grad_norm, + } + + if should_save: + if shared_min is None or shared_max is None: + shared_min, shared_max = 0.0, 1.0 + + gaussian_depth_vis_path = iteration_image_dir / f"gaussian_depth_vis_iter_{iteration:02d}.png" + 
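# All depth PNGs below are rendered with the shared [shared_min, shared_max]
+            # range computed above, so the GT / Gaussian / mesh panels stay
+            # visually comparable within one iteration.
+            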
plt.imsave( + gaussian_depth_vis_path, + gaussian_depth_map, + cmap="viridis", + vmin=shared_min, + vmax=shared_max, + ) + + depth_vis_path = iteration_image_dir / f"mesh_depth_vis_iter_{iteration:02d}.png" + plt.imsave( + depth_vis_path, + mesh_depth_map, + cmap="viridis", + vmin=shared_min, + vmax=shared_max, + ) + + gt_depth_vis_path = iteration_image_dir / f"gt_depth_vis_iter_{iteration:02d}.png" + plt.imsave( + gt_depth_vis_path, + gt_depth_map, + cmap="viridis", + vmin=shared_min, + vmax=shared_max, + ) + + normal_vis_path = iteration_image_dir / f"mesh_normal_vis_iter_{iteration:02d}.png" + mesh_normals_rgb = normals_to_rgb(mesh_normals_map) + save_normal_visualization(mesh_normals_rgb, normal_vis_path) + + gaussian_normal_vis_path = iteration_image_dir / f"gaussian_normal_vis_iter_{iteration:02d}.png" + gaussian_normals_rgb = normals_to_rgb(gaussian_normals_map) + save_normal_visualization(gaussian_normals_rgb, gaussian_normal_vis_path) + + gt_normal_vis_path = iteration_image_dir / f"gt_normal_vis_iter_{iteration:02d}.png" + gt_normals_rgb = normals_to_rgb(gt_normals_map) + save_normal_visualization(gt_normals_rgb, gt_normal_vis_path) + + output_npz = output_dir / f"mesh_render_iter_{iteration:02d}.npz" + np.savez( + output_npz, + mesh_depth=mesh_depth_map, + gaussian_depth=gaussian_depth_map, + depth_gt=gt_depth_map, + mesh_normals=mesh_normals_map, + gaussian_normals=gaussian_normals_map, + normal_gt=gt_normals_map, + depth_loss=depth_loss_value, + normal_loss=normal_loss_value, + mesh_loss=mesh_loss_value, + mesh_depth_loss=mesh_depth_loss, + mesh_normal_loss=mesh_normal_loss, + occupied_centers_loss=occupied_loss, + occupancy_labels_loss=labels_loss, + loss=loss_value, + moving_loss=moving_loss, + depth_mae=depth_stats["mae"], + depth_rmse=depth_stats["rmse"], + normal_valid_px=normal_stats["valid_px"], + normal_mean_cos=normal_stats["mean_cos"], + grad_norm=grad_norm, + iteration=iteration, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + if overlap_mask.any(): + depth_diff_vis = np.full_like( + gaussian_depth_map, np.nan, dtype=np.float32 + ) + depth_diff_vis[overlap_mask] = np.abs(depth_delta[overlap_mask]) + save_heatmap( + depth_diff_vis, + iteration_image_dir / f"depth_diff_iter_{iteration:02d}.png", + f"|Pred-GT| iter {iteration}", + ) + + composite_path = iteration_image_dir / f"comparison_iter_{iteration:02d}.png" + fig, axes = plt.subplots(2, 3, figsize=(18, 10), dpi=300) + ax_gt_depth, ax_gaussian_depth, ax_mesh_depth = axes[0] + ax_gt_normals, ax_gaussian_normals, ax_mesh_normals = axes[1] + + im0 = ax_gt_depth.imshow(gt_depth_map, cmap="viridis", vmin=shared_min, vmax=shared_max) + ax_gt_depth.set_title("GT depth") + ax_gt_depth.axis("off") + fig.colorbar(im0, ax=ax_gt_depth, fraction=0.046, pad=0.04) + + im1 = ax_gaussian_depth.imshow(gaussian_depth_map, cmap="viridis", vmin=shared_min, vmax=shared_max) + ax_gaussian_depth.set_title("Gaussian depth") + ax_gaussian_depth.axis("off") + fig.colorbar(im1, ax=ax_gaussian_depth, fraction=0.046, pad=0.04) + + im2 = ax_mesh_depth.imshow(mesh_depth_map, cmap="viridis", vmin=shared_min, vmax=shared_max) + ax_mesh_depth.set_title("Mesh depth") + ax_mesh_depth.axis("off") + fig.colorbar(im2, ax=ax_mesh_depth, fraction=0.046, pad=0.04) + + ax_gt_normals.imshow(gt_normals_rgb) + ax_gt_normals.set_title("GT normals") + ax_gt_normals.axis("off") + + ax_gaussian_normals.imshow(gaussian_normals_rgb) + ax_gaussian_normals.set_title("Gaussian normals") + ax_gaussian_normals.axis("off") + + 
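# Normal panels are tone-mapped from [-1, 1] to RGB via normals_to_rgb.
+            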
ax_mesh_normals.imshow(mesh_normals_rgb) + ax_mesh_normals.set_title("Mesh normals") + ax_mesh_normals.axis("off") + + info_lines = [ + f"Iteration: {iteration:02d}", + f"View index: {view_index}", + f"GT depth valid px: {int(gt_valid.sum())}", + f"Gaussian depth valid px: {int(gaussian_valid.sum())}", + f"|Pred - GT| mean={diff_mean:.3f}, max={diff_max:.3f}, RMSE={diff_rmse:.3f}", + f"Depth loss={_fmt(depth_loss_value)} (w={loss_weights['depth']:.2f}, mae={depth_stats['mae']:.3f}, rmse={depth_stats['rmse']:.3f})", + f"Normal loss={_fmt(normal_loss_value)} (w={loss_weights['normal']:.2f}, px={normal_stats['valid_px']}, cos={normal_stats['mean_cos']:.3f})", + f"Mesh loss={_fmt(mesh_loss_value)}", + f"Mesh depth loss={_fmt(mesh_depth_loss)} mesh normal loss={_fmt(mesh_normal_loss)}", + f"Occupied centers={_fmt(occupied_loss)} labels={_fmt(labels_loss)}", + ] + fig.suptitle("\n".join(info_lines), fontsize=12, y=0.98) + fig.tight_layout(rect=[0, 0, 1, 0.94]) + fig.savefig(composite_path, dpi=300) + plt.close(fig) + + mesh_ready_for_export = mesh_active or mesh_state.get("delaunay_tets") is not None + if mesh_ready_for_export: + with torch.no_grad(): + export_mesh_from_state( + gaussians=gaussians, + mesh_state=mesh_state, + output_path=output_dir / f"mesh_iter_{iteration:02d}.ply", + reference_camera=None, + ) + else: + if not mesh_export_warned: + print( + "[INFO] Mesh 正则尚未启动或尚未生成 Delaunay,已跳过网格导出以避免全量三角化。" + ) + mesh_export_warned = True + + if should_save_detail and detail_image_dir is not None: + save_detail_visualizations( + iteration=iteration, + detail_dir=detail_image_dir, + view_index=view_index, + gaussian_depth_map=gaussian_depth_map, + mesh_depth_map=mesh_depth_map, + gt_depth_map=gt_depth_map, + gaussian_normals_map=gaussian_normals_map, + mesh_normals_map=mesh_normals_map, + gt_normals_map=gt_normals_map, + gaussian_valid=gaussian_valid, + mesh_valid=mesh_valid, + gt_valid=gt_valid, + shared_min=shared_min if shared_min is not None else 0.0, + shared_max=shared_max if shared_max is not None else 1.0, + depth_stats=depth_stats, + normal_stats=normal_stats, + loss_summary=loss_summary, + ) + + if lock_view_mode and lock_view_output_dir is not None: + if view_index in previous_depth: + prev_depth_map = previous_depth[view_index] + depth_valid_mask = ( + np.isfinite(prev_depth_map) + & np.isfinite(gaussian_depth_map) + & (prev_depth_map > 0.0) + & (gaussian_depth_map > 0.0) + ) + if depth_valid_mask.any(): + depth_delta = np.abs(gaussian_depth_map - prev_depth_map) + depth_diff = np.full_like( + gaussian_depth_map, np.nan, dtype=np.float32 + ) + depth_diff[depth_valid_mask] = depth_delta[depth_valid_mask] + save_heatmap( + depth_diff, + lock_view_output_dir / f"depth_diff_iter_{iteration:02d}_temporal.png", + f"Depth Δ iter {iteration}", + ) + if view_index in previous_normals: + prev_normals_map = previous_normals[view_index] + normal_valid_mask = np.all( + np.isfinite(prev_normals_map), axis=-1 + ) & np.all(np.isfinite(gaussian_normals_map), axis=-1) + if normal_valid_mask.any(): + normal_delta = gaussian_normals_map - prev_normals_map + if normal_delta.ndim == 3: + normal_diff = np.linalg.norm(normal_delta, axis=-1) + else: + normal_diff = np.abs(normal_delta) + normal_diff_vis = np.full_like( + normal_diff, np.nan, dtype=np.float32 + ) + normal_diff_vis[normal_valid_mask] = normal_diff[normal_valid_mask] + save_heatmap( + normal_diff_vis, + lock_view_output_dir / f"normal_diff_iter_{iteration:02d}_temporal.png", + f"Normal Δ iter {iteration}", + ) + + if lock_view_mode: + 
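# Cache this iteration's maps so the next pass over the same locked view
+            # can render the temporal |delta depth| / |delta normal| heatmaps above.
+            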
previous_depth[view_index] = gaussian_depth_map + previous_normals[view_index] = gaussian_normals_map + lock_repeat_ptr += 1 + if lock_repeat_ptr >= lock_repeat: + lock_repeat_ptr = 0 + lock_view_ptr = (lock_view_ptr + 1) % num_views + with torch.no_grad(): + # 输出完整指标轨迹及汇总曲线,方便任务结束后快速复盘 + history_npz = output_dir / "metrics_history.npz" + np.savez( + history_npz, + **{k: np.asarray(v, dtype=np.float32) for k, v in stats_history.items()}, + ) + summary_fig = output_dir / "metrics_summary.png" + if stats_history["iteration"]: + fig, axes = plt.subplots(2, 2, figsize=(16, 10), dpi=200) + iters = np.asarray(stats_history["iteration"]) + axes[0, 0].plot(iters, stats_history["depth_loss"], label="depth") + axes[0, 0].plot(iters, stats_history["normal_loss"], label="normal") + axes[0, 0].plot(iters, stats_history["mesh_loss"], label="mesh") + axes[0, 0].set_title("Total losses") + axes[0, 0].set_xlabel("Iteration") + axes[0, 0].legend() + + axes[0, 1].plot(iters, stats_history["mesh_depth_loss"], label="mesh depth") + axes[0, 1].plot(iters, stats_history["mesh_normal_loss"], label="mesh normal") + axes[0, 1].plot(iters, stats_history["occupied_centers_loss"], label="occupied centers") + axes[0, 1].plot(iters, stats_history["occupancy_labels_loss"], label="occupancy labels") + axes[0, 1].set_title("Mesh regularization components") + axes[0, 1].set_xlabel("Iteration") + axes[0, 1].legend() + + axes[1, 0].plot(iters, stats_history["depth_mae"], label="depth MAE") + axes[1, 0].plot(iters, stats_history["depth_rmse"], label="depth RMSE") + axes[1, 0].set_title("Depth metrics") + axes[1, 0].set_xlabel("Iteration") + axes[1, 0].legend() + + axes[1, 1].plot(iters, stats_history["normal_mean_cos"], label="mean cos") + axes[1, 1].plot(iters, stats_history["grad_norm"], label="grad norm") + axes[1, 1].set_title("Normals / Gradients") + axes[1, 1].set_xlabel("Iteration") + axes[1, 1].legend() + + fig.tight_layout() + fig.savefig(summary_fig) + plt.close(fig) + print(f"[INFO] 已保存曲线汇总:{summary_fig}") + print(f"[INFO] 记录所有迭代指标到 {history_npz}") + final_mesh_path = output_dir / "mesh_final.ply" + final_gaussian_path = output_dir / "gaussians_final.ply" + print(f"[INFO] 导出最终 mesh 到 {final_mesh_path}") + export_mesh_from_state( + gaussians=gaussians, + mesh_state=mesh_state, + output_path=final_mesh_path, + reference_camera=None, + ) + print(f"[INFO] 导出最终高斯到 {final_gaussian_path}") + gaussians.save_ply(str(final_gaussian_path)) + print("[INFO] 循环结束,所有结果已写入输出目录。") + + +if __name__ == "__main__": + main() diff --git a/milo/yufu2mesh_new_backup_20250114.py b/milo/yufu2mesh_new_backup_20250114.py new file mode 100644 index 0000000..007300f --- /dev/null +++ b/milo/yufu2mesh_new_backup_20250114.py @@ -0,0 +1,1204 @@ +from pathlib import Path +import math +import json +import random +from typing import Any, Dict, List, Optional, Sequence, Tuple +from types import SimpleNamespace + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +import trimesh +import yaml + +from argparse import ArgumentParser + +from functional import ( + compute_delaunay_triangulation, + extract_mesh, + frustum_cull_mesh, +) +from scene.cameras import Camera +from scene.gaussian_model import GaussianModel +from gaussian_renderer import render_full +from gaussian_renderer.radegs import render_radegs, integrate_radegs +from arguments import PipelineParams +from regularization.regularizer.mesh import ( + initialize_mesh_regularization, + compute_mesh_regularization, +) +from 
regularization.sdf.learnable import convert_occupancy_to_sdf +from utils.geometry_utils import flatten_voronoi_features + +# DISCOVER-SE 相机轨迹使用 OpenGL 右手坐标系(相机前方为 -Z,向上为 +Y), +# 而 MILo/colmap 渲染管线假设的是前方 +Z、向上 -Y。需要在读入时做一次轴翻转。 +OPENGL_TO_COLMAP = np.diag([1.0, -1.0, -1.0]).astype(np.float32) + + +class DepthProvider: + """负责加载并缓存 Discoverse 深度图,统一形状、裁剪和掩码。""" + + def __init__( + self, + depth_root: Path, + image_height: int, + image_width: int, + device: torch.device, + clip_min: Optional[float] = None, + clip_max: Optional[float] = None, + ) -> None: + self.depth_root = Path(depth_root) + if not self.depth_root.is_dir(): + raise FileNotFoundError(f"深度目录不存在:{self.depth_root}") + self.image_height = image_height + self.image_width = image_width + self.device = device + self.clip_min = clip_min + self.clip_max = clip_max + self._cache: Dict[int, torch.Tensor] = {} + self._mask_cache: Dict[int, torch.Tensor] = {} + + def _file_for_index(self, view_index: int) -> Path: + return self.depth_root / f"depth_img_0_{view_index}.npy" + + def _load_numpy(self, file_path: Path) -> np.ndarray: + depth_np = np.load(file_path) + depth_np = np.squeeze(depth_np) + if depth_np.ndim != 2: + raise ValueError(f"{file_path} 深度数组维度异常:{depth_np.shape}") + if depth_np.shape != (self.image_height, self.image_width): + raise ValueError( + f"{file_path} 深度分辨率应为 {(self.image_height, self.image_width)},当前为 {depth_np.shape}" + ) + if self.clip_min is not None or self.clip_max is not None: + min_val = self.clip_min if self.clip_min is not None else None + max_val = self.clip_max if self.clip_max is not None else None + depth_np = np.clip( + depth_np, + min_val if min_val is not None else depth_np.min(), + max_val if max_val is not None else depth_np.max(), + ) + return depth_np.astype(np.float32) + + def get(self, view_index: int) -> Tuple[torch.Tensor, torch.Tensor]: + """返回 (depth_tensor, valid_mask),均在 GPU 上。""" + if view_index not in self._cache: + file_path = self._file_for_index(view_index) + if not file_path.is_file(): + raise FileNotFoundError(f"缺少深度文件:{file_path}") + depth_np = self._load_numpy(file_path) + depth_tensor = torch.from_numpy(depth_np).to(self.device) + valid_mask = torch.isfinite(depth_tensor) & (depth_tensor > 0.0) + if self.clip_min is not None: + valid_mask &= depth_tensor >= self.clip_min + if self.clip_max is not None: + valid_mask &= depth_tensor <= self.clip_max + self._cache[view_index] = depth_tensor + self._mask_cache[view_index] = valid_mask + return self._cache[view_index], self._mask_cache[view_index] + + def as_numpy(self, view_index: int) -> np.ndarray: + depth_tensor, _ = self.get(view_index) + return depth_tensor.detach().cpu().numpy() + + +class NormalGroundTruthCache: + """缓存以初始高斯生成的法线 GT,避免训练阶段重复渲染。""" + + def __init__( + self, + cache_dir: Path, + image_height: int, + image_width: int, + device: torch.device, + ) -> None: + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.image_height = image_height + self.image_width = image_width + self.device = device + self._memory_cache: Dict[int, torch.Tensor] = {} + + def _file_path(self, view_index: int) -> Path: + return self.cache_dir / f"normal_view_{view_index:04d}.npy" + + def has(self, view_index: int) -> bool: + return self._file_path(view_index).is_file() + + def store(self, view_index: int, normal_tensor: torch.Tensor) -> None: + normals_np = prepare_normals(normal_tensor) + np.save(self._file_path(view_index), normals_np.astype(np.float16)) + + def get(self, view_index: int) -> 
torch.Tensor: + if view_index not in self._memory_cache: + path = self._file_path(view_index) + if not path.is_file(): + raise FileNotFoundError( + f"未找到视角 {view_index} 的法线缓存:{path},请先完成预计算。" + ) + normals_np = np.load(path) + expected_shape = (self.image_height, self.image_width, 3) + if normals_np.shape != expected_shape: + raise ValueError( + f"{path} 法线缓存尺寸应为 {expected_shape},当前为 {normals_np.shape}" + ) + normals_tensor = ( + torch.from_numpy(normals_np.astype(np.float32)) + .permute(2, 0, 1) + .to(self.device) + ) + self._memory_cache[view_index] = normals_tensor + return self._memory_cache[view_index] + + def clear_memory_cache(self) -> None: + self._memory_cache.clear() + + def ensure_all(self, cameras: Sequence[Camera], render_fn) -> None: + total = len(cameras) + for idx, camera in enumerate(cameras): + if self.has(idx): + continue + with torch.no_grad(): + pkg = render_fn(camera) + normal_tensor = pkg["normal"] + self.store(idx, normal_tensor) + if (idx + 1) % 10 == 0 or idx + 1 == total: + print(f"[INFO] 预计算法线缓存 {idx + 1}/{total}") + + +def compute_depth_loss_tensor( + pred_depth: torch.Tensor, + gt_depth: torch.Tensor, + gt_mask: torch.Tensor, +) -> Tuple[torch.Tensor, Dict[str, float]]: + """返回深度 L1 损失及统计信息。""" + if pred_depth.dim() == 3: + pred_depth = pred_depth.squeeze(0) + if pred_depth.shape != gt_depth.shape: + raise ValueError( + f"预测深度尺寸 {pred_depth.shape} 与 GT {gt_depth.shape} 不一致。" + ) + valid_mask = gt_mask & torch.isfinite(pred_depth) & (pred_depth > 0.0) + valid_pixels = int(valid_mask.sum().item()) + if valid_pixels == 0: + zero = torch.zeros((), device=pred_depth.device) + stats = {"valid_px": 0, "mae": float("nan"), "rmse": float("nan")} + return zero, stats + diff = pred_depth - gt_depth + abs_diff = diff.abs() + loss = abs_diff[valid_mask].mean() + rmse = torch.sqrt((diff[valid_mask] ** 2).mean()) + stats = { + "valid_px": valid_pixels, + "mae": float(abs_diff[valid_mask].mean().detach().item()), + "rmse": float(rmse.detach().item()), + } + return loss, stats + + +def compute_normal_loss_tensor( + pred_normals: torch.Tensor, + gt_normals: torch.Tensor, + base_mask: torch.Tensor, +) -> Tuple[torch.Tensor, Dict[str, float]]: + """基于余弦相似度的法线损失。""" + if pred_normals.dim() != 3 or pred_normals.shape[0] != 3: + raise ValueError(f"预测法线维度应为 (3,H,W),当前为 {pred_normals.shape}") + if gt_normals.shape != pred_normals.shape: + raise ValueError( + f"法线 GT 尺寸 {gt_normals.shape} 与预测 {pred_normals.shape} 不一致。" + ) + gt_mask = torch.isfinite(gt_normals).all(dim=0) + pred_mask = torch.isfinite(pred_normals).all(dim=0) + valid_mask = base_mask & gt_mask & pred_mask + valid_pixels = int(valid_mask.sum().item()) + if valid_pixels == 0: + zero = torch.zeros((), device=pred_normals.device) + stats = {"valid_px": 0, "mean_cos": float("nan")} + return zero, stats + + pred_unit = F.normalize(pred_normals, dim=0) + gt_unit = F.normalize(gt_normals, dim=0) + cos_sim = (pred_unit * gt_unit).sum(dim=0).clamp(-1.0, 1.0) + loss_map = (1.0 - cos_sim) * valid_mask + loss = loss_map.sum() / valid_mask.sum() + stats = {"valid_px": valid_pixels, "mean_cos": float(cos_sim[valid_mask].mean().item())} + return loss, stats + + +class ManualScene: + """最小 Scene 封装,提供 mesh regularization 所需的接口。""" + + def __init__(self, cameras: Sequence[Camera]): + self._train_cameras = list(cameras) + + def getTrainCameras(self, scale: float = 1.0) -> List[Camera]: + return list(self._train_cameras) + + def getTrainCameras_warn_up( + self, + iteration: int, + warn_until_iter: int, + scale: float = 1.0, + scale2: float 
= 2.0, + ) -> List[Camera]: + return list(self._train_cameras) + + +def build_render_functions( + gaussians: GaussianModel, + pipe: PipelineParams, + background: torch.Tensor, +): + """构建训练/重建所需的渲染器:训练走 render_full 以获得精确深度梯度,SDF 仍沿用 RaDe-GS。""" + + def render_view(view: Camera) -> Dict[str, torch.Tensor]: + pkg = render_full( + viewpoint_camera=view, + pc=gaussians, + pipe=pipe, + bg_color=background, + compute_expected_normals=False, + compute_expected_depth=True, + compute_accurate_median_depth_gradient=True, + ) + if "area_max" not in pkg: + pkg["area_max"] = torch.zeros_like(pkg["radii"]) + return pkg + + def render_for_sdf( + view: Camera, + gaussians_override: Optional[GaussianModel] = None, + pipeline_override: Optional[PipelineParams] = None, + background_override: Optional[torch.Tensor] = None, + kernel_size: float = 0.0, + require_depth: bool = True, + require_coord: bool = False, + ): + pc_obj = gaussians if gaussians_override is None else gaussians_override + pipe_obj = pipe if pipeline_override is None else pipeline_override + bg_color = background if background_override is None else background_override + pkg = render_radegs( + viewpoint_camera=view, + pc=pc_obj, + pipe=pipe_obj, + bg_color=bg_color, + kernel_size=kernel_size, + scaling_modifier=1.0, + require_coord=require_coord, + require_depth=require_depth, + ) + if "area_max" not in pkg: + pkg["area_max"] = torch.zeros_like(pkg["radii"]) + return { + "render": pkg["render"].detach(), + "median_depth": pkg["median_depth"].detach(), + } + + return render_view, render_for_sdf + + +def load_mesh_config(config_name: str) -> Dict[str, Any]: + """支持直接传文件路径或 configs/mesh/ .yaml。""" + candidate = Path(config_name) + if not candidate.is_file(): + candidate = Path(__file__).resolve().parent / "configs" / "mesh" / f"{config_name}.yaml" + if not candidate.is_file(): + raise FileNotFoundError(f"无法找到 mesh 配置:{config_name}") + with candidate.open("r", encoding="utf-8") as fh: + return yaml.safe_load(fh) + + +def ensure_gaussian_occupancy(gaussians: GaussianModel) -> None: + """mesh regularization 依赖 9 维 occupancy 网格,此处在推理环境补齐缓冲。""" + needs_init = ( + not getattr(gaussians, "learn_occupancy", False) + or not hasattr(gaussians, "_occupancy_shift") + or gaussians._occupancy_shift.numel() == 0 + ) + if needs_init: + gaussians.learn_occupancy = True + base = torch.zeros((gaussians._xyz.shape[0], 9), device=gaussians._xyz.device) + shift = torch.zeros_like(base) + gaussians._base_occupancy = nn.Parameter(base.requires_grad_(False), requires_grad=False) + gaussians._occupancy_shift = nn.Parameter(shift.requires_grad_(True)) + + +def export_mesh_from_state( + gaussians: GaussianModel, + mesh_state: Dict[str, Any], + output_path: Path, + reference_camera: Optional[Camera] = None, +) -> None: + """根据当前 mesh_state 导出网格,并可选做视椎裁剪。""" + gaussian_idx = mesh_state.get("delaunay_xyz_idx") + delaunay_tets = mesh_state.get("delaunay_tets") + + if delaunay_tets is None: + delaunay_tets = compute_delaunay_triangulation( + means=gaussians.get_xyz, + scales=gaussians.get_scaling, + rotations=gaussians.get_rotation, + gaussian_idx=gaussian_idx, + ) + + occupancy = ( + gaussians.get_occupancy + if gaussian_idx is None + else gaussians.get_occupancy[gaussian_idx] + ) + pivots_sdf = convert_occupancy_to_sdf(flatten_voronoi_features(occupancy)) + + mesh = extract_mesh( + delaunay_tets=delaunay_tets, + pivots_sdf=pivots_sdf, + means=gaussians.get_xyz, + scales=gaussians.get_scaling, + rotations=gaussians.get_rotation, + gaussian_idx=gaussian_idx, + ) + + 
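# Extraction recap: per-Gaussian occupancy -> SDF values at the Voronoi
+    # pivots (convert_occupancy_to_sdf) -> surface pulled out of the Delaunay
+    # tetrahedra by extract_mesh.
+    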
mesh_to_save = mesh if reference_camera is None else frustum_cull_mesh(mesh, reference_camera) + verts = mesh_to_save.verts.detach().cpu().numpy() + faces = mesh_to_save.faces.detach().cpu().numpy() + trimesh.Trimesh(vertices=verts, faces=faces, process=False).export(output_path) + + +def quaternion_to_rotation_matrix(quaternion: Sequence[float]) -> np.ndarray: + """将单位四元数转换为 3x3 旋转矩阵。""" + # 这里显式转换 DISCOVERSE 导出的四元数,确保后续符合 MILo 的旋转约定 + q = np.asarray(quaternion, dtype=np.float64) + if q.shape != (4,): + raise ValueError(f"四元数需要包含 4 个分量,当前形状为 {q.shape}") + w, x, y, z = q + xx = x * x + yy = y * y + zz = z * z + xy = x * y + xz = x * z + yz = y * z + wx = w * x + wy = w * y + wz = w * z + rotation = np.array( + [ + [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)], + [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)], + [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)], + ] + ) + return rotation + + +def freeze_gaussian_model(model: GaussianModel) -> None: + """显式关闭高斯模型中参数的梯度。""" + # 推理阶段冻结高斯参数,后续循环只做前向评估 + tensor_attrs = [ + "_xyz", + "_features_dc", + "_features_rest", + "_scaling", + "_rotation", + "_opacity", + ] + for attr in tensor_attrs: + value = getattr(model, attr, None) + if isinstance(value, torch.Tensor): + value.requires_grad_(False) + + +def prepare_depth_map(depth_tensor: torch.Tensor) -> np.ndarray: + """将深度张量转为二维 numpy 数组。""" + # 统一 squeeze 逻辑,防止 Matplotlib 因 shape 异常报错 + depth_np = depth_tensor.detach().cpu().numpy() + depth_np = np.squeeze(depth_np) + if depth_np.ndim == 1: + depth_np = np.expand_dims(depth_np, axis=0) + return depth_np + + +def prepare_normals(normal_tensor: torch.Tensor) -> np.ndarray: + """将法线张量转换为 HxWx3 的 numpy 数组。""" + # 兼容渲染输出为 (3,H,W) 或 (H,W,3) 的两种格式 + normals_np = normal_tensor.detach().cpu().numpy() + normals_np = np.squeeze(normals_np) + if normals_np.ndim == 3 and normals_np.shape[0] == 3: + normals_np = np.transpose(normals_np, (1, 2, 0)) + if normals_np.ndim == 2: + normals_np = normals_np[..., None] + return normals_np + + +def normals_to_rgb(normals: np.ndarray) -> np.ndarray: + """将 [-1,1] 范围的法线向量映射到 [0,1] 以便可视化。""" + normals = np.nan_to_num(normals, nan=0.0, posinf=0.0, neginf=0.0) + rgb = 0.5 * (normals + 1.0) + return np.clip(rgb, 0.0, 1.0).astype(np.float32) + + +def save_normal_visualization(normal_rgb: np.ndarray, output_path: Path) -> None: + """保存法线可视化图像。""" + plt.imsave(output_path, normal_rgb) + + +def load_cameras_from_json( + json_path: str, + image_height: int, + image_width: int, + fov_y_deg: float, +) -> List[Camera]: + """按照 MILo 的视图约定读取相机文件,并进行坐标系转换。""" + # 自定义读取 DISCOVERSE 风格 JSON,统一转换到 COLMAP 世界->相机坐标 + pose_path = Path(json_path) + if not pose_path.is_file(): + raise FileNotFoundError(f"未找到相机 JSON:{json_path}") + + with pose_path.open("r", encoding="utf-8") as fh: + camera_list = json.load(fh) + + if isinstance(camera_list, dict): + for key in ("frames", "poses", "camera_poses"): + if key in camera_list and isinstance(camera_list[key], list): + camera_list = camera_list[key] + break + else: + raise ValueError(f"{json_path} 中的 JSON 结构不包含可识别的相机列表。") + + if not isinstance(camera_list, list) or not camera_list: + raise ValueError(f"{json_path} 中没有有效的相机条目。") + + fov_y = math.radians(fov_y_deg) + aspect_ratio = image_width / image_height + fov_x = 2.0 * math.atan(aspect_ratio * math.tan(fov_y * 0.5)) + + cameras: List[Camera] = [] + for idx, entry in enumerate(camera_list): + if "quaternion" in entry: + rotation_c2w = quaternion_to_rotation_matrix(entry["quaternion"]).astype(np.float32) + elif "rotation" in entry: + 
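# Both branches must yield a camera-to-world rotation; quaternions are
+            # unpacked as (w, x, y, z) by quaternion_to_rotation_matrix above.
+            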
rotation_c2w = np.asarray(entry["rotation"], dtype=np.float32) + else: + raise KeyError(f"相机条目 {idx} 未提供 quaternion 或 rotation。") + + if rotation_c2w.shape != (3, 3): + raise ValueError(f"相机条目 {idx} 的旋转矩阵形状应为 (3,3),实际为 {rotation_c2w.shape}") + + # DISCOVER-SE 的 quaternion/rotation 直接导入后,渲染出来的 PNG 会上下翻转, + # 说明其前进方向仍是 OpenGL 的 -Z。通过右乘 diag(1,-1,-1) 将其显式转换到 + # MILo/colmap 的坐标系,使得后续投影矩阵与深度图一致。 + rotation_c2w = rotation_c2w @ OPENGL_TO_COLMAP + + if "position" in entry: + camera_center = np.asarray(entry["position"], dtype=np.float32) + if camera_center.shape != (3,): + raise ValueError(f"相机条目 {idx} 的 position 应为 3 维向量,实际为 {camera_center.shape}") + rotation_w2c = rotation_c2w.T + translation = (-rotation_w2c @ camera_center).astype(np.float32) + elif "translation" in entry: + translation = np.asarray(entry["translation"], dtype=np.float32) + # 如果 JSON 已直接存储 colmap 风格的 T(即世界到相机),这里假设它与旋转 + # 一样来自 OpenGL 坐标。严格来说也应执行同样的轴变换,但现有数据集只有 + # position 字段;为避免重复转换,这里只做类型检查并保留原值。 + elif "tvec" in entry: + translation = np.asarray(entry["tvec"], dtype=np.float32) + else: + raise KeyError(f"相机条目 {idx} 未提供 position/translation/tvec 信息。") + + if translation.shape != (3,): + raise ValueError(f"相机条目 {idx} 的平移向量应为长度 3,实际为 {translation.shape}") + + image_name = ( + entry.get("name") + or entry.get("image_name") + or entry.get("img_name") + or f"view_{idx:04d}" + ) + + camera = Camera( + colmap_id=str(idx), + R=rotation_c2w, + T=translation, + FoVx=fov_x, + FoVy=fov_y, + image=torch.zeros(3, image_height, image_width), + gt_alpha_mask=None, + image_name=image_name, + uid=idx, + data_device="cuda", + ) + cameras.append(camera) + return cameras + + +def save_heatmap(data: np.ndarray, output_path: Path, title: str) -> None: + """将二维数据保存为热力图,便于直观观察差异。""" + # 迭代间深度 / 法线差分可视化,快速定位局部变化 + plt.figure(figsize=(6, 4)) + plt.imshow(data, cmap="inferno") + plt.title(title) + plt.colorbar() + plt.tight_layout() + plt.savefig(output_path) + plt.close() + + +def main(): + parser = ArgumentParser(description="桥梁场景高斯到网格迭代分析脚本") + parser.add_argument("--num_iterations", type=int, default=5, help="执行循环的次数") + parser.add_argument("--ma_beta", type=float, default=0.8, help="loss 滑动平均系数") + parser.add_argument( + "--depth_loss_weight", type=float, default=0.3, help="深度一致性项权重(默认 0.3)" + ) + parser.add_argument( + "--normal_loss_weight", type=float, default=0.05, help="法线一致性项权重(默认 0.05)" + ) + parser.add_argument( + "--lr", + "--learning_rate", + dest="lr", + type=float, + default=1e-3, + help="XYZ 学习率(默认 1e-3)", + ) + parser.add_argument( + "--shape_lr", + type=float, + default=5e-4, + help="缩放/旋转/不透明度的学习率(默认 5e-4)", + ) + parser.add_argument( + "--delaunay_reset_interval", + type=int, + default=1000, + help="每隔多少次迭代重建一次 Delaunay(<=0 表示每次重建)", + ) + parser.add_argument( + "--mesh_config", + type=str, + default="medium", + help="mesh 配置名称或路径(默认 medium)", + ) + parser.add_argument( + "--save_interval", + type=int, + default=None, + help="保存可视化/npz 的间隔,默认与 Delaunay 重建间隔相同", + ) + parser.add_argument( + "--heatmap_dir", + type=str, + default="yufu2mesh_outputs", + help="保存热力图等输出的目录", + ) + parser.add_argument( + "--depth_gt_dir", + type=str, + default="/home/zoyo/Desktop/MILo_rtx50/milo/data/bridge_clean/depth", + help="Discoverse 深度 npy 所在目录", + ) + parser.add_argument( + "--depth_clip_min", + type=float, + default=0.0, + help="深度最小裁剪值,<=0 表示不裁剪", + ) + parser.add_argument( + "--depth_clip_max", + type=float, + default=None, + help="深度最大裁剪值,None 表示不裁剪", + ) + parser.add_argument( + "--normal_cache_dir", + type=str, + default=None, + 
help="法线缓存目录,默认为 runs/ /normal_gt", + ) + parser.add_argument( + "--skip_normal_gt_generation", + action="store_true", + help="已存在缓存时跳过初始法线 GT 预计算", + ) + parser.add_argument("--seed", type=int, default=0, help="控制随机性的种子") + parser.add_argument( + "--lock_view_index", + type=int, + default=None, + help="固定视角索引,仅在指定时输出热力图", + ) + parser.add_argument( + "--log_interval", + type=int, + default=100, + help="控制迭代日志的打印频率(默认每次迭代打印)", + ) + parser.add_argument( + "--warn_until_iter", + type=int, + default=3000, + help="surface sampling warmup 迭代数(用于 mesh downsample)", + ) + parser.add_argument( + "--imp_metric", + type=str, + default="outdoor", + choices=["outdoor", "indoor"], + help="surface sampling 的重要性度量类型", + ) + parser.add_argument( + "--mesh_start_iter", + type=int, + default=2000, + help="mesh 正则起始迭代(默认 2000,避免冷启动阶段干扰)", + ) + parser.add_argument( + "--mesh_update_interval", + type=int, + default=5, + help="mesh 正则重建/回传间隔,>1 可减少 DMTet 抖动(默认 5)", + ) + parser.add_argument( + "--mesh_depth_weight", + type=float, + default=0.1, + help="mesh 深度项权重覆盖(默认 0.1,原配置通常为 0.05)", + ) + parser.add_argument( + "--mesh_normal_weight", + type=float, + default=0.1, + help="mesh 法线项权重覆盖(默认 0.1,原配置通常为 0.05)", + ) + parser.add_argument( + "--disable_shape_training", + action="store_true", + help="禁用缩放/旋转/不透明度的优化,仅用于调试", + ) + pipe = PipelineParams(parser) + args = parser.parse_args() + + pipe.debug = getattr(args, "debug", False) + + # 所有输出固定写入 milo/runs/ 下,便于管理实验产物 + base_run_dir = Path(__file__).resolve().parent / "runs" + output_dir = base_run_dir / args.heatmap_dir + output_dir.mkdir(parents=True, exist_ok=True) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + ply_path = "/home/zoyo/Desktop/MILo_rtx50/milo/data/bridge_clean/yufu_bridge_cleaned.ply" + camera_poses_json = "/home/zoyo/Desktop/MILo_rtx50/milo/data/bridge_clean/camera_poses_cam1.json" + + gaussians = GaussianModel(sh_degree=0, learn_occupancy=True) + gaussians.load_ply(ply_path) + freeze_attrs = [ + "_features_dc", + "_features_rest", + "_base_occupancy", + "_occupancy_shift", + ] + for attr in freeze_attrs: + value = getattr(gaussians, attr, None) + if isinstance(value, torch.Tensor): + value.requires_grad_(False) + gaussians._xyz.requires_grad_(True) + + shape_trainable = [] + shape_attr_list = ["_scaling", "_rotation", "_opacity"] + if not args.disable_shape_training: + for attr in shape_attr_list: + value = getattr(gaussians, attr, None) + if isinstance(value, torch.Tensor): + value.requires_grad_(True) + shape_trainable.append(value) + else: + for attr in shape_attr_list: + value = getattr(gaussians, attr, None) + if isinstance(value, torch.Tensor): + value.requires_grad_(False) + + height = 720 + width = 1280 + fov_y_deg = 75.0 + + train_cameras = load_cameras_from_json( + json_path=camera_poses_json, + image_height=height, + image_width=width, + fov_y_deg=fov_y_deg, + ) + num_views = len(train_cameras) + print(f"[INFO] 成功加载 {num_views} 个相机视角。") + + device = gaussians._xyz.device + background = torch.tensor([0.0, 0.0, 0.0], device=device) + + mesh_config = load_mesh_config(args.mesh_config) + mesh_config["start_iter"] = max(0, args.mesh_start_iter) + mesh_config["stop_iter"] = max(mesh_config.get("stop_iter", args.num_iterations), args.num_iterations) + mesh_config["mesh_update_interval"] = max(1, args.mesh_update_interval) + mesh_config["delaunay_reset_interval"] = args.delaunay_reset_interval + mesh_config["depth_weight"] = 
args.mesh_depth_weight + mesh_config["normal_weight"] = args.mesh_normal_weight + # 这里默认沿用 surface 采样以对齐训练阶段;如仅需快速分析,也可以切换为 random 提升速度。 + mesh_config["delaunay_sampling_method"] = "surface" + + scene_wrapper = ManualScene(train_cameras) + + ensure_gaussian_occupancy(gaussians) + if gaussians.spatial_lr_scale <= 0: + gaussians.spatial_lr_scale = 1.0 + gaussians.set_occupancy_mode(mesh_config.get("occupancy_mode", "occupancy_shift")) + + render_view, render_for_sdf = build_render_functions(gaussians, pipe, background) + + depth_clip_min = args.depth_clip_min if args.depth_clip_min > 0.0 else None + depth_clip_max = args.depth_clip_max + depth_provider = DepthProvider( + depth_root=Path(args.depth_gt_dir), + image_height=height, + image_width=width, + device=device, + clip_min=depth_clip_min, + clip_max=depth_clip_max, + ) + + normal_cache_dir = Path(args.normal_cache_dir) if args.normal_cache_dir else (output_dir / "normal_gt") + normal_cache = NormalGroundTruthCache( + cache_dir=normal_cache_dir, + image_height=height, + image_width=width, + device=device, + ) + if args.skip_normal_gt_generation: + missing = [idx for idx in range(num_views) if not normal_cache.has(idx)] + if missing: + raise RuntimeError( + f"跳过法线 GT 预计算被拒绝,仍有 {len(missing)} 个视角缺少缓存(示例 {missing[:5]})。" + ) + else: + print("[INFO] 开始预计算初始法线 GT(仅进行一次,若存在缓存会自动跳过)。") + normal_cache.ensure_all(train_cameras, render_view) + normal_cache.clear_memory_cache() + + mesh_renderer, mesh_state = initialize_mesh_regularization(scene_wrapper, mesh_config) + mesh_state["reset_delaunay_samples"] = True + mesh_state["reset_sdf_values"] = True + + param_groups = [{"params": [gaussians._xyz], "lr": args.lr}] + if shape_trainable: + param_groups.append({"params": shape_trainable, "lr": args.shape_lr}) + optimizer = torch.optim.Adam(param_groups) + mesh_args = SimpleNamespace( + warn_until_iter=args.warn_until_iter, + imp_metric=args.imp_metric, + ) + + # 记录整个迭代过程中的指标与梯度,结束时统一写入 npz/曲线 + stats_history: Dict[str, List[float]] = { + "iteration": [], + "depth_loss": [], + "normal_loss": [], + "mesh_loss": [], + "mesh_depth_loss": [], + "mesh_normal_loss": [], + "occupied_centers_loss": [], + "occupancy_labels_loss": [], + "depth_mae": [], + "depth_rmse": [], + "normal_mean_cos": [], + "normal_valid_px": [], + "grad_norm": [], + } + + moving_loss = None + previous_depth: Dict[int, np.ndarray] = {} + previous_normals: Dict[int, np.ndarray] = {} + camera_stack = list(range(num_views)) + random.shuffle(camera_stack) + save_interval = args.save_interval if args.save_interval is not None else args.delaunay_reset_interval + if save_interval is None or save_interval <= 0: + save_interval = 1 + + for iteration in range(args.num_iterations): + optimizer.zero_grad(set_to_none=True) + if args.lock_view_index is not None: + view_index = args.lock_view_index % num_views + else: + if not camera_stack: + camera_stack = list(range(num_views)) + random.shuffle(camera_stack) + view_index = camera_stack.pop() + viewpoint = train_cameras[view_index] + + training_pkg = render_view(viewpoint) + gt_depth_tensor, gt_depth_mask = depth_provider.get(view_index) + depth_loss_tensor, depth_stats = compute_depth_loss_tensor( + pred_depth=training_pkg["median_depth"], + gt_depth=gt_depth_tensor, + gt_mask=gt_depth_mask, + ) + gt_normals_tensor = normal_cache.get(view_index) + normal_loss_tensor, normal_stats = compute_normal_loss_tensor( + pred_normals=training_pkg["normal"], + gt_normals=gt_normals_tensor, + base_mask=gt_depth_mask, + ) + def _zero_mesh_pkg() -> Dict[str, 
Any]: + zero = torch.zeros((), device=device) + depth_zero = torch.zeros_like(training_pkg["median_depth"]) + normal_zero = torch.zeros_like(training_pkg["normal"].permute(1, 2, 0)) + return { + "mesh_loss": zero, + "mesh_depth_loss": zero, + "mesh_normal_loss": zero, + "occupied_centers_loss": zero, + "occupancy_labels_loss": zero, + "updated_state": mesh_state, + "mesh_render_pkg": { + "depth": depth_zero, + "normals": normal_zero, + }, + } + + mesh_active = iteration >= mesh_config["start_iter"] + if mesh_active: + mesh_pkg = compute_mesh_regularization( + iteration=iteration, + render_pkg=training_pkg, + viewpoint_cam=viewpoint, + viewpoint_idx=view_index, + gaussians=gaussians, + scene=scene_wrapper, + pipe=pipe, + background=background, + kernel_size=0.0, + config=mesh_config, + mesh_renderer=mesh_renderer, + mesh_state=mesh_state, + render_func=render_for_sdf, + weight_adjustment=1.0, + args=mesh_args, + integrate_func=integrate_radegs, + ) + else: + mesh_pkg = _zero_mesh_pkg() + + mesh_state = mesh_pkg["updated_state"] + mesh_loss_tensor = mesh_pkg["mesh_loss"] + total_loss = ( + args.depth_loss_weight * depth_loss_tensor + + args.normal_loss_weight * normal_loss_tensor + + mesh_loss_tensor + ) + depth_loss_value = float(depth_loss_tensor.detach().item()) + normal_loss_value = float(normal_loss_tensor.detach().item()) + mesh_loss_value = float(mesh_loss_tensor.detach().item()) + loss_value = float(total_loss.detach().item()) + + if total_loss.requires_grad: + total_loss.backward() + grad_norm = float(gaussians._xyz.grad.detach().norm().item()) + optimizer.step() + else: + optimizer.zero_grad(set_to_none=True) + grad_norm = float("nan") + + mesh_render_pkg = mesh_pkg["mesh_render_pkg"] + mesh_depth_map = prepare_depth_map(mesh_render_pkg["depth"]) + mesh_normals_map = prepare_normals(mesh_render_pkg["normals"]) + gaussian_depth_map = prepare_depth_map(training_pkg["median_depth"]) + gaussian_normals_map = prepare_normals(training_pkg["normal"]) + gt_depth_map = depth_provider.as_numpy(view_index) + gt_normals_map = prepare_normals(gt_normals_tensor) + + mesh_valid = np.isfinite(mesh_depth_map) & (mesh_depth_map > 0.0) + gaussian_valid = np.isfinite(gaussian_depth_map) & (gaussian_depth_map > 0.0) + gt_valid = np.isfinite(gt_depth_map) & (gt_depth_map > 0.0) + overlap_mask = gaussian_valid & gt_valid + + depth_delta = gaussian_depth_map - gt_depth_map + if overlap_mask.any(): + delta_abs = np.abs(depth_delta[overlap_mask]) + diff_mean = float(delta_abs.mean()) + diff_max = float(delta_abs.max()) + diff_rmse = float(np.sqrt(np.mean(depth_delta[overlap_mask] ** 2))) + else: + diff_mean = diff_max = diff_rmse = float("nan") + + mesh_depth_loss = float(mesh_pkg["mesh_depth_loss"].item()) + mesh_normal_loss = float(mesh_pkg["mesh_normal_loss"].item()) + occupied_loss = float(mesh_pkg["occupied_centers_loss"].item()) + labels_loss = float(mesh_pkg["occupancy_labels_loss"].item()) + + moving_loss = ( + loss_value + if moving_loss is None + else args.ma_beta * moving_loss + (1 - args.ma_beta) * loss_value + ) + + stats_history["iteration"].append(float(iteration)) + stats_history["depth_loss"].append(depth_loss_value) + stats_history["normal_loss"].append(normal_loss_value) + stats_history["mesh_loss"].append(mesh_loss_value) + stats_history["mesh_depth_loss"].append(mesh_depth_loss) + stats_history["mesh_normal_loss"].append(mesh_normal_loss) + stats_history["occupied_centers_loss"].append(occupied_loss) + stats_history["occupancy_labels_loss"].append(labels_loss) + 
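# moving_loss above is an exponential moving average of the total loss:
+        #   ma_t = beta * ma_{t-1} + (1 - beta) * loss_t, with beta = --ma_beta.
+        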
stats_history["depth_mae"].append(depth_stats["mae"]) + stats_history["depth_rmse"].append(depth_stats["rmse"]) + stats_history["normal_mean_cos"].append(normal_stats["mean_cos"]) + stats_history["normal_valid_px"].append(float(normal_stats["valid_px"])) + stats_history["grad_norm"].append(grad_norm) + + def _fmt(value: float) -> str: + return f"{value:.6f}" + + if (iteration % max(1, args.log_interval) == 0) or (iteration == args.num_iterations - 1): + print( + "[INFO] Iter {iter:02d} | loss={total} (depth={depth}, normal={normal}, mesh={mesh}) | ma_loss={ma}".format( + iter=iteration, + total=_fmt(loss_value), + depth=_fmt(depth_loss_value), + normal=_fmt(normal_loss_value), + mesh=_fmt(mesh_loss_value), + ma=f"{moving_loss:.6f}", + ) + ) + + should_save = (save_interval <= 0) or (iteration % save_interval == 0) + if should_save: + valid_values: List[np.ndarray] = [] + if mesh_valid.any(): + valid_values.append(mesh_depth_map[mesh_valid].reshape(-1)) + if gaussian_valid.any(): + valid_values.append(gaussian_depth_map[gaussian_valid].reshape(-1)) + if gt_valid.any(): + valid_values.append(gt_depth_map[gt_valid].reshape(-1)) + if valid_values: + all_valid = np.concatenate(valid_values) + shared_min = float(all_valid.min()) + shared_max = float(all_valid.max()) + else: + shared_min, shared_max = 0.0, 1.0 + + gaussian_depth_vis_path = output_dir / f"gaussian_depth_vis_iter_{iteration:02d}.png" + plt.imsave( + gaussian_depth_vis_path, + gaussian_depth_map, + cmap="viridis", + vmin=shared_min, + vmax=shared_max, + ) + + depth_vis_path = output_dir / f"mesh_depth_vis_iter_{iteration:02d}.png" + plt.imsave( + depth_vis_path, + mesh_depth_map, + cmap="viridis", + vmin=shared_min, + vmax=shared_max, + ) + + gt_depth_vis_path = output_dir / f"gt_depth_vis_iter_{iteration:02d}.png" + plt.imsave( + gt_depth_vis_path, + gt_depth_map, + cmap="viridis", + vmin=shared_min, + vmax=shared_max, + ) + + normal_vis_path = output_dir / f"mesh_normal_vis_iter_{iteration:02d}.png" + mesh_normals_rgb = normals_to_rgb(mesh_normals_map) + save_normal_visualization(mesh_normals_rgb, normal_vis_path) + + gaussian_normal_vis_path = output_dir / f"gaussian_normal_vis_iter_{iteration:02d}.png" + gaussian_normals_rgb = normals_to_rgb(gaussian_normals_map) + save_normal_visualization(gaussian_normals_rgb, gaussian_normal_vis_path) + + gt_normal_vis_path = output_dir / f"gt_normal_vis_iter_{iteration:02d}.png" + gt_normals_rgb = normals_to_rgb(gt_normals_map) + save_normal_visualization(gt_normals_rgb, gt_normal_vis_path) + + output_npz = output_dir / f"mesh_render_iter_{iteration:02d}.npz" + np.savez( + output_npz, + mesh_depth=mesh_depth_map, + gaussian_depth=gaussian_depth_map, + depth_gt=gt_depth_map, + mesh_normals=mesh_normals_map, + gaussian_normals=gaussian_normals_map, + normal_gt=gt_normals_map, + depth_loss=depth_loss_value, + normal_loss=normal_loss_value, + mesh_loss=mesh_loss_value, + mesh_depth_loss=mesh_depth_loss, + mesh_normal_loss=mesh_normal_loss, + occupied_centers_loss=occupied_loss, + occupancy_labels_loss=labels_loss, + loss=loss_value, + moving_loss=moving_loss, + depth_mae=depth_stats["mae"], + depth_rmse=depth_stats["rmse"], + normal_valid_px=normal_stats["valid_px"], + normal_mean_cos=normal_stats["mean_cos"], + grad_norm=grad_norm, + iteration=iteration, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + if overlap_mask.any(): + depth_diff_vis = np.zeros_like(gaussian_depth_map) + depth_diff_vis[overlap_mask] = depth_delta[overlap_mask] + save_heatmap( + np.abs(depth_diff_vis), + 
output_dir / f"depth_diff_iter_{iteration:02d}.png", + f"|Pred-GT| iter {iteration}", + ) + + composite_path = output_dir / f"comparison_iter_{iteration:02d}.png" + fig, axes = plt.subplots(2, 3, figsize=(18, 10), dpi=300) + ax_gt_depth, ax_gaussian_depth, ax_mesh_depth = axes[0] + ax_gt_normals, ax_gaussian_normals, ax_mesh_normals = axes[1] + + im0 = ax_gt_depth.imshow(gt_depth_map, cmap="viridis", vmin=shared_min, vmax=shared_max) + ax_gt_depth.set_title("GT depth") + ax_gt_depth.axis("off") + fig.colorbar(im0, ax=ax_gt_depth, fraction=0.046, pad=0.04) + + im1 = ax_gaussian_depth.imshow(gaussian_depth_map, cmap="viridis", vmin=shared_min, vmax=shared_max) + ax_gaussian_depth.set_title("Gaussian depth") + ax_gaussian_depth.axis("off") + fig.colorbar(im1, ax=ax_gaussian_depth, fraction=0.046, pad=0.04) + + im2 = ax_mesh_depth.imshow(mesh_depth_map, cmap="viridis", vmin=shared_min, vmax=shared_max) + ax_mesh_depth.set_title("Mesh depth") + ax_mesh_depth.axis("off") + fig.colorbar(im2, ax=ax_mesh_depth, fraction=0.046, pad=0.04) + + ax_gt_normals.imshow(gt_normals_rgb) + ax_gt_normals.set_title("GT normals") + ax_gt_normals.axis("off") + + ax_gaussian_normals.imshow(gaussian_normals_rgb) + ax_gaussian_normals.set_title("Gaussian normals") + ax_gaussian_normals.axis("off") + + ax_mesh_normals.imshow(mesh_normals_rgb) + ax_mesh_normals.set_title("Mesh normals") + ax_mesh_normals.axis("off") + + info_lines = [ + f"Iteration: {iteration:02d}", + f"View index: {view_index}", + f"GT depth valid px: {int(gt_valid.sum())}", + f"Gaussian depth valid px: {int(gaussian_valid.sum())}", + f"|Pred - GT| mean={diff_mean:.3f}, max={diff_max:.3f}, RMSE={diff_rmse:.3f}", + f"Depth loss={_fmt(depth_loss_value)} (w={args.depth_loss_weight:.2f}, mae={depth_stats['mae']:.3f}, rmse={depth_stats['rmse']:.3f})", + f"Normal loss={_fmt(normal_loss_value)} (w={args.normal_loss_weight:.2f}, px={normal_stats['valid_px']}, cos={normal_stats['mean_cos']:.3f})", + f"Mesh loss={_fmt(mesh_loss_value)}", + f"Mesh depth loss={_fmt(mesh_depth_loss)} mesh normal loss={_fmt(mesh_normal_loss)}", + f"Occupied centers={_fmt(occupied_loss)} labels={_fmt(labels_loss)}", + ] + fig.suptitle("\n".join(info_lines), fontsize=12, y=0.98) + fig.tight_layout(rect=[0, 0, 1, 0.94]) + fig.savefig(composite_path, dpi=300) + plt.close(fig) + + if args.lock_view_index is not None: + if view_index in previous_depth: + depth_diff = np.abs(gaussian_depth_map - previous_depth[view_index]) + save_heatmap( + depth_diff, + output_dir / f"depth_diff_iter_{iteration:02d}_temporal.png", + f"Depth Δ iter {iteration}", + ) + if view_index in previous_normals: + normal_delta = gaussian_normals_map - previous_normals[view_index] + if normal_delta.ndim == 3: + normal_diff = np.linalg.norm(normal_delta, axis=-1) + else: + normal_diff = np.abs(normal_delta) + save_heatmap( + normal_diff, + output_dir / f"normal_diff_iter_{iteration:02d}_temporal.png", + f"Normal Δ iter {iteration}", + ) + + with torch.no_grad(): + export_mesh_from_state( + gaussians=gaussians, + mesh_state=mesh_state, + output_path=output_dir / f"mesh_iter_{iteration:02d}.ply", + reference_camera=None, + ) + + if args.lock_view_index is not None: + previous_depth[view_index] = gaussian_depth_map + previous_normals[view_index] = gaussian_normals_map + with torch.no_grad(): + # 输出完整指标轨迹及汇总曲线,方便任务结束后快速复盘 + history_npz = output_dir / "metrics_history.npz" + np.savez( + history_npz, + **{k: np.asarray(v, dtype=np.float32) for k, v in stats_history.items()}, + ) + summary_fig = output_dir / 
"metrics_summary.png" + if stats_history["iteration"]: + fig, axes = plt.subplots(2, 2, figsize=(16, 10), dpi=200) + iters = np.asarray(stats_history["iteration"]) + axes[0, 0].plot(iters, stats_history["depth_loss"], label="depth") + axes[0, 0].plot(iters, stats_history["normal_loss"], label="normal") + axes[0, 0].plot(iters, stats_history["mesh_loss"], label="mesh") + axes[0, 0].set_title("Total losses") + axes[0, 0].set_xlabel("Iteration") + axes[0, 0].legend() + + axes[0, 1].plot(iters, stats_history["mesh_depth_loss"], label="mesh depth") + axes[0, 1].plot(iters, stats_history["mesh_normal_loss"], label="mesh normal") + axes[0, 1].plot(iters, stats_history["occupied_centers_loss"], label="occupied centers") + axes[0, 1].plot(iters, stats_history["occupancy_labels_loss"], label="occupancy labels") + axes[0, 1].set_title("Mesh regularization components") + axes[0, 1].set_xlabel("Iteration") + axes[0, 1].legend() + + axes[1, 0].plot(iters, stats_history["depth_mae"], label="depth MAE") + axes[1, 0].plot(iters, stats_history["depth_rmse"], label="depth RMSE") + axes[1, 0].set_title("Depth metrics") + axes[1, 0].set_xlabel("Iteration") + axes[1, 0].legend() + + axes[1, 1].plot(iters, stats_history["normal_mean_cos"], label="mean cos") + axes[1, 1].plot(iters, stats_history["grad_norm"], label="grad norm") + axes[1, 1].set_title("Normals / Gradients") + axes[1, 1].set_xlabel("Iteration") + axes[1, 1].legend() + + fig.tight_layout() + fig.savefig(summary_fig) + plt.close(fig) + print(f"[INFO] 已保存曲线汇总:{summary_fig}") + print(f"[INFO] 记录所有迭代指标到 {history_npz}") + final_mesh_path = output_dir / "mesh_final.ply" + final_gaussian_path = output_dir / "gaussians_final.ply" + print(f"[INFO] 导出最终 mesh 到 {final_mesh_path}") + export_mesh_from_state( + gaussians=gaussians, + mesh_state=mesh_state, + output_path=final_mesh_path, + reference_camera=None, + ) + print(f"[INFO] 导出最终高斯到 {final_gaussian_path}") + gaussians.save_ply(str(final_gaussian_path)) + print("[INFO] 循环结束,所有结果已写入输出目录。") + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index 380ca65..ff9da00 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,19 @@ +# Core dependencies +torch==2.7.1 +torchvision==0.22.1 +torchaudio==2.7.1 +numpy==2.3.3 +pillow==11.3.0 +scipy==1.16.2 + +# Computer Vision and 3D processing open3d==0.19.0 trimesh==4.6.8 +pymeshlab==2025.7 scikit-image==0.24.0 opencv-python==4.11.0.86 plyfile==1.1 -tqdm==4.67.1 \ No newline at end of file + +# Utilities +tqdm==4.67.1 +matplotlib==3.10.7 \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h b/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h index 3ff1b54..344b618 100755 --- a/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h +++ b/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h @@ -10,6 +10,9 @@ */ #pragma once +#include +#include +#include #include #include diff --git a/submodules/diff-gaussian-rasterization_gof/cuda_rasterizer/rasterizer_impl.h b/submodules/diff-gaussian-rasterization_gof/cuda_rasterizer/rasterizer_impl.h index 647f1a9..5e924fa 100644 --- a/submodules/diff-gaussian-rasterization_gof/cuda_rasterizer/rasterizer_impl.h +++ b/submodules/diff-gaussian-rasterization_gof/cuda_rasterizer/rasterizer_impl.h @@ -10,6 +10,9 @@ */ #pragma once +#include +#include +#include #include #include diff --git a/submodules/diff-gaussian-rasterization_ms/cuda_rasterizer/rasterizer_impl.h 
b/submodules/diff-gaussian-rasterization_ms/cuda_rasterizer/rasterizer_impl.h
index 5c76881..2a1f616 100644
--- a/submodules/diff-gaussian-rasterization_ms/cuda_rasterizer/rasterizer_impl.h
+++ b/submodules/diff-gaussian-rasterization_ms/cuda_rasterizer/rasterizer_impl.h
@@ -10,6 +10,9 @@
  */
 
 #pragma once
+#include <cstdint>
+#include <cstdio>
+#include <cstddef>
 
 #include <iostream>
 #include <vector>
diff --git a/submodules/simple-knn/simple_knn.cu b/submodules/simple-knn/simple_knn.cu
index e72e4c9..b998aaf 100644
--- a/submodules/simple-knn/simple_knn.cu
+++ b/submodules/simple-knn/simple_knn.cu
@@ -1,3 +1,5 @@
+#include <float.h>
+#include <cstdint>
 /*
  * Copyright (C) 2023, Inria
  * GRAPHDECO research group, https://team.inria.fr/graphdeco
@@ -20,7 +22,6 @@
 #include <cuda_runtime_api.h>
 #include <thrust/device_vector.h>
 #include <thrust/sequence.h>
-#define __CUDACC__
 #include <thrust/execution_policy.h>
 #include <thrust/extrema.h>
 
diff --git a/submodules/simple-knn/spatial.cu b/submodules/simple-knn/spatial.cu
index 1a6a654..207fc3a 100644
--- a/submodules/simple-knn/spatial.cu
+++ b/submodules/simple-knn/spatial.cu
@@ -20,7 +20,7 @@ distCUDA2(const torch::Tensor& points)
   auto float_opts = points.options().dtype(torch::kFloat32);
   torch::Tensor means = torch::full({P}, 0.0, float_opts);
 
-  SimpleKNN::knn(P, (float3*)points.contiguous().data<float>(), means.contiguous().data<float>());
+  SimpleKNN::knn(P, (float3*)points.contiguous().data_ptr<float>(), means.contiguous().data_ptr<float>());
 
   return means;
 }
\ No newline at end of file
diff --git a/submodules/tetra_triangulation/src/force_abi.h b/submodules/tetra_triangulation/src/force_abi.h
new file mode 100644
index 0000000..db5ecf3
--- /dev/null
+++ b/submodules/tetra_triangulation/src/force_abi.h
@@ -0,0 +1,38 @@
+// force_abi.h
+// -----------------------------------------------------------------------------
+// Force GNU libstdc++ to use the *new* C++11 ABI (_GLIBCXX_USE_CXX11_ABI = 1)
+// for this translation unit. **Must be included before ANY standard headers.**
+//
+// Why?
+//   - PyTorch ≥ 2.6 official binaries are built with the new C++11 ABI (=1).
+//   - If your extension is compiled with the old ABI (=0), you'll hit runtime
+//     linker errors such as undefined symbol: c10::detail::torchCheckFail(...RKSs).
+//
+// Usage:
+//   Place this as the VERY FIRST include in each .cpp that builds your extension:
+//     #include "force_abi.h"
+//     #include <torch/extension.h>
+//     ...
+// -----------------------------------------------------------------------------
+
+#pragma once
+
+// Only meaningful for GCC's libstdc++; harmless elsewhere.
+#if defined(__GNUC__) && !defined(_LIBCPP_VERSION)
+
+// If the macro was already defined (e.g., by compiler flags), reset it first.
+# ifdef _GLIBCXX_USE_CXX11_ABI
+#  undef _GLIBCXX_USE_CXX11_ABI
+# endif
+// Enforce the new (C++11) ABI.
+# define _GLIBCXX_USE_CXX11_ABI 1
+
+// Optional sanity check: if some libstdc++ internals are already visible,
+// it likely means a standard header slipped in before this file. In that case
+// overriding the ABI here won't affect those already-included headers.
+# if defined(_GLIBCXX_RELEASE) || defined(__GLIBCXX__) || defined(_GLIBCXX_BEGIN_NAMESPACE_VERSION)
+#  warning "force_abi.h should be included BEFORE any standard library headers."
+# endif
+
+#endif // defined(__GNUC__) && !defined(_LIBCPP_VERSION)
+
diff --git a/submodules/tetra_triangulation/src/py_binding.cpp b/submodules/tetra_triangulation/src/py_binding.cpp
index 0b2d43f..8fafa7e 100755
--- a/submodules/tetra_triangulation/src/py_binding.cpp
+++ b/submodules/tetra_triangulation/src/py_binding.cpp
@@ -1,3 +1,4 @@
+#include "force_abi.h"
 #include <torch/extension.h>
 #include
 
diff --git a/submodules/tetra_triangulation/src/triangulation.cpp b/submodules/tetra_triangulation/src/triangulation.cpp
index 8044295..f083085 100755
--- a/submodules/tetra_triangulation/src/triangulation.cpp
+++ b/submodules/tetra_triangulation/src/triangulation.cpp
@@ -1,3 +1,4 @@
+#include "force_abi.h"
 #include "triangulation.h"
 #include
 
@@ -66,4 +67,4 @@ std::vector triangulate(size_t num_points, float3* points) {
 	// 0, max_depth);
 
 	return cells;
-}
\ No newline at end of file
+}
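
The training loop above serializes every per-iteration metric it tracks into `metrics_history.npz`. A minimal sketch of reloading that archive for ad-hoc inspection after a run; the key names mirror the `stats_history` dict the script saves, while the run directory path is a hypothetical example:

```python
# Sketch: reload the metrics archive written at the end of depth training.
# Keys mirror the `stats_history` dict serialized by the script; the run
# directory below is a hypothetical example, not a path the repo guarantees.
from pathlib import Path

import numpy as np

run_dir = Path("runs/depth_sweep/lr1.0_clip6_dense")  # hypothetical run output
history = np.load(run_dir / "metrics_history.npz")

print(f"{len(history['iteration'])} iterations logged; keys: {sorted(history.files)}")
print(
    f"final depth MAE={history['depth_mae'][-1]:.4f}, "
    f"RMSE={history['depth_rmse'][-1]:.4f}, "
    f"mean normal cos={history['normal_mean_cos'][-1]:.4f}"
)
```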
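
`force_abi.h` only matters when the compiled extensions and the installed PyTorch wheel disagree on the libstdc++ ABI. Before rebuilding the submodules, the wheel's ABI can be confirmed with PyTorch's public `torch.compiled_with_cxx11_abi()` helper; a minimal sketch:

```python
# Sketch: check which libstdc++ ABI the installed PyTorch wheel was built
# with, so the extension's _GLIBCXX_USE_CXX11_ABI matches (see force_abi.h).
import torch

new_abi = torch.compiled_with_cxx11_abi()  # True -> _GLIBCXX_USE_CXX11_ABI=1
print(f"torch {torch.__version__} built with the new C++11 ABI: {new_abi}")
if not new_abi:
    print(
        "Old-ABI wheel detected: build the extensions with "
        "-D_GLIBCXX_USE_CXX11_ABI=0 instead of forcing the new ABI."
    )
```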