Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,3 @@ If our code is helpful, please cite our papers as follows:
year={2026}
}
```

12 changes: 8 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,16 @@ def main():
clip_model, clip_preprocess = None, None
clip_tokenizer = None

model = VGGT()
_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
# model = VGGT()
# _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
# model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
use_point_map = True
model = VGGT(enable_point=use_point_map, enable_track=False)
ckpt = torch.load(os.getenv("HOME", "")+"/model_tracker_fixed_e20.pt", map_location="cpu")
incompat = model.load_state_dict(ckpt, strict=False)

model.eval()
model = model.to(torch.bfloat16) # use half precision
model = model.to(torch.float16) # use half precision
model = model.to(device)

# Use the provided image folder path
Expand Down
28 changes: 14 additions & 14 deletions setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,23 @@ cd ..
# 3. Clone and install our fork of VGGT
echo "Cloning and installing VGGT..."
cd third_party
git clone https://github.com/MIT-SPARK/VGGT_SPARK.git vggt
git clone https://github.com/stepeos/SparkFastVGGT.git vggt
pip install -e ./vggt
cd ..

# 4. Install Perception Encoder
echo "Cloning and installing Perception Encoder..."
cd third_party
git clone https://github.com/facebookresearch/perception_models.git
pip install -e ./perception_models
cd ..

# 5. Install SAM 3
echo "Cloning and installing SAM 3..."
cd third_party
git clone https://github.com/facebookresearch/sam3.git
pip install -e ./sam3
cd ..
# # 4. Install Perception Encoder
# echo "Cloning and installing Perception Encoder..."
# cd third_party
# git clone https://github.com/facebookresearch/perception_models.git
# pip install -e ./perception_models
# cd ..

# # 5. Install SAM 3
# echo "Cloning and installing SAM 3..."
# cd third_party
# git clone https://github.com/facebookresearch/sam3.git
# pip install -e ./sam3
# cd ..

# 6. Install current repo in editable mode
echo "Installing current repo..."
Expand Down
33 changes: 19 additions & 14 deletions vggt_slam/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from vggt_slam.graph import PoseGraph
from vggt_slam.scale_solver import estimate_scale_pairwise
from vggt_slam.viewer import Viewer
from vggt.utils.eval_utils import get_vgg_input_imgs, load_images_rgb


DEBUG = False

Expand Down Expand Up @@ -298,13 +300,16 @@ def sample_pixel_coordinates(self, H, W, n):
def run_predictions(self, image_names, model, max_loops, clip_model, clip_preprocess):
device = "cuda" if torch.cuda.is_available() else "cpu"
t1 = time.time()
with self.vggt_timer:
images = load_and_preprocess_images(image_names).to(device)

images = load_images_rgb(image_names)
images_array = np.stack(images)
vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array)
print(f"Loaded and preprocessed {len(image_names)} images in {time.time() - t1:.2f} seconds")
print(f"Preprocessed images shape: {images.shape}")
print(f"Preprocessed images shape: {images_array.shape}")

# print("Running inference...")
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
# dtype = torch.float16

# First submap so set new pcd num to 0
if self.map.get_largest_key() is None:
Expand All @@ -315,9 +320,9 @@ def run_predictions(self, image_names, model, max_loops, clip_model, clip_prepro
print(f"Creating new submap with id {new_pcd_num}")
t1 = time.time()
new_submap = Submap(new_pcd_num)
new_submap.add_all_frames(images)
new_submap.add_all_frames(vgg_input.to(torch.float16).cpu())
new_submap.set_frame_ids(image_names)
new_submap.set_last_non_loop_frame_index(images.shape[0] - 1)
new_submap.set_last_non_loop_frame_index(images_array.shape[0] - 1)
new_submap.set_all_retrieval_vectors(self.image_retrieval.get_all_submap_embeddings(new_submap))
new_submap.set_img_names(image_names)

Expand All @@ -328,12 +333,11 @@ def run_predictions(self, image_names, model, max_loops, clip_model, clip_prepro

self.current_working_submap = new_submap
print(f"Created new submap in {time.time() - t1:.2f} seconds")
model.update_patch_dimensions(patch_width, patch_height)

with torch.no_grad():
t1 = time.time()
with self.vggt_timer:
predictions = model(images)
print(f"VGGT model inference took {time.time() - t1:.2f} seconds")
with torch.amp.autocast(device, dtype=dtype):
predictions = model(vgg_input.to(device).to(dtype))

# Check for loop closures and add retrieval vectors from new submap to the database
predictions_lc = None
Expand All @@ -344,10 +348,11 @@ def run_predictions(self, image_names, model, max_loops, clip_model, clip_prepro
print(colored("detected_loops", "yellow"), detected_loops)
retrieved_frames = self.map.get_frames_from_loops(detected_loops)
with torch.no_grad():
lc_frames = torch.stack((new_submap.get_frame_at_index(detected_loops[0].query_submap_frame), retrieved_frames[0]), axis=0)
predictions_lc = model(lc_frames, compute_similarity=True)
loop_closure_frame_names = [new_submap.get_img_names_at_index(detected_loops[0].query_submap_frame),
self.map.get_submap(detected_loops[0].detected_submap_id).get_img_names_at_index(detected_loops[0].detected_submap_frame)]
with torch.amp.autocast(device, dtype=dtype):
lc_frames = torch.stack((new_submap.get_frame_at_index(detected_loops[0].query_submap_frame), retrieved_frames[0]), axis=0)
predictions_lc = model(lc_frames.to(device), compute_similarity=True)
loop_closure_frame_names = [new_submap.get_img_names_at_index(detected_loops[0].query_submap_frame),
self.map.get_submap(detected_loops[0].detected_submap_id).get_img_names_at_index(detected_loops[0].detected_submap_frame)]

# Visualize loop closure frames
if DEBUG:
Expand All @@ -361,7 +366,7 @@ def run_predictions(self, image_names, model, max_loops, clip_model, clip_prepro
plt.show()

print("Converting pose encoding to extrinsic and intrinsic matrices...")
extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], vgg_input.shape[-2:])
predictions["extrinsic"] = extrinsic
predictions["intrinsic"] = intrinsic

Expand Down
10 changes: 9 additions & 1 deletion vggt_slam/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@ def __init__(self, port: int = 8080):
print(f"Starting viser server on port {port}")

self.server = viser.ViserServer(host="0.0.0.0", port=port)
self.server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")
# self.server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")
self.server.gui.configure_theme(
titlebar_content=None,
control_layout="collapsible",
dark_mode=True,
brand_color=(20, 20, 20),
)
black_img = np.zeros((1, 1, 3), dtype=np.uint8)
self.server.scene.set_background_image(black_img, format="jpeg")

# --- GUI Elements ---
self.gui_show_frames = self.server.gui.add_checkbox("Show Cameras", initial_value=True)
Expand Down