diff --git a/depth_anything/dpt.py b/depth_anything/dpt.py
index d09c0e7..e793f2c 100644
--- a/depth_anything/dpt.py
+++ b/depth_anything/dpt.py
@@ -141,7 +141,7 @@ def forward(self, out_features, patch_h, patch_w,need_fp=False,need_prior=False,
 
 
 class DPT_DINOv2(nn.Module):
-    def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False, localhub=True, version='v1'):
+    def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False, localhub=True, version='v1', to_onnx=False):
         super(DPT_DINOv2, self).__init__()
 
         assert encoder in ['vits', 'vitb', 'vitl']
@@ -153,6 +153,7 @@ def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1
         }
         self.encoder = encoder
         self.version = version
+        self.to_onnx = to_onnx
         # in case the Internet connection is not stable, please load the DINOv2 locally
         # if localhub:
        #     self.pretrained = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=True)
@@ -178,7 +179,10 @@ def forward(self, x,need_fp=False,teacher_features=None,alpha=0.8,prior_mode='te
 
         depth_all = self.depth_head(features, patch_h, patch_w,need_fp,teacher_features,alpha)
         depth=depth_all['out']
         depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
-        depth = F.relu(depth).squeeze(1)
+        if self.to_onnx:
+            depth = F.relu(depth)  # keep the 4D (N, 1, H, W) shape; .squeeze(1) breaks the TensorRT export
+        else:
+            depth = F.relu(depth).squeeze(1)
         depth_out['out']=depth
         return depth_out
diff --git a/tools/infer_onnx.py b/tools/infer_onnx.py
new file mode 100644
index 0000000..05cddd6
--- /dev/null
+++ b/tools/infer_onnx.py
@@ -0,0 +1,111 @@
+import onnxruntime as ort
+import cv2
+import numpy as np
+import torch  # needed for F.interpolate in postprocess_depth
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+
+def normalize_depth(depth):
+    # Normalize a 2D depth map to the [0, 1] range for visualization
+    eps = 1e-6
+    depth_min = np.min(depth)
+    depth_max = np.max(depth)
+    normalized_depth = (depth - depth_min) / (depth_max - depth_min + eps)
+    return normalized_depth
+
+def preprocess_image(image_path, target_size=518):
+    raw_image = cv2.imread(image_path)
+    if raw_image is None:
+        raise ValueError(f"Cannot read image: {image_path}")
+
+    # The model expects a square input, so resize directly
+    image_rgb = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)
+    image_resized = cv2.resize(image_rgb, (target_size, target_size), interpolation=cv2.INTER_CUBIC)
+
+    image_float = image_resized.astype(np.float32) / 255.0
+
+    mean = np.array([0.485, 0.456, 0.406])
+    std = np.array([0.229, 0.224, 0.225])
+    image_normalized = (image_float - mean) / std
+
+    # HWC to NCHW
+    image_transposed = image_normalized.transpose(2, 0, 1)
+    image_tensor = np.expand_dims(image_transposed, axis=0).astype(np.float32)
+
+    return image_tensor, (raw_image.shape[0], raw_image.shape[1])
+
+def postprocess_depth(depth_tensor_numpy, original_size):
+    """
+    Postprocess the raw 4D model output into a 2D depth map:
+    - convert the numpy array to a torch tensor,
+    - resize to the original image dimensions,
+    - convert back to numpy.
+    """
+    # 1. Convert the numpy array to a torch tensor.
+    #    The input is already in the expected 4D shape (1, 1, H, W).
+    depth_tensor_torch = torch.from_numpy(depth_tensor_numpy)
+
+    h, w = original_size
+
+    # 2. Interpolate using torch.nn.functional;
+    #    align_corners=False is the default and generally recommended here
+    depth_resized = F.interpolate(depth_tensor_torch, size=(h, w), mode='bilinear', align_corners=False)
+
+    # 3. Squeeze and convert back to a CPU numpy array
+    depth_output = depth_resized.squeeze().cpu().numpy()
+
+    return depth_output
+
+def save_depth_map(depth, output_path, colormap='inferno'):
+    # Assumes depth is already normalized to the 0-1 range
+    depth_raw_path = output_path.replace('.png', '_raw.npy')
+    np.save(depth_raw_path, depth)
+
+    if colormap == 'inferno':
+        depth_colored = (plt.get_cmap(colormap)(depth)[:, :, :3] * 255).astype(np.uint8)
+    elif colormap == 'spectral':
+        # plt.get_cmap, not cm.get_cmap: the latter was removed in matplotlib 3.9
+        spectral_cmap = plt.get_cmap('Spectral_r')
+        depth_colored = (spectral_cmap(depth) * 255).astype(np.uint8)[:, :, :3]
+    else:  # grayscale
+        depth_colored = (depth * 255).astype(np.uint8)
+
+    # Convert to BGR for OpenCV
+    depth_colored_bgr = cv2.cvtColor(depth_colored, cv2.COLOR_RGB2BGR)
+
+    cv2.imwrite(output_path, depth_colored_bgr)
+    print(f"Depth map saved: {output_path}")
+    print(f"Raw depth data saved: {depth_raw_path}")
+
+def infer_onnx(onnx_path, image_path, output_path, colormap='inferno'):
+    print(f"Processing: {image_path}")
+    # Run on the CPU; swap in CUDAExecutionProvider for GPU inference
+    sess = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
+    input_name = sess.get_inputs()[0].name
+    output_name = sess.get_outputs()[0].name
+
+    image_tensor, original_size = preprocess_image(image_path)
+    print(f"Input tensor shape: {image_tensor.shape}")
+    print(f"Original image size: {original_size}")
+
+    # The model outputs disparity, which is inversely proportional to depth
+    disparity = sess.run([output_name], {input_name: image_tensor})[0]
+    print(f"Model output shape (disparity): {disparity.shape}")
+
+    # Postprocess: resize the 4D (1, 1, H, W) output back to the original size
+    depth_resized = postprocess_depth(disparity, original_size)
+    print(f"Resized depth map shape: {depth_resized.shape}")
+
+    # Normalize the final depth map for visualization
+    depth_normalized = normalize_depth(depth_resized)
+
+    save_depth_map(depth_normalized, output_path, colormap)
+
+if __name__ == '__main__':
+    onnx_path = '/home/e300/code/DepthAnythingAC/depth_anything_AC_vits.onnx'
+    image_path = '/home/e300/Downloads/WhatsApp Image 2025-09-26 at 1.11.48 PM.jpeg'
+    output_path = 'syed_depth.png'
+
+    infer_onnx(onnx_path, image_path, output_path)
\ No newline at end of file
diff --git a/tools/to_onnx.py b/tools/to_onnx.py
new file mode 100644
index 0000000..997b1ae
--- /dev/null
+++ b/tools/to_onnx.py
@@ -0,0 +1,60 @@
+# tools/to_onnx.py
+import torch
+import os
+import sys
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from depth_anything.dpt import DepthAnything_AC
+
+def load_model(model_path, encoder='vits'):
+    model_configs = {
+        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384], 'version': 'v2', 'to_onnx': True}
+    }
+    model = DepthAnything_AC(model_configs[encoder])
+    checkpoint = torch.load(model_path, map_location='cpu')
+    model.load_state_dict(checkpoint, strict=False)
+    model.eval()
+    return model
+
+def export_to_onnx(model, onnx_path, input_size=(518, 518), encoder='vits'):
+    # Static dummy input; the batch and spatial dims are declared dynamic below
+    dummy_input = torch.randn(1, 3, input_size[0], input_size[1])
+
+    # The model returns a dict, so wrap it to expose only the 'out' tensor
+    class ModelWrapper(torch.nn.Module):
+        def __init__(self, model):
+            super().__init__()
+            self.model = model
+
+        def forward(self, x):
+            output = self.model(x)
+            return output['out']  # extract the 'out' tensor from the dictionary
+
+    wrapped_model = ModelWrapper(model)
+    wrapped_model.eval()  # set the wrapper to eval mode as well
+
+    # Export to ONNX with dynamic axes
+    torch.onnx.export(
+        wrapped_model,  # export the wrapper, not the raw dict-returning model
+        dummy_input,
+        onnx_path,
+        export_params=True,
+        opset_version=12,  # opset 12 or higher is preferable if your TensorRT build supports it
+        do_constant_folding=True,
+        input_names=['input'],
+        output_names=['depth'],
+        dynamic_axes={
+            'input': {0: 'batch_size', 2: 'height', 3: 'width'},
+            'depth': {0: 'batch_size', 2: 'height', 3: 'width'}
+        },
+        verbose=False  # set to True for debugging
+    )
+    print(f"Model exported to ONNX: {onnx_path}")
+
+if __name__ == '__main__':
+    model_path = 'checkpoints/depth_anything_AC_vits.pth'  # replace with your checkpoint path
+    onnx_path = 'depth_anything_AC_vits.onnx'  # output ONNX file
+    encoder = 'vits'  # adjust if needed
+    model = load_model(model_path, encoder)
+    export_to_onnx(model, onnx_path, encoder=encoder)
\ No newline at end of file