run_onnx.py
With the static ONNX model everything runs correctly at an input size of 320x320. With the dynamic ONNX model, an error is raised when the width is too small:
ONNX inference error: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'/bridge_model/stereo_encoder/Reshape_64' Status Message: D:\a_work\1\s\onnxruntime\core/providers/cpu/tensor/reshape_helper.h:30 onnxruntime::ReshapeHelper::ReshapeHelper i < input_shape.NumDimensions() was false. The dimension with value zero exceeds the dimension size of the input tensor.
With the width fixed at 320, I get this error:
2025-11-06 11:12:02.5807984 [E:onnxruntime:, sequential_executor.cc:572 onnxruntime::ExecuteKernel] Non-zero status code returned while running Reshape node. Name:'/bridge_model/align_0/layers.1/attn/nmp/attn/Reshape_3' Status Message: D:\a_work\1\s\onnxruntime\core/providers/cpu/tensor/reshape_helper.h:47 onnxruntime::ReshapeHelper::ReshapeHelper input_shape_size == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{28,4,72,72}, requested shape:{28,49,4,72,72}
The height is below 320 but not fixed; it can be as small as 32 or as large as 320.
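As a quick check (sketch only, using the standard onnx Python API; the file name assumes the opset-17 export produced by the script below), the target shape feeding the failing Reshape node can be looked up directly in the graph. If it is a fully constant vector such as [28, 49, 4, 72, 72] from the second error, that shape was frozen in at trace time and will not follow the dynamic height/width axes:

import onnx
from onnx import numpy_helper

m = onnx.load("./onnx/bridgedepth_for_trt_opset17_320_320_dynamic.onnx")
target = "/bridge_model/stereo_encoder/Reshape_64"  # node name from the first error
for node in m.graph.node:
    if node.op_type == "Reshape" and node.name == target:
        shape_input = node.input[1]  # second input of Reshape is the requested shape
        # Case 1: the shape tensor is a graph initializer, i.e. a constant baked in at export time.
        for init in m.graph.initializer:
            if init.name == shape_input:
                print("constant shape:", numpy_helper.to_array(init))
        # Case 2: the shape comes from a Constant node.
        for producer in m.graph.node:
            if producer.op_type == "Constant" and shape_input in producer.output:
                for attr in producer.attribute:
                    if attr.name == "value":
                        print("constant shape:", numpy_helper.to_array(attr.t))
# If nothing is printed, the shape is computed dynamically (a Shape/Gather/Concat chain)
# and the problem is in how those dimensions are derived.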
Below is the code used to export the model to ONNX:
"""
Export the BridgeDepth model to ONNX for TensorRT conversion
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os, sys, argparse
code_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(f'{code_dir}/../')
from bridgedepth.bridgedepth import BridgeDepth
from bridgedepth.monocular.dinov2.models.vision_transformer import DinoVisionTransformer
import math
class SimplifiedBridgeDepth(nn.Module):
"""replace the BridgeDepth, remove all the problem operations"""
def __init__(self, bridge_model):
super().__init__()
self.bridge_model = bridge_model
self.bridge_model.eval()
# replace all the problem operations
self._patch_all_operations()
def _patch_all_operations(self):
"""fix all the ONNX incompatible operations"""
# 遍历所有子模块,找到DinoVisionTransformer的实例
for module in self.bridge_model.modules():
if isinstance(module, DinoVisionTransformer):
print("🎯 Found DinoVisionTransformer, applying ONNX patch for interpolate_pos_encoding...")
# 使用一个闭包工厂函数来正确捕获当前的`module`实例
def make_new_interpolate_func(dino_module):
# 这就是我们新的、对ONNX导出友好的函数实现
def onnx_friendly_interpolate_pos_encoding(x, w, h):
"""
这个版本被重写以避免TracerWarning,并确保尺寸计算是动态的。
"""
# 获取模型参数和输入形状信息
pos_embed = dino_module.pos_embed
patch_size = dino_module.patch_size
num_tokens = dino_module.num_tokens # 通常是1 (CLS token)
npatch = x.shape[1] - num_tokens
N = pos_embed.shape[1] - num_tokens
# 这是一个优化路径,如果输入尺寸恰好和预训练尺寸一致,直接返回
# 这个if语句会产生TracerWarning,但在这种情况下是可接受的,
# 因为我们主要关心的是动态尺寸的路径。
if npatch == N and w == h:
return pos_embed
# --- 核心动态计算逻辑 ---
pos_embed_float = pos_embed.float()
class_pos_embed = pos_embed_float[:, 0:num_tokens] # 兼容有无register_token的情况
patch_pos_embed = pos_embed_float[:, num_tokens:]
dim = x.shape[-1]
# 动态计算目标网格尺寸 h0, w0
# 这里的 w 和 h 是Python int,这是从上层函数 x.shape 解包得到的
# 这是ONNX追踪的难点,但 F.interpolate 的 size 参数可以接受 int
w0 = w // patch_size
h0 = h // patch_size
# 预训练时的网格尺寸
M = int(math.sqrt(N))
if M * M != N:
raise ValueError("预训练的位置编码数量必须是完全平方数。")
# 将位置编码从 1D 序列重塑为 2D 网格
patch_pos_embed_grid = patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2)
# 使用 F.interpolate 进行动态缩放
# 这是最关键的操作。我们传递动态计算出的 h0, w0
# ONNX应该将此转换为一个Resize节点,其目标尺寸依赖于输入 h, w
resized_patch_pos_embed = F.interpolate(
patch_pos_embed_grid,
size=(h0, w0),
mode="bicubic",
align_corners=False # 关键:对动态尺寸更鲁棒
)
# 将缩放后的2D网格重塑回1D序列
patch_pos_embed_flat = resized_patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
# 将CLS token的位置编码和缩放后的图像块位置编码拼接起来
final_pos_embed = torch.cat((class_pos_embed, patch_pos_embed_flat), dim=1)
return final_pos_embed.to(x.dtype)
# 返回新创建的函数
return onnx_friendly_interpolate_pos_encoding
# 将原始方法替换为我们新创建的、绑定了正确`module`实例的函数
module.interpolate_pos_encoding = make_new_interpolate_func(module)
# 1. replace fill_diagonal
original_fill_diagonal = torch.Tensor.fill_diagonal_
def safe_fill_diagonal(tensor, value, wrap=False):
if tensor.dim() >= 2:
diag_size = min(tensor.shape[-2], tensor.shape[-1])
diag_mask = torch.eye(diag_size, dtype=torch.bool, device=tensor.device)
if tensor.dim() > 2:
for _ in range(tensor.dim() - 2):
diag_mask = diag_mask.unsqueeze(0)
diag_mask = diag_mask.expand_as(tensor)
tensor.masked_fill_(diag_mask, value)
return tensor
torch.Tensor.fill_diagonal_ = safe_fill_diagonal
# 2. replace interpolate
original_interpolate = F.interpolate
def safe_interpolate(input, size=None, scale_factor=None, mode='nearest',
align_corners=None, recompute_scale_factor=None, **kwargs):
# remove all the parameters that may cause problems
return original_interpolate(
input, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor
)
F.interpolate = safe_interpolate
# 3. fix prepare_input
original_prepare_input = self.bridge_model.prepare_input
def onnx_prepare_input(inputs):
img1 = inputs["img1"].to(self.bridge_model.device)
img2 = inputs["img2"].to(self.bridge_model.device)
# simplify the padding logic
if not self.bridge_model.training:
from bridgedepth.utils.frame_utils import InputPadder
self.bridge_model.padder = InputPadder(img1.shape, mode="nmrf", divis_by=16)
img1, img2 = self.bridge_model.padder.pad(img1, img2)
else:
self.bridge_model.padder = None
H, W = img1.shape[-2:]
mono_size = (H // 8 * 7, W // 8 * 7)
mono_input = F.interpolate(img1, size=mono_size, mode="bilinear", align_corners=False)
inputs["img1"] = img1
inputs["img2"] = img2
inputs["image"] = mono_input.sub_(self.bridge_model.mean).div_(self.bridge_model.std)
return inputs
self.bridge_model.prepare_input = onnx_prepare_input
def forward(self, img1, img2):
inputs = {'img1': img1, 'img2': img2}
results = self.bridge_model(inputs)
disparity = results['disp_pred']
if disparity.dim() == 4 and disparity.shape[1] == 1:
disparity = disparity.squeeze(1)
return disparity
def export_onnx_for_tensorrt():
"""export the ONNX for TensorRT"""
print("🎯 export the ONNX for TensorRT")
# load the model
# model = BridgeDepth.from_pretrained('./checkpoints/hospital/step_112000.pth')
model = BridgeDepth.from_pretrained('./checkpoints/bridge_rvc_pretrain.pth')
model = model.to(torch.device("cuda")).eval()
# simplify the model
simplified_model = SimplifiedBridgeDepth(model)
MAX=320
H = MAX
W = MAX
# prepare the input
dummy_input1 = torch.randn(1, 3, H, W).cuda()
dummy_input2 = torch.randn(1, 3, H, W).cuda()
# test the simplified model
print("🧪 test the simplified model...")
with torch.no_grad():
output = simplified_model(dummy_input1, dummy_input2)
print(f"✅ output shape: {output.shape}")
# try multiple opset versions
for opset in [17, 18]:
output_file = f"./onnx/bridgedepth_for_trt_opset{opset}_{H}_{W}_dynamic.onnx"
print(f"📤 try to export the ONNX (opset {opset})...")
try:
torch.onnx.export(
simplified_model,
(dummy_input1, dummy_input2),
output_file,
input_names=["left_image", "right_image"],
output_names=["disparity"],
opset_version=opset,
do_constant_folding=True,
dynamic_axes={
'left_image': {0: 'batch_size', 2: 'height', 3: 'width'},
'right_image': {0: 'batch_size', 2: 'height', 3: 'width'},
'disparity': {0: 'batch_size', 1: 'height', 2: 'width'}
# 'left_image': {0: 'batch_size'},
# 'right_image': {0: 'batch_size'},
# 'disparity': {0: 'batch_size'}
},
verbose=True
)
# verify the ONNX
try:
import onnx
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)
import os
file_size = os.path.getsize(output_file) / (1024**2)
print(f"✅ ONNX export success: {output_file}")
print(f"📊 file size: {file_size:.1f} MB")
# generate the trtexec command
print(f"\n🛠️ correct trtexec command:")
print(f"# basic conversion")
print(f"trtexec --onnx={output_file} --saveEngine=./onnx/bridgedepth_opset{opset}.trt")
print(f"")
print(f"# FP16 optimization")
print(f"trtexec --onnx={output_file} --saveEngine=./onnx/bridgedepth_opset{opset}_fp16.trt --fp16")
print(f"")
print(f"# specify the input shape")
print(f"trtexec --onnx={output_file} --saveEngine=./onnx/bridgedepth_opset{opset}_shaped.trt --shapes=left_image:1x3x480x640,right_image:1x3x480x640")
return output_file
except Exception as e:
print(f"⚠️ ONNX verification failed: {e}")
continue
except Exception as e:
print(f"❌ opset {opset} export failed: {str(e)[:100]}...")
continue
print("💥 all ONNX exports failed")
return None
if __name__ == '__main__':
    onnx_file = export_onnx_for_tensorrt()
    if onnx_file:
        print(f"\n🎉 success! now run:")
        print(f"trtexec --onnx={onnx_file} --saveEngine=./onnx/bridgedepth.trt --fp16")
    else:
        print(f"\n💡 suggest using torch-tensorrt directly")