From 37b2e7563c1aaff0a2c3e3f8a29fc4e895ed9b66 Mon Sep 17 00:00:00 2001 From: stepeos Date: Mon, 13 Apr 2026 01:04:04 +0300 Subject: [PATCH 1/2] add token merge --- README.md | 1 - main.py | 12 ++++++++---- setup.sh | 28 ++++++++++++++-------------- vggt_slam/solver.py | 36 ++++++++++++++++++++++-------------- vggt_slam/viewer.py | 10 +++++++++- 5 files changed, 53 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index c33c9f4..36f8d2c 100644 --- a/README.md +++ b/README.md @@ -182,4 +182,3 @@ If our code is helpful, please cite our papers as follows: year={2026} } ``` - diff --git a/main.py b/main.py index d6d0962..479f6eb 100644 --- a/main.py +++ b/main.py @@ -69,12 +69,16 @@ def main(): clip_model, clip_preprocess = None, None clip_tokenizer = None - model = VGGT() - _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt" - model.load_state_dict(torch.hub.load_state_dict_from_url(_URL)) + # model = VGGT() + # _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt" + # model.load_state_dict(torch.hub.load_state_dict_from_url(_URL)) + use_point_map = True + model = VGGT(enable_point=use_point_map, enable_track=False) + ckpt = torch.load(os.getenv("HOME", "")+"/model_tracker_fixed_e20.pt", map_location="cpu") + incompat = model.load_state_dict(ckpt, strict=False) model.eval() - model = model.to(torch.bfloat16) # use half precision + model = model.to(torch.float16) # use half precision model = model.to(device) # Use the provided image folder path diff --git a/setup.sh b/setup.sh index 10d1520..0d23905 100755 --- a/setup.sh +++ b/setup.sh @@ -17,23 +17,23 @@ cd .. # 3. Clone and install our fork of VGGT echo "Cloning and installing VGGT..." cd third_party -git clone https://github.com/MIT-SPARK/VGGT_SPARK.git vggt +git clone https://github.com/stepeos/SparkFastVGGT.git vggt pip install -e ./vggt cd .. -# 4. Install Perception Encoder -echo "Cloning and installing Perception Encoder..." -cd third_party -git clone https://github.com/facebookresearch/perception_models.git -pip install -e ./perception_models -cd .. - -# 5. Install SAM 3 -echo "Cloning and installing SAM 3..." -cd third_party -git clone https://github.com/facebookresearch/sam3.git -pip install -e ./sam3 -cd .. +# # 4. Install Perception Encoder +# echo "Cloning and installing Perception Encoder..." +# cd third_party +# git clone https://github.com/facebookresearch/perception_models.git +# pip install -e ./perception_models +# cd .. + +# # 5. Install SAM 3 +# echo "Cloning and installing SAM 3..." +# cd third_party +# git clone https://github.com/facebookresearch/sam3.git +# pip install -e ./sam3 +# cd .. # 6. Install current repo in editable mode echo "Installing current repo..." diff --git a/vggt_slam/solver.py b/vggt_slam/solver.py index cf1ffeb..18eac12 100644 --- a/vggt_slam/solver.py +++ b/vggt_slam/solver.py @@ -20,6 +20,11 @@ from vggt_slam.graph import PoseGraph from vggt_slam.scale_solver import estimate_scale_pairwise from vggt_slam.viewer import Viewer +<<<<<<< HEAD +======= +from vggt.utils.eval_utils import get_vgg_input_imgs, load_images_rgb + +>>>>>>> 8ed9c22 (add token merge) DEBUG = False @@ -298,13 +303,16 @@ def sample_pixel_coordinates(self, H, W, n): def run_predictions(self, image_names, model, max_loops, clip_model, clip_preprocess): device = "cuda" if torch.cuda.is_available() else "cpu" t1 = time.time() - with self.vggt_timer: - images = load_and_preprocess_images(image_names).to(device) + + images = load_images_rgb(image_names) + images_array = np.stack(images) + vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array) print(f"Loaded and preprocessed {len(image_names)} images in {time.time() - t1:.2f} seconds") - print(f"Preprocessed images shape: {images.shape}") + print(f"Preprocessed images shape: {images_array.shape}") # print("Running inference...") dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + # dtype = torch.float16 # First submap so set new pcd num to 0 if self.map.get_largest_key() is None: @@ -315,9 +323,9 @@ def run_predictions(self, image_names, model, max_loops, clip_model, clip_prepro print(f"Creating new submap with id {new_pcd_num}") t1 = time.time() new_submap = Submap(new_pcd_num) - new_submap.add_all_frames(images) + new_submap.add_all_frames(vgg_input.to(torch.float16).cpu()) new_submap.set_frame_ids(image_names) - new_submap.set_last_non_loop_frame_index(images.shape[0] - 1) + new_submap.set_last_non_loop_frame_index(images_array.shape[0] - 1) new_submap.set_all_retrieval_vectors(self.image_retrieval.get_all_submap_embeddings(new_submap)) new_submap.set_img_names(image_names) @@ -328,12 +336,11 @@ def run_predictions(self, image_names, model, max_loops, clip_model, clip_prepro self.current_working_submap = new_submap print(f"Created new submap in {time.time() - t1:.2f} seconds") + model.update_patch_dimensions(patch_width, patch_height) with torch.no_grad(): - t1 = time.time() - with self.vggt_timer: - predictions = model(images) - print(f"VGGT model inference took {time.time() - t1:.2f} seconds") + with torch.amp.autocast(device, dtype=dtype): + predictions = model(vgg_input.to(device).to(dtype)) # Check for loop closures and add retrieval vectors from new submap to the database predictions_lc = None @@ -344,10 +351,11 @@ def run_predictions(self, image_names, model, max_loops, clip_model, clip_prepro print(colored("detected_loops", "yellow"), detected_loops) retrieved_frames = self.map.get_frames_from_loops(detected_loops) with torch.no_grad(): - lc_frames = torch.stack((new_submap.get_frame_at_index(detected_loops[0].query_submap_frame), retrieved_frames[0]), axis=0) - predictions_lc = model(lc_frames, compute_similarity=True) - loop_closure_frame_names = [new_submap.get_img_names_at_index(detected_loops[0].query_submap_frame), - self.map.get_submap(detected_loops[0].detected_submap_id).get_img_names_at_index(detected_loops[0].detected_submap_frame)] + with torch.amp.autocast(device, dtype=dtype): + lc_frames = torch.stack((new_submap.get_frame_at_index(detected_loops[0].query_submap_frame), retrieved_frames[0]), axis=0) + predictions_lc = model(lc_frames.to(device), compute_similarity=True) + loop_closure_frame_names = [new_submap.get_img_names_at_index(detected_loops[0].query_submap_frame), + self.map.get_submap(detected_loops[0].detected_submap_id).get_img_names_at_index(detected_loops[0].detected_submap_frame)] # Visualize loop closure frames if DEBUG: @@ -361,7 +369,7 @@ def run_predictions(self, image_names, model, max_loops, clip_model, clip_prepro plt.show() print("Converting pose encoding to extrinsic and intrinsic matrices...") - extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:]) + extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], vgg_input.shape[-2:]) predictions["extrinsic"] = extrinsic predictions["intrinsic"] = intrinsic diff --git a/vggt_slam/viewer.py b/vggt_slam/viewer.py index c259b29..6b59bc0 100644 --- a/vggt_slam/viewer.py +++ b/vggt_slam/viewer.py @@ -11,7 +11,15 @@ def __init__(self, port: int = 8080): print(f"Starting viser server on port {port}") self.server = viser.ViserServer(host="0.0.0.0", port=port) - self.server.gui.configure_theme(titlebar_content=None, control_layout="collapsible") + # self.server.gui.configure_theme(titlebar_content=None, control_layout="collapsible") + self.server.gui.configure_theme( + titlebar_content=None, + control_layout="collapsible", + dark_mode=True, + brand_color=(20, 20, 20), + ) + black_img = np.zeros((1, 1, 3), dtype=np.uint8) + self.server.scene.set_background_image(black_img, format="jpeg") # --- GUI Elements --- self.gui_show_frames = self.server.gui.add_checkbox("Show Cameras", initial_value=True) From 8d6e5f82a1543eb0cfe1783636e2fc49f6a06522 Mon Sep 17 00:00:00 2001 From: stepeos Date: Thu, 7 May 2026 18:29:40 +0300 Subject: [PATCH 2/2] remove merge error --- vggt_slam/solver.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vggt_slam/solver.py b/vggt_slam/solver.py index 18eac12..079a684 100644 --- a/vggt_slam/solver.py +++ b/vggt_slam/solver.py @@ -20,11 +20,8 @@ from vggt_slam.graph import PoseGraph from vggt_slam.scale_solver import estimate_scale_pairwise from vggt_slam.viewer import Viewer -<<<<<<< HEAD -======= from vggt.utils.eval_utils import get_vgg_input_imgs, load_images_rgb ->>>>>>> 8ed9c22 (add token merge) DEBUG = False