feat: Inference-time gradient checkpointing for low-VRAM GPUs #2

Open. jashshah999 wants to merge 1 commit into MIT-SPARK:main from jashshah999:feat/inference-checkpointing

@jashshah999

Summary

Adds two memory optimization features for running VGGT inference on GPUs with limited VRAM (8-12GB):

  1. Inference-time gradient checkpointing (use_checkpointing_inference): Reuses the existing torch.utils.checkpoint machinery (currently training-only) during inference. Recomputes activations instead of storing all 24 layers of intermediates. Saves ~40% peak VRAM.

  2. Sequential head execution (sequential_heads): Frees camera_head intermediates before running depth_head, reducing peak memory from holding both heads' activations simultaneously.
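The first item can be sketched on a toy module. This is a minimal illustration, not the PR's actual diff: `ToyAggregator` and its layers are hypothetical, and only the attribute name `use_checkpointing_inference` and the use of `torch.utils.checkpoint` come from the description above. Checkpointing trades memory for compute only where autograd would otherwise retain activations, so the sketch checks output equality and does not try to reproduce the quoted savings.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyAggregator(nn.Module):
    """Hypothetical stand-in for the aggregator: a stack of blocks plus an opt-in flag."""

    def __init__(self, dim: int = 32, depth: int = 4) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(depth)]
        )
        self.use_checkpointing_inference = False  # attribute name taken from the PR

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Checkpoint while training, or in eval mode when the flag is set.
        use_ckpt = self.training or self.use_checkpointing_inference
        for block in self.blocks:
            if use_ckpt:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


model = ToyAggregator().eval()
x = torch.randn(2, 32)
with torch.no_grad():
    baseline = model(x)
    model.use_checkpointing_inference = True
    checkpointed = model(x)
```

Comparing `baseline` and `checkpointed` with `torch.equal` is a quick way to confirm that the flag changes only how the forward pass is executed, not what it returns.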

Changes

  • vggt/models/aggregator.py: Added use_checkpointing_inference attribute. When True, both frame and global attention blocks use checkpoint() during eval (previously only during training).
  • vggt/models/vggt.py: Added sequential_heads attribute. When True, explicitly frees pose_enc_list and calls torch.cuda.empty_cache() between camera_head and depth_head.
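The `sequential_heads` change might look roughly like the following. This is a sketch under assumptions: `ToyModel`, its linear heads, and the four made-up refinement outputs are hypothetical, while the names `camera_head`, `depth_head`, `pose_enc_list`, and `sequential_heads` and the `torch.cuda.empty_cache()` call come from the description above.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical two-head model; head names follow the PR, internals are made up."""

    def __init__(self, dim: int = 16) -> None:
        super().__init__()
        self.camera_head = nn.Linear(dim, 9)
        self.depth_head = nn.Linear(dim, 1)
        self.sequential_heads = False  # attribute name taken from the PR

    def forward(self, tokens: torch.Tensor) -> dict:
        predictions = {}

        # Stand-in for a head that returns several refinement outputs.
        pose_enc_list = [self.camera_head(tokens) for _ in range(4)]
        predictions["pose_enc"] = pose_enc_list[-1]  # keep only the final estimate

        if self.sequential_heads:
            # Drop the list so the earlier outputs can be freed, then release
            # cached blocks before the next head allocates.
            del pose_enc_list
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        predictions["depth"] = self.depth_head(tokens)
        return predictions


toy = ToyModel().eval()
tokens = torch.randn(3, 16)
with torch.no_grad():
    default_out = toy(tokens)
    toy.sequential_heads = True
    sequential_out = toy(tokens)
```

The `torch.cuda.is_available()` guard is only there so the sketch also runs on CPU; on a GPU the call returns unused cached memory to the driver between the two heads.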

Usage

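from vggt.models.vggt import VGGT

model = VGGT()
model.eval()
# Recompute attention-block activations during inference instead of storing them.
model.aggregator.use_checkpointing_inference = True
# Free camera_head intermediates before depth_head runs.
model.sequential_heads = True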

Companion PR: MIT-SPARK/VGGT-SLAM#41 (adds --low_vram CLI flag that sets these attributes automatically).
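How a `--low_vram` flag could map onto these two attributes is sketched below. The companion PR's real implementation is not shown here, so `build_parser` and `apply_low_vram` are hypothetical helper names, and a `SimpleNamespace` stands in for the model so the example runs without VGGT installed.

```python
import argparse
from types import SimpleNamespace


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser; only the flag name --low_vram comes from the PR text.
    parser = argparse.ArgumentParser(description="Hypothetical VGGT-SLAM entry point")
    parser.add_argument(
        "--low_vram",
        action="store_true",
        help="enable inference checkpointing and sequential heads",
    )
    return parser


def apply_low_vram(model, low_vram: bool) -> None:
    """Set both attributes from this PR on an already constructed model."""
    if low_vram:
        model.aggregator.use_checkpointing_inference = True
        model.sequential_heads = True


# Stand-in object so the sketch runs without the real VGGT weights.
model = SimpleNamespace(
    aggregator=SimpleNamespace(use_checkpointing_inference=False),
    sequential_heads=False,
)
args = build_parser().parse_args(["--low_vram"])
apply_low_vram(model, args.low_vram)
```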

Trade-offs

  • Checkpointing adds ~30% inference time (recomputation cost) but saves ~40% VRAM
  • Sequential heads add negligible time but save peak memory from concurrent head execution
  • No change to outputs: results are mathematically identical
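One way to check these trade-offs on a given GPU is to time a forward pass and read back the peak allocation. The helper below is a sketch: `measure` is a hypothetical name, the `nn.Linear` layer is a placeholder for the real model, and the percentages quoted above are the PR's own figures, not output of this code.

```python
import time

import torch


def measure(fn, *args):
    """Return (output, seconds, peak_bytes); peak_bytes is None without CUDA."""
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    out = fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    seconds = time.perf_counter() - start
    peak = torch.cuda.max_memory_allocated() if use_cuda else None
    return out, seconds, peak


# Placeholder workload; swap in the real model and inputs to compare settings.
layer = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)
out, seconds, peak = measure(layer, x)
```

Running `measure` once with the flags off and once with them on, on the same inputs, gives a like-for-like comparison of wall time and peak memory, and the two outputs can be compared with `torch.allclose`.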

…xecution

Reduces peak VRAM by ~40% during inference, enabling VGGT-SLAM to
run on GPUs with 8-12GB VRAM.

Changes:
- aggregator.py: New `use_checkpointing_inference` flag that enables
  gradient checkpointing for both frame and global attention blocks
  during inference (not just training). Recomputes activations during
  the backward-free forward pass instead of storing all intermediates.
- vggt.py: New `sequential_heads` flag that frees camera_head
  intermediates before running depth_head, reducing peak memory.

These are controlled by VGGT-SLAM's --low_vram flag (see companion
PR: MIT-SPARK/VGGT-SLAM#41).