diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..57c8efb --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +data/* +output/* +__pycache__/ +*.pyc +*.pyo +*.pyd \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..4b5a294 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "python-envs.defaultEnvManager": "ms-python.python:conda", + "python-envs.defaultPackageManager": "ms-python.python:conda" +} \ No newline at end of file diff --git a/PROJECT_SUMMARY.md b/PROJECT_SUMMARY.md new file mode 100644 index 0000000..1653f13 --- /dev/null +++ b/PROJECT_SUMMARY.md @@ -0,0 +1,599 @@ +# FastGS / TD-FastGS — 项目全面总结 + +> **FastGS**: CVPR 2026 论文 — 将 3D Gaussian Splatting 训练加速至 **100 秒** +> **TD-FastGS**: 在本仓库中实现的时域扩展,支持 **动态场景 4D Gaussian Splatting** + +--- + +## 1. 项目概述 + +FastGS 是一个通用的 3DGS 加速框架,在原始 3D Gaussian Splatting 基础上引入三项关键改进: +- **多视图一致性致密化 (VCD)** +- **多视图一致性剪枝 (VCP)** +- **Compact Box (CB)** + +本仓库 (`fast4dgs`) 在 FastGS 基础上扩展了 **TD-FastGS**(移植自 TD-4DGS 的时域机制),支持动态场景的 4D Gaussian Splatting 训练与渲染。 + +--- + +## 2. 目录结构 + +``` +fast4dgs/ +├── train.py # 主训练脚本(3D + 4D 双入口) +├── render.py # 渲染脚本(生成测试/训练视图图像) +├── convert.py # COLMAP 数据转换脚本 +├── metrics.py # 评估指标计算(PSNR, SSIM, LPIPS) +├── export_frames.py # 4DGS: 逐帧导出高斯点云 PLY +├── slim_ply.py # 4DGS: 对 point_cloud.ply 瘦身(仅保留播放器所需属性) +├── full_eval.py # 批量完整评估脚本 +├── train_base.sh # FastGS 标准模式训练脚本 +├── train_big.sh # FastGS-Big(高质量)模式训练脚本 +├── environment.yml # Conda 环境配置 +├── README.md # 项目说明文档 +├── prompt.md # TD-FastGS 实现指导文档 +│ +├── arguments/ +│ └── __init__.py # 命令行参数定义(ModelParams, PipelineParams, OptimizationParams) +│ +├── scene/ +│ ├── __init__.py # Scene 类:场景管理与数据加载分发 +│ ├── gaussian_model.py # GaussianModel 核心类(含 4D 时域扩展) +│ ├── cameras.py # Camera 类(含惰性图像加载、LRU缓存) +│ ├── dataset_readers.py # 数据读取器(COLMAP, Blender, COLMAP4D) +│ ├── colmap_loader.py # COLMAP 二进制/文本格式读取 +│ +├── gaussian_renderer/ +│ ├── __init__.py # render_fastgs(3D)和 render_4d(4D)渲染函数 +│ ├── network_gui.py # GUI 可视化服务器(旧版) +│ ├── network_gui_ws.py # GUI 可视化服务器(WebSocket) +│ +├── utils/ +│ ├── fast_utils.py # FastGS 核心工具函数(VCD/VCP评分、4D采样等) +│ ├── loss_utils.py # 损失函数(L1, L2, SSIM) +│ ├── image_utils.py # 图像质量指标(PSNR, MSE) +│ ├── general_utils.py # 通用工具(逆sigmoid, 学习率调度, 四元数等) +│ ├── graphics_utils.py # 图形学工具(投影矩阵, FOV计算等) +│ ├── camera_utils.py # 相机加载工具(惰性路径、分辨率计算) +│ ├── sh_utils.py # 球谐函数工具 +│ ├── system_utils.py # 系统工具(文件IO等) +│ +├── submodules/ +│ ├── diff-gaussian-rasterization_fastgs/ # FastGS 自定义 CUDA 光栅化核 +│ ├── fused-ssim/ # 快速 SSIM CUDA 扩展 +│ └── simple-knn/ # KNN CUDA 扩展(用于初始化) +│ +├── tests/ +│ └── test_td_fastgs.py # TD-FastGS 单元测试 +│ +├── lpipsPyTorch/ # LPIPS 感知损失(第三方) +│ +├── memory/ +│ ├── MEMORY.md # 内存/数据格式索引 +│ └── flower300-data-format.md # flower300 数据集格式说明 +│ +├── data/ +│ └── flower300/ # 示例 4D 数据集(36 相机 × 300 帧) +│ +└── output/ # 训练输出目录 +``` + +--- + +## 3. 核心工作流程 + +### 3.1 整体流程图 + +```mermaid +graph TD + A[输入数据] --> B{3D 还是 4D?} + B -->|3D| C1[Scene.init 调用 Colmap/Blender 读取器] + B -->|4D| C2[Scene.init 调用 Colmap4D 读取器] + C1 --> D1[create_from_pcd 初始化 3D 高斯] + C2 --> D2[create_from_pcd_4d 初始化 4D 高斯] + D1 --> E1[training 3D训练循环] + D2 --> E2[training_4d 4D训练循环] + E1 --> F1[render_fastgs 3D渲染] + E2 --> F2[render_4d 时空渲染] + F1 --> G1[VCD/VCP 致密化+剪枝] + F2 --> G2[时域感知 VCD/VCP] + G1 --> H1[保存 PLY] + G2 --> H2[保存 4D PLY] + H1 --> I[render.py 渲染评估] + H2 --> I + I --> J[metrics.py 指标计算] + H2 --> K[export_frames.py 逐帧导出] + H2 --> L[slim_ply.py 瘦身] +``` + +### 3.2 训练流程 + +训练入口为 `train.py`,参数解析后调用 `training()` 函数。函数自动检测是否为 4DGS 数据集(检测 `static_points/` 和 `dynamic_points/` 目录),然后分派到对应的训练循环: + +#### 3D 训练循环 (`training`) +1. 初始化 `GaussianModel` 与 `Scene` +2. 每轮迭代: + - 随机采样一个训练视角 + - 调用 `render_fastgs` 渲染 + - 计算 L1 + (1-SSIM) 损失 + - 反向传播 + - **致密化阶段** (iter < 15000): + - 多视图一致性的 clone/split (VCD) + - 每 3000 步清洗不透明度 + - **后致密化阶段** (iter 15000~30000, 每 3000 步): + - 多视图一致性剪枝 (VCP) + - 优化器步进 +3. 保存 PLY 点云 + +#### 4D 训练循环 (`training_4d`) +1. 初始化 `GaussianModel`(含时域属性) +2. 每轮迭代: + - **3 阶段相机采样**: + - Stage 1 (≤3000): 仅采样 frame-0(收敛静态背景) + - Stage 2 (≤10000): 滑动窗口采样(4 帧连续窗口) + - Stage 3 (>10000): 全局均匀采样 + - 调用 `render_4d` 时空渲染(含因果剪枝) + - 损失: L1 + (1-SSIM) + λ_v * L_smooth + λ_scale * scale_penalty + - 反向传播后调用**三级梯度闸门** + - 时域感知 VCD/VCP(静态/动态点分别使用不同阈值) + - **解耦不透明度重置**(仅静态点) + - 优化器步进后调用**静态硬拉回** + +### 3.3 渲染流程 + +#### 3D 渲染 (`render_fastgs`) +1. 设置光栅化配置(投影矩阵、SH 度数等) +2. 提取高斯属性(位置、颜色、不透明度、缩放、旋转) +3. 调用 CUDA 光栅化器生成图像 +4. 返回渲染图像、屏幕空间点、可见性过滤、半径、度量计数 + +#### 4D 渲染 (`render_4d`) +1. **时空变换**: `x'(t) = x₀ + v·(t - t_μ)`, 计算时域权重 `w_t` +2. **因果剪枝**: 仅保留 `t_μ ≤ t` 且 `α·w_t > τ_alive` 的高斯子集 +3. 在存活子集上执行 CUDA 光栅化(Compact Box 自动在 kernel 内运行) +4. **回填**: 将子集结果(半径、度量计数)回填到全尺寸张量 + +--- + +## 4. 核心算法原理 + +### 4.1 FastGS 三项改进 + +| 改进项 | 符号 | 说明 | +|-------|------|------| +| **VCD** | $s^i_d = \frac{1}{K}\sum_j\sum_{p\in\Omega_i}\mathbb{I}(M^j_{mask}(p)=1)$ | 多视图一致性致密化,仅当跨 K 视图的高误差像素计数均值 > τ_d(默认 5)时才 clone/split | +| **VCP** | $s^i_p = \mathcal{N}(\sum_j(\sum_{p\in\Omega_i}\mathbb{I}(M^j_{mask}(p)=1))\cdot E^j_{photo})$ | 多视图一致性剪枝,当 $s^i_p > \tau_p$(默认 0.9)时删除 | +| **CB** | $(\mathbf{p}-\mu)\Sigma^{-1}(\mathbf{p}-\mu)^T \leq \beta(2\ln\frac{\sigma}{\tau_\alpha})$ | Compact Box:用马氏距离代替 3-sigma 规则减少 Gaussian-tile 对 | + +### 4.2 TD-FastGS 时域扩展 + +每个高斯基元附加 **5 个时域属性**: + +| 属性 | 符号 | 可学习 | 语义 | +|------|------|--------|------| +| 出生时间 | $t_\mu$ | ❌(冻结) | 锚定于 SfM 帧索引,归一化到 [0,1] | +| 生命半径(log空间) | $\sigma_{t,raw}$ | ✅ | $\sigma_t = e^{\sigma_{t,raw}}$ | +| 运动速度 | $\mathbf{v} \in \mathbb{R}^3$ | ✅(动态点)/ ❌(静态点锁死) | | + +**时空变换**: +$$\mathbf{x}'(t) = \mathbf{x}_0 + \mathbf{v} \cdot (t - t_\mu)$$ +$$\alpha'_i(t) = \alpha_i \cdot \underbrace{\exp\left(-\frac{(t - t_\mu^{(i)})^2}{2\sigma_t^{(i)2} + \epsilon}\right)}_{w_t^{(i)}(t)}$$ + +**因果存活条件**: +$$\text{alive}(i, t) = \mathbb{1}[t_\mu^{(i)} \leq t] \wedge \mathbb{1}[\alpha'_i(t) > \tau_{alive}], \quad \tau_{alive} = 0.005$$ + +### 4.3 三级梯度闸门(4D) + +反向传播后、优化器步进前执行: +1. **静态点**: velocity 和 sigma_t_raw 梯度清零 +2. **动态 & 当前帧** ($w_t > \text{thresh}$): 所有梯度通过 +3. **动态 & 其他帧**: 几何梯度(xyz, features, scaling, rotation)清零,opacity/velocity/sigma_t 保留 + +### 4.4 速度初始化(从光流估计 3D 速度) + +TD-FastGS 支持从**光流**(2D optical flow)自动估计每个动态点的初始 3D 速度,作为高斯属性 $\mathbf{v}$ 的初始值。 + +#### 数据准备 + +光流文件放在 `flows//.npy` 目录下,每个文件是一个 `(Hf, Wf, 2)` 的 NumPy 数组,存储每个像素在流图坐标系中的 `(du, dv)` 位移。 + +#### 算法流程 (`load_flow_velocities`) + +对于每一个有动态点的帧 $f$: + +```mermaid +flowchart TD + A[动态点云 3D 位置
points_world] --> B[投影到每个相机
p_cam = Rc@p_world + Tc] + B --> C{深度 > 0.01?} + C -->|是| D[投影到像素坐标
u,v = fx*x/z+cx, fy*y/z+cy] + D --> E[缩放到流图分辨率
uf, vf = u*Wf/W_cam, v*Hf/H_cam] + E --> F{在流图边界内?} + F -->|是| G[双线性采样光流
→ flow_u, flow_v] + G --> H[f+1 帧像素位置
u1 = u + flow_u, v1 = v + flow_v] + H --> I[反投影两个像素到相机空间射线
→ 3D 位移方向] + I --> J[深度缩放 → 度量位移
dir0 *= depth, dir1 *= depth] + J --> K[相机空间位移→世界空间
disp_world = Rw @ (dir1 - dir0)] + K --> L[除以帧间隔 Δt
→ 速度估计] + L --> M[多相机平均
→ 最终 3D 速度] +``` + +**关键公式**: + +1. **投影到像素**: $u = f_x \cdot \frac{x_c}{z} + c_x$, $v = f_y \cdot \frac{y_c}{z} + c_y$ +2. **双线性采光流**: $flow = \sum_{i=0}^1\sum_{j=0}^1 w_{ij} \cdot flow_{v_i,u_j}$ +3. **反投影 + 深度缩放**: $\mathbf{dir} = [\frac{u - c_x}{f_x}, \frac{v - c_y}{f_y}, 1] \cdot z$ +4. **速度**: $\mathbf{v} = \frac{\mathbf{dir}_1 - \mathbf{dir}_0}{\Delta t}$ + +#### 数据流 + +``` +flows//.npy # 2D 光流输入 + ↓ +load_flow_velocities() # 多视图三角化 → 3D 速度 + ↓ +load_temporal_point_cloud_pcd() # 组装 TemporalPointCloud.velocities + ↓ +TemporalPointCloud # 传入 create_from_pcd_4d() + ↓ +GaussianModel._velocity # 作为可学习参数初始化 +``` + +#### 回退机制 + +- 若无 `flows/` 目录 → velocity 初始化为零 +- 最后一帧无后续帧光流 → velocity 为零 +- 投影到所有相机均不在视野内 → velocity 为零 +- 所有静态点 velocity 恒为零(之后由 `enforce_static_constraints` 强制保持) + +### 4.5 速度平滑正则化 + +$$L_{smooth} = \frac{1}{K}\sum_k w_k \cdot ||\mathbf{v}_{a_k} - \mathbf{v}_{b_k}||^2$$ +$$w_k = \exp\left(-\frac{||\mathbf{x}_{a_k} - \mathbf{x}_{b_k}||^2}{2\bar{s}^2}\right)$$ + +随机采样 K=4096 对动态点,以空间距离加权约束速度一致性。 + +--- + +## 5. 完整参数列表 + +### 5.1 模型参数 (ModelParams) + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `--sh_degree` | 3 | 球谐函数阶数(≤3) | +| `--source_path` / `-s` | "" | 数据源路径 | +| `--model_path` / `-m` | "" | 模型输出路径(默认 output/\) | +| `--images` / `-i` | "images" | 图像子目录名 | +| `--resolution` / `-r` | -1 | 分辨率控制(1/2/4/8 为比例,-1 自动缩放到 1.6K,其他为指定宽度) | +| `--white_background` / `-w` | False | 使用白色背景(NeRF 合成数据集) | +| `--data_device` | "cuda" | 数据存放设备(大数据集建议 "cpu") | +| `--eval` | False | 使用 MipNeRF360 风格训练/测试拆分 | +| `--force_4dgs` | False | 强制使用 4DGS 数据读取器 | +| `--n_frames` | -1 | 时序帧数(-1 自动推断) | + +### 5.2 管线参数 (PipelineParams) + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `--convert_SHs_python` | False | 用 PyTorch 计算 SH 前向/反向 | +| `--compute_cov3D_python` | False | 用 PyTorch 计算 3D 协方差 | +| `--debug` | False | 调试模式 | +| `--antialiasing` | False | 抗锯齿 | + +### 5.3 优化参数 (OptimizationParams) + +#### 学习率 + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `--position_lr_init` | 0.00016 | 位置初始学习率 | +| `--position_lr_final` | 0.0000016 | 位置最终学习率 | +| `--position_lr_delay_mult` | 0.01 | 位置学习率延迟乘数 | +| `--position_lr_max_steps` | 30000 | 位置学习率最大步数 | +| `--feature_lr` | 0.0025 | 特征学习率(旧版兼容) | +| `--lowfeature_lr` | 0.0025 | 低阶 SH 系数 (features_dc) 学习率 | +| `--highfeature_lr` | 0.005 | 高阶 SH 系数 (features_rest) 学习率 | +| `--opacity_lr` | 0.025 | 不透明度学习率 | +| `--scaling_lr` | 0.005 | 缩放学习率 | +| `--rotation_lr` | 0.001 | 旋转学习率 | + +#### FastGS 专有参数 + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `--loss_thresh` | 0.1 | 损失图阈值(越低保留越多高斯) | +| `--grad_abs_thresh` | 0.0012 | 绝对梯度阈值(split 判断) | +| `--grad_thresh` | 0.0002 | 梯度阈值(clone 判断) | +| `--dense` | 0.001 | 场景范围的百分比,超过则强制致密化 | +| `--mult` | 0.5 | Compact Box 乘数,控制每个 splat 的 tile 数 | +| `--densification_interval` | 100 | 致密化间隔(步数) | +| `--densify_from_iter` | 500 | 开始致密化的迭代 | +| `--densify_until_iter` | 15000 | 结束致密化的迭代 | +| `--opacity_reset_interval` | 3000 | 不透明度重置间隔 | +| `--lambda_dssim` | 0.2 | SSIM 损失权重 | +| `--percent_dense` | 0.001 | 密集点百分比 | +| `--random_background` | False | 随机背景颜色 | +| `--optimizer_type` | "default" | 优化器类型(default / sparse_adam) | + +#### TD-FastGS 4D 时域参数 + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `--velocity_lr` | 0.0016 | 速度学习率 | +| `--sigma_t_lr` | 0.002 | 时域宽度(log空间)学习率 | +| `--lambda_velocity` | 0.01 | 速度平滑正则化权重 λ_v | +| `--velocity_smooth_pairs` | 4096 | 速度平滑采样的点对数 | +| `--tau_alive` | 0.005 | 因果剪枝不透明度阈值 | +| `--tau_d_static` | 5.0 | 静态点致密化(VCD)阈值 | +| `--tau_d_dynamic` | 2.5 | 动态点致密化(VCD)阈值 | +| `--tau_p` | 0.9 | 剪枝(VCP)阈值 | +| `--wt_densify_thresh` | 0.2 | 致密化/剪枝的活跃窗口阈值 w_t | +| `--wt_current_thresh` | 0.5 | 梯度闸门的"当前帧"阈值 | +| `--static_only_until` | 3000 | Stage 1 边界:之前仅采样 frame-0 | +| `--temporal_window_until` | 10000 | Stage 2 边界:之前滑动窗口采样 | +| `--temporal_window_size` | 4 | Stage 2 滑动窗口宽度(帧数) | +| `--lambda_scale_penalty` | 0.0 | 动态点缩放惩罚权重(0=关闭) | + +### 5.4 其他参数 + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `--iterations` | 30000 | 总训练迭代数 | +| `--test_iterations` | 7000 30000 | 测试迭代列表 | +| `--save_iterations` | 7000 30000 \ | 保存迭代列表 | +| `--checkpoint_iterations` | — | 检查点保存迭代 | +| `--start_checkpoint` | — | 恢复训练的检查点路径 | +| `--debug_from` | — | 开始调试的迭代 | +| `--quiet` | — | 静默模式 | +| `--ip` | 127.0.0.1 | GUI 服务器 IP | +| `--port` | 6009 | GUI 服务器端口 | + +### 5.5 后处理工具参数 + +#### `slim_ply.py` +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `-i` / `--input` | point_cloud.ply | 输入 PLY 路径 | +| `-o` / `--output` | point_cloud_slim.ply | 输出 PLY 路径 | +| `-n` / `--num-frames` | 交互输入 | 动画总帧数 | +| `--vel-threshold` | 0.001 | 速度模长阈值,低于此值视为静态高斯(0 禁用分离) | + +#### `export_frames.py` +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `--ply` | 必填 | 训练好的 point_cloud.ply 路径 | +| `--out` | 必填 | 逐帧 PLY 输出目录 | +| `--num_frames` | 80 | 导出帧数 | +| `--threshold` | 0.005 | 剪枝不透明度阈值 | + +#### `render.py` +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `--iteration` | -1(最大) | 渲染迭代 | +| `--skip_train` | False | 跳过训练集渲染 | +| `--skip_test` | False | 跳过测试集渲染 | +| `--mult` | 0.5 | Compact Box 乘数 | + +#### `convert.py` +| 参数 | 说明 | +|------|------| +| `-s` / `--source_path` | 源路径(必填) | +| `--camera` | 相机模型(默认 OPENCV) | +| `--colmap_executable` | COLMAP 可执行文件路径 | +| `--resize` | 是否生成缩放图像(2×/4×/8×) | +| `--skip_matching` | 跳过特征提取与匹配 | + +--- + +## 6. 数据格式 + +### 6.1 3D 数据集格式 + +#### COLMAP 格式 +``` +dataset/ +├── sparse/0/ +│ ├── cameras.bin (.txt) # 相机内参 +│ ├── images.bin (.txt) # 图像外参 +│ └── points3D.bin (.txt) # SfM 点云 +├── images/ # 输入图像 +└── input/ # (可选) convert.py 原始输入 +``` + +#### Blender/NeRF 合成格式 +``` +dataset/ +├── transforms_train.json # 训练集变换 +├── transforms_test.json # 测试集变换 +└── points3d.ply # 初始点云(无 COLMAP 时自动生成随机点) +``` + +### 6.2 4D 数据集格式 (flower300 布局) + +``` +dataset/ +├── sparse/0/ +│ ├── cameras.txt # 相机标定(36 个固定视角) +│ ├── images.txt # 图像名(1.png~36.png = 相机 ID) +│ └── points3D.txt # 可能为空(点云来自下面 PLY) +├── images/ +│ ├── / +│ │ └── images/ +│ │ ├── 1.png # 相机 1 在该帧的图像 +│ │ ├── 2.png +│ │ └── ... # ~36 张/帧 +│ ├── ... # frames 1~N(如 300) +├── static_points/ +│ └── pcd1.ply # 静态背景点云(~17k 点, t_mu=0) +├── dynamic_points/ +│ ├── pcd.ply # 逐帧动态点云(~2300 点/帧) +├── flows/ # (可选) 光流用于初始化 3D 速度 + ├── / # 帧目录,如 1/, 2/, ... + │ ├── 1.npy # 相机 1 的 2D 光流 (Hf, Wf, 2) + │ ├── 2.npy + │ └── ... # 每个相机一个 .npy 文件 +``` + +### 6.3 4D PLY 属性 + +| 属性 | 类型 | 说明 | +|------|------|------| +| `x, y, z` | float | 位置 | +| `f_dc_0, f_dc_1, f_dc_2` | float | 低阶 SH 系数(DC) | +| `f_rest_0..n` | float | 高阶 SH 系数 | +| `opacity` | float | 不透明度(原始值,需 sigmoid) | +| `scale_0, scale_1, scale_2` | float | 缩放(log 空间) | +| `rot_0, rot_1, rot_2, rot_3` | float | 旋转四元数 | +| `t_mu` | float | 出生时间(归一化到 [0,1]) | +| `sigma_t_raw` | float | 生命半径(log 空间) | +| `vel_x, vel_y, vel_z` | float | 运动速度 | +| `is_static` | float | 是否为静态点(0 或 1) | + +`slim_ply.py` 输出使用 DT-4DGS 标准命名(`t_sigma`, `velocity_0/1/2`)。 + +--- + +## 7. 训练模式 + +### 7.1 标准模式 (FastGS) + +`train_base.sh` — `--densification_interval 500` +- 约 100 秒完成训练(RTX 4090) +- 适合快速原型和大多数场景 + +### 7.2 高质量模式 (FastGS-Big) + +`train_big.sh` — `--densification_interval 100` +- 使用 `--mode final_count`,为每个场景预设高斯数量上限 +- 场景特定的 `big_budgets`(如 bicycle: 598万, room: 154万) +- 训练时间略长,质量更高 + +### 7.3 Budget 模式 + +`full_eval.py --mode budget` +- 使用 `--mode multiplier` 和场景特定的预算乘数 +- 用于批量评估 + +--- + +## 8. 完整评估流程 + +``` +1. 数据准备: python convert.py -s # COLMAP SfM +2. 训练: python train.py -s -m [options] +3. 渲染: python render.py -m [--skip_train] [--mult x] +4. 指标: python metrics.py -m +``` + +### 4D 后处理流程 +``` +5. 瘦身: python slim_ply.py -i point_cloud.ply -o slim.ply -n +6. 逐帧导出: python export_frames.py --ply point_cloud.ply --out frames/ --num_frames +``` + +### 批量完整评估 +``` +python full_eval.py -m360 -tat -db [--mode big|budget] +``` + +--- + +## 9. 关键类图 + +```mermaid +classDiagram + class GaussianModel { + - _xyz, _features_dc, _features_rest + - _scaling, _rotation, _opacity + - is_4d: bool + - _t_mu, _sigma_t_raw, _velocity + - is_static: Tensor(bool) + + create_from_pcd() + + create_from_pcd_4d() + + load_ply() / save_ply() + + training_setup() + + compute_temporal_weight(t) + + densify_and_prune_fastgs() + + densify_and_prune_4d() + + apply_gradient_gating() + + enforce_static_constraints() + + reset_opacity_decoupled() + } + + class Scene { + - train_cameras, test_cameras + - is_4dgs: bool + - n_frames: int + + getTrainCameras() + + getTestCameras() + + save(iteration) + } + + class Camera { + + timestamp: float + + frame_idx: int + + original_image (lazy loading) + + world_view_transform + + full_proj_transform + } + + class SceneInfo { + - point_cloud: BasicPointCloud + - temporal_point_cloud: TemporalPointCloud + - train_cameras, test_cameras + - n_frames: int + } + + class TemporalPointCloud { + - points, colors, normals + - timestamps, is_static + - velocities (optional) + } + + Scene --> GaussianModel + Scene --> Camera + Scene --> SceneInfo + SceneInfo --> TemporalPointCloud + GaussianModel ..> Camera : render +``` + +--- + +## 10. 依赖与环境 + +| 组件 | 版本 | +|------|------| +| Python | 3.7.13 | +| PyTorch | 1.12.1 | +| CUDA Toolkit | 11.6 | +| torchvision | 0.13.1 | +| torchaudio | 0.12.1 | +| CUDA 编译器 (nvcc) | ≥ 11.8 | +| 系统 CUDA (nvidia-smi) | ≥ 12.2 | + +**子模块**(需编译的 CUDA 扩展): +- `diff-gaussian-rasterization_fastgs` — FastGS 自定义光栅化器 +- `simple-knn` — KNN 距离计算 +- `fused-ssim` — 快速 SSIM + +--- + +## 11. 性能指标 + +| 数据集 | FastGS 训练时间 | FastGS-Big 训练时间 | 对比 3DGS | +|--------|:-:|:-:|:-:| +| Mip-NeRF 360 (室外) | ~100s | ~2min | 3.32× DashGaussian | +| Mip-NeRF 360 (室内) | ~80s | ~1.5min | — | +| Deep Blending | ~60s | ~1min | 15.45× 加速 | +| Tanks&Temples | ~70s | ~1.5min | — | + +--- + +## 12. 注意事项 + +1. **静态点初始化**: `sigma_t_raw = log(1000)` → 时域权重 ≈ 1,全场可见 +2. **动态点初始化**: `sigma_t_raw = log(2.5/n_frames)` → 覆盖约 2.5 帧 +3. **速度初始化**: 优先从 `flows//` 目录下的 2D 光流 `.npy` 文件估算。若无光流则速度初始化为零,训练过程中由速度平滑正则化驱动收敛 +4. **静态硬拉回**: 每步优化后,静态点的 `velocity=0`、`sigma_t_raw=log(1000)`、`t_mu=0`,并清空 Adam 动量 +5. **LRU 图像缓存**: Camera 类使用有界 LRU 缓存(默认 64 张),避免 4D 数据集大量图像同时驻留 GPU +6. **窗口系统兼容性**: 终端使用 Windows PowerShell,不支持 `&&` 链式命令 diff --git a/arguments/__init__.py b/arguments/__init__.py index 52888ae..800600f 100755 --- a/arguments/__init__.py +++ b/arguments/__init__.py @@ -44,7 +44,7 @@ def extract(self, args): setattr(group, arg[0], arg[1]) return group -class ModelParams(ParamGroup): +class ModelParams(ParamGroup): def __init__(self, parser, sentinel=False): self.sh_degree = 3 self._source_path = "" @@ -54,6 +54,10 @@ def __init__(self, parser, sentinel=False): self._white_background = False self.data_device = "cuda" self.eval = False + # TD-FastGS 4D extension. When force_4dgs is True the scene loader always + # uses the 4DGS reader; otherwise it is auto-detected from the dataset layout. + self.force_4dgs = False + self.n_frames = -1 # number of temporal frames; -1 => infer from data super().__init__(parser, "Loading Parameters", sentinel) def extract(self, args): @@ -101,6 +105,22 @@ def __init__(self, parser): self.random_background = False self.optimizer_type = "default" + + # ----- TD-FastGS 4D (temporal) parameters ----- + self.velocity_lr = 0.0016 # learning rate for per-Gaussian velocity v + self.sigma_t_lr = 0.002 # learning rate for sigma_t_raw (life radius, log space) + self.lambda_velocity = 0.01 # weight of the velocity-smoothness regularizer (lambda_v) + self.velocity_smooth_pairs = 4096 # number of point-pairs sampled for L_smooth + self.tau_alive = 0.005 # causal pruning threshold on alpha'(t) + self.tau_d_static = 5.0 # densification (VCD) threshold for static points + self.tau_d_dynamic = 2.5 # densification (VCD) threshold for dynamic points + self.tau_p = 0.9 # pruning (VCP) threshold + self.wt_densify_thresh = 0.2 # w_t active-window threshold used for densify/prune gating + self.wt_current_thresh = 0.5 # w_t "current frame" threshold for the gradient gate + self.static_only_until = 3000 # stage-1 boundary: sample only frame 0 before this + self.temporal_window_until = 10000 # stage-2 boundary: sliding-window sampling before this + self.temporal_window_size = 4 # sliding-window width (frames) in stage 2 + self.lambda_scale_penalty = 0.0 # soft scale penalty weight for dynamic points (0 => off) super().__init__(parser, "Optimization Parameters") def get_combined_args(parser : ArgumentParser): diff --git a/arguments/__pycache__/__init__.cpython-37.pyc b/arguments/__pycache__/__init__.cpython-37.pyc deleted file mode 100755 index 99f1f03..0000000 Binary files a/arguments/__pycache__/__init__.cpython-37.pyc and /dev/null differ diff --git a/arguments/__pycache__/__init__.cpython-38.pyc b/arguments/__pycache__/__init__.cpython-38.pyc deleted file mode 100755 index a3969f7..0000000 Binary files a/arguments/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/export_frames.py b/export_frames.py new file mode 100644 index 0000000..1cd9187 --- /dev/null +++ b/export_frames.py @@ -0,0 +1,146 @@ +""" +从训练好的 4DGS point_cloud.ply 中,按帧导出每一帧的高斯点云。 + +兼容两种属性命名: + TD-FastGS (本项目): sigma_t_raw, vel_x/vel_y/vel_z, is_static + DT-4DGS (原始): t_sigma, velocity_0/1/2 + +对于每个帧时间 t,执行: + 1. 位置偏移: xyz_t = xyz + velocity * (t - t_mu) + 2. 时域高斯权重: w = exp(- (t - t_mu)^2 / (2 * sigma^2 + 1e-5)) + 3. 不透明度调制: opacity_t = opacity * w + 4. 因果律: 仅保留 t_mu <= t 且 opacity_t > threshold 的点 + +用法: +python export_frames.py --ply --out --num_frames 80 +python export_frames.py --ply output/flower_1/point_cloud/iteration_30000/point_cloud.ply --out output/flower_1/frames --num_frames 80 +""" + +import argparse +import os +import numpy as np +from plyfile import PlyData, PlyElement + + +def load_4dgs_ply(path): + plydata = PlyData.read(path) + v = plydata["vertex"] + all_props = [p.name for p in v.properties] + + xyz = np.stack([v["x"], v["y"], v["z"]], axis=1) # (N, 3) + + # opacity (raw, pre-sigmoid) + opacity_raw = np.asarray(v["opacity"]) # (N,) + + # t_mu always present + t_mu = np.asarray(v["t_mu"]) # (N,) + + # Temporal-width field: TD-FastGS saves "sigma_t_raw"; DT-4DGS saves "t_sigma". + if "sigma_t_raw" in all_props: + t_sigma_raw = np.asarray(v["sigma_t_raw"]) + elif "t_sigma" in all_props: + t_sigma_raw = np.asarray(v["t_sigma"]) + else: + raise ValueError("PLY has neither 'sigma_t_raw' nor 't_sigma' — not a 4DGS file?") + + # Velocity: TD-FastGS uses vel_x/y/z; DT-4DGS uses velocity_0/1/2. + if "vel_x" in all_props: + velocity = np.stack([v["vel_x"], v["vel_y"], v["vel_z"]], axis=1) + elif "velocity_0" in all_props: + velocity = np.stack([v["velocity_0"], v["velocity_1"], v["velocity_2"]], axis=1) + else: + raise ValueError("PLY has no velocity attributes — not a 4DGS file?") + + return plydata, xyz, opacity_raw, t_mu, t_sigma_raw, velocity, all_props + + +def sigmoid(x): + return 1.0 / (1.0 + np.exp(-np.clip(x, -80, 80))) + + +def export_frame(plydata, xyz, opacity_raw, t_mu, t_sigma_raw, velocity, frame_t, threshold=0.005): + """返回该帧存活点的索引、偏移后的 xyz、调制后的 opacity_raw""" + dt = frame_t - t_mu # (N,) + xyz_t = xyz + velocity * dt[:, None] # (N, 3) + + sigma = np.exp(t_sigma_raw) # (N,) + temporal_weight = np.exp(-(dt ** 2) / (2.0 * sigma ** 2 + 1e-5)) # (N,) + + opacity_activated = sigmoid(opacity_raw) # (N,) + opacity_t = opacity_activated * temporal_weight # (N,) + + # 因果律剪枝 + alive = (t_mu <= frame_t) & (opacity_t > threshold) + + # 调制后的 opacity 转回 raw (inverse sigmoid) + opacity_t_clamped = np.clip(opacity_t, 1e-7, 1.0 - 1e-7) + opacity_t_raw = np.log(opacity_t_clamped / (1.0 - opacity_t_clamped)) + + return alive, xyz_t, opacity_t_raw + + +def save_frame_ply(plydata, alive_mask, xyz_t, opacity_t_raw, output_path): + """将存活点写成标准 3DGS PLY(去掉 4DGS 专属字段)""" + src = plydata["vertex"] + + # Skip all 4DGS-specific fields regardless of naming convention. + skip = { + "t_mu", + "sigma_t_raw", "t_sigma", # TD-FastGS / DT-4DGS temporal width + "vel_x", "vel_y", "vel_z", # TD-FastGS velocity + "velocity_0", "velocity_1", "velocity_2", # DT-4DGS velocity + "is_static", # TD-FastGS static flag + } + keep_names = [p.name for p in src.properties if p.name not in skip] + src_dtype = src.data.dtype + dtype_out = [(name, src_dtype[name].str) for name in keep_names] + + n_alive = int(alive_mask.sum()) + elements = np.empty(n_alive, dtype=dtype_out) + + for name in keep_names: + col = np.asarray(src[name])[alive_mask] + if name == "x": + col = xyz_t[alive_mask, 0] + elif name == "y": + col = xyz_t[alive_mask, 1] + elif name == "z": + col = xyz_t[alive_mask, 2] + elif name == "opacity": + col = opacity_t_raw[alive_mask] + elements[name] = col + + el = PlyElement.describe(elements, "vertex") + PlyData([el]).write(output_path) + + +def main(): + parser = argparse.ArgumentParser(description="Export per-frame PLY from a 4DGS checkpoint") + parser.add_argument("--ply", type=str, required=True, help="Path to the trained point_cloud.ply") + parser.add_argument("--out", type=str, required=True, help="Output directory for per-frame PLY files") + parser.add_argument("--num_frames", type=int, default=80, help="Number of frames to export") + parser.add_argument("--threshold", type=float, default=0.005, help="Opacity threshold for pruning") + args = parser.parse_args() + + os.makedirs(args.out, exist_ok=True) + + print(f"Loading PLY: {args.ply}") + plydata, xyz, opacity_raw, t_mu, t_sigma_raw, velocity, all_props = load_4dgs_ply(args.ply) + print(f"Total gaussians: {xyz.shape[0]}") + + for i in range(args.num_frames): + frame_t = i / max(args.num_frames - 1, 1) + alive, xyz_t, opacity_t_raw = export_frame( + plydata, xyz, opacity_raw, t_mu, t_sigma_raw, velocity, frame_t, args.threshold + ) + n_alive = int(alive.sum()) + + out_path = os.path.join(args.out, f"{i + 1}.ply") + save_frame_ply(plydata, alive, xyz_t, opacity_t_raw, out_path) + print(f"Frame {i + 1}/{args.num_frames} t={frame_t:.4f} alive={n_alive} -> {out_path}") + + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py index 0d584cc..c300c96 100755 --- a/gaussian_renderer/__init__.py +++ b/gaussian_renderer/__init__.py @@ -107,4 +107,142 @@ def render_fastgs(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.T "viewspace_points": screenspace_points, "visibility_filter" : (radii > 0).nonzero(), "radii": radii, - "accum_metric_counts" : accum_metric_counts} \ No newline at end of file + "accum_metric_counts" : accum_metric_counts} + + +def render_4d(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, mult, + scaling_modifier = 1.0, override_color = None, get_flag=None, metric_map = None): + """TD-FastGS 4D render. + + Mandatory ordering: + 1. Spatio-temporal transform: translate centers to the current frame. + 2. Causal pruning: build the alive_mask sparse subset. + 3. Compact Box + rasterization on the alive subset only (FastGS CB runs + inside the CUDA kernel, so subsetting the inputs is what guarantees CB + is computed only over alive Gaussians). + 4. Back-fill per-Gaussian outputs (radii, metric counts) to full size so the + FastGS VCD/VCP statistics keep operating on full-size tensors. + + The full-size `screenspace_points` is indexed to form the subset means2D; the + rasterizer backward therefore scatters the screen-space gradient back into the + full-size tensor, so add_densification_stats works exactly as in the 3D path. + """ + t = float(viewpoint_camera.timestamp) + + # --- Step 1: spatio-temporal transform (kept in graph; w_t feeds sigma_t_raw) --- + w_t = pc.compute_temporal_weight(t) # (N,) + dt = t - pc._t_mu # (N,) + xyz_transformed = pc.get_xyz + pc._velocity * dt.unsqueeze(-1) # (N, 3) + opacity_eff = pc.get_opacity.squeeze(-1) * w_t # (N,) + + N = pc.get_xyz.shape[0] + + # --- Step 2: causal pruning (boolean mask; no grad needed for the mask) --- + with torch.no_grad(): + causal_mask = pc._t_mu <= (t + 1e-6) # (N,) bool + alive_mask = causal_mask & (opacity_eff > pc.tau_alive) + alive_idx = alive_mask.nonzero(as_tuple=False).squeeze(-1) + + # Full-size screen-space tensor; subset rows receive grad via index backward. + screenspace_points = torch.zeros((N, 4), dtype=pc.get_xyz.dtype, + requires_grad=True, device="cuda") + 0 + try: + screenspace_points.retain_grad() + except Exception: + pass + + radii_full = torch.zeros(N, dtype=torch.int, device="cuda") + accum_full = torch.zeros(N, dtype=torch.int, device="cuda") + + if alive_idx.numel() == 0: + # Nothing alive at this time: return a background image and empty stats. + H, W = int(viewpoint_camera.image_height), int(viewpoint_camera.image_width) + rendered_image = bg_color.view(3, 1, 1).expand(3, H, W).contiguous() + return {"render": rendered_image, + "viewspace_points": screenspace_points, + "visibility_filter": (radii_full > 0).nonzero(), + "radii": radii_full, + "accum_metric_counts": accum_full, + "w_t": w_t, + "alive_mask": alive_mask} + + # Rasterization configuration. + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + if metric_map is None: + metric_map = torch.zeros(int(viewpoint_camera.image_height) * int(viewpoint_camera.image_width), + dtype=torch.int, device='cuda') + + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=pc.active_sh_degree, + campos=viewpoint_camera.camera_center, + mult=mult, + prefiltered=False, + debug=pipe.debug, + get_flag=get_flag, + metric_map=metric_map + ) + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + # --- Step 3: extract the alive subset of every per-Gaussian input --- + means3D = xyz_transformed[alive_idx] + means2D = screenspace_points[alive_idx] # grad scatters back to full size + opacity = opacity_eff[alive_idx].unsqueeze(-1) + + scales = None + rotations = None + cov3D_precomp = None + if pipe.compute_cov3D_python: + cov3D_precomp = pc.get_covariance(scaling_modifier)[alive_idx] + else: + scales = pc.get_scaling[alive_idx] + rotations = pc.get_rotation[alive_idx] + + dc = None + shs = None + colors_precomp = None + if override_color is None: + if pipe.convert_SHs_python: + shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1) ** 2)[alive_idx] + dir_pp = (xyz_transformed[alive_idx] - viewpoint_camera.camera_center.repeat(alive_idx.shape[0], 1)) + dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) + sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) + colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) + else: + dc = pc.get_features_dc[alive_idx] + shs = pc.get_features_rest[alive_idx] + else: + colors_precomp = override_color[alive_idx] + + rendered_image, radii_sparse, accum_sparse = rasterizer( + means3D=means3D, + means2D=means2D, + dc=dc, + shs=shs, + colors_precomp=colors_precomp, + opacities=opacity, + scales=scales, + rotations=rotations, + cov3D_precomp=cov3D_precomp) + + # --- Step 4: back-fill per-Gaussian outputs to full size --- + radii_full[alive_idx] = radii_sparse + if accum_sparse is not None and accum_sparse.numel() == alive_idx.numel(): + accum_full[alive_idx] = accum_sparse.to(accum_full.dtype) + + return {"render": rendered_image, + "viewspace_points": screenspace_points, + "visibility_filter": (radii_full > 0).nonzero(), + "radii": radii_full, + "accum_metric_counts": accum_full, + "w_t": w_t, + "alive_mask": alive_mask} \ No newline at end of file diff --git a/gaussian_renderer/__pycache__/__init__.cpython-37.pyc b/gaussian_renderer/__pycache__/__init__.cpython-37.pyc deleted file mode 100755 index b1827fe..0000000 Binary files a/gaussian_renderer/__pycache__/__init__.cpython-37.pyc and /dev/null differ diff --git a/gaussian_renderer/__pycache__/__init__.cpython-38.pyc b/gaussian_renderer/__pycache__/__init__.cpython-38.pyc deleted file mode 100755 index 71f0a7d..0000000 Binary files a/gaussian_renderer/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/gaussian_renderer/__pycache__/network_gui_ws.cpython-37.pyc b/gaussian_renderer/__pycache__/network_gui_ws.cpython-37.pyc deleted file mode 100755 index 2753eb6..0000000 Binary files a/gaussian_renderer/__pycache__/network_gui_ws.cpython-37.pyc and /dev/null differ diff --git a/gaussian_renderer/__pycache__/network_gui_ws.cpython-38.pyc b/gaussian_renderer/__pycache__/network_gui_ws.cpython-38.pyc deleted file mode 100755 index 964edaf..0000000 Binary files a/gaussian_renderer/__pycache__/network_gui_ws.cpython-38.pyc and /dev/null differ diff --git a/lpipsPyTorch/__pycache__/__init__.cpython-37.pyc b/lpipsPyTorch/__pycache__/__init__.cpython-37.pyc deleted file mode 100755 index 26ba86b..0000000 Binary files a/lpipsPyTorch/__pycache__/__init__.cpython-37.pyc and /dev/null differ diff --git a/lpipsPyTorch/__pycache__/__init__.cpython-38.pyc b/lpipsPyTorch/__pycache__/__init__.cpython-38.pyc deleted file mode 100755 index d806ad7..0000000 Binary files a/lpipsPyTorch/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/lpipsPyTorch/modules/__pycache__/lpips.cpython-37.pyc b/lpipsPyTorch/modules/__pycache__/lpips.cpython-37.pyc deleted file mode 100755 index 4661736..0000000 Binary files a/lpipsPyTorch/modules/__pycache__/lpips.cpython-37.pyc and /dev/null differ diff --git a/lpipsPyTorch/modules/__pycache__/lpips.cpython-38.pyc b/lpipsPyTorch/modules/__pycache__/lpips.cpython-38.pyc deleted file mode 100755 index 7459afb..0000000 Binary files a/lpipsPyTorch/modules/__pycache__/lpips.cpython-38.pyc and /dev/null differ diff --git a/lpipsPyTorch/modules/__pycache__/networks.cpython-37.pyc b/lpipsPyTorch/modules/__pycache__/networks.cpython-37.pyc deleted file mode 100755 index 9098a37..0000000 Binary files a/lpipsPyTorch/modules/__pycache__/networks.cpython-37.pyc and /dev/null differ diff --git a/lpipsPyTorch/modules/__pycache__/networks.cpython-38.pyc b/lpipsPyTorch/modules/__pycache__/networks.cpython-38.pyc deleted file mode 100755 index 3daf375..0000000 Binary files a/lpipsPyTorch/modules/__pycache__/networks.cpython-38.pyc and /dev/null differ diff --git a/lpipsPyTorch/modules/__pycache__/utils.cpython-37.pyc b/lpipsPyTorch/modules/__pycache__/utils.cpython-37.pyc deleted file mode 100755 index 7e0e32b..0000000 Binary files a/lpipsPyTorch/modules/__pycache__/utils.cpython-37.pyc and /dev/null differ diff --git a/lpipsPyTorch/modules/__pycache__/utils.cpython-38.pyc b/lpipsPyTorch/modules/__pycache__/utils.cpython-38.pyc deleted file mode 100755 index 9c55060..0000000 Binary files a/lpipsPyTorch/modules/__pycache__/utils.cpython-38.pyc and /dev/null differ diff --git a/memory/MEMORY.md b/memory/MEMORY.md new file mode 100644 index 0000000..2dac2b4 --- /dev/null +++ b/memory/MEMORY.md @@ -0,0 +1,3 @@ +# Memory Index + +- [flower300 data format](flower300-data-format.md) — multi-view-video 4D dataset layout (36 cams x 300 frames, pcd.ply), differs from prompt.md spec diff --git a/memory/flower300-data-format.md b/memory/flower300-data-format.md new file mode 100644 index 0000000..6a3f53d --- /dev/null +++ b/memory/flower300-data-format.md @@ -0,0 +1,16 @@ +--- +name: flower300-data-format +description: Layout of the user's flower300 multi-view-video 4D dataset +metadata: + type: project +--- +The user's TD-FastGS data (`flower300/`, also the target real-data format) is a multi-view video, NOT the `points3D.ply`/`frame_*.ply` layout assumed by the original prompt.md spec: + +- `sparse/0/{cameras,images,points3D}.txt` — COLMAP calibration for 36 fixed cameras (PINHOLE, 3839x2159). `images.txt` names them `1.png`..`36.png` (these are CAMERA ids, not frames). `points3D.txt` is EMPTY — no COLMAP point cloud; init comes only from the PLYs below. +- `images//images/.png` — frames 1..300, each folder has 36 cam PNGs (some frames also have redundant `.jpg` duplicates that must be ignored). Training cameras = cross product 36 cams x 300 frames = 10,800 images. +- `static_points/pcd1.ply` — ~17k static background pts (t_mu=0). +- `dynamic_points/pcd.ply` — frames 1..300, ~2300 pts each, born at that frame's time. + +Timestamp normalization MUST be identical for cameras and dynamic PLYs: `t = (frame_id - fmin)/(fmax - fmin)`, fmin=1 fmax=300. Static => t=0. + +10,800 full-res images cannot be eagerly loaded to GPU — Camera needs lazy disk loading with a bounded LRU CPU cache. See [[fast4dgs-td-implementation]]. diff --git a/prompt.md b/prompt.md new file mode 100644 index 0000000..827307c --- /dev/null +++ b/prompt.md @@ -0,0 +1,756 @@ +# TD-FastGS 实现提示词 +> 提供给 Claude Opus 4.8 (claude-opus-4-8) 的完整实现指导文档 +> 任务:以 FastGS 为基础代码,移植 TD-4DGS 的时域机制,实现高效的 4D Gaussian Splatting + +--- + +## 你的任务 + +你是一名计算机视觉领域的资深工程师,需要将一套名为 **TD-4DGS** 的时域扩展机制移植到 **FastGS** 代码库中,实现 **TD-FastGS**:一个在保持 FastGS 训练加速优势的同时,支持动态场景重建的 4D Gaussian Splatting 系统。 + +请严格按照本文档逐步实现,不要跳过任何小节,不要做出文档未要求的"顺手优化"。每完成一个模块后,写出该模块的单元测试代码(不需要运行,只需写出测试逻辑)。 + +--- + +## 背景知识 + +### FastGS 的核心机制 + +FastGS 在原版 3DGS 的基础上做了三项改进: + +**1. 多视图一致性致密化(VCD)**:不再用图像空间梯度幅值判断是否致密化,而是随机采样 K 个视角,对每个视角生成逐像素 L1 误差图(min-max 归一化),提取高误差掩码,统计每个高斯在其 2D footprint 内跨多视角的高误差像素均值作为致密化重要性分数: + +$$s^i_d = \frac{1}{K} \sum_{j=1}^{K} \sum_{p \in \Omega_i} \mathbb{I}\left(M^j_{mask}(p) = 1\right)$$ + +仅当 $s^i_d > \tau_d$(默认=5)时才执行 clone/split。 + +**2. 多视图一致性剪枝(VCP)**:结合逐视角光度损失 $E^j_{photo}$,计算剪枝分数: + +$$s^i_p = \mathcal{N}\left(\sum_{j=1}^{K} \left(\sum_{p \in \Omega_i} \mathbb{I}\left(M^j_{mask}(p) = 1\right)\right) \cdot E^j_{photo}\right)$$ + +当 $s^i_p > \tau_p$(默认=0.9)时删除该高斯。 + +**3. Compact Box(CB)**:在光栅化预处理阶段,用 Mahalanobis 距离阈值代替 3-sigma 规则,进一步减少 Gaussian-tile 对: + +$$(\mathbf{p} - \mu_{i_{2D}}) \Sigma^{-1}_{i_{2D}} (\mathbf{p} - \mu_{i_{2D}})^T \leq \beta \left(2\ln\frac{\sigma_i}{\tau_\alpha}\right)$$ + +FastGS 基于 `3DGS-accel`(集成了 Taming-3DGS 的 per-splat 并行反传和 SH 加速),默认 30K 迭代,K=10,λ=0.2,NVIDIA RTX 4090 上约 100 秒完成训练。 + +--- + +### TD-4DGS 的时域机制 + +TD-4DGS 在每个高斯基元上附加 **5 个显式时域标量**: + +| 属性 | 符号 | 可学习 | 语义 | +|------|------|--------|------| +| 出生时间 | $t_\mu$ | **否** | 锚定于 SfM 帧索引,归一化到 [0,1] | +| 生命半径(log空间) | $\sigma_{t,raw}$ | 是 | $\sigma_t = e^{\sigma_{t,raw}}$ | +| 运动速度 | $\mathbf{v} \in \mathbb{R}^3$ | 是(动态点),锁死(静态点) | + +时空变换: +$$\mathbf{x}'(t) = \mathbf{x}_0 + \mathbf{v} \cdot (t - t_\mu)$$ +$$\alpha'_i(t) = \alpha_i \cdot \underbrace{\exp\left(-\frac{(t - t_\mu^{(i)})^2}{2\sigma_t^{(i)2} + \epsilon}\right)}_{w_t^{(i)}(t)}$$ + +因果存活条件: +$$\text{alive}(i, t) = \mathbb{1}\left[t_\mu^{(i)} \leq t\right] \wedge \mathbb{1}\left[\alpha'_i(t) > \tau_{alive}\right], \quad \tau_{alive} = 0.005$$ + +每帧仅光栅化满足 alive 条件的稀疏子集。 + +--- + +## 数据格式约定 + +输入数据结构: +``` +dataset/ +├── static_points/ # 背景 SfM 点云(单帧或多帧均值) +│ └── points3D.ply +├── dynamic_points/ # 逐帧前景 SfM 点云 +│ ├── frame_0000.ply +│ ├── frame_0001.ply +│ └── ... +├── images/ +│ ├── cam_00/ +│ │ ├── frame_0000.jpg +│ │ └── ... +│ └── cam_01/ +│ └── ... +└── sparse/ # COLMAP 标定结果(cameras.bin, images.bin) +``` + +- $N_{cam}$:固定视角相机数(如 36) +- $N_{frame}$:时序帧数(如 80-150) +- 每帧每视角一张图像,时间戳归一化为 $t \in [0, 1]$ + +--- + +## 文件修改清单 + +需要修改的文件(以 FastGS 代码库为基础): + +| 文件 | 修改性质 | +|------|---------| +| `scene/gaussian_model.py` | **核心扩展**:时域属性注册、初始化、PLY 序列化、时域感知 ADC | +| `gaussian_renderer/__init__.py` | **渲染扩展**:时空变换、因果剪枝、alive_mask 回填 | +| `train.py` | **训练逻辑**:梯度闸门、静态硬拉回、解耦 opacity reset、3 阶段采样、时域感知 VCD/VCP | +| `scene/__init__.py` | 场景管理:4DGS 自动检测、解耦点云加载 | +| `scene/dataset_readers.py` | 数据读取:多帧场景读取器、时间戳赋值 | +| `scene/cameras.py` | 相机对象:timestamp 属性 | + +--- + +## 模块一:高斯模型时域扩展(`gaussian_model.py`) + +### 1.1 时域属性注册 + +在 `GaussianModel.__init__` 中新增以下张量,与现有属性并列注册到 optimizer 参数组: + +```python +# 时域属性(仅动态点有效,静态点锁死) +self._t_mu = torch.empty(0) # shape (N,),出生时间,不可学习 +self._sigma_t_raw = torch.empty(0) # shape (N,),生命半径 log 空间,可学习 +self._velocity = torch.empty(0) # shape (N, 3),运动速度,可学习(动态点) +self.is_static = torch.empty(0, dtype=torch.bool) # shape (N,),身份掩码,不进优化器 +``` + +重要:`_t_mu` 和 `is_static` **不加入 optimizer**,仅作为常驻属性。`_sigma_t_raw` 和 `_velocity` 加入 optimizer,但静态点的对应行梯度在每步后被硬清零。 + +### 1.2 初始化策略 + +静态点初始化: +```python +t_mu_static = torch.zeros(N_static) +sigma_t_raw_static = torch.full((N_static,), math.log(1000.0)) # sigma=1000,全时域可见 +velocity_static = torch.zeros(N_static, 3) +is_static_flags = torch.ones(N_static, dtype=torch.bool) +``` + +动态点初始化(逐帧 SfM 点云加载后): +```python +# timestamp 为该帧归一化时间戳,N_frames 为总帧数 +t_mu_dynamic = torch.full((N_dynamic,), timestamp) # 锚定到帧索引 +sigma_t_raw_dynamic = torch.full( + (N_dynamic,), math.log(2.5 / N_frames) +) # 初始覆盖约 2.5 帧 +velocity_dynamic = torch.zeros(N_dynamic, 3) +is_static_flags = torch.zeros(N_dynamic, dtype=torch.bool) +``` + +拼接顺序:静态点在前,动态点在后(便于后续 masking)。 + +### 1.3 时域权重计算(核心辅助函数) + +```python +def compute_temporal_weight(self, t: float) -> torch.Tensor: + """ + 计算所有高斯在时刻 t 的时域活跃权重 w_t^(i)(t) + 返回 shape (N,) 的权重张量,静态点权重恒为 1.0 + """ + sigma_t = torch.exp(self._sigma_t_raw) # (N,) + dt = t - self._t_mu # (N,) + w_t = torch.exp(-dt ** 2 / (2 * sigma_t ** 2 + 1e-8)) # (N,) + # 静态点恒为 1.0 + w_t[self.is_static] = 1.0 + return w_t +``` + +### 1.4 时域感知 VCD(替换原版 `densify_and_clone` / `densify_and_split`) + +**修改要点**:在 FastGS 原始 VCD 的多视图分数计算中,引入时域权重: + +```python +def compute_vcd_score_4d(self, viewspace_points_list, visibility_filter_list, + rendered_images, gt_images, timestamps, tau=0.5): + """ + 时域感知 VCD 分数计算。 + + 参数: + viewspace_points_list: 每个采样视角的 2D 投影点列表 + visibility_filter_list: 每个视角的可见性掩码列表 + rendered_images: list of (3, H, W),K 个采样视角的渲染结果 + gt_images: list of (3, H, W),对应的 GT 图像 + timestamps: list of float,每个采样视角的时间戳 + tau: 高误差像素阈值 + + 返回: + scores: shape (N,),时域加权 VCD 分数 + """ + N = self._xyz.shape[0] + scores = torch.zeros(N, device="cuda") + + for j, (render_j, gt_j, t_j, pts_j, vis_j) in enumerate( + zip(rendered_images, gt_images, timestamps, + viewspace_points_list, visibility_filter_list) + ): + # 1. 计算逐像素 L1 误差图并归一化 + err_map = (render_j - gt_j).abs().mean(dim=0) # (H, W) + err_map_norm = (err_map - err_map.min()) / (err_map.max() - err_map.min() + 1e-8) + mask_j = (err_map_norm > tau).float() # (H, W) + + # 2. 计算当前视角时间戳下各高斯的时域权重 + w_t = self.compute_temporal_weight(t_j) # (N,) + + # 3. 从渲染器前向传播中获取高误差像素计数 + # (FastGS 原版在 render 前向传播中直接统计 2D footprint 内的高误差像素数) + # 这里用 w_t 对每个高斯的像素计数进行加权 + pixel_count_j = self._get_footprint_error_count(pts_j, vis_j, mask_j) # (N,) + + scores = scores + w_t * pixel_count_j + + return scores / len(timestamps) + +def _get_footprint_error_count(self, viewspace_pts, visibility, error_mask): + """ + 统计每个可见高斯在其 2D footprint 内的高误差像素数。 + 复用 FastGS 渲染器前向传播中已实现的统计逻辑。 + 不可见的高斯返回 0。 + """ + N = self._xyz.shape[0] + counts = torch.zeros(N, device="cuda") + # 实现参考 FastGS 原版 render 中的 footprint 统计 + # ...(保持与 FastGS 原版 VCD 一致的实现方式,只是在外部加 w_t 权重) + return counts +``` + +**致密化条件修改**: + +```python +def densify_and_prune_4d(self, vcd_scores, vcp_scores, tau_d=5.0, tau_p=0.9, + min_opacity=0.005, max_screen_size=None): + """ + 时域感知 ADC 主函数,替换 FastGS 原版 densify_and_prune。 + + VCD(致密化):仅对 w_t 活跃期内且 VCD 分数超过阈值的高斯执行 clone/split + VCP(剪枝):分静态/动态两套逻辑 + """ + # --- VCD: 致密化 --- + # 基础条件:VCD 分数超阈值 + densify_mask = vcd_scores > tau_d + + # 动态点额外条件:仅在活跃期(w_t > 0.2)内允许致密化 + # (通过采样当前 batch 时间戳的 w_t 均值估计) + dynamic_active = (~self.is_static) & (self._current_wt_mean > 0.2) + static_mask = self.is_static + densify_mask = densify_mask & (static_mask | dynamic_active) + + # 执行 clone/split(与 FastGS 原版逻辑相同,子代继承 is_static 和时域属性) + self._clone_with_temporal(densify_mask & (grad < threshold)) + self._split_with_temporal(densify_mask & (grad >= threshold)) + + # --- VCP: 剪枝 --- + # 静态点:原版 VCP 逻辑 + static_prune = self.is_static & (vcp_scores > tau_p) + + # 动态点:Credit-Assigned 联合剪枝 + # 仅在活跃期(w_t > 0.2)内判断,避免非活跃帧的正常动态点被误杀 + dynamic_active_prune = (~self.is_static) & (self._current_wt_mean > 0.2) + dynamic_prune = dynamic_active_prune & (vcp_scores > tau_p) + + prune_mask = static_prune | dynamic_prune + + # 附加:过小/过透明点 + opacity_prune = (self.get_opacity.squeeze() < min_opacity) + if max_screen_size is not None: + size_prune = self.get_scaling.max(dim=1).values > max_screen_size + prune_mask = prune_mask | opacity_prune | size_prune + else: + prune_mask = prune_mask | opacity_prune + + self.prune_points(prune_mask) +``` + +### 1.5 clone/split 中的时域属性继承 + +```python +def _clone_with_temporal(self, mask): + """Clone 时,子代继承父代所有属性,包括时域属性""" + # 克隆所有属性(参考 FastGS 原版 clone 实现) + new_xyz = self._xyz[mask] + new_t_mu = self._t_mu[mask] + new_sigma_t_raw = self._sigma_t_raw[mask] + new_velocity = self._velocity[mask] + new_is_static = self.is_static[mask] + # ... 其他属性同原版 + + self._append_gaussians(new_xyz, new_t_mu, new_sigma_t_raw, + new_velocity, new_is_static, ...) + +def _split_with_temporal(self, mask, N_splits=2): + """Split 时,子代继承父代时域属性(位置扰动,时域参数不变)""" + # 子代时域属性 = 父代时域属性(直接复制) + # 子代位置 = 父代位置 + 沿主轴方向的随机扰动 + # 子代 scale 缩小(原版行为) + # 注意:静态点子代的 velocity 和 sigma_t_raw 需要在 post-optim 中硬拉回 + pass +``` + +### 1.6 Credit-Assigned 剪枝(VCP 的动态点条件) + +在 VCP 分数计算之外,以下情况的动态点需额外保护(不被 VCP 删除): +- `w_t < 0.2`:当前非活跃期,即使 VCP 分数高也不剪枝(可能下一帧才是活跃期) +- `alpha * w_t > 0.005`:在活跃期内仍有足够不透明度 + +--- + +## 模块二:渲染管线时域扩展(`gaussian_renderer/__init__.py`) + +### 2.1 渲染主函数修改 + +在调用 CUDA 光栅化器之前,插入时空变换和因果剪枝: + +```python +def render_4d(viewpoint_camera, pc: GaussianModel, pipe, bg_color, + scaling_modifier=1.0, override_color=None): + """ + 4D 渲染主函数。 + + 关键顺序(不可更改): + 1. 时空变换:将高斯中心平移到当前帧位置 + 2. 因果剪枝:生成 alive_mask 稀疏子集 + 3. Compact Box 计算:在 alive 子集上计算 CB(重要:必须在此顺序) + 4. 光栅化:对 alive 子集进行 tile-based rasterization + 5. alive_mask 回填:将稀疏子集统计量映射回全量尺寸 + """ + t = viewpoint_camera.timestamp # float in [0, 1] + + # Step 1: 时空变换 + w_t = pc.compute_temporal_weight(t) # (N,) + + # 位置变换:x' = x0 + v * (t - t_mu) + dt = t - pc._t_mu # (N,) + xyz_transformed = pc.get_xyz + pc._velocity * dt.unsqueeze(-1) # (N, 3) + + # 时域有效不透明度:alpha' = alpha * w_t + opacity = pc.get_opacity.squeeze(-1) * w_t # (N,) + + # Step 2: 因果剪枝(alive_mask) + # 条件1:t_mu <= t(因果律,高斯未"出生"则不渲染) + causal_mask = pc._t_mu <= t + 1e-6 # (N,) bool + # 条件2:alpha' > tau_alive(时域活跃度不透明度阈值) + alive_mask = causal_mask & (opacity > 0.005) # (N,) bool + + # Step 3 & 4: 在 alive 子集上执行 FastGS 原版渲染流程(含 CB) + # 提取稀疏子集 + alive_idx = alive_mask.nonzero(as_tuple=False).squeeze(-1) + xyz_alive = xyz_transformed[alive_idx] + opacity_alive = opacity[alive_idx] + # ... 提取其他属性子集 + + # 调用 FastGS 的 CB 光栅化器(传入变换后的位置) + rendered_image, radii_sparse, viewspace_pts_sparse = rasterize_cb( + xyz_alive, opacity_alive, ... + ) + + # Step 5: 将稀疏子集的 radii 和 viewspace_pts 回填到全量尺寸 + radii_full = torch.zeros(pc._xyz.shape[0], device="cuda") + viewspace_pts_full = torch.zeros(pc._xyz.shape[0], 2, device="cuda") + radii_full[alive_idx] = radii_sparse + viewspace_pts_full[alive_idx] = viewspace_pts_sparse + + return { + "render": rendered_image, + "viewspace_points": viewspace_pts_full, # 全量尺寸,与 FastGS VCD 统计兼容 + "visibility_filter": alive_mask, # 全量尺寸 + "radii": radii_full, + "w_t": w_t, # 返回供 train.py 使用 + } +``` + +### 2.2 alive_mask 的双层回填 + +回填时注意:FastGS 的 VCD 统计是在 **全量高斯的 viewspace_points 上运行**的(用 radii 判断 footprint)。回填必须保持 shape 和 dtype 与原版一致,否则 VCD 统计的 footprint 计算会出错。 + +具体地: +- `radii_full[~alive_mask] = 0`:未激活高斯的 radii 置 0,使其 footprint 为空集 +- `viewspace_pts_full`:需要携带 `requires_grad=True`,因为 FastGS 通过 `viewspace_points.grad` 统计致密化梯度 + +--- + +## 模块三:训练循环修改(`train.py`) + +### 3.1 三级梯度闸门 + +在每次 `loss.backward()` 后、`optimizer.step()` 前插入: + +```python +def apply_gradient_gating(gaussians: GaussianModel, t_current: float): + """ + 三级梯度闸门: + - 静态点:v 和 sigma_t_raw 的梯度强制清零 + - 动态点(当前帧,即 t 接近 t_mu):所有参数梯度放行 + - 动态点(其他帧):xyz/f_dc/f_rest/scaling/rotation 梯度清零, + 只允许 opacity/velocity/sigma_t_raw 梯度通过 + + 注意:此函数中"当前帧"的判断使用 w_t > 0.5 作为阈值(在 ~1σ 生命周期核心内) + """ + w_t = gaussians.compute_temporal_weight(t_current) # (N,) + + is_static = gaussians.is_static # (N,) bool + is_dynamic_current = (~is_static) & (w_t > 0.5) # 动态且在当前帧活跃窗口内 + is_dynamic_other = (~is_static) & (w_t <= 0.5) # 动态但不在当前帧 + + # 静态点:锁死速度和时域参数 + for param_name in ['_velocity', '_sigma_t_raw']: + param = getattr(gaussians, param_name) + if param.grad is not None: + param.grad[is_static] = 0.0 + + # 动态点(其他帧):锁死几何参数,只允许 opacity/velocity/sigma_t_raw 更新 + geo_params = ['_xyz', '_features_dc', '_features_rest', '_scaling', '_rotation'] + for param_name in geo_params: + param = getattr(gaussians, param_name) + if param.grad is not None: + param.grad[is_dynamic_other] = 0.0 + + # 注意:opacity 梯度对所有动态点(包括 other 帧)都放行 + # 这是允许 opacity 跨时间轴统筹优化的关键设计 +``` + +### 3.2 静态点硬拉回(每步 optimizer.step() 后执行) + +```python +def enforce_static_constraints(gaussians: GaussianModel): + """ + 物理硬拉回,对抗 Adam 动量残余。 + 必须在 optimizer.step() 之后立即调用。 + """ + with torch.no_grad(): + static_mask = gaussians.is_static + + # 速度强制归零 + gaussians._velocity.data[static_mask] = 0.0 + + # 生命半径强制设为全时域可见(log(1000)) + gaussians._sigma_t_raw.data[static_mask] = math.log(1000.0) + + # 出生时间强制为 0 + gaussians._t_mu.data[static_mask] = 0.0 + + # 同步清除对应的 Adam 动量(防止动量把参数拉回去) + for group in gaussians.optimizer.param_groups: + for p in group['params']: + if p is gaussians._velocity or p is gaussians._sigma_t_raw: + state = gaussians.optimizer.state[p] + if 'exp_avg' in state: + state['exp_avg'][static_mask] = 0.0 + state['exp_avg_sq'][static_mask] = 0.0 +``` + +### 3.3 解耦 opacity 重置策略 + +```python +def reset_opacity_decoupled(gaussians: GaussianModel, reset_value: float = 0.01): + """ + 分治重置: + - 静态点:重置到 reset_value(原版行为) + - 动态点:不重置(保护动态前景的 opacity 状态) + + 同时同步 Adam 状态,防止重置失效。 + """ + with torch.no_grad(): + static_mask = gaussians.is_static + + # 仅对静态点执行 opacity 重置 + opacities_new = gaussians.get_opacity.clone() + opacities_new[static_mask] = inverse_sigmoid( + torch.ones(static_mask.sum(), 1, device="cuda") * reset_value + ) + + # 通过 replace_tensor_to_optimizer 同步更新 Adam 状态 + gaussians.replace_tensor_to_optimizer(opacities_new, "opacity") +``` + +### 3.4 时序感知相机采样策略(3 阶段) + +```python +def sample_camera_4d(train_cameras: List, iteration: int, + N_frames: int, N_cam_per_frame: int) -> Camera: + """ + 3 阶段采样策略: + + 阶段1(iter <= 3000):静态强化期 + 仅从第 0 帧的 N_cam 个视角中采样,优先收敛静态背景基座 + + 阶段2(3000 < iter <= 10000):时序滑窗期 + 在时间轴上随机选一个起始帧,取相邻 W=4 帧内的视角 + 为速度 v 提供连续帧对比梯度(速度的唯一有效监督信号) + + 阶段3(iter > 10000):全局随机期 + 全部训练相机无放回随机采样,保证全时域全视角覆盖 + """ + if iteration <= 3000: + # 阶段1:仅第 0 帧 + frame_0_cameras = [c for c in train_cameras if c.frame_idx == 0] + return random.choice(frame_0_cameras) + + elif iteration <= 10000: + # 阶段2:时序滑窗 + window_size = 4 + start_frame = random.randint(0, N_frames - window_size) + window_cameras = [ + c for c in train_cameras + if start_frame <= c.frame_idx < start_frame + window_size + ] + return random.choice(window_cameras) + + else: + # 阶段3:全局随机 + return random.choice(train_cameras) +``` + +### 3.5 时域感知 VCD/VCP 集成到训练主循环 + +FastGS 原版的 VCD/VCP 调用逻辑是:在每次 densification 时重新渲染 K 个采样视角,获取误差图和光度损失,计算分数。 + +在 4D 版本中,修改采样视角的方式: + +```python +def sample_views_for_vcd_vcp( + train_cameras: List, + K: int = 10, + iteration: int = 0 +) -> List[Camera]: + """ + 为 VCD/VCP 采样 K 个视角。 + + 与相机采样策略对齐: + - iter <= 3000:只从第 0 帧采样(保证静态场景一致性) + - iter > 3000:全局随机采样(覆盖时域,但注意时域权重会自动过滤非活跃视角的贡献) + + 返回:K 个 Camera 对象,包含 timestamp 属性 + """ + if iteration <= 3000: + pool = [c for c in train_cameras if c.frame_idx == 0] + else: + pool = train_cameras + + return random.sample(pool, min(K, len(pool))) +``` + +在计算 VCD/VCP 分数时,将每个采样视角的时间戳 `t_j` 传入,使 `compute_temporal_weight(t_j)` 自动对非活跃高斯的贡献置零。 + +### 3.6 训练主循环时间线 + +``` +iter 0 ─────────────────────────────────────────────── 30000 + │ + 500 ├─ 致密化开始(VCD + VCP,每 500 轮) + 2000 ├─ SH → 1 阶 + 3000 ├─ 首次 opacity reset(仅静态点) + │ 结束静态专属采样 → 进入时序滑窗期 + 6000 ├─ opacity reset(仅静态点)+ SH → 2 阶 + 10000 ├─ opacity reset(仅静态点)+ SH → 3 阶(满阶) + │ 进入全局随机期 + 12000 ├─ opacity reset(仅静态点) + 15000 ├─ opacity reset(仅静态点)+ 致密化结束 + │ 之后 VCP 每 3000 轮执行一次(仅剪枝,不致密化) + 30000 └─ 训练结束,保存 PLY 序列 +``` + +--- + +## 模块四:场景与数据读取(`scene/` 目录) + +### 4.1 相机对象扩展(`cameras.py`) + +在 `Camera` 类中添加: + +```python +class Camera: + def __init__(self, ..., timestamp: float = 0.0, frame_idx: int = 0): + # 现有属性... + self.timestamp = timestamp # float in [0, 1],归一化时间戳 + self.frame_idx = frame_idx # int,帧序号 + + # 延迟图像加载:仅记录路径,不解码像素 + # self._image_path = image_path + # self._image = None # 懒加载 + + @property + def original_image(self): + # 懒加载实现(可选,用于大规模数据集) + if self._image is None: + self._image = load_image_to_tensor(self._image_path) + return self._image +``` + +### 4.2 场景读取器(`dataset_readers.py`) + +```python +def read_4dgs_scene(scene_path: str, N_frames: int) -> Tuple[List, PointCloud]: + """ + 读取 4DGS 场景数据: + + 1. 读取 COLMAP 标定结果(cameras.bin + images.bin) + 2. 构建 Camera 对象列表,分配 timestamp = frame_idx / (N_frames - 1) + 3. 读取 static_points/points3D.ply 作为背景点云 + 4. 逐帧读取 dynamic_points/frame_XXXX.ply 作为前景点云 + 5. 返回 (camera_list, StaticPointCloud, DynamicPointCloudList) + + 前景点云打包方式:DynamicPointCloudList[i] 是第 i 帧的点云, + 每个点带 timestamp = i / (N_frames - 1),用于初始化 t_mu + """ + pass +``` + +--- + +## 模块五:损失函数 + +```python +# 总损失(与 TD-4DGS 保持一致) +loss = (1 - lambda_s) * L1 + lambda_s * (1 - SSIM) + lambda_v * L_smooth + +# 速度平滑正则(仅动态点) +def compute_velocity_smoothness_loss(gaussians: GaussianModel, K_pairs: int = 4096): + """ + 随机采样 K_pairs 对动态点,计算空间高斯核加权的速度一致性损失。 + + L_smooth = (1/K) * sum_k w_k * ||v_{a_k} - v_{b_k}||^2 + w_k = exp(-||x_{a_k} - x_{b_k}||^2 / (2 * s_bar^2)) + + 其中 s_bar^2 是局部空间尺度(用 KNN 距离估计)。 + """ + dynamic_idx = (~gaussians.is_static).nonzero(as_tuple=False).squeeze(-1) + if dynamic_idx.shape[0] < 2: + return torch.tensor(0.0, device="cuda") + + # 随机采样点对 + idx_a = torch.randint(0, dynamic_idx.shape[0], (K_pairs,)) + idx_b = torch.randint(0, dynamic_idx.shape[0], (K_pairs,)) + + pos_a = gaussians.get_xyz[dynamic_idx[idx_a]] # (K, 3) + pos_b = gaussians.get_xyz[dynamic_idx[idx_b]] # (K, 3) + vel_a = gaussians._velocity[dynamic_idx[idx_a]] # (K, 3) + vel_b = gaussians._velocity[dynamic_idx[idx_b]] # (K, 3) + + dist_sq = ((pos_a - pos_b) ** 2).sum(-1) # (K,) + s_bar_sq = dist_sq.mean().detach() + 1e-8 # 全局尺度估计(简化版) + + w = torch.exp(-dist_sq / (2 * s_bar_sq)) # (K,) + loss = (w * ((vel_a - vel_b) ** 2).sum(-1)).mean() + + return loss + +# 可选:深度正则(指数衰减权重) +depth_weight = 1.0 * math.exp(-iteration / 5000.0) # 从 1.0 衰减到 ~0.007 at 30K +loss = loss + depth_weight * L_depth # 仅在有深度图时启用 +``` + +--- + +## 模块六:动态 Scale 约束(可选,谨慎使用) + +如果发现动态高斯出现"膨胀偷懒"(少量大球代替大量小球),可以在 optimizer.step() 后施加软惩罚,**不建议硬钳位**(硬钳位会破坏 Adam 动量): + +```python +# 软惩罚替代硬钳位 +scale_limit = math.log(0.05 * scene_extent) +dynamic_mask = ~gaussians.is_static +scale_excess = (gaussians._scaling[dynamic_mask] - scale_limit).clamp(min=0) +scale_penalty = scale_excess.pow(2).mean() +loss = loss + 0.001 * scale_penalty # 权重可调 +``` + +--- + +## 关键注意事项(必读) + +### ⚠️ 注意1:alive_mask 与 CB 的执行顺序 +**CB 必须在 alive_mask 过滤之后执行**。如果先计算 CB 再过滤 alive 点,会对不该渲染的高斯浪费 tile pair 计算。正确顺序:提取 alive 子集 → 对子集执行 CB → 光栅化。 + +### ⚠️ 注意2:VCD 时域权重不能在 `torch.no_grad()` 块中计算 +VCD/VCP 分数计算时,`compute_temporal_weight` 必须在计算图中(需要梯度回传到 `sigma_t_raw`)。不要在 `with torch.no_grad():` 块中调用。 + +### ⚠️ 注意3:FastGS 的 τ_d 阈值需为动态点降低 +FastGS 原版 τ_d=5 是针对静态场景调优的,动态场景中每个高斯覆盖的视角更少(因为时域滤波),导致致密化分数天然偏低。建议对动态点使用 τ_d_dynamic = τ_d_static × 0.5 = 2.5。 + +```python +# 在致密化条件中: +tau_d_effective = torch.where( + gaussians.is_static, + torch.tensor(5.0), + torch.tensor(2.5) # 动态点阈值减半 +) +densify_mask = vcd_scores > tau_d_effective +``` + +### ⚠️ 注意4:时域权重的梯度截断位置 +在渲染时,`opacity = base_opacity * w_t` 中,`w_t` 对 `sigma_t_raw` 有梯度。**不要**在 `w_t` 处调用 `.detach()`,否则 `sigma_t_raw` 将无法通过光度损失学习。但在梯度闸门中对静态点清零 `sigma_t_raw.grad` 时,**在 `loss.backward()` 之后**执行。 + +### ⚠️ 注意5:PLY 序列化需保存时域属性 +保存模型时,除了原版 3DGS 的属性外,还需保存 `t_mu`、`sigma_t_raw`、`velocity`、`is_static`。建议在 PLY header 中添加自定义属性字段,并在 `load_ply` 中兼容旧版(无时域属性时降级为静态模式)。 + +### ⚠️ 注意6:大规模数据集的内存管理 +对于 150 帧 × 36 视角 = 5400 相机,建议启用延迟图像加载(Lazy Loading)。如果实现 LRU 缓存,缓存大小建议不超过 200 帧(约 2-4 GB GPU 显存)。 + +### ⚠️ 注意7:FastGS 的 per-splat 并行反传兼容性 +FastGS 从 Taming-3DGS 引入了 per-splat 并行反传(替代原版 per-pixel)。梯度闸门中对 `_xyz.grad` 等参数的清零操作,需要确认 per-splat 反传输出的梯度格式与原版一致,否则索引可能错位。建议在测试阶段先用原版反传验证梯度闸门的正确性,再切换到 per-splat。 + +--- + +## 验证清单 + +在实现完成后,按顺序验证以下测试点: + +1. **静态硬拉回**:训练 1000 步后,检查 `gaussians._velocity[is_static].abs().max()` 是否为 0。 +2. **因果律约束**:以 t=0.1 渲染时,`t_mu > 0.1` 的动态高斯的 `alive_mask` 全为 False。 +3. **时域权重梯度**:检查 `sigma_t_raw.grad[~is_static]` 在 backward 后非零。 +4. **VCD 时域权重**:对一个 `t_mu=0.8` 的动态高斯,在 t=0.0 的视角下的 VCD 分数贡献应近似为 0。 +5. **解耦 opacity reset**:reset 后,`get_opacity()[~is_static]` 的值不应变化。 +6. **子代继承**:clone 之后,新创建点的 `is_static` 与父代一致。 +7. **CB 顺序**:检查传入光栅化器的高斯数量等于 `alive_mask.sum()`,而非全量 N。 +8. **损失下降**:在单帧静态场景(退化为标准 FastGS)上,训练 5K 步的 PSNR 曲线应与原版 FastGS 近似(±0.5 dB 以内)。 + +--- + +## 超参数参考 + +| 参数 | 推荐值 | 说明 | +|------|--------|------| +| K(VCD/VCP 采样视角数) | 10 | 与 FastGS 原版一致 | +| τ_d(静态点致密化阈值) | 5 | 与 FastGS 原版一致 | +| τ_d(动态点致密化阈值) | 2.5 | 动态点视角覆盖稀疏,需降低 | +| τ_p(剪枝阈值) | 0.9 | 与 FastGS 原版一致 | +| τ_alive(因果剪枝阈值) | 0.005 | alpha'(t) 最小有效不透明度 | +| w_t 活跃期阈值(致密化) | 0.2 | 约 ±2σ 生命周期核心内 | +| w_t 当前帧阈值(梯度闸门) | 0.5 | 约 ±1σ 生命周期核心内 | +| σ_t 初始值(动态点) | log(2.5/N_frames) | 初始覆盖约 2.5 帧 | +| σ_t 初始值(静态点) | log(1000) | 全时域可见 | +| λ_v(速度平滑权重) | 0.01 | 可在 0.005-0.05 之间调整 | +| 动态点 opacity reset 值 | 不重置 | 保护动态前景 | +| 静态点 opacity reset 值 | 0.01 | 原版行为 | +| β(CB Mahalanobis 缩放) | 与 FastGS 原版一致 | 动态场景可适当放宽 | + +--- + +## 输出格式 + +训练完成后,模型保存为: +``` +output/ +├── point_cloud/ +│ ├── iteration_30000/ +│ │ └── point_cloud.ply # 包含时域属性的完整模型 +├── renders/ +│ ├── frame_0000/ +│ │ ├── cam_00.png +│ │ └── ... +│ └── ... +└── cfg_args # 训练超参数记录 +``` + +PLY 文件中额外属性字段(在原版 3DGS 属性之后追加): +``` +property float t_mu +property float sigma_t_raw +property float vel_x +property float vel_y +property float vel_z +property uchar is_static # 0 或 1 +``` + +--- + +*文档版本 v1.0 | 基于 FastGS (arXiv:2511.04283v3) 和 TD-4DGS 内部技术报告* diff --git a/scene/__init__.py b/scene/__init__.py index 2b31398..e52c5d2 100755 --- a/scene/__init__.py +++ b/scene/__init__.py @@ -40,17 +40,33 @@ def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration self.train_cameras = {} self.test_cameras = {} - if os.path.exists(os.path.join(args.source_path, "sparse")): + # Detect a 4DGS dataset: forced via --force_4dgs, or auto-detected from the + # decoupled static/dynamic point-cloud layout described in prompt.md. + force_4dgs = getattr(args, "force_4dgs", False) + has_temporal_layout = os.path.isdir(os.path.join(args.source_path, "static_points")) or \ + os.path.isdir(os.path.join(args.source_path, "dynamic_points")) + self.is_4dgs = force_4dgs or has_temporal_layout + + if self.is_4dgs: + print("[TD-FastGS] 4DGS dataset detected; using the temporal scene reader.") + n_frames = getattr(args, "n_frames", -1) + scene_info = sceneLoadTypeCallbacks["Colmap4D"](args.source_path, args.images, + args.eval, n_frames=n_frames) + self.n_frames = scene_info.n_frames + elif os.path.exists(os.path.join(args.source_path, "sparse")): scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) + self.n_frames = 1 elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): print("Found transforms_train.json file, assuming Blender data set!") scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) + self.n_frames = 1 else: assert False, "Could not recognize scene type!" if not self.loaded_iter: - with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: - dest_file.write(src_file.read()) + if scene_info.ply_path is not None and os.path.exists(scene_info.ply_path): + with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: + dest_file.write(src_file.read()) json_cams = [] camlist = [] if scene_info.test_cameras: @@ -79,6 +95,9 @@ def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration "point_cloud", "iteration_" + str(self.loaded_iter), "point_cloud.ply")) + elif self.is_4dgs and scene_info.temporal_point_cloud is not None: + self.gaussians.create_from_pcd_4d(scene_info.temporal_point_cloud, + self.cameras_extent, self.n_frames) else: self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) diff --git a/scene/__pycache__/__init__.cpython-37.pyc b/scene/__pycache__/__init__.cpython-37.pyc deleted file mode 100755 index 3eac0cc..0000000 Binary files a/scene/__pycache__/__init__.cpython-37.pyc and /dev/null differ diff --git a/scene/__pycache__/__init__.cpython-38.pyc b/scene/__pycache__/__init__.cpython-38.pyc deleted file mode 100755 index 4ef7e26..0000000 Binary files a/scene/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/scene/__pycache__/cameras.cpython-37.pyc b/scene/__pycache__/cameras.cpython-37.pyc deleted file mode 100755 index 77f4e83..0000000 Binary files a/scene/__pycache__/cameras.cpython-37.pyc and /dev/null differ diff --git a/scene/__pycache__/cameras.cpython-38.pyc b/scene/__pycache__/cameras.cpython-38.pyc deleted file mode 100755 index 27a10fa..0000000 Binary files a/scene/__pycache__/cameras.cpython-38.pyc and /dev/null differ diff --git a/scene/__pycache__/colmap_loader.cpython-37.pyc b/scene/__pycache__/colmap_loader.cpython-37.pyc deleted file mode 100755 index 2a71ae7..0000000 Binary files a/scene/__pycache__/colmap_loader.cpython-37.pyc and /dev/null differ diff --git a/scene/__pycache__/colmap_loader.cpython-38.pyc b/scene/__pycache__/colmap_loader.cpython-38.pyc deleted file mode 100755 index 9c4f315..0000000 Binary files a/scene/__pycache__/colmap_loader.cpython-38.pyc and /dev/null differ diff --git a/scene/__pycache__/dataset_readers.cpython-37.pyc b/scene/__pycache__/dataset_readers.cpython-37.pyc deleted file mode 100755 index eddb03f..0000000 Binary files a/scene/__pycache__/dataset_readers.cpython-37.pyc and /dev/null differ diff --git a/scene/__pycache__/dataset_readers.cpython-38.pyc b/scene/__pycache__/dataset_readers.cpython-38.pyc deleted file mode 100755 index 9df8005..0000000 Binary files a/scene/__pycache__/dataset_readers.cpython-38.pyc and /dev/null differ diff --git a/scene/__pycache__/gaussian_model.cpython-37.pyc b/scene/__pycache__/gaussian_model.cpython-37.pyc deleted file mode 100755 index 9e02776..0000000 Binary files a/scene/__pycache__/gaussian_model.cpython-37.pyc and /dev/null differ diff --git a/scene/__pycache__/gaussian_model.cpython-38.pyc b/scene/__pycache__/gaussian_model.cpython-38.pyc deleted file mode 100755 index 7814895..0000000 Binary files a/scene/__pycache__/gaussian_model.cpython-38.pyc and /dev/null differ diff --git a/scene/cameras.py b/scene/cameras.py index abf6e52..755962c 100755 --- a/scene/cameras.py +++ b/scene/cameras.py @@ -12,12 +12,40 @@ import torch from torch import nn import numpy as np +from collections import OrderedDict +from PIL import Image from utils.graphics_utils import getWorld2View2, getProjectionMatrix +from utils.general_utils import PILtoTorch + +# Bounded LRU cache of decoded+resized images, kept on CPU and keyed by +# (image_path, width, height). For the multi-view-video 4D datasets there can be +# tens of thousands of images, so eager GPU upload is impossible; cameras load +# lazily on first access and the most recently used images stay resident. +_IMAGE_CACHE = OrderedDict() +_IMAGE_CACHE_CAP = 64 + + +def _cache_get(key): + img = _IMAGE_CACHE.get(key) + if img is not None: + _IMAGE_CACHE.move_to_end(key) + return img + + +def _cache_put(key, tensor): + _IMAGE_CACHE[key] = tensor + _IMAGE_CACHE.move_to_end(key) + while len(_IMAGE_CACHE) > _IMAGE_CACHE_CAP: + _IMAGE_CACHE.popitem(last=False) + class Camera(nn.Module): def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, image_name, uid, - trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" + trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", + timestamp: float = 0.0, frame_idx: int = 0, + image_path: str = None, resolution=None, + gt_width: int = None, gt_height: int = None ): super(Camera, self).__init__() @@ -29,6 +57,11 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, self.FoVy = FoVy self.image_name = image_name + # Temporal attributes (TD-FastGS 4D extension). + # timestamp is the normalized time in [0, 1]; frame_idx is the integer frame index. + self.timestamp = timestamp + self.frame_idx = frame_idx + try: self.data_device = torch.device(data_device) except Exception as e: @@ -36,14 +69,28 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) self.data_device = torch.device("cuda") - self.original_image = image.clamp(0.0, 1.0).to(self.data_device) - self.image_width = self.original_image.shape[2] - self.image_height = self.original_image.shape[1] + # Lazy-loading state (used when `image` is None). + self.image_path = image_path + self.resolution = resolution # (W, H) target for PILtoTorch + self._gt_alpha_mask = gt_alpha_mask # only meaningful in the eager path - if gt_alpha_mask is not None: - self.original_image *= gt_alpha_mask.to(self.data_device) + if image is not None: + # Eager path (3D Colmap / Blender readers): keep the original behavior. + self._eager_image = image.clamp(0.0, 1.0).to(self.data_device) + self.image_width = self._eager_image.shape[2] + self.image_height = self._eager_image.shape[1] + if gt_alpha_mask is not None: + self._eager_image = self._eager_image * gt_alpha_mask.to(self.data_device) + else: + self._eager_image = self._eager_image * torch.ones( + (1, self.image_height, self.image_width), device=self.data_device) else: - self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) + # Lazy path: dims come from the (post-resize) gt_width/gt_height. + self._eager_image = None + assert image_path is not None and gt_width is not None and gt_height is not None, \ + "Lazy Camera requires image_path, gt_width, gt_height" + self.image_width = gt_width + self.image_height = gt_height self.zfar = 100.0 self.znear = 0.01 @@ -56,6 +103,29 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) self.camera_center = self.world_view_transform.inverse()[3, :3] + @property + def original_image(self): + """Return the GT image on the camera's device. + + Eager cameras hold the tensor directly. Lazy cameras decode + resize from + disk on first access, serving repeats from a bounded CPU LRU cache so the + full image set never has to live in memory at once. + """ + if self._eager_image is not None: + return self._eager_image + + key = (self.image_path, self.image_width, self.image_height) + cpu_img = _cache_get(key) + if cpu_img is None: + pil = Image.open(self.image_path) + resized = PILtoTorch(pil, self.resolution) # (C, H, W) in [0, 1] + rgb = resized[:3, ...].clamp(0.0, 1.0) + if resized.shape[0] == 4: + rgb = rgb * resized[3:4, ...] + cpu_img = rgb.contiguous() + _cache_put(key, cpu_img) + return cpu_img.to(self.data_device) + class MiniCam: def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): self.image_width = width diff --git a/scene/dataset_readers.py b/scene/dataset_readers.py index 2a6f904..4ec51f3 100755 --- a/scene/dataset_readers.py +++ b/scene/dataset_readers.py @@ -34,6 +34,21 @@ class CameraInfo(NamedTuple): image_name: str width: int height: int + timestamp: float = 0.0 # normalized time in [0, 1] (4D extension) + frame_idx: int = 0 # integer frame index (4D extension) + +class TemporalPointCloud(NamedTuple): + """Point cloud carrying per-point temporal metadata for TD-FastGS. + + Points are ordered static-first, then dynamic (frame-by-frame), matching the + concatenation order expected by GaussianModel.create_from_pcd_4d. + """ + points: np.array # (N, 3) + colors: np.array # (N, 3) in [0, 1] + normals: np.array # (N, 3) + timestamps: np.array # (N,) normalized birth time t_mu in [0, 1] + is_static: np.array # (N,) bool, True for background points + velocities: np.array = None # (N, 3) flow-estimated initial velocity, None → zero-init class SceneInfo(NamedTuple): point_cloud: BasicPointCloud @@ -41,6 +56,8 @@ class SceneInfo(NamedTuple): test_cameras: list nerf_normalization: dict ply_path: str + temporal_point_cloud: object = None # TemporalPointCloud for 4D scenes, else None + n_frames: int = 1 # number of temporal frames def getNerfppNorm(cam_info): def get_center_and_diag(cam_centers): @@ -254,7 +271,427 @@ def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): ply_path=ply_path) return scene_info +import re +import glob + +def parse_frame_idx(name): + """Extract the integer frame index from an image name. + + Looks for a `frame_` token first (the documented layout), then falls + back to the last run of digits in the name. Returns 0 if nothing is found. + """ + base = os.path.basename(str(name)) + m = re.search(r"frame[_-]?(\d+)", base, flags=re.IGNORECASE) + if m is not None: + return int(m.group(1)) + digits = re.findall(r"\d+", base) + if digits: + return int(digits[-1]) + return 0 + +def _fetch_ply_xyz_rgb(path): + plydata = PlyData.read(path) + vertices = plydata['vertex'] + positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T + try: + colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 + except (ValueError, KeyError): + colors = np.ones_like(positions) * 0.5 + return positions, colors + +def load_temporal_point_cloud_pcd(scene_path, frame_to_t, flows_dir="flows"): + """Load static + per-frame dynamic point clouds for the multi-view-video layout. + + Layout (flower300 / two): + scene_path/static_points/*.ply (e.g. pcd1.ply, t_mu=0) + scene_path/dynamic_points/pcd.ply (frame N, t_mu=frame_to_t[N]) + + Static points come first (t_mu=0, is_static=True), followed by dynamic points + for each frame in ascending frame order. `frame_to_t` maps an integer frame id + to its normalized timestamp (the same map used for the cameras), so a dynamic + Gaussian's birth time equals its source frame's camera timestamp exactly. + + If scene_path/flows_dir exists, 3D velocities are estimated from optical flow + and stored in the returned TemporalPointCloud.velocities field (static=0). + """ + static_dir = os.path.join(scene_path, "static_points") + dyn_dir = os.path.join(scene_path, "dynamic_points") + + pts_list, col_list, ts_list, static_list = [], [], [], [] + # Track per dynamic point which frame it belongs to (for flow estimation). + frame_id_list = [] + + if os.path.isdir(static_dir): + for spath in sorted(glob.glob(os.path.join(static_dir, "*.ply")), + key=lambda p: parse_frame_idx(p)): + s_pts, s_col = _fetch_ply_xyz_rgb(spath) + if s_pts.shape[0] == 0: + continue + pts_list.append(s_pts) + col_list.append(s_col) + ts_list.append(np.zeros(s_pts.shape[0], dtype=np.float32)) + static_list.append(np.ones(s_pts.shape[0], dtype=bool)) + frame_id_list.append(np.full(s_pts.shape[0], -1, dtype=np.int32)) + + if os.path.isdir(dyn_dir): + frame_files = sorted(glob.glob(os.path.join(dyn_dir, "*.ply")), + key=lambda p: parse_frame_idx(p)) + for fpath in frame_files: + fidx = parse_frame_idx(fpath) + d_pts, d_col = _fetch_ply_xyz_rgb(fpath) + if d_pts.shape[0] == 0: + continue + t = float(frame_to_t.get(fidx, 0.0)) + pts_list.append(d_pts) + col_list.append(d_col) + ts_list.append(np.full(d_pts.shape[0], t, dtype=np.float32)) + static_list.append(np.zeros(d_pts.shape[0], dtype=bool)) + frame_id_list.append(np.full(d_pts.shape[0], fidx, dtype=np.int32)) + + if not pts_list: + return None + + points = np.concatenate(pts_list, axis=0).astype(np.float32) + colors = np.concatenate(col_list, axis=0).astype(np.float32) + timestamps = np.concatenate(ts_list, axis=0).astype(np.float32) + is_static = np.concatenate(static_list, axis=0) + frame_ids_all = np.concatenate(frame_id_list, axis=0) + normals = np.zeros_like(points) + + # Estimate velocities from optical flow if the flows directory exists. + velocities = None + flows_root = os.path.join(scene_path, flows_dir) + if os.path.isdir(flows_root): + dyn_mask = ~is_static + if dyn_mask.any(): + print(f"[flow-vel] Estimating initial velocities from optical flow …") + vel_dyn = load_flow_velocities( + scene_path=scene_path, + points_world=points[dyn_mask], + frame_ids=frame_ids_all[dyn_mask], + frame_to_t=frame_to_t, + flows_dir=flows_dir, + ) + velocities = np.zeros_like(points) + velocities[dyn_mask] = vel_dyn + v_mag = np.linalg.norm(vel_dyn, axis=1) + print(f"[flow-vel] velocity magnitude: mean={v_mag.mean():.4f} " + f"p50={np.median(v_mag):.4f} p95={np.percentile(v_mag,95):.4f}") + + return TemporalPointCloud(points=points, colors=colors, normals=normals, + timestamps=timestamps, is_static=is_static, + velocities=velocities) + +def _read_colmap_calib_full(path): + """Like _read_colmap_calib but also returns raw intrinsic params (fx,fy,cx,cy).""" + try: + cam_extrinsics = read_extrinsics_binary(os.path.join(path, "sparse/0", "images.bin")) + cam_intrinsics = read_intrinsics_binary(os.path.join(path, "sparse/0", "cameras.bin")) + except Exception: + cam_extrinsics = read_extrinsics_text(os.path.join(path, "sparse/0", "images.txt")) + cam_intrinsics = read_intrinsics_text(os.path.join(path, "sparse/0", "cameras.txt")) + + cams = {} # keyed by image name stem + for key in cam_extrinsics: + extr = cam_extrinsics[key] + intr = cam_intrinsics[extr.camera_id] + height, width = intr.height, intr.width + R = np.transpose(qvec2rotmat(extr.qvec)) # world←cam rotation + T = np.array(extr.tvec) # world-to-cam translation + if intr.model == "SIMPLE_PINHOLE": + fx = fy = intr.params[0] + cx, cy = intr.params[1], intr.params[2] + elif intr.model == "PINHOLE": + fx, fy = intr.params[0], intr.params[1] + cx, cy = intr.params[2], intr.params[3] + else: + continue # skip unsupported models + name = os.path.basename(extr.name).split(".")[0] + cams[name] = {"R": R, "T": T, "fx": fx, "fy": fy, + "cx": cx, "cy": cy, "width": width, "height": height} + return cams + + +def load_flow_velocities(scene_path, points_world, frame_ids, frame_to_t, flows_dir="flows"): + """Estimate per-point 3D velocity from optical flow for each dynamic frame. + + For each dynamic point at frame f the function: + 1. Projects the point into every camera that has a flow file for frame f. + 2. Reads flow(u,v) at that pixel (bilinear); adds the displacement to get + the pixel position at frame f+1. + 3. Un-projects both pixel positions at depth=1 into camera-space rays, takes + the difference as a direction, converts to world space and normalises by + the frame interval in normalised time (Δt). + 4. Averages the estimates from all cameras that could see the point. + + Returns float32 array (N_dynamic, 3). Points whose frame has no flow (last frame) + or that project outside all cameras get velocity=0. + + Args: + scene_path: dataset root + points_world: (N, 3) 3-D positions of *dynamic* points only, in world space + frame_ids: (N,) integer frame index for each point + frame_to_t: dict frame_int → normalised timestamp + flows_dir: relative path inside scene_path to flow folder + """ + calib = _read_colmap_calib_full(scene_path) + if not calib: + return np.zeros((len(points_world), 3), dtype=np.float32) + + flows_root = os.path.join(scene_path, flows_dir) + N = len(points_world) + velocities = np.zeros((N, 3), dtype=np.float32) + + unique_frames = np.unique(frame_ids) + sorted_frames = sorted(frame_to_t.keys()) + frame_idx_map = {f: i for i, f in enumerate(sorted_frames)} + + for f in unique_frames: + frame_flow_dir = os.path.join(flows_root, str(f)) + if not os.path.isdir(frame_flow_dir): + continue + + # Next frame for Δt + fi = frame_idx_map.get(int(f)) + if fi is None or fi + 1 >= len(sorted_frames): + continue + f_next = sorted_frames[fi + 1] + delta_t = frame_to_t[f_next] - frame_to_t[f] + if abs(delta_t) < 1e-8: + continue + + mask = (frame_ids == f) + pts = points_world[mask] # (M, 3) + M = pts.shape[0] + vel_sum = np.zeros((M, 3), dtype=np.float64) + vel_cnt = np.zeros(M, dtype=np.float32) + + for cam_name, c in calib.items(): + flow_path = os.path.join(frame_flow_dir, cam_name + ".npy") + if not os.path.exists(flow_path): + continue + flow = np.load(flow_path) # (Hf, Wf, 2) values in flow-image coords + Hf, Wf = flow.shape[:2] + H_cam, W_cam = c["height"], c["width"] + + # W2C matrix + R, T = c["R"], c["T"] + # COLMAP convention: R is world←cam rotation (R^T is cam←world) + # T is cam translation in world coords (applied as p_cam = R^T p_world - R^T T) + # Actually in COLMAP: p_cam = R^T @ (p_world - T) where T = camera center in world + # But dataset_readers stores R = transpose(qvec2rotmat) and T = tvec (not center). + # So: p_cam = R.T @ p_world + T (standard COLMAP w2c) + Rc = R.T # (3,3) cam←world rotation + Tc = T # (3,) translation part of w2c + + pts_cam = (Rc @ pts.T).T + Tc # (M, 3) + + # Only points in front of camera + z = pts_cam[:, 2] + valid = z > 0.01 + if not np.any(valid): + continue + + # Project to pixel (using full-res intrinsics) + u0 = (pts_cam[valid, 0] / z[valid]) * c["fx"] + c["cx"] + v0 = (pts_cam[valid, 1] / z[valid]) * c["fy"] + c["cy"] + + # Scale to flow image resolution + scale_u = Wf / W_cam + scale_v = Hf / H_cam + uf = u0 * scale_u + vf = v0 * scale_v + + # Clip to flow image bounds + in_bounds = (uf >= 0) & (uf < Wf - 1) & (vf >= 0) & (vf < Hf - 1) + if not np.any(in_bounds): + continue + + valid_idx = np.where(valid)[0][in_bounds] # indices into pts + uf_valid = uf[in_bounds] + vf_valid = vf[in_bounds] + + # Bilinear sample flow + ui = uf_valid.astype(np.int32) + vi = vf_valid.astype(np.int32) + du = uf_valid - ui + dv = vf_valid - vi + # flow is (Hf, Wf, 2): [delta_u_in_flow_coords, delta_v_in_flow_coords] + f00 = flow[vi, ui ] # (K, 2) + f10 = flow[vi+1, ui ] + f01 = flow[vi, ui+1] + f11 = flow[vi+1, ui+1] + flow_uv = (f00 * (1-du[:,None]) * (1-dv[:,None]) + + f01 * (1-dv[:,None]) * du[:,None] + + f10 * dv[:,None] * (1-du[:,None]) + + f11 * dv[:,None] * du[:,None]) # (K, 2) + + # Convert flow from flow-image coords to full-res pixel coords + flow_u_px = flow_uv[:, 0] / scale_u + flow_v_px = flow_uv[:, 1] / scale_v + + # Pixel coords at frame f+1 + u1 = u0[in_bounds] + flow_u_px + v1 = v0[in_bounds] + flow_v_px + + # Unproject both pixels at depth=1 → direction in camera space + inv_fx, inv_fy = 1.0 / c["fx"], 1.0 / c["fy"] + dir0 = np.stack([(u0[in_bounds] - c["cx"]) * inv_fx, + (v0[in_bounds] - c["cy"]) * inv_fy, + np.ones(len(u1))], axis=1) # (K, 3) + dir1 = np.stack([(u1 - c["cx"]) * inv_fx, + (v1 - c["cy"]) * inv_fy, + np.ones(len(u1))], axis=1) # (K, 3) + + # Scale directions by actual depth so displacement is metric + depth = z[valid][in_bounds] + dir0 *= depth[:, None] + dir1 *= depth[:, None] + + # Displacement in camera space → world space (rotation only, no translation) + Rw = Rc.T # world←cam + disp_world = (Rw @ (dir1 - dir0).T).T # (K, 3) + + vel_world = disp_world / delta_t + vel_sum[valid_idx] += vel_world + vel_cnt[valid_idx] += 1 + + has_est = vel_cnt > 0 + velocities[mask] = np.where(has_est[:, None], + (vel_sum / np.maximum(vel_cnt[:, None], 1)).astype(np.float32), + 0.0) + n_est = int(has_est.sum()) + print(f"[flow-vel] frame {f}: {M} pts, {n_est} got flow estimate " + f"({M - n_est} zero-init)") + + return velocities + + +def _read_colmap_calib(path): + """Read COLMAP camera calibration without opening any images. + + Returns a list of dicts {uid, R, T, FovX, FovY, width, height, name} ordered + by image name. `name` is the COLMAP image NAME stem (here a camera id like + "1", "2", ...). Supports both binary and text COLMAP exports. + """ + try: + cam_extrinsics = read_extrinsics_binary(os.path.join(path, "sparse/0", "images.bin")) + cam_intrinsics = read_intrinsics_binary(os.path.join(path, "sparse/0", "cameras.bin")) + except Exception: + cam_extrinsics = read_extrinsics_text(os.path.join(path, "sparse/0", "images.txt")) + cam_intrinsics = read_intrinsics_text(os.path.join(path, "sparse/0", "cameras.txt")) + + cams = [] + for key in cam_extrinsics: + extr = cam_extrinsics[key] + intr = cam_intrinsics[extr.camera_id] + height, width = intr.height, intr.width + R = np.transpose(qvec2rotmat(extr.qvec)) + T = np.array(extr.tvec) + if intr.model == "SIMPLE_PINHOLE": + fx = intr.params[0] + FovY = focal2fov(fx, height) + FovX = focal2fov(fx, width) + elif intr.model == "PINHOLE": + fx, fy = intr.params[0], intr.params[1] + FovY = focal2fov(fy, height) + FovX = focal2fov(fx, width) + else: + assert False, ("Colmap camera model not handled: only undistorted " + "datasets (PINHOLE or SIMPLE_PINHOLE) supported!") + cams.append({"uid": intr.id, "R": R, "T": T, "FovX": FovX, "FovY": FovY, + "width": width, "height": height, + "name": os.path.basename(extr.name).split(".")[0]}) + cams.sort(key=lambda c: c["name"]) + return cams + + +def readColmap4DSceneInfo(path, images, eval, n_frames=-1, llffhold=8): + """4DGS multi-view-video reader (flower300 layout). + + `sparse/0` calibrates a fixed set of cameras (the COLMAP image names are CAMERA + ids, not frames). The actual frames live as folders under `images//images/`, + each holding one image per camera. Training views are therefore the cross + product (camera x frame); a Camera is emitted per (camera, frame) pair with the + frame's normalized timestamp. Images are NOT opened here (lazy loading): the + CameraInfo carries `image=None` and the path, and dims come from the COLMAP + intrinsics. The decoupled static/dynamic .ply clouds provide 4D initialization. + """ + reading_dir = "images" if images is None else images + images_root = os.path.join(path, reading_dir) + + calib = _read_colmap_calib(path) + + # Discover integer-named frame folders under images/. + frames = [] + if os.path.isdir(images_root): + for entry in os.listdir(images_root): + if entry.isdigit() and os.path.isdir(os.path.join(images_root, entry)): + frames.append(int(entry)) + frames.sort() + if not frames: + frames = [0] + fmin, fmax = frames[0], frames[-1] + span = float(fmax - fmin) if fmax > fmin else 1.0 + frame_to_t = {f: (float(f - fmin) / span) for f in frames} + + if n_frames is None or n_frames <= 0: + n_frames = len(frames) + + # Cross product: one CameraInfo per (frame, camera), lazily loaded. + cam_infos = [] + uid = 0 + missing = 0 + for f in frames: + frame_dir = os.path.join(images_root, str(f), "images") + t = frame_to_t[f] + for c in calib: + img_path = os.path.join(frame_dir, c["name"] + ".png") + if not os.path.exists(img_path): + missing += 1 + continue + cam_infos.append(CameraInfo( + uid=uid, R=c["R"], T=c["T"], FovY=c["FovY"], FovX=c["FovX"], + image=None, image_path=img_path, + image_name="f{}_c{}".format(f, c["name"]), + width=c["width"], height=c["height"], + timestamp=t, frame_idx=f)) + uid += 1 + if missing: + print("[TD-FastGS] Warning: {} (camera, frame) images were missing and skipped.".format(missing)) + print("[TD-FastGS] Built {} cameras across {} frames ({} calibrated cams).".format( + len(cam_infos), len(frames), len(calib))) + + if eval: + train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] + test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] + else: + train_cam_infos = cam_infos + test_cam_infos = [] + + nerf_normalization = getNerfppNorm(train_cam_infos) + + temporal_pcd = load_temporal_point_cloud_pcd(path, frame_to_t) + if temporal_pcd is not None: + pcd = BasicPointCloud(points=temporal_pcd.points, + colors=temporal_pcd.colors, + normals=temporal_pcd.normals) + else: + pcd = None + # The decoupled clouds are the only init source (points3D is empty here); no + # input.ply round-trip, so report no ply_path. + ply_path = None + + return SceneInfo(point_cloud=pcd, + train_cameras=train_cam_infos, + test_cameras=test_cam_infos, + nerf_normalization=nerf_normalization, + ply_path=ply_path, + temporal_point_cloud=temporal_pcd, + n_frames=n_frames) + sceneLoadTypeCallbacks = { "Colmap": readColmapSceneInfo, - "Blender" : readNerfSyntheticInfo + "Blender" : readNerfSyntheticInfo, + "Colmap4D": readColmap4DSceneInfo } \ No newline at end of file diff --git a/scene/gaussian_model.py b/scene/gaussian_model.py index ec4b440..f4ae095 100755 --- a/scene/gaussian_model.py +++ b/scene/gaussian_model.py @@ -10,6 +10,7 @@ # import torch +import math import numpy as np from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation, identity_gate from torch import nn @@ -53,7 +54,7 @@ def modify_functions(self): def __init__(self, sh_degree, optimizer_type="default"): self.active_sh_degree = 0 self.optimizer_type = optimizer_type - self.max_sh_degree = sh_degree + self.max_sh_degree = sh_degree self._xyz = torch.empty(0) self._features_dc = torch.empty(0) self._features_rest = torch.empty(0) @@ -68,6 +69,19 @@ def __init__(self, sh_degree, optimizer_type="default"): self.shoptimizer = None self.percent_dense = 0 self.spatial_lr_scale = 0 + + # ----- TD-FastGS temporal attributes ----- + # _t_mu and is_static are resident (NOT in the optimizer); _sigma_t_raw and + # _velocity are optimized but the static rows are hard-zeroed every step. + self.is_4d = False # toggled on by create_from_pcd_4d / load_ply + self._t_mu = torch.empty(0) # (N,) birth time, frozen + self._sigma_t_raw = torch.empty(0) # (N,) life radius (log space), learnable + self._velocity = torch.empty(0) # (N,3) motion velocity, learnable (dynamic) + self.is_static = torch.empty(0, dtype=torch.bool) # (N,) identity mask + self.tau_alive = 0.005 # causal pruning threshold on alpha'(t) + self.n_frames = 1 # number of temporal frames + self._current_wt_mean = torch.empty(0) # cached batch-mean w_t for ADC gating + self.setup_functions() def capture(self, optimizer_type): @@ -189,6 +203,91 @@ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float): self._opacity = nn.Parameter(opacities.requires_grad_(True)) self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + def _init_temporal_static(self, N): + """Default temporal attributes for a fully-static model (3D fallback).""" + self._t_mu = torch.zeros(N, device="cuda") + self._sigma_t_raw = nn.Parameter( + torch.full((N,), math.log(1000.0), device="cuda").requires_grad_(True)) + self._velocity = nn.Parameter( + torch.zeros((N, 3), device="cuda").requires_grad_(True)) + self.is_static = torch.ones(N, dtype=torch.bool, device="cuda") + + def compute_temporal_weight(self, t): + """Per-Gaussian temporal activity weight w_t^(i)(t), shape (N,). + + w_t = exp(-(t - t_mu)^2 / (2 sigma_t^2 + eps)); static points are pinned + to 1.0. Must NOT be wrapped in torch.no_grad() when its gradient w.r.t. + sigma_t_raw is needed (rendering / VCD score).""" + sigma_t = torch.exp(self._sigma_t_raw) # (N,) + dt = t - self._t_mu # (N,) + w_t = torch.exp(-dt ** 2 / (2 * sigma_t ** 2 + 1e-8)) + # Pin static points to 1.0 without breaking the graph for dynamic points. + w_t = torch.where(self.is_static, torch.ones_like(w_t), w_t) + return w_t + + def create_from_pcd_4d(self, tpcd, spatial_lr_scale, n_frames): + """Initialize Gaussians from a TemporalPointCloud (static-first ordering). + + Static points: t_mu=0, sigma_t=1000 (full-time visible), v=0, frozen. + Dynamic points: t_mu=birth timestamp, sigma_t covering ~2.5 frames, v=0. + """ + self.is_4d = True + self.n_frames = max(int(n_frames), 1) + self.spatial_lr_scale = spatial_lr_scale + + points = np.asarray(tpcd.points) + colors = np.asarray(tpcd.colors) + timestamps = np.asarray(tpcd.timestamps) + is_static_np = np.asarray(tpcd.is_static) + + fused_point_cloud = torch.tensor(points).float().cuda() + fused_color = RGB2SH(torch.tensor(colors).float().cuda()) + features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() + features[:, :3, 0] = fused_color + features[:, 3:, 1:] = 0.0 + + N = fused_point_cloud.shape[0] + print(f"[TD-FastGS] points at init: {N} " + f"(static={int(is_static_np.sum())}, dynamic={int((~is_static_np).sum())})") + + dist2 = torch.clamp_min(distCUDA2(fused_point_cloud), 0.0000001) + scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3) + rots = torch.zeros((N, 4), device="cuda") + rots[:, 0] = 1 + opacities = self.inverse_opacity_activation(0.1 * torch.ones((N, 1), dtype=torch.float, device="cuda")) + + self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) + self._features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)) + self._scaling = nn.Parameter(scales.requires_grad_(True)) + self._rotation = nn.Parameter(rots.requires_grad_(True)) + self._opacity = nn.Parameter(opacities.requires_grad_(True)) + self.max_radii2D = torch.zeros((N), device="cuda") + + # Temporal attributes. + is_static = torch.tensor(is_static_np, dtype=torch.bool, device="cuda") + t_mu = torch.tensor(timestamps, dtype=torch.float, device="cuda") + # Static points: t_mu pinned to 0. + t_mu = torch.where(is_static, torch.zeros_like(t_mu), t_mu) + + sigma_t_raw = torch.empty(N, device="cuda") + sigma_t_raw[is_static] = math.log(1000.0) + sigma_t_raw[~is_static] = math.log(2.5 / max(self.n_frames, 1)) + + # Use flow-estimated velocities if available, otherwise zero-init. + vel_np = getattr(tpcd, "velocities", None) + if vel_np is not None and np.asarray(vel_np).shape == (N, 3): + velocity = torch.tensor(np.asarray(vel_np), dtype=torch.float, device="cuda") + print(f"[TD-FastGS] initialising velocity from optical flow " + f"(nonzero: {int((velocity.norm(dim=1) > 1e-6).sum())}/{N})") + else: + velocity = torch.zeros((N, 3), device="cuda") + + self.is_static = is_static + self._t_mu = t_mu + self._sigma_t_raw = nn.Parameter(sigma_t_raw.requires_grad_(True)) + self._velocity = nn.Parameter(velocity.requires_grad_(True)) + def training_setup(self, training_args): self.percent_dense = training_args.percent_dense self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") @@ -204,6 +303,16 @@ def training_setup(self, training_args): ] sh_l = [{'params': [self._features_rest], 'lr': training_args.highfeature_lr / 20.0, "name": "f_rest"}] + # Temporal parameters share the main optimizer so prune/clone/cat keep them + # in sync with the geometry tensors (writeback is keyed on group["name"]). + if self.is_4d: + if not isinstance(self._sigma_t_raw, nn.Parameter): + self._init_temporal_static(self.get_xyz.shape[0]) + l.append({'params': [self._velocity], + 'lr': getattr(training_args, "velocity_lr", 0.0016), "name": "velocity"}) + l.append({'params': [self._sigma_t_raw], + 'lr': getattr(training_args, "sigma_t_lr", 0.002), "name": "sigma_t_raw"}) + if self.optimizer_type == "default": self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) self.shoptimizer = torch.optim.Adam(sh_l, lr=0.0, eps=1e-15) @@ -255,6 +364,8 @@ def construct_list_of_attributes(self): l.append('scale_{}'.format(i)) for i in range(self._rotation.shape[1]): l.append('rot_{}'.format(i)) + if self.is_4d: + l += ['t_mu', 'sigma_t_raw', 'vel_x', 'vel_y', 'vel_z', 'is_static'] return l def save_ply(self, path): @@ -271,7 +382,15 @@ def save_ply(self, path): dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] elements = np.empty(xyz.shape[0], dtype=dtype_full) - attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) + if self.is_4d: + t_mu = self._t_mu.detach().cpu().numpy()[:, None] + sigma_t_raw = self._sigma_t_raw.detach().cpu().numpy()[:, None] + velocity = self._velocity.detach().cpu().numpy() + is_static = self.is_static.detach().cpu().numpy().astype(np.float32)[:, None] + attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation, + t_mu, sigma_t_raw, velocity, is_static), axis=1) + else: + attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) elements[:] = list(map(tuple, attributes)) el = PlyElement.describe(elements, 'vertex') PlyData([el]).write(path) @@ -281,6 +400,64 @@ def reset_opacity(self): optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") self._opacity = optimizable_tensors["opacity"] + def reset_opacity_decoupled(self, reset_value=0.01): + """Decoupled opacity reset (TD-FastGS): only static points are reset; dynamic + points keep their opacity to protect foreground temporal state. Adam state is + synced via replace_tensor_to_optimizer.""" + if not self.is_4d: + return self.reset_opacity() + with torch.no_grad(): + static_mask = self.is_static + target = self.inverse_opacity_activation( + torch.min(self.get_opacity, torch.ones_like(self.get_opacity) * reset_value)) + opacities_new = self._opacity.clone() + opacities_new[static_mask] = target[static_mask] + optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") + self._opacity = optimizable_tensors["opacity"] + + def enforce_static_constraints(self): + """Hard pull-back for static points to counter Adam momentum residue. + Call immediately after optimizer.step(). Zeros velocity, pins sigma_t_raw to + log(1000) and t_mu to 0, and clears the matching Adam moments.""" + if not self.is_4d: + return + with torch.no_grad(): + static_mask = self.is_static + if static_mask.sum() == 0: + return + self._velocity.data[static_mask] = 0.0 + self._sigma_t_raw.data[static_mask] = math.log(1000.0) + self._t_mu[static_mask] = 0.0 + for group in self.optimizer.param_groups: + if group["name"] in ("velocity", "sigma_t_raw"): + p = group["params"][0] + state = self.optimizer.state.get(p, None) + if state is not None and "exp_avg" in state: + state["exp_avg"][static_mask] = 0.0 + state["exp_avg_sq"][static_mask] = 0.0 + + def apply_gradient_gating(self, t_current, wt_current_thresh=0.5): + """Three-level gradient gate (call after backward(), before step()): + - static points: velocity & sigma_t_raw grads zeroed; + - dynamic & current (w_t > thresh): all grads pass; + - dynamic & other frame: geometry grads zeroed, opacity/velocity/sigma_t pass.""" + if not self.is_4d: + return + with torch.no_grad(): + w_t = self.compute_temporal_weight(t_current).detach() + is_static = self.is_static + is_dynamic_other = (~is_static) & (w_t <= wt_current_thresh) + + for name in ("_velocity", "_sigma_t_raw"): + p = getattr(self, name) + if p.grad is not None: + p.grad[is_static] = 0.0 + + for name in ("_xyz", "_features_dc", "_features_rest", "_scaling", "_rotation"): + p = getattr(self, name) + if p.grad is not None: + p.grad[is_dynamic_other] = 0.0 + def load_ply(self, path): plydata = PlyData.read(path) @@ -322,6 +499,25 @@ def load_ply(self, path): self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) + # Restore temporal attributes if present; otherwise degrade to static mode. + prop_names = [p.name for p in plydata.elements[0].properties] + N = xyz.shape[0] + if "t_mu" in prop_names and "sigma_t_raw" in prop_names and "vel_x" in prop_names: + self.is_4d = True + t_mu = np.asarray(plydata.elements[0]["t_mu"]) + sigma_t_raw = np.asarray(plydata.elements[0]["sigma_t_raw"]) + vel = np.stack((np.asarray(plydata.elements[0]["vel_x"]), + np.asarray(plydata.elements[0]["vel_y"]), + np.asarray(plydata.elements[0]["vel_z"])), axis=1) + if "is_static" in prop_names: + is_static = np.asarray(plydata.elements[0]["is_static"]) > 0.5 + else: + is_static = np.zeros(N, dtype=bool) + self._t_mu = torch.tensor(t_mu, dtype=torch.float, device="cuda") + self._sigma_t_raw = nn.Parameter(torch.tensor(sigma_t_raw, dtype=torch.float, device="cuda").requires_grad_(True)) + self._velocity = nn.Parameter(torch.tensor(vel, dtype=torch.float, device="cuda").requires_grad_(True)) + self.is_static = torch.tensor(is_static, dtype=torch.bool, device="cuda") + self.active_sh_degree = self.max_sh_degree def replace_tensor_to_optimizer(self, tensor, name): @@ -372,6 +568,13 @@ def prune_points(self, mask): self._scaling = optimizable_tensors["scaling"] self._rotation = optimizable_tensors["rotation"] + if self.is_4d: + self._velocity = optimizable_tensors["velocity"] + self._sigma_t_raw = optimizable_tensors["sigma_t_raw"] + # Resident (non-optimizer) temporal tensors. + self._t_mu = self._t_mu[valid_points_mask] + self.is_static = self.is_static[valid_points_mask] + self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] self.xyz_gradient_accum_abs = self.xyz_gradient_accum_abs[valid_points_mask] @@ -406,7 +609,8 @@ def cat_tensors_to_optimizer(self, tensors_dict): return optimizable_tensors - def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_tmp_radii): + def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_tmp_radii, + new_velocity=None, new_sigma_t_raw=None, new_t_mu=None, new_is_static=None): d = {"xyz": new_xyz, "f_dc": new_features_dc, "f_rest": new_features_rest, @@ -414,6 +618,10 @@ def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new "scaling" : new_scaling, "rotation" : new_rotation} + if self.is_4d: + d["velocity"] = new_velocity + d["sigma_t_raw"] = new_sigma_t_raw + optimizable_tensors = self.cat_tensors_to_optimizer(d) self._xyz = optimizable_tensors["xyz"] self._features_dc = optimizable_tensors["f_dc"] @@ -422,6 +630,13 @@ def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new self._scaling = optimizable_tensors["scaling"] self._rotation = optimizable_tensors["rotation"] + if self.is_4d: + self._velocity = optimizable_tensors["velocity"] + self._sigma_t_raw = optimizable_tensors["sigma_t_raw"] + # Resident temporal tensors grow by concatenation. + self._t_mu = torch.cat((self._t_mu, new_t_mu), dim=0) + self.is_static = torch.cat((self.is_static, new_is_static), dim=0) + self.tmp_radii = torch.cat((self.tmp_radii, new_tmp_radii)) self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") self.xyz_gradient_accum_abs = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") # abs @@ -447,7 +662,18 @@ def densify_and_split_fastgs(self, metric_mask, filter, N=2): new_opacity = self._opacity[selected_pts_mask].repeat(N,1) new_tmp_radii = self.tmp_radii[selected_pts_mask].repeat(N) - self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation, new_tmp_radii) + if self.is_4d: + # Children inherit temporal attributes verbatim (position is perturbed, + # temporal parameters are copied). Static children stay static. + new_velocity = self._velocity[selected_pts_mask].repeat(N, 1) + new_sigma_t_raw = self._sigma_t_raw[selected_pts_mask].repeat(N) + new_t_mu = self._t_mu[selected_pts_mask].repeat(N) + new_is_static = self.is_static[selected_pts_mask].repeat(N) + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation, new_tmp_radii, + new_velocity=new_velocity, new_sigma_t_raw=new_sigma_t_raw, + new_t_mu=new_t_mu, new_is_static=new_is_static) + else: + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation, new_tmp_radii) prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) self.prune_points(prune_filter) @@ -463,7 +689,16 @@ def densify_and_clone_fastgs(self, metric_mask, filter): new_rotation = self._rotation[selected_pts_mask] new_tmp_radii = self.tmp_radii[selected_pts_mask] - self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_tmp_radii) + if self.is_4d: + new_velocity = self._velocity[selected_pts_mask] + new_sigma_t_raw = self._sigma_t_raw[selected_pts_mask] + new_t_mu = self._t_mu[selected_pts_mask] + new_is_static = self.is_static[selected_pts_mask] + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_tmp_radii, + new_velocity=new_velocity, new_sigma_t_raw=new_sigma_t_raw, + new_t_mu=new_t_mu, new_is_static=new_is_static) + else: + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_tmp_radii) def densify_and_prune_fastgs(self, max_screen_size, min_opacity, extent, radii, args, importance_score = None, pruning_score = None): @@ -534,7 +769,104 @@ def final_prune_fastgs(self, min_opacity, pruning_score = None): """Final-stage pruning: remove Gaussians based on opacity and multi-view consistency. In the final stage we remove Gaussians that have low opacity or that are flagged by our multi-view reconstruction consistency metric (provided as `pruning_score`).""" - prune_mask = (self.get_opacity < min_opacity).squeeze() + prune_mask = (self.get_opacity < min_opacity).squeeze() scores_mask = pruning_score > 0.9 final_prune = torch.logical_or(prune_mask, scores_mask) + self.prune_points(final_prune) + + # ===================== TD-FastGS temporal ADC ===================== + + def set_current_wt_mean(self, timestamps): + """Cache the per-Gaussian mean temporal weight over a set of view timestamps. + Used by the densify/prune gating to decide which dynamic points are 'active' + in the current densification batch. Computed without grad.""" + with torch.no_grad(): + if len(timestamps) == 0: + self._current_wt_mean = torch.ones(self.get_xyz.shape[0], device="cuda") + return + acc = torch.zeros(self.get_xyz.shape[0], device="cuda") + for t in timestamps: + acc += self.compute_temporal_weight(float(t)) + self._current_wt_mean = acc / len(timestamps) + + def densify_and_prune_4d(self, max_screen_size, min_opacity, extent, radii, args, + importance_score=None, pruning_score=None): + """Temporal-aware ADC. Mirrors densify_and_prune_fastgs but: + - uses per-point densification thresholds (static=tau_d_static, + dynamic=tau_d_dynamic) and gates dynamic densify on the active window; + - prunes static points by VCP and dynamic points by credit-assigned VCP + restricted to their active window (w_t > wt_densify_thresh).""" + grad_vars = self.xyz_gradient_accum / self.denom + grad_vars[grad_vars.isnan()] = 0.0 + self.tmp_radii = radii + + grads_abs = self.xyz_gradient_accum_abs / self.denom + grads_abs[grads_abs.isnan()] = 0.0 + + grad_qualifiers = torch.norm(grad_vars, dim=-1) >= args.grad_thresh + grad_qualifiers_abs = torch.norm(grads_abs, dim=-1) >= args.grad_abs_thresh + clone_qualifiers = torch.max(self.get_scaling, dim=1).values <= args.dense * extent + split_qualifiers = torch.max(self.get_scaling, dim=1).values > args.dense * extent + + all_clones = torch.logical_and(clone_qualifiers, grad_qualifiers) + all_splits = torch.logical_and(split_qualifiers, grad_qualifiers_abs) + + # Per-point densification threshold (dynamic points use a lower tau_d). + tau_d = torch.where(self.is_static, + torch.full_like(self._current_wt_mean, args.tau_d_static), + torch.full_like(self._current_wt_mean, args.tau_d_dynamic)) + metric_mask = importance_score.squeeze() > tau_d + + # Dynamic points may only densify inside their active window. + dynamic_active = (~self.is_static) & (self._current_wt_mean > args.wt_densify_thresh) + densify_allowed = self.is_static | dynamic_active + metric_mask = metric_mask & densify_allowed + + self.densify_and_clone_fastgs(metric_mask, all_clones) + self.densify_and_split_fastgs(metric_mask, all_splits) + + # ---- pruning ---- + # Clone/split appended new points at the end, so vcp / wt_mean (computed at + # the pre-densification size) must be padded to the current size. New points + # get score 0 (eligible for opacity pruning only, never VCP pruning). + N_now = self.get_xyz.shape[0] + vcp = pruning_score.squeeze() + if vcp.shape[0] < N_now: + pad = torch.zeros(N_now - vcp.shape[0], device=vcp.device, dtype=vcp.dtype) + vcp = torch.cat((vcp, pad), dim=0) + wt_mean = self._current_wt_mean + if wt_mean.shape[0] < N_now: + pad = torch.zeros(N_now - wt_mean.shape[0], device=wt_mean.device, dtype=wt_mean.dtype) + wt_mean = torch.cat((wt_mean, pad), dim=0) + + prune_mask = (self.get_opacity < min_opacity).squeeze() + if max_screen_size: + big_points_vs = self.max_radii2D > max_screen_size + big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent + prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) + + # VCP score prune: static always eligible; dynamic only inside active window. + static_prune = self.is_static & (vcp > args.tau_p) + dyn_active_now = (~self.is_static) & (wt_mean > args.wt_densify_thresh) + dynamic_prune = dyn_active_now & (vcp > args.tau_p) + prune_mask = prune_mask | static_prune | dynamic_prune + + self.prune_points(prune_mask) + + opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity) * 0.8)) + optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") + self._opacity = optimizable_tensors["opacity"] + self.tmp_radii = None + torch.cuda.empty_cache() + + def final_prune_4d(self, min_opacity, pruning_score=None, args=None): + """Final-stage temporal pruning. Like final_prune_fastgs, but dynamic points + are protected outside their active window (credit assignment).""" + prune_mask = (self.get_opacity < min_opacity).squeeze() + vcp = pruning_score.squeeze() + wt_thresh = args.wt_densify_thresh if args is not None else 0.2 + static_prune = self.is_static & (vcp > 0.9) + dyn_active = (~self.is_static) & (self._current_wt_mean > wt_thresh) + dynamic_prune = dyn_active & (vcp > 0.9) + final_prune = prune_mask | static_prune | dynamic_prune self.prune_points(final_prune) \ No newline at end of file diff --git a/slim_ply.py b/slim_ply.py new file mode 100644 index 0000000..ea9e156 --- /dev/null +++ b/slim_ply.py @@ -0,0 +1,170 @@ +""" +对 4DGS point_cloud.ply 进行瘦身,仅保留渲染器所需属性。 + +兼容两种 4DGS 属性命名(输入自动检测,输出统一为 DT-4DGS 命名): + TD-FastGS (本项目): sigma_t_raw, vel_x/vel_y/vel_z, is_static + DT-4DGS (原始): t_sigma, velocity_0/1/2 + +输出始终使用 DT-4DGS 字段名,与播放器兼容: + x, y, z + f_dc_0, f_dc_1, f_dc_2 + f_rest_0..2, f_rest_15..17, f_rest_30..32 + opacity + scale_0, scale_1, scale_2 + rot_0, rot_1, rot_2, rot_3 + t_mu, t_sigma ← 时域中心与宽度(标准名) + velocity_0/1/2 ← 速度(标准名) + +用法: + python slim_ply.py # 默认处理 point_cloud.ply → point_cloud_slim.ply + python slim_ply.py -i input.ply -o out.ply # 指定输入/输出路径 +""" + +import argparse +import os +import numpy as np +from plyfile import PlyData, PlyElement + + +# Output attribute list (DT-4DGS naming — what the player expects). +_OUT_PROPS = [ + "x", "y", "z", + "f_dc_0", "f_dc_1", "f_dc_2", + "f_rest_0", "f_rest_1", "f_rest_2", + "f_rest_15", "f_rest_16", "f_rest_17", + "f_rest_30", "f_rest_31", "f_rest_32", + "opacity", + "scale_0", "scale_1", "scale_2", + "rot_0", "rot_1", "rot_2", "rot_3", + "t_mu", "t_sigma", + "velocity_0", "velocity_1", "velocity_2", +] + +# Mapping: input field name → canonical output field name. +# Fields absent from this map are copied under the same name. +_RENAME = { + "sigma_t_raw": "t_sigma", + "vel_x": "velocity_0", + "vel_y": "velocity_1", + "vel_z": "velocity_2", +} + +# Fields that exist only in TD-FastGS and have no output equivalent → drop. +_DROP = {"is_static"} + + +def slim_ply(input_path: str, output_path: str, num_frames: int, vel_threshold: float = 1e-3) -> None: + print(f"读取:{input_path}") + plydata = PlyData.read(input_path) + src = plydata["vertex"] + src_dtype = src.data.dtype + + all_props = [p.name for p in src.properties] + print(f"原始属性数:{len(all_props)}") + + # Build the output column map: output_name → source data array. + col_data = {} # output_name → np.ndarray + col_dtype = {} # output_name → dtype str + for name in all_props: + out_name = _RENAME.get(name, name) + if out_name in _DROP or out_name not in _OUT_PROPS: + continue + col_data[out_name] = np.asarray(src.data[name]) + col_dtype[out_name] = src_dtype[name].str + + # Determine which output columns we actually have. + out_names = [n for n in _OUT_PROPS if n in col_data] + removed = [n for n in all_props if (_RENAME.get(n, n) not in col_data)] + print(f"输出属性数:{len(out_names)}") + print(f"删除/映射后丢弃属性数:{len(all_props) - len(out_names)}") + if removed: + print(f"删除属性:{removed}") + + # Detect which sigma/velocity source was present for informational output. + sigma_src = "sigma_t_raw" if "sigma_t_raw" in all_props else ("t_sigma" if "t_sigma" in all_props else None) + vel_src = ("vel_x","vel_y","vel_z") if "vel_x" in all_props else \ + (("velocity_0","velocity_1","velocity_2") if "velocity_0" in all_props else None) + if sigma_src: + print(f"时域宽度:{sigma_src} → t_sigma") + if vel_src: + print(f"速度字段:{vel_src} → velocity_0/1/2") + + # Assemble output structured array. + N = len(src.data) + out_dtype = [(n, col_dtype[n]) for n in out_names] + new_data = np.empty(N, dtype=out_dtype) + for n in out_names: + new_data[n] = col_data[n] + + # 静态/动态高斯分离:速度模长 < vel_threshold 的为静态高斯 + comments = [f"num_frames {num_frames}"] + has_velocity = all(n in col_data for n in ("velocity_0", "velocity_1", "velocity_2")) + num_static = 0 + + if has_velocity and vel_threshold > 0: + vel_sq = (new_data["velocity_0"].astype(np.float64) ** 2 + + new_data["velocity_1"].astype(np.float64) ** 2 + + new_data["velocity_2"].astype(np.float64) ** 2) + is_static_mask = vel_sq < vel_threshold ** 2 + num_static = int(is_static_mask.sum()) + num_dynamic = N - num_static + print(f"静态高斯(|v|<{vel_threshold}):{num_static} 动态高斯:{num_dynamic}") + + if num_static > 0: + static_idx = np.where(is_static_mask)[0] + dynamic_idx = np.where(~is_static_mask)[0] + order = np.concatenate([static_idx, dynamic_idx]) + new_data = new_data[order] + comments.append(f"num_static {num_static}") + else: + print("未启用静态分离(vel_threshold=0 或无速度属性)") + + # 动态高斯按 t_mu 排序 + if "t_mu" in out_names and num_static < N: + dyn_chunk = new_data[num_static:].copy() + dyn_order = np.argsort(dyn_chunk["t_mu"], kind="stable") + new_data[num_static:] = dyn_chunk[dyn_order] + print(f"动态高斯按 t_mu 排序({len(dyn_chunk)} 条),支持时域窗口优化") + + new_element = PlyElement.describe(new_data, "vertex") + PlyData([new_element], text=False, comments=comments).write(output_path) + + orig_mb = os.path.getsize(input_path) / 1024 / 1024 + slim_mb = os.path.getsize(output_path) / 1024 / 1024 + print(f"\n完成!") + print(f" 总帧数:{num_frames}(已写入 PLY 头部 comment)") + if num_static > 0: + print(f" 静态高斯:{num_static}(预计算一次,不参与逐帧 GPU compute)") + print(f" 原始文件:{orig_mb:.1f} MB") + print(f" 瘦身文件:{slim_mb:.1f} MB ({slim_mb/orig_mb*100:.1f}%)") + print(f" 输出路径:{output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="4DGS PLY 属性瘦身工具(输出统一为 DT-4DGS 命名)") + parser.add_argument("-i", "--input", default="point_cloud.ply", help="输入 PLY 路径(默认:point_cloud.ply)") + parser.add_argument("-o", "--output", default="point_cloud_slim.ply", help="输出 PLY 路径(默认:point_cloud_slim.ply)") + parser.add_argument("-n", "--num-frames", type=int, default=None, help="动画总帧数(不传则交互输入)") + parser.add_argument("--vel-threshold", type=float, default=1e-3, help="速度模长阈值,低于此值视为静态高斯(默认 1e-3),0 禁用分离") + args = parser.parse_args() + + if not os.path.isfile(args.input): + print(f"错误:找不到输入文件 {args.input}") + raise SystemExit(1) + + num_frames = args.num_frames + if num_frames is None: + while True: + try: + num_frames = int(input("请输入动画总帧数:").strip()) + if num_frames > 0: + break + print("帧数必须大于 0,请重新输入。") + except ValueError: + print("请输入有效整数。") + + slim_ply(args.input, args.output, num_frames, args.vel_threshold) + + +if __name__ == "__main__": + main() diff --git a/submodules/diff-gaussian-rasterization_fastgs/cuda_rasterizer/auxiliary.h b/submodules/diff-gaussian-rasterization_fastgs/cuda_rasterizer/auxiliary.h index 27d372d..e66ad7c 100755 --- a/submodules/diff-gaussian-rasterization_fastgs/cuda_rasterizer/auxiliary.h +++ b/submodules/diff-gaussian-rasterization_fastgs/cuda_rasterizer/auxiliary.h @@ -333,6 +333,11 @@ __device__ inline uint32_t duplicateToTilesTouched( return 0; } + // Below 1/255 effective alpha, same as render skip; avoids log(negative) -> NaN tile bounds + if (con_o.w <= (1.0f / 255.0f)) { + return 0; + } + // Threshold: opacity * Gaussian = 1 / 255 float t = 2.0f * log(con_o.w * 255.0f); t = mult * t; // beta in Compact Box diff --git a/submodules/diff-gaussian-rasterization_fastgs/cuda_rasterizer/forward.cu b/submodules/diff-gaussian-rasterization_fastgs/cuda_rasterizer/forward.cu index 50c844d..f89453a 100755 --- a/submodules/diff-gaussian-rasterization_fastgs/cuda_rasterizer/forward.cu +++ b/submodules/diff-gaussian-rasterization_fastgs/cuda_rasterizer/forward.cu @@ -19,6 +19,12 @@ #include namespace cg = cooperative_groups; +struct MaxUint32Op { + __device__ __forceinline__ uint32_t operator()(const uint32_t& a, const uint32_t& b) const { + return max(a, b); + } +}; + // Forward method for converting the input spherical harmonics // coefficients of each Gaussian to a simple RGB color. __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* dc, const float* shs, bool* clamped) @@ -431,7 +437,7 @@ renderCUDA( // max reduce the last contributor typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - last_contributor = BlockReduce(temp_storage).Reduce(last_contributor, cub::Max()); + last_contributor = BlockReduce(temp_storage).Reduce(last_contributor, MaxUint32Op()); if (block.thread_rank() == 0) { max_contrib[tile_id] = last_contributor; } diff --git a/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs.egg-info/PKG-INFO b/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs.egg-info/PKG-INFO new file mode 100644 index 0000000..f386790 --- /dev/null +++ b/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs.egg-info/PKG-INFO @@ -0,0 +1,5 @@ +Metadata-Version: 2.4 +Name: diff_gaussian_rasterization_fastgs +Version: 0.0.0 +License-File: LICENSE.md +Dynamic: license-file diff --git a/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs.egg-info/SOURCES.txt b/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs.egg-info/SOURCES.txt new file mode 100644 index 0000000..7e3806a --- /dev/null +++ b/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs.egg-info/SOURCES.txt @@ -0,0 +1,14 @@ +LICENSE.md +README.md +ext.cpp +rasterize_points.cu +setup.py +cuda_rasterizer/adam.cu +cuda_rasterizer/backward.cu +cuda_rasterizer/forward.cu +cuda_rasterizer/rasterizer_impl.cu +diff_gaussian_rasterization_fastgs/__init__.py +diff_gaussian_rasterization_fastgs.egg-info/PKG-INFO +diff_gaussian_rasterization_fastgs.egg-info/SOURCES.txt +diff_gaussian_rasterization_fastgs.egg-info/dependency_links.txt +diff_gaussian_rasterization_fastgs.egg-info/top_level.txt \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs.egg-info/dependency_links.txt b/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs.egg-info/top_level.txt b/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs.egg-info/top_level.txt new file mode 100644 index 0000000..96191af --- /dev/null +++ b/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs.egg-info/top_level.txt @@ -0,0 +1 @@ +diff_gaussian_rasterization_fastgs diff --git a/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs/__init__.py b/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs/__init__.py index d06aaa3..64b6ffd 100755 --- a/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs/__init__.py +++ b/submodules/diff-gaussian-rasterization_fastgs/diff_gaussian_rasterization_fastgs/__init__.py @@ -94,7 +94,7 @@ def forward( if raster_settings.debug: cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted try: - num_rendered, num_buckets, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) + num_rendered, num_buckets, color, radii, geomBuffer, binningBuffer, imgBuffer, sampleBuffer, accum_metric_counts = _C.rasterize_gaussians(*args) except Exception as ex: torch.save(cpu_args, "snapshot_fw.dump") print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.") diff --git a/submodules/diff-gaussian-rasterization_fastgs/setup.py b/submodules/diff-gaussian-rasterization_fastgs/setup.py index 1874cba..56c4c50 100755 --- a/submodules/diff-gaussian-rasterization_fastgs/setup.py +++ b/submodules/diff-gaussian-rasterization_fastgs/setup.py @@ -27,7 +27,10 @@ "cuda_rasterizer/adam.cu", "rasterize_points.cu", "ext.cpp"], - extra_compile_args={"nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]}) + extra_compile_args={"nvcc": [ + "-allow-unsupported-compiler", + "-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/"), + ]}) ], cmdclass={ 'build_ext': BuildExtension diff --git a/submodules/fused-ssim/setup.py b/submodules/fused-ssim/setup.py index fcffbbe..9e04cd4 100755 --- a/submodules/fused-ssim/setup.py +++ b/submodules/fused-ssim/setup.py @@ -9,7 +9,8 @@ name="fused_ssim_cuda", sources=[ "ssim.cu", - "ext.cpp"]) + "ext.cpp"], + extra_compile_args={"nvcc": ["-allow-unsupported-compiler"]}) ], cmdclass={ 'build_ext': BuildExtension diff --git a/submodules/simple-knn/build/lib.win-amd64-cpython-312/simple_knn/_C.cp312-win_amd64.pyd b/submodules/simple-knn/build/lib.win-amd64-cpython-312/simple_knn/_C.cp312-win_amd64.pyd new file mode 100644 index 0000000..c941542 Binary files /dev/null and b/submodules/simple-knn/build/lib.win-amd64-cpython-312/simple_knn/_C.cp312-win_amd64.pyd differ diff --git a/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/.ninja_deps b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/.ninja_deps new file mode 100644 index 0000000..47cc80e Binary files /dev/null and b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/.ninja_deps differ diff --git a/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/.ninja_log b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/.ninja_log new file mode 100644 index 0000000..db83787 --- /dev/null +++ b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/.ninja_log @@ -0,0 +1,4 @@ +# ninja log v7 +8 8851 8021761600000000 F:/project/FastGS/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/simple_knn.obj f0901b659a32d63c +2 9379 8021761599700000 F:/project/FastGS/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/ext.obj 5d7b84f32a4a6e80 +14 27726 8021761600000000 F:/project/FastGS/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/spatial.obj a8f12bd2e3bfcbf9 diff --git a/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/_C.cp312-win_amd64.exp b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/_C.cp312-win_amd64.exp new file mode 100644 index 0000000..dc6d9c2 Binary files /dev/null and b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/_C.cp312-win_amd64.exp differ diff --git a/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/_C.cp312-win_amd64.lib b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/_C.cp312-win_amd64.lib new file mode 100644 index 0000000..32b59b6 Binary files /dev/null and b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/_C.cp312-win_amd64.lib differ diff --git a/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/build.ninja b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/build.ninja new file mode 100644 index 0000000..cbbc25e --- /dev/null +++ b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/build.ninja @@ -0,0 +1,38 @@ +ninja_required_version = 1.3 +cxx = cl +nvcc = C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.0\bin\nvcc + +cflags = /nologo /O2 /W3 /GL /DNDEBUG /MD -IC:\ProgramData\miniconda3\envs\fastgs\Lib\site-packages\torch\include -IC:\ProgramData\miniconda3\envs\fastgs\Lib\site-packages\torch\include\torch\csrc\api\include "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.0\include" -IC:\ProgramData\miniconda3\envs\fastgs\include -IC:\ProgramData\miniconda3\envs\fastgs\Include "-IC:\Program Files (x86)\Microsoft Visual Studio\18\BuildTools\VC\Tools\MSVC\14.50.35717\include" "-IC:\Program Files (x86)\Microsoft Visual Studio\18\BuildTools\VC\Auxiliary\VS\include" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.26100.0\ucrt" "-IC:\Program Files (x86)\Windows Kits\10\\include\10.0.26100.0\\um" "-IC:\Program Files (x86)\Windows Kits\10\\include\10.0.26100.0\\shared" "-IC:\Program Files (x86)\Windows Kits\10\\include\10.0.26100.0\\winrt" "-IC:\Program Files (x86)\Windows Kits\10\\include\10.0.26100.0\\cppwinrt" "-IC:\Program Files (x86)\Windows Kits\NETFXSDK\4.8\include\um" /MD /wd4819 /wd4251 /wd4244 /wd4267 /wd4275 /wd4018 /wd4190 /wd4624 /wd4067 /wd4068 /EHsc +post_cflags = /wd4624 -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C /std:c++17 +cuda_cflags = -std=c++17 -Xcompiler /MD -Xcompiler /wd4819 -Xcompiler /wd4251 -Xcompiler /wd4244 -Xcompiler /wd4267 -Xcompiler /wd4275 -Xcompiler /wd4018 -Xcompiler /wd4190 -Xcompiler /wd4624 -Xcompiler /wd4067 -Xcompiler /wd4068 -Xcompiler /EHsc --use-local-env -Xcudafe --diag_suppress=base_class_has_different_dll_interface -Xcudafe --diag_suppress=field_without_dll_interface -Xcudafe --diag_suppress=dll_interface_conflict_none_assumed -Xcudafe --diag_suppress=dll_interface_conflict_dllexport_assumed -IC:\ProgramData\miniconda3\envs\fastgs\Lib\site-packages\torch\include -IC:\ProgramData\miniconda3\envs\fastgs\Lib\site-packages\torch\include\torch\csrc\api\include "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.0\include" -IC:\ProgramData\miniconda3\envs\fastgs\include -IC:\ProgramData\miniconda3\envs\fastgs\Include "-IC:\Program Files (x86)\Microsoft Visual Studio\18\BuildTools\VC\Tools\MSVC\14.50.35717\include" "-IC:\Program Files (x86)\Microsoft Visual Studio\18\BuildTools\VC\Auxiliary\VS\include" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.26100.0\ucrt" "-IC:\Program Files (x86)\Windows Kits\10\\include\10.0.26100.0\\um" "-IC:\Program Files (x86)\Windows Kits\10\\include\10.0.26100.0\\shared" "-IC:\Program Files (x86)\Windows Kits\10\\include\10.0.26100.0\\winrt" "-IC:\Program Files (x86)\Windows Kits\10\\include\10.0.26100.0\\cppwinrt" "-IC:\Program Files (x86)\Windows Kits\NETFXSDK\4.8\include\um" +cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -allow-unsupported-compiler -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -gencode=arch=compute_120,code=compute_120 -gencode=arch=compute_120,code=sm_120 +cuda_dlink_post_cflags = +sycl_dlink_post_cflags = +ldflags = + +rule compile + command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags + deps = msvc + +rule cuda_compile + depfile = $out.d + deps = gcc + command = $nvcc -MD -MF $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags + + + + + + + +build F$:\project\FastGS\submodules\simple-knn\build\temp.win-amd64-cpython-312\Release\ext.obj: compile F$:\project\FastGS\submodules\simple-knn\ext.cpp +build F$:\project\FastGS\submodules\simple-knn\build\temp.win-amd64-cpython-312\Release\simple_knn.obj: cuda_compile F$:\project\FastGS\submodules\simple-knn\simple_knn.cu +build F$:\project\FastGS\submodules\simple-knn\build\temp.win-amd64-cpython-312\Release\spatial.obj: cuda_compile F$:\project\FastGS\submodules\simple-knn\spatial.cu + + + + + + + + diff --git a/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/ext.obj b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/ext.obj new file mode 100644 index 0000000..2dbfd39 Binary files /dev/null and b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/ext.obj differ diff --git a/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/simple_knn.obj b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/simple_knn.obj new file mode 100644 index 0000000..1f376db Binary files /dev/null and b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/simple_knn.obj differ diff --git a/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/spatial.obj b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/spatial.obj new file mode 100644 index 0000000..f8c9bac Binary files /dev/null and b/submodules/simple-knn/build/temp.win-amd64-cpython-312/Release/spatial.obj differ diff --git a/submodules/simple-knn/setup.py b/submodules/simple-knn/setup.py index 580d2bd..68a0b28 100755 --- a/submodules/simple-knn/setup.py +++ b/submodules/simple-knn/setup.py @@ -27,7 +27,7 @@ "spatial.cu", "simple_knn.cu", "ext.cpp"], - extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags}) + extra_compile_args={"nvcc": ["-allow-unsupported-compiler"], "cxx": cxx_compiler_flags}) ], cmdclass={ 'build_ext': BuildExtension diff --git a/submodules/simple-knn/simple_knn.egg-info/PKG-INFO b/submodules/simple-knn/simple_knn.egg-info/PKG-INFO index 872241a..2e0eba4 100755 --- a/submodules/simple-knn/simple_knn.egg-info/PKG-INFO +++ b/submodules/simple-knn/simple_knn.egg-info/PKG-INFO @@ -1,4 +1,5 @@ -Metadata-Version: 2.1 +Metadata-Version: 2.4 Name: simple_knn Version: 0.0.0 License-File: LICENSE.md +Dynamic: license-file diff --git a/tests/test_td_fastgs.py b/tests/test_td_fastgs.py new file mode 100644 index 0000000..f524d6b --- /dev/null +++ b/tests/test_td_fastgs.py @@ -0,0 +1,194 @@ +"""Unit tests for the TD-FastGS temporal extension. + +These tests exercise the temporal logic that does not require the CUDA rasterizer. +They are written to be run with pytest on a CUDA-enabled machine: + + pytest tests/test_td_fastgs.py + +Tests that intrinsically need rasterization (full render_4d, VCD/VCP scoring) are +described as structured logic / smoke checks and skipped when CUDA is absent. +""" + +import math +import numpy as np +import pytest + +try: + import torch + HAS_CUDA = torch.cuda.is_available() +except Exception: + HAS_CUDA = False + +pytestmark = pytest.mark.skipif(not HAS_CUDA, reason="requires CUDA") + + +def _make_model(n_static=4, n_dynamic=6, n_frames=10): + """Build a small 4D GaussianModel via create_from_pcd_4d.""" + from scene.gaussian_model import GaussianModel + from scene.dataset_readers import TemporalPointCloud + + N = n_static + n_dynamic + pts = np.random.randn(N, 3).astype(np.float32) + cols = np.random.rand(N, 3).astype(np.float32) + normals = np.zeros((N, 3), np.float32) + is_static = np.zeros(N, dtype=bool) + is_static[:n_static] = True + ts = np.zeros(N, np.float32) + # dynamic points born at staggered frames + ts[n_static:] = np.linspace(0.0, 1.0, n_dynamic).astype(np.float32) + + tpcd = TemporalPointCloud(points=pts, colors=cols, normals=normals, + timestamps=ts, is_static=is_static) + g = GaussianModel(sh_degree=1) + g.create_from_pcd_4d(tpcd, spatial_lr_scale=1.0, n_frames=n_frames) + + class _Opt: # minimal training-args stub + percent_dense = 0.01 + position_lr_init = 1e-4 + position_lr_final = 1e-6 + position_lr_delay_mult = 0.01 + position_lr_max_steps = 30000 + lowfeature_lr = 2.5e-3 + highfeature_lr = 5e-3 + opacity_lr = 0.025 + scaling_lr = 5e-3 + rotation_lr = 1e-3 + velocity_lr = 1.6e-3 + sigma_t_lr = 2e-3 + g.training_setup(_Opt()) + return g + + +def test_init_ordering_and_flags(): + """1.6/6: static-first ordering, correct is_static flags and sigma_t init.""" + g = _make_model(n_static=4, n_dynamic=6, n_frames=10) + assert g.is_static[:4].all() + assert (~g.is_static[4:]).all() + # static sigma_t_raw ~ log(1000); dynamic ~ log(2.5/n_frames) + assert torch.allclose(g._sigma_t_raw[:4], + torch.full((4,), math.log(1000.0), device="cuda"), atol=1e-4) + assert torch.allclose(g._sigma_t_raw[4:], + torch.full((6,), math.log(2.5 / 10), device="cuda"), atol=1e-4) + # static t_mu pinned to 0 + assert torch.all(g._t_mu[:4] == 0) + + +def test_temporal_weight_static_is_one(): + """compute_temporal_weight pins static points to 1.0 at any t.""" + g = _make_model() + for t in (0.0, 0.3, 1.0): + w = g.compute_temporal_weight(t) + assert torch.allclose(w[g.is_static], torch.ones(int(g.is_static.sum()), device="cuda")) + + +def test_temporal_weight_gradient_flows_to_sigma(): + """3: sigma_t_raw receives gradient through w_t for dynamic points.""" + g = _make_model() + w = g.compute_temporal_weight(0.5) + # a dynamic-only objective + loss = (w[~g.is_static] ** 2).sum() + loss.backward() + assert g._sigma_t_raw.grad is not None + assert g._sigma_t_raw.grad[~g.is_static].abs().sum() > 0 + + +def test_static_hard_pullback(): + """1: after perturbing then enforcing constraints, static velocity == 0.""" + g = _make_model() + with torch.no_grad(): + g._velocity.data[g.is_static] = 5.0 + g._sigma_t_raw.data[g.is_static] = 0.0 + g._t_mu[g.is_static] = 0.7 + g.enforce_static_constraints() + assert g._velocity[g.is_static].abs().max() == 0 + assert torch.allclose(g._sigma_t_raw[g.is_static], + torch.full((int(g.is_static.sum()),), math.log(1000.0), device="cuda")) + assert torch.all(g._t_mu[g.is_static] == 0) + + +def test_gradient_gating_static_velocity_zeroed(): + """3.1: gate zeros static velocity/sigma grads and frozen-frame geometry grads.""" + g = _make_model() + # fake gradients on all params + for name in ("_velocity", "_sigma_t_raw", "_xyz", "_scaling"): + p = getattr(g, name) + p.grad = torch.ones_like(p) + g.apply_gradient_gating(t_current=0.0, wt_current_thresh=0.5) + # static velocity / sigma grads cleared + assert g._velocity.grad[g.is_static].abs().max() == 0 + assert g._sigma_t_raw.grad[g.is_static].abs().max() == 0 + # dynamic points far from t=0 (w_t<=0.5) have geometry grads cleared + w = g.compute_temporal_weight(0.0) + dyn_other = (~g.is_static) & (w <= 0.5) + if dyn_other.any(): + assert g._xyz.grad[dyn_other].abs().max() == 0 + + +def test_decoupled_opacity_reset_preserves_dynamic(): + """5: dynamic opacity unchanged after decoupled reset; static lowered.""" + g = _make_model() + with torch.no_grad(): + g._opacity.data.fill_(g.inverse_opacity_activation(torch.tensor(0.9)).item()) + dyn_before = g.get_opacity[~g.is_static].clone() + g.reset_opacity_decoupled(reset_value=0.01) + assert torch.allclose(g.get_opacity[~g.is_static], dyn_before, atol=1e-5) + assert g.get_opacity[g.is_static].max() <= 0.05 + + +def test_save_load_roundtrip(tmp_path): + """PLY serialization preserves temporal attributes (graceful for legacy too).""" + from scene.gaussian_model import GaussianModel + g = _make_model() + path = str(tmp_path / "pc.ply") + g.save_ply(path) + g2 = GaussianModel(sh_degree=1) + g2.load_ply(path) + assert g2.is_4d + assert torch.allclose(g2._t_mu, g._t_mu, atol=1e-5) + assert torch.equal(g2.is_static, g.is_static) + assert torch.allclose(g2._velocity, g._velocity, atol=1e-5) + + +def test_causal_mask_logic(): + """2: causal mask excludes Gaussians born after t (pure tensor logic).""" + g = _make_model() + t = 0.1 + causal = g._t_mu <= (t + 1e-6) + # any dynamic point with t_mu > 0.1 must be excluded + born_later = (~g.is_static) & (g._t_mu > 0.1) + if born_later.any(): + assert (~causal[born_later]).all() + + +def test_velocity_smoothness_nonnegative(): + """Module 5: velocity-smoothness loss is a finite non-negative scalar.""" + from utils.fast_utils import compute_velocity_smoothness_loss + g = _make_model() + with torch.no_grad(): + g._velocity.data[~g.is_static] = torch.randn_like(g._velocity[~g.is_static]) + loss = compute_velocity_smoothness_loss(g, K_pairs=128) + assert loss.item() >= 0 + assert torch.isfinite(loss) + + +def test_child_inherits_is_static(): + """6: clone preserves is_static for children (structural check via clone path).""" + g = _make_model(n_static=4, n_dynamic=6) + n_before = g.get_xyz.shape[0] + g.tmp_radii = torch.zeros(n_before, device="cuda") + # clone all static points + metric = g.is_static.clone() + filt = torch.ones(n_before, dtype=bool, device="cuda") + g.densify_and_clone_fastgs(metric, filt) + # appended children should all be static (parents were static) + n_added = g.get_xyz.shape[0] - n_before + assert n_added == int(metric.sum()) + assert g.is_static[n_before:].all() + + +# --- Logic-only descriptions for rasterizer-dependent paths ----------------- +# test_render_4d_subset_size: render_4d must pass exactly alive_mask.sum() +# Gaussians to the rasterizer (CB runs after the alive subset is extracted), +# while returning full-size radii / viewspace_points / visibility_filter. +# test_vcd_temporal_zeroing: a t_mu=0.8 dynamic Gaussian contributes ~0 to the +# VCD score under a t=0.0 view because w_t(0.0) -> 0 for that Gaussian. diff --git a/train.py b/train.py index 0cf7e1f..de2d58c 100755 --- a/train.py +++ b/train.py @@ -11,12 +11,12 @@ import torch import numpy as np -import os, random, time +import os, random, time, math from random import randint from lpipsPyTorch import lpips from utils.loss_utils import l1_loss from fused_ssim import fused_ssim as fast_ssim -from gaussian_renderer import render_fastgs, network_gui_ws +from gaussian_renderer import render_fastgs, render_4d, network_gui_ws import sys from scene import Scene, GaussianModel from utils.general_utils import safe_state @@ -31,7 +31,9 @@ except ImportError: TENSORBOARD_FOUND = False -from utils.fast_utils import compute_gaussian_score_fastgs, sampling_cameras +from utils.fast_utils import (compute_gaussian_score_fastgs, sampling_cameras, + compute_gaussian_score_fastgs_4d, sample_camera_4d, + sample_views_for_vcd_vcp, compute_velocity_smoothness_loss) def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from, websockets): @@ -39,6 +41,13 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi tb_writer = prepare_output_and_logger(dataset) gaussians = GaussianModel(dataset.sh_degree, opt.optimizer_type) scene = Scene(dataset, gaussians) + + # Dispatch to the temporal training loop for 4DGS datasets. + if getattr(scene, "is_4dgs", False): + return training_4d(dataset, opt, pipe, testing_iterations, saving_iterations, + checkpoint_iterations, checkpoint, debug_from, websockets, + tb_writer=tb_writer, gaussians=gaussians, scene=scene) + gaussians.training_setup(opt) if checkpoint: (model_params, first_iter) = torch.load(checkpoint) @@ -175,8 +184,156 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi # scene.save(iteration) print(f"Gaussian number: {gaussians._xyz.shape[0]}") print(f"Training time: {total_time}") - -def prepare_output_and_logger(args): + + +def training_4d(dataset, opt, pipe, testing_iterations, saving_iterations, + checkpoint_iterations, checkpoint, debug_from, websockets, + tb_writer=None, gaussians=None, scene=None): + """TD-FastGS temporal training loop. + + Differences from the 3D loop: + - renders through render_4d at each camera's timestamp; + - 3-stage temporal camera sampling; + - velocity-smoothness regularizer added to the loss; + - three-level gradient gate after backward(), before step(); + - static hard pull-back after step(); + - decoupled (static-only) opacity reset; + - temporal-aware VCD/VCP densification and pruning. + """ + first_iter = 0 + gaussians.training_setup(opt) + if checkpoint: + (model_params, first_iter) = torch.load(checkpoint) + gaussians.restore(model_params, opt) + + n_frames = getattr(scene, "n_frames", 1) + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + iter_start = torch.cuda.Event(enable_timing=True) + iter_end = torch.cuda.Event(enable_timing=True) + total_time = 0.0 + + train_cameras = scene.getTrainCameras().copy() + + ema_loss_for_log = 0.0 + progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress (4D)") + first_iter += 1 + bg = torch.rand((3), device="cuda") if opt.random_background else background + + for iteration in range(first_iter, opt.iterations + 1): + iter_start.record() + gaussians.update_learning_rate(iteration) + if iteration % 1000 == 0: + gaussians.oneupSHdegree() + + # 3-stage temporal camera sampling. + viewpoint_cam = sample_camera_4d(train_cameras, iteration, n_frames, opt) + + if (iteration - 1) == debug_from: + pipe.debug = True + + render_pkg = render_4d(viewpoint_cam, gaussians, pipe, bg, opt.mult) + image = render_pkg["render"] + viewspace_point_tensor = render_pkg["viewspace_points"] + visibility_filter = render_pkg["visibility_filter"] + radii = render_pkg["radii"] + + gt_image = viewpoint_cam.original_image.cuda() + Ll1 = l1_loss(image, gt_image) + ssim_value = fast_ssim(image.unsqueeze(0), gt_image.unsqueeze(0)) + loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value) + + # Velocity-smoothness regularizer over dynamic points. + if opt.lambda_velocity > 0: + loss = loss + opt.lambda_velocity * compute_velocity_smoothness_loss( + gaussians, opt.velocity_smooth_pairs) + + # Optional soft scale penalty for dynamic points. + if opt.lambda_scale_penalty > 0: + scale_limit = math.log(0.05 * scene.cameras_extent + 1e-8) + dyn = ~gaussians.is_static + if dyn.any(): + excess = (gaussians._scaling[dyn] - scale_limit).clamp(min=0) + loss = loss + opt.lambda_scale_penalty * excess.pow(2).mean() + + loss.backward() + + # Three-level gradient gate (after backward, before step). + gaussians.apply_gradient_gating(viewpoint_cam.timestamp, opt.wt_current_thresh) + + iter_end.record() + + with torch.no_grad(): + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + if iteration % 10 == 0: + progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + if iteration in saving_iterations: + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(iteration) + + _training_4d_densify(iteration, opt, pipe, bg, dataset, scene, + gaussians, train_cameras, n_frames, + viewspace_point_tensor, visibility_filter, radii) + + # Optimization step + static hard pull-back. + if iteration < opt.iterations: + if opt.optimizer_type == "default": + gaussians.optimizer_step(iteration) + elif opt.optimizer_type == "sparse_adam": + visible = radii > 0 + gaussians.optimizer.step(visible, radii.shape[0]) + gaussians.optimizer.zero_grad(set_to_none=True) + gaussians.enforce_static_constraints() + + torch.cuda.synchronize() + total_time += iter_start.elapsed_time(iter_end) / 1e3 + + print(f"[TD-FastGS] Gaussian number: {gaussians._xyz.shape[0]}") + print(f"[TD-FastGS] Training time: {total_time}") + + +def _training_4d_densify(iteration, opt, pipe, bg, dataset, scene, gaussians, + train_cameras, n_frames, viewspace_point_tensor, + visibility_filter, radii): + """Densification / pruning sub-step for the 4D loop (temporal VCD/VCP).""" + if iteration < opt.densify_until_iter: + gaussians.max_radii2D[visibility_filter] = torch.max( + gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) + gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) + + if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: + size_threshold = 20 if iteration > opt.opacity_reset_interval else None + camlist = sample_views_for_vcd_vcp(train_cameras, 10, iteration, opt) + timestamps = [c.timestamp for c in camlist] + gaussians.set_current_wt_mean(timestamps) + importance_score, pruning_score = compute_gaussian_score_fastgs_4d( + camlist, gaussians, pipe, bg, opt, render_4d, DENSIFY=True) + gaussians.densify_and_prune_4d( + max_screen_size=size_threshold, min_opacity=0.005, + extent=scene.cameras_extent, radii=radii, args=opt, + importance_score=importance_score, pruning_score=pruning_score) + + # Decoupled opacity reset: static points only. + if iteration % opt.opacity_reset_interval == 0 or \ + (dataset.white_background and iteration == opt.densify_from_iter): + gaussians.reset_opacity_decoupled() + + # Post-densification temporal pruning every 3k iters in [15k, 30k). + if iteration % 3000 == 0 and 15000 < iteration < 30000: + camlist = sample_views_for_vcd_vcp(train_cameras, 10, iteration, opt) + timestamps = [c.timestamp for c in camlist] + gaussians.set_current_wt_mean(timestamps) + _, pruning_score = compute_gaussian_score_fastgs_4d( + camlist, gaussians, pipe, bg, opt, render_4d, DENSIFY=False) + gaussians.final_prune_4d(min_opacity=0.1, pruning_score=pruning_score, args=opt) + + +def prepare_output_and_logger(args): if not args.model_path: if os.getenv('OAR_JOB_ID'): unique_str=os.getenv('OAR_JOB_ID') diff --git a/utils/__pycache__/camera_utils.cpython-37.pyc b/utils/__pycache__/camera_utils.cpython-37.pyc deleted file mode 100755 index 5dace64..0000000 Binary files a/utils/__pycache__/camera_utils.cpython-37.pyc and /dev/null differ diff --git a/utils/__pycache__/camera_utils.cpython-38.pyc b/utils/__pycache__/camera_utils.cpython-38.pyc deleted file mode 100755 index 98c0ba1..0000000 Binary files a/utils/__pycache__/camera_utils.cpython-38.pyc and /dev/null differ diff --git a/utils/__pycache__/fast_utils.cpython-37.pyc b/utils/__pycache__/fast_utils.cpython-37.pyc deleted file mode 100755 index a7a0bc8..0000000 Binary files a/utils/__pycache__/fast_utils.cpython-37.pyc and /dev/null differ diff --git a/utils/__pycache__/general_utils.cpython-37.pyc b/utils/__pycache__/general_utils.cpython-37.pyc deleted file mode 100755 index 572dde9..0000000 Binary files a/utils/__pycache__/general_utils.cpython-37.pyc and /dev/null differ diff --git a/utils/__pycache__/general_utils.cpython-38.pyc b/utils/__pycache__/general_utils.cpython-38.pyc deleted file mode 100755 index b9a097d..0000000 Binary files a/utils/__pycache__/general_utils.cpython-38.pyc and /dev/null differ diff --git a/utils/__pycache__/graphics_utils.cpython-37.pyc b/utils/__pycache__/graphics_utils.cpython-37.pyc deleted file mode 100755 index 5a862a4..0000000 Binary files a/utils/__pycache__/graphics_utils.cpython-37.pyc and /dev/null differ diff --git a/utils/__pycache__/graphics_utils.cpython-38.pyc b/utils/__pycache__/graphics_utils.cpython-38.pyc deleted file mode 100755 index 994180e..0000000 Binary files a/utils/__pycache__/graphics_utils.cpython-38.pyc and /dev/null differ diff --git a/utils/__pycache__/image_utils.cpython-37.pyc b/utils/__pycache__/image_utils.cpython-37.pyc deleted file mode 100755 index 6584050..0000000 Binary files a/utils/__pycache__/image_utils.cpython-37.pyc and /dev/null differ diff --git a/utils/__pycache__/image_utils.cpython-38.pyc b/utils/__pycache__/image_utils.cpython-38.pyc deleted file mode 100755 index 5d7f516..0000000 Binary files a/utils/__pycache__/image_utils.cpython-38.pyc and /dev/null differ diff --git a/utils/__pycache__/loss_utils.cpython-37.pyc b/utils/__pycache__/loss_utils.cpython-37.pyc deleted file mode 100755 index 0c1694b..0000000 Binary files a/utils/__pycache__/loss_utils.cpython-37.pyc and /dev/null differ diff --git a/utils/__pycache__/loss_utils.cpython-38.pyc b/utils/__pycache__/loss_utils.cpython-38.pyc deleted file mode 100755 index 4ded4ee..0000000 Binary files a/utils/__pycache__/loss_utils.cpython-38.pyc and /dev/null differ diff --git a/utils/__pycache__/sh_utils.cpython-37.pyc b/utils/__pycache__/sh_utils.cpython-37.pyc deleted file mode 100755 index 513bd90..0000000 Binary files a/utils/__pycache__/sh_utils.cpython-37.pyc and /dev/null differ diff --git a/utils/__pycache__/sh_utils.cpython-38.pyc b/utils/__pycache__/sh_utils.cpython-38.pyc deleted file mode 100755 index 62b7853..0000000 Binary files a/utils/__pycache__/sh_utils.cpython-38.pyc and /dev/null differ diff --git a/utils/__pycache__/system_utils.cpython-37.pyc b/utils/__pycache__/system_utils.cpython-37.pyc deleted file mode 100755 index 4faa8ed..0000000 Binary files a/utils/__pycache__/system_utils.cpython-37.pyc and /dev/null differ diff --git a/utils/__pycache__/system_utils.cpython-38.pyc b/utils/__pycache__/system_utils.cpython-38.pyc deleted file mode 100755 index 086141b..0000000 Binary files a/utils/__pycache__/system_utils.cpython-38.pyc and /dev/null differ diff --git a/utils/__pycache__/taming_utils.cpython-37.pyc b/utils/__pycache__/taming_utils.cpython-37.pyc deleted file mode 100755 index 6b52744..0000000 Binary files a/utils/__pycache__/taming_utils.cpython-37.pyc and /dev/null differ diff --git a/utils/__pycache__/taming_utils.cpython-38.pyc b/utils/__pycache__/taming_utils.cpython-38.pyc deleted file mode 100755 index e3b3683..0000000 Binary files a/utils/__pycache__/taming_utils.cpython-38.pyc and /dev/null differ diff --git a/utils/camera_utils.py b/utils/camera_utils.py index 1a54d0a..386f2fb 100755 --- a/utils/camera_utils.py +++ b/utils/camera_utils.py @@ -16,27 +16,49 @@ WARNED = False -def loadCam(args, id, cam_info, resolution_scale): - orig_w, orig_h = cam_info.image.size +def _compute_resolution(args, orig_w, orig_h, resolution_scale): + """Resolve the target (W, H) for an image given the CLI resolution policy. + Shared by the eager and lazy loading paths so they downscale identically. + """ if args.resolution in [1, 2, 4, 8]: - resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) - else: # should be a type that converts to float - if args.resolution == -1: - if orig_w > 1600: - global WARNED - if not WARNED: - print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " - "If this is not desired, please explicitly specify '--resolution/-r' as 1") - WARNED = True - global_down = orig_w / 1600 - else: - global_down = 1 + return (round(orig_w / (resolution_scale * args.resolution)), + round(orig_h / (resolution_scale * args.resolution))) + # type that converts to float + if args.resolution == -1: + if orig_w > 1600: + global WARNED + if not WARNED: + print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " + "If this is not desired, please explicitly specify '--resolution/-r' as 1") + WARNED = True + global_down = orig_w / 1600 else: - global_down = orig_w / args.resolution + global_down = 1 + else: + global_down = orig_w / args.resolution + + scale = float(global_down) * float(resolution_scale) + return (int(orig_w / scale), int(orig_h / scale)) + - scale = float(global_down) * float(resolution_scale) - resolution = (int(orig_w / scale), int(orig_h / scale)) +def loadCam(args, id, cam_info, resolution_scale): + # Lazy path: 4D multi-view-video cameras carry no decoded image. Compute the + # target resolution from the COLMAP intrinsic dims and defer disk decoding to + # Camera.original_image (bounded LRU cache). + if getattr(cam_info, "image", None) is None: + resolution = _compute_resolution(args, cam_info.width, cam_info.height, resolution_scale) + return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, + FoVx=cam_info.FovX, FoVy=cam_info.FovY, + image=None, gt_alpha_mask=None, + image_name=cam_info.image_name, uid=id, data_device=args.data_device, + timestamp=getattr(cam_info, "timestamp", 0.0), + frame_idx=getattr(cam_info, "frame_idx", 0), + image_path=cam_info.image_path, resolution=resolution, + gt_width=resolution[0], gt_height=resolution[1]) + + orig_w, orig_h = cam_info.image.size + resolution = _compute_resolution(args, orig_w, orig_h, resolution_scale) resized_image_rgb = PILtoTorch(cam_info.image, resolution) @@ -46,10 +68,12 @@ def loadCam(args, id, cam_info, resolution_scale): if resized_image_rgb.shape[1] == 4: loaded_mask = resized_image_rgb[3:4, ...] - return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, - FoVx=cam_info.FovX, FoVy=cam_info.FovY, + return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, + FoVx=cam_info.FovX, FoVy=cam_info.FovY, image=gt_image, gt_alpha_mask=loaded_mask, - image_name=cam_info.image_name, uid=id, data_device=args.data_device) + image_name=cam_info.image_name, uid=id, data_device=args.data_device, + timestamp=getattr(cam_info, "timestamp", 0.0), + frame_idx=getattr(cam_info, "frame_idx", 0)) def cameraList_from_camInfos(cam_infos, resolution_scale, args): camera_list = [] diff --git a/utils/fast_utils.py b/utils/fast_utils.py index 2fbe9e7..78f3314 100755 --- a/utils/fast_utils.py +++ b/utils/fast_utils.py @@ -15,7 +15,7 @@ def sampling_cameras(my_viewpoint_stack): for _ in range(num_cams): loc = random.randint(0, len(my_viewpoint_stack) - 1) camlist.append(my_viewpoint_stack.pop(loc)) - + return camlist def get_loss(reconstructed_image, original_image): @@ -103,3 +103,136 @@ def compute_gaussian_score_fastgs(camlist, gaussians, pipe, bg, args, DENSIFY = else: importance_score = None return importance_score, pruning_score + + +# ============================================================================ +# TD-FastGS 4D (temporal) helpers +# ============================================================================ + +def sample_camera_4d(train_cameras, iteration, n_frames, opt): + """Three-stage temporal camera sampling strategy. + + Stage 1 (iter <= static_only_until): sample only frame-0 views to first + converge the static background base. + Stage 2 (<= temporal_window_until): pick a random start frame and sample + within a contiguous window of `temporal_window_size` frames, giving the + velocity term adjacent-frame supervision. + Stage 3 (otherwise): uniform random over all training cameras. + """ + if iteration <= opt.static_only_until: + pool = [c for c in train_cameras if c.frame_idx == 0] + if not pool: + pool = train_cameras + return random.choice(pool) + elif iteration <= opt.temporal_window_until: + w = max(1, opt.temporal_window_size) + start = random.randint(0, max(0, n_frames - w)) + pool = [c for c in train_cameras if start <= c.frame_idx < start + w] + if not pool: + pool = train_cameras + return random.choice(pool) + else: + return random.choice(train_cameras) + + +def sample_views_for_vcd_vcp(train_cameras, K, iteration, opt): + """Sample K views for the temporal VCD/VCP score computation. + + Aligned with the camera-sampling strategy: stage 1 draws only frame-0 views + (static-scene consistency); afterwards it samples globally and the temporal + weight in the score automatically discounts inactive views. Returns a list of + Camera objects (each carrying a `timestamp`).""" + if iteration <= opt.static_only_until: + pool = [c for c in train_cameras if c.frame_idx == 0] + if not pool: + pool = train_cameras + else: + pool = train_cameras + return random.sample(pool, min(K, len(pool))) + + +def compute_velocity_smoothness_loss(gaussians, K_pairs=4096): + """Spatially-weighted velocity consistency regularizer over dynamic points. + + L_smooth = (1/K) * sum_k w_k * ||v_{a_k} - v_{b_k}||^2 + w_k = exp(-||x_{a_k} - x_{b_k}||^2 / (2 * s_bar^2)) + + s_bar^2 is a global spatial scale estimated from the sampled pair distances. + Returns a zero scalar if there are fewer than two dynamic points. + """ + dynamic_idx = (~gaussians.is_static).nonzero(as_tuple=False).squeeze(-1) + if dynamic_idx.shape[0] < 2: + return torch.zeros((), device="cuda") + + n = dynamic_idx.shape[0] + idx_a = dynamic_idx[torch.randint(0, n, (K_pairs,), device=dynamic_idx.device)] + idx_b = dynamic_idx[torch.randint(0, n, (K_pairs,), device=dynamic_idx.device)] + + pos_a = gaussians.get_xyz[idx_a] + pos_b = gaussians.get_xyz[idx_b] + vel_a = gaussians._velocity[idx_a] + vel_b = gaussians._velocity[idx_b] + + dist_sq = ((pos_a - pos_b) ** 2).sum(-1) + s_bar_sq = dist_sq.mean().detach() + 1e-8 + w = torch.exp(-dist_sq / (2 * s_bar_sq)) + loss = (w * ((vel_a - vel_b) ** 2).sum(-1)).mean() + return loss + + +def compute_gaussian_score_fastgs_4d(camlist, gaussians, pipe, bg, args, render_4d, DENSIFY=False): + """Temporal-aware multi-view VCD/VCP score. + + Identical in structure to compute_gaussian_score_fastgs, but (1) renders each + view through the 4D renderer at that view's timestamp, and (2) weights each + view's per-Gaussian high-error counts by the temporal weight w_t(t_view) so + that Gaussians inactive at a view's time contribute ~0 to that view's score. + + IMPORTANT: compute_temporal_weight must be called OUTSIDE torch.no_grad() so + that gradients can flow to sigma_t_raw. The caller is responsible for the + grad context; here we keep w_t in the graph and only detach where counts are + combined into the (non-differentiable) score statistics. + """ + full_metric_counts = None + full_metric_score = None + + for view in range(len(camlist)): + cam = camlist[view] + t_j = cam.timestamp + + render_pkg0 = render_4d(cam, gaussians, pipe, bg, args.mult) + render_image = render_pkg0["render"] + photometric_loss = compute_photometric_loss(cam, render_image) + + gt_image = cam.original_image.cuda() + l1_loss_norm = get_loss(render_image, gt_image) + metric_map = (l1_loss_norm > args.loss_thresh).int() + + render_pkg = render_4d(cam, gaussians, pipe, bg, args.mult, + get_flag=True, metric_map=metric_map) + accum_loss_counts = render_pkg["accum_metric_counts"] # (N_full,) + + # Temporal weight at this view's time, broadcast over the full set. + w_t = gaussians.compute_temporal_weight(t_j).detach() # (N,) + weighted_counts = accum_loss_counts.float() * w_t + + if DENSIFY: + if full_metric_counts is None: + full_metric_counts = weighted_counts.clone() + else: + full_metric_counts += weighted_counts + + contrib = photometric_loss * weighted_counts + if full_metric_score is None: + full_metric_score = contrib.clone() + else: + full_metric_score += contrib + + denom = (torch.max(full_metric_score) - torch.min(full_metric_score)) + pruning_score = (full_metric_score - torch.min(full_metric_score)) / (denom + 1e-8) + + if DENSIFY: + importance_score = full_metric_counts / len(camlist) + else: + importance_score = None + return importance_score, pruning_score