Skip to content

我在使用生成的动态onnx模型出现很多报错 #14

@1311523821

Description

@1311523821

run_onnx.py

我在使用静态onnx的时候可以正确运行,输入尺寸为320x320,使用动态onnx的时候
宽度过小的时候会报错
ONNX inference error: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'/bridge_model/stereo_encoder/Reshape_64' Status Message: D:\a_work\1\s\onnxruntime\core/providers/cpu/tensor/reshape_helper.h:30 onnxruntime::ReshapeHelper::ReshapeHelper i < input_shape.NumDimensions() was false. The dimension with value zero exceeds the dimension size of the input tensor.

宽度固定在320会报错
2025-11-06 11:12:02.5807984 [E:onnxruntime:, sequential_executor.cc:572 onnxruntime::ExecuteKernel] Non-zero status code returned while running Reshape node. Name:'/bridge_model/align_0/layers.1/attn/nmp/attn/Reshape_3' Status Message: D:\a_work\1\s\onnxruntime\core/providers/cpu/tensor/reshape_helper.h:47 onnxruntime::ReshapeHelper::ReshapeHelper input_shape_size == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{28,4,72,72}, requested shape:{28,49,4,72,72}

高度是小于320的,但是没有固定值,可能小到32,可能大到320
下面是转为onnx的代码

"""
Export the BridgeDepth model to ONNX for TensorRT conversion
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import os, sys, argparse

# Make the repository root importable so the `bridgedepth` package resolves
# no matter which directory the script is launched from.
# Fix: the pasted code used `file`, which raises NameError at import time;
# the correct dunder is `__file__`.
code_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(f'{code_dir}/../')
from bridgedepth.bridgedepth import BridgeDepth

from bridgedepth.monocular.dinov2.models.vision_transformer import DinoVisionTransformer
import torch.nn.functional as F
import math
class SimplifiedBridgeDepth(nn.Module):
    """Wrap a BridgeDepth model and replace every ONNX-unfriendly operation.

    The wrapper monkey-patches the wrapped model (and some torch globals) so
    that `torch.onnx.export` can trace it with dynamic input sizes.
    Interface: forward(img1, img2) -> disparity tensor of shape (B, H, W).

    NOTE: the indentation of this class was reconstructed from a
    whitespace-mangled paste; the statement order is unchanged.
    """

    def __init__(self, bridge_model):
        super().__init__()
        self.bridge_model = bridge_model
        self.bridge_model.eval()

        # replace all the problem operations
        self._patch_all_operations()

    def _patch_all_operations(self):
        """Fix all ONNX-incompatible operations in the wrapped model."""
        # Walk every submodule and patch each DinoVisionTransformer instance.
        for module in self.bridge_model.modules():
            if isinstance(module, DinoVisionTransformer):
                print("🎯 Found DinoVisionTransformer, applying ONNX patch for interpolate_pos_encoding...")

                # Closure factory so each patched method captures its own
                # `module` instance (avoids the late-binding closure pitfall).
                def make_new_interpolate_func(dino_module):

                    # ONNX-export-friendly reimplementation.
                    def onnx_friendly_interpolate_pos_encoding(x, w, h):
                        """Rewritten to avoid TracerWarnings and keep the size
                        computation dynamic during tracing."""
                        # Model parameters and input-shape information.
                        pos_embed = dino_module.pos_embed
                        patch_size = dino_module.patch_size
                        num_tokens = dino_module.num_tokens  # usually 1 (the CLS token)

                        npatch = x.shape[1] - num_tokens
                        N = pos_embed.shape[1] - num_tokens

                        # Fast path: input size matches the pretraining size.
                        # This `if` emits a TracerWarning, which is acceptable
                        # here because the dynamic path is what matters.
                        if npatch == N and w == h:
                            return pos_embed

                        # --- core dynamic-size logic ---
                        pos_embed_float = pos_embed.float()
                        class_pos_embed = pos_embed_float[:, 0:num_tokens]  # works with or without register tokens
                        patch_pos_embed = pos_embed_float[:, num_tokens:]
                        dim = x.shape[-1]

                        # Target grid size h0, w0.  `w`/`h` arrive as Python
                        # ints unpacked from x.shape upstream; F.interpolate's
                        # `size` argument accepts plain ints.
                        w0 = w // patch_size
                        h0 = h // patch_size

                        # Pretraining grid size (must be a square grid).
                        M = int(math.sqrt(N))
                        if M * M != N:
                            raise ValueError("预训练的位置编码数量必须是完全平方数。")

                        # Reshape the 1-D embedding sequence into a 2-D grid.
                        patch_pos_embed_grid = patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2)

                        # Dynamic resize; ONNX should lower this to a Resize
                        # node whose target size depends on the inputs h, w.
                        resized_patch_pos_embed = F.interpolate(
                            patch_pos_embed_grid,
                            size=(h0, w0),
                            mode="bicubic",
                            align_corners=False  # more robust for dynamic sizes
                        )

                        # Flatten the resized grid back into a 1-D sequence.
                        patch_pos_embed_flat = resized_patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

                        # Concatenate the CLS position embedding with the
                        # resized patch position embeddings.
                        final_pos_embed = torch.cat((class_pos_embed, patch_pos_embed_flat), dim=1)

                        return final_pos_embed.to(x.dtype)

                    return onnx_friendly_interpolate_pos_encoding

                # Replace the original method with the instance-bound version.
                module.interpolate_pos_encoding = make_new_interpolate_func(module)

        # 1. replace fill_diagonal (in-place fill via a boolean eye mask,
        #    which traces cleanly, instead of torch.Tensor.fill_diagonal_)
        original_fill_diagonal = torch.Tensor.fill_diagonal_
        def safe_fill_diagonal(tensor, value, wrap=False):
            if tensor.dim() >= 2:
                diag_size = min(tensor.shape[-2], tensor.shape[-1])
                diag_mask = torch.eye(diag_size, dtype=torch.bool, device=tensor.device)
                if tensor.dim() > 2:
                    for _ in range(tensor.dim() - 2):
                        diag_mask = diag_mask.unsqueeze(0)
                    diag_mask = diag_mask.expand_as(tensor)
                tensor.masked_fill_(diag_mask, value)
            return tensor

        torch.Tensor.fill_diagonal_ = safe_fill_diagonal

        # 2. replace interpolate
        original_interpolate = F.interpolate
        def safe_interpolate(input, size=None, scale_factor=None, mode='nearest',
                             align_corners=None, recompute_scale_factor=None, **kwargs):
            # remove all the parameters that may cause problems
            return original_interpolate(
                input, size=size, scale_factor=scale_factor, mode=mode,
                align_corners=align_corners, recompute_scale_factor=recompute_scale_factor
            )

        F.interpolate = safe_interpolate

        # 3. fix prepare_input
        original_prepare_input = self.bridge_model.prepare_input
        def onnx_prepare_input(inputs):
            img1 = inputs["img1"].to(self.bridge_model.device)
            img2 = inputs["img2"].to(self.bridge_model.device)

            # simplify the padding logic
            if not self.bridge_model.training:
                from bridgedepth.utils.frame_utils import InputPadder
                self.bridge_model.padder = InputPadder(img1.shape, mode="nmrf", divis_by=16)
                img1, img2 = self.bridge_model.padder.pad(img1, img2)
            else:
                self.bridge_model.padder = None

            H, W = img1.shape[-2:]
            mono_size = (H // 8 * 7, W // 8 * 7)
            mono_input = F.interpolate(img1, size=mono_size, mode="bilinear", align_corners=False)

            inputs["img1"] = img1
            inputs["img2"] = img2
            inputs["image"] = mono_input.sub_(self.bridge_model.mean).div_(self.bridge_model.std)
            return inputs

        self.bridge_model.prepare_input = onnx_prepare_input

    def forward(self, img1, img2):
        """Run the wrapped model; returns the predicted disparity (B, H, W)."""
        inputs = {'img1': img1, 'img2': img2}
        results = self.bridge_model(inputs)
        disparity = results['disp_pred']

        # Drop a singleton channel dim so the ONNX output is (B, H, W).
        if disparity.dim() == 4 and disparity.shape[1] == 1:
            disparity = disparity.squeeze(1)

        return disparity

def export_onnx_for_tensorrt():
    """Export the BridgeDepth model to ONNX for TensorRT conversion.

    Tries opsets 17 and 18 in order; returns the path of the first ONNX file
    that exports and passes the ONNX checker, or None if every attempt fails.
    Requires a CUDA device and the checkpoint on disk.
    """
    print("🎯 export the ONNX for TensorRT")

    # load the model
    model = BridgeDepth.from_pretrained('./checkpoints/bridge_rvc_pretrain.pth')

    model = model.to(torch.device("cuda")).eval()

    # simplify the model (patches away ONNX-hostile ops)
    simplified_model = SimplifiedBridgeDepth(model)
    MAX = 320
    H = MAX
    W = MAX
    # prepare the input
    dummy_input1 = torch.randn(1, 3, H, W).cuda()
    dummy_input2 = torch.randn(1, 3, H, W).cuda()

    # sanity-check the simplified model before tracing
    print("🧪 test the simplified model...")
    with torch.no_grad():
        output = simplified_model(dummy_input1, dummy_input2)
        print(f"✅ output shape: {output.shape}")

    # try multiple opset versions, newest last
    for opset in [17, 18]:
        output_file = f"./onnx/bridgedepth_for_trt_opset{opset}_{H}_{W}_dynamic.onnx"

        print(f"📤 try to export the ONNX (opset {opset})...")
        try:
            torch.onnx.export(
                simplified_model,
                (dummy_input1, dummy_input2),
                output_file,
                input_names=["left_image", "right_image"],
                output_names=["disparity"],
                opset_version=opset,
                do_constant_folding=True,
                # mark batch/height/width dynamic so one engine serves many sizes
                dynamic_axes={
                    'left_image': {0: 'batch_size', 2: 'height', 3: 'width'},
                    'right_image': {0: 'batch_size', 2: 'height', 3: 'width'},
                    'disparity': {0: 'batch_size', 1: 'height', 2: 'width'}
                },
                verbose=True
            )

            # verify the exported ONNX
            try:
                import onnx
                onnx_model = onnx.load(output_file)
                onnx.checker.check_model(onnx_model)

                # `os` is already imported at module level
                file_size = os.path.getsize(output_file) / (1024**2)

                print(f"✅ ONNX export success: {output_file}")
                print(f"📊 file size: {file_size:.1f} MB")

                # generate the trtexec command
                print(f"\n🛠️  correct trtexec command:")
                print(f"# basic conversion")
                print(f"trtexec --onnx={output_file} --saveEngine=./onnx/bridgedepth_opset{opset}.trt")
                print(f"")
                print(f"# FP16 optimization")
                print(f"trtexec --onnx={output_file} --saveEngine=./onnx/bridgedepth_opset{opset}_fp16.trt --fp16")
                print(f"")
                print(f"# specify the input shape")
                print(f"trtexec --onnx={output_file} --saveEngine=./onnx/bridgedepth_opset{opset}_shaped.trt --shapes=left_image:1x3x480x640,right_image:1x3x480x640")

                return output_file

            except Exception as e:
                print(f"⚠️  ONNX verification failed: {e}")
                continue

        except Exception as e:
            print(f"❌ opset {opset} export failed: {str(e)[:100]}...")
            continue

    print("💥 all ONNX exports failed")
    return None

# Script entry point.  Fix: the pasted code tested `name == 'main'`, which is
# a NameError; the correct guard is `__name__ == '__main__'`.
if __name__ == '__main__':
    onnx_file = export_onnx_for_tensorrt()

    if onnx_file:
        print(f"\n🎉 success! now run:")
        print(f"trtexec --onnx={onnx_file} --saveEngine=./onnx/bridgedepth.trt --fp16")
    else:
        print(f"\n💡 suggest using torch-tensorrt directly")

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions