From c069cbdc14ae73d802f0d1c32d2456aa5ae850cf Mon Sep 17 00:00:00 2001 From: klemen1999 Date: Wed, 8 Apr 2026 23:26:25 +0700 Subject: [PATCH 1/4] bumped onnx to 1.21 and fixed/improved some related dependencies --- modelconverter/packages/base_exporter.py | 27 +++-- modelconverter/utils/config.py | 49 +++++---- modelconverter/utils/onnx_compatibility.py | 72 +++++++++++++ modelconverter/utils/onnx_tools.py | 44 ++++---- modelconverter/utils/types.py | 8 ++ requirements.txt | 2 +- tests/test_utils/test_onnx_compatibility.py | 112 ++++++++++++++++++++ 7 files changed, 259 insertions(+), 55 deletions(-) create mode 100644 modelconverter/utils/onnx_compatibility.py create mode 100644 tests/test_utils/test_onnx_compatibility.py diff --git a/modelconverter/packages/base_exporter.py b/modelconverter/packages/base_exporter.py index d65f126..96d9d71 100644 --- a/modelconverter/packages/base_exporter.py +++ b/modelconverter/packages/base_exporter.py @@ -6,7 +6,6 @@ from typing import Any import numpy as np -import onnx from loguru import logger from modelconverter.utils import ( @@ -20,6 +19,7 @@ RandomCalibrationConfig, SingleStageConfig, ) +from modelconverter.utils.onnx_compatibility import save_onnx_model from modelconverter.utils.subprocess import SubprocessResult from modelconverter.utils.types import InputFileType, Target @@ -129,7 +129,13 @@ def simplify_onnx(self) -> Path: ) return self.input_model - onnx_sim, check = simplify(str(self.input_model)) + try: + onnx_sim, check = simplify(str(self.input_model)) + except Exception as e: + logger.warning( + f"Failed to simplify ONNX: {e}. Proceeding without simplification." + ) + return self.input_model if not check: logger.warning( "Provided ONNX could not be simplified. 
" @@ -141,15 +147,14 @@ def simplify_onnx(self) -> Path: self.input_model, "simplified.onnx" ) logger.info(f"Saving simplified ONNX to {onnx_sim_path}") - if self.input_model.with_suffix(".onnx_data").exists(): - onnx.save( - onnx_sim, - str(onnx_sim_path), - save_as_external_data=True, - location=f"{onnx_sim_path.name}_data", - ) - else: - onnx.save(onnx_sim, str(onnx_sim_path)) + save_onnx_model( + onnx_sim, + onnx_sim_path, + save_as_external_data=self.input_model.with_suffix( + ".onnx_data" + ).exists(), + location=f"{onnx_sim_path.name}_data", + ) return onnx_sim_path @abstractmethod diff --git a/modelconverter/utils/config.py b/modelconverter/utils/config.py index 568966d..ac1ae1c 100644 --- a/modelconverter/utils/config.py +++ b/modelconverter/utils/config.py @@ -22,6 +22,7 @@ from modelconverter.utils.filesystem_utils import resolve_path from modelconverter.utils.layout import make_default_layout from modelconverter.utils.metadata import Metadata, get_metadata +from modelconverter.utils.onnx_compatibility import save_onnx_model from modelconverter.utils.types import ( DataType, Encoding, @@ -641,14 +642,9 @@ def _get_onnx_node_info( f"Output value info for node '{node_name}' not found." ) - shape = [ - dim.dim_value for dim in output_value_info.type.tensor_type.shape.dim - ] - if any(dim == 0 for dim in shape): - raise ValueError( - "Dynamic shapes are not supported. " - f"Shape of node '{node_name}' is {shape}." - ) + shape = _get_static_onnx_shape( + output_value_info.type.tensor_type, f"node '{node_name}'" + ) data_type = output_value_info.type.tensor_type.elem_type return shape, DataType.from_onnx_dtype(data_type) @@ -662,12 +658,7 @@ def _get_onnx_tensor_info( def extract_tensor_info( tensor_type: TypeProto.Tensor, ) -> tuple[list[int], DataType]: - shape = [dim.dim_value for dim in tensor_type.shape.dim] - if any(dim == 0 for dim in shape): - raise ValueError( - "Dynamic shapes are not supported. " - f"Shape of tensor '{tensor_name}' is {shape}." 
- ) + shape = _get_static_onnx_shape(tensor_type, f"tensor '{tensor_name}'") return shape, DataType.from_onnx_dtype(tensor_type.elem_type) for tensor in chain(model.graph.input, model.graph.output): @@ -687,6 +678,21 @@ def extract_tensor_info( raise NameError(f"Tensor '{tensor_name}' not found in the ONNX model.") +def _get_static_onnx_shape( + tensor_type: TypeProto.Tensor, tensor_name: str +) -> list[int]: + shape = [] + for dim in tensor_type.shape.dim: + if dim.HasField("dim_value") and dim.dim_value > 0: + shape.append(dim.dim_value) + else: + raise ValueError( + "Dynamic shapes are not supported. " + f"Shape of {tensor_name} is {[d.dim_value for d in tensor_type.shape.dim]}." + ) + return shape + + def _get_onnx_inter_info( model_path: Path, name: str ) -> tuple[list[int] | None, DataType | None]: @@ -739,12 +745,9 @@ def generate_renamed_onnx( if output_name in rename_dict: node.output[i] = rename_dict[output_name] - if model_data_path: - onnx.save( - model, - str(output_path), - save_as_external_data=True, - location=f"{output_path.name}_data", - ) - else: - onnx.save(model, str(output_path)) + save_onnx_model( + model, + output_path, + save_as_external_data=model_data_path is not None, + location=f"{output_path.name}_data", + ) diff --git a/modelconverter/utils/onnx_compatibility.py b/modelconverter/utils/onnx_compatibility.py new file mode 100644 index 0000000..7ba3b84 --- /dev/null +++ b/modelconverter/utils/onnx_compatibility.py @@ -0,0 +1,72 @@ +from pathlib import Path + +import ml_dtypes +import numpy as np +import onnx +from onnx.external_data_helper import convert_model_to_external_data + + +def ensure_onnx_helper_compatibility() -> None: + helper = onnx.helper + + def _convert_scalar( + value: float, dtype: np.dtype, container: np.dtype + ) -> int: + arr = np.asarray(value, dtype=dtype) + return arr.view(container).item() + + if not hasattr(helper, "float32_to_bfloat16"): + helper.float32_to_bfloat16 = lambda value: _convert_scalar( # type: 
ignore[attr-defined] + value, ml_dtypes.bfloat16, np.uint16 + ) + + if not hasattr(helper, "float32_to_float8e4m3"): + dtype_map = { + (False, False): ml_dtypes.float8_e4m3, + (True, False): ml_dtypes.float8_e4m3fn, + (True, True): ml_dtypes.float8_e4m3fnuz, + (False, True): ml_dtypes.float8_e4m3b11fnuz, + } + + def float32_to_float8e4m3( + value: float, *, fn: bool = True, uz: bool = False + ) -> int: + return _convert_scalar(value, dtype_map[(fn, uz)], np.uint8) + + helper.float32_to_float8e4m3 = float32_to_float8e4m3 # type: ignore[attr-defined] + + +def save_onnx_model( + model: onnx.ModelProto, + output_path: str | Path, + *, + save_as_external_data: bool = False, + location: str | None = None, +) -> None: + output_path = Path(output_path) + + if save_as_external_data: + external_data_path = output_path.with_name( + location or f"{output_path.name}_data" + ) + if external_data_path.exists(): + external_data_path.unlink() + convert_model_to_external_data( + model, + all_tensors_to_one_file=True, + location=external_data_path.name, + size_threshold=0, + convert_attribute=False, + ) + onnx.save( + model, + str(output_path), + save_as_external_data=True, + all_tensors_to_one_file=True, + location=external_data_path.name, + size_threshold=0, + convert_attribute=False, + ) + return + + onnx.save(model, str(output_path)) diff --git a/modelconverter/utils/onnx_tools.py b/modelconverter/utils/onnx_tools.py index 1d916d3..79bf7a9 100644 --- a/modelconverter/utils/onnx_tools.py +++ b/modelconverter/utils/onnx_tools.py @@ -4,16 +4,25 @@ import numpy as np import onnx -import onnx_graphsurgeon as gs import onnxruntime as ort from loguru import logger from onnx import TensorProto, checker, helper from onnxsim import simplify from modelconverter.utils.config import InputConfig, OutputConfig +from modelconverter.utils.onnx_compatibility import ( + ensure_onnx_helper_compatibility, + save_onnx_model, +) from .exceptions import ONNXException +ensure_onnx_helper_compatibility() 
+ +# GraphSurgeon still imports helper conversion functions removed in ONNX 1.21. +# Patch them back in before importing GraphSurgeon. +import onnx_graphsurgeon as gs # noqa: E402 + def get_opset_version(model: onnx.ModelProto) -> int: for imp in model.opset_import: @@ -284,15 +293,12 @@ def onnx_attach_normalization_to_inputs( graph.initializer.extend(new_initializers) - if model_data_path: - onnx.save( - model, - str(save_path), - save_as_external_data=True, - location=f"{save_path.name}_data", - ) - else: - onnx.save(model, str(save_path)) + save_onnx_model( + model, + save_path, + save_as_external_data=model_data_path is not None, + location=f"{save_path.name}_data", + ) checker.check_model(str(save_path)) @@ -413,7 +419,7 @@ def optimize_onnx(self) -> None: with tempfile.NamedTemporaryFile( delete=True, suffix=".onnx" ) as tmp_onnx_file: - onnx.save( + save_onnx_model( optimized_onnx_model, tmp_onnx_file.name, save_as_external_data=True, @@ -440,15 +446,13 @@ def export_onnx(self) -> None: self.onnx_model.ir_version = min(self.onnx_model.ir_version, 10) - if self.has_external_data: - onnx.save( - self.onnx_model, - str(self.output_path), - save_as_external_data=True, - location=f"{self.output_path.name}_data", - ) - else: - onnx.save(self.onnx_model, self.output_path) + save_onnx_model( + self.onnx_model, + self.output_path, + save_as_external_data=self.has_external_data, + location=f"{self.output_path.name}_data", + ) + onnx.checker.check_model(str(self.output_path)) def add_outputs(self, output_names: list[str]) -> None: """Add output nodes to the ONNX model. 
diff --git a/modelconverter/utils/types.py b/modelconverter/utils/types.py index d2bace3..66ecf0b 100644 --- a/modelconverter/utils/types.py +++ b/modelconverter/utils/types.py @@ -20,6 +20,7 @@ class Encoding(Enum): class DataType(Enum): + BFLOAT16 = "bfloat16" FLOAT16 = "float16" FLOAT32 = "float32" FLOAT64 = "float64" @@ -28,6 +29,7 @@ class DataType(Enum): INT16 = "int16" INT32 = "int32" INT64 = "int64" + UINT4 = "uint4" UINT8 = "uint8" UINT16 = "uint16" UINT32 = "uint32" @@ -93,10 +95,13 @@ def from_dlc_dtype(cls, dtype: str) -> "DataType": @classmethod def from_onnx_dtype(cls, dtype: int) -> "DataType": dtype_map = { + TensorProto.BFLOAT16: "bfloat16", TensorProto.FLOAT16: "float16", TensorProto.FLOAT: "float32", TensorProto.DOUBLE: "float64", + TensorProto.INT4: "int4", TensorProto.UINT8: "uint8", + TensorProto.UINT4: "uint4", TensorProto.UINT16: "uint16", TensorProto.UINT32: "uint32", TensorProto.UINT64: "uint64", @@ -172,13 +177,16 @@ def from_ir_runtime_dtype(cls, dtype: str) -> "DataType": def as_numpy_dtype(self) -> np.dtype: return { + "bfloat16": np.float32, # Preserve bfloat16 range better than float16. "float16": np.float16, "float32": np.float32, "float64": np.float64, + "int4": np.int8, # NumPy has no 4-bit signed integer dtype. "int8": np.int8, "int16": np.int16, "int32": np.int32, "int64": np.int64, + "uint4": np.uint8, # NumPy has no 4-bit unsigned integer dtype. 
"uint8": np.uint8, "uint16": np.uint16, "uint32": np.uint32, diff --git a/requirements.txt b/requirements.txt index b9dacc0..8d4e5ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ Pillow luxonis-ml[data,nn_archive,s3,gcs]>=0.8.2 cyclopts==4.8.0 -onnx>=1.17.0,<1.19.0 +onnx==1.21.0 onnxruntime onnxsim docker diff --git a/tests/test_utils/test_onnx_compatibility.py b/tests/test_utils/test_onnx_compatibility.py new file mode 100644 index 0000000..1372827 --- /dev/null +++ b/tests/test_utils/test_onnx_compatibility.py @@ -0,0 +1,112 @@ +from pathlib import Path + +import numpy as np +import pytest +from onnx import TensorProto, helper, numpy_helper + +from modelconverter.packages.base_exporter import Exporter +from modelconverter.utils.config import generate_renamed_onnx +from modelconverter.utils.onnx_compatibility import ( + ensure_onnx_helper_compatibility, + save_onnx_model, +) +from modelconverter.utils.types import DataType + + +@pytest.mark.parametrize( + ("tensor_name", "expected"), + [ + ("BFLOAT16", DataType.BFLOAT16), + ("INT4", DataType.INT4), + ("UINT4", DataType.UINT4), + ], +) +def test_extended_onnx_dtype_support(tensor_name: str, expected: DataType): + if not hasattr(TensorProto, tensor_name): + pytest.skip(f"{tensor_name} is not available in this ONNX version") + assert ( + DataType.from_onnx_dtype(getattr(TensorProto, tensor_name)) == expected + ) + + +def test_simplify_onnx_falls_back_on_error( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +): + import onnx + import onnxsim + + input_tensor = helper.make_tensor_value_info( + "input0", TensorProto.FLOAT, [1, 4] + ) + output_tensor = helper.make_tensor_value_info( + "output0", TensorProto.FLOAT, [1, 4] + ) + node = helper.make_node("Identity", inputs=["input0"], outputs=["output0"]) + model = helper.make_model( + helper.make_graph( + [node], "SimplifyFallbackModel", [input_tensor], [output_tensor] + ) + ) + input_path = tmp_path / "fallback.onnx" + onnx.save(model, input_path) 
+ + monkeypatch.setattr( + onnxsim, + "simplify", + lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError("boom")), + ) + + class DummyExporter: + input_model = input_path + _attach_suffix = staticmethod(Exporter._attach_suffix) + + assert Exporter.simplify_onnx(DummyExporter()) == input_path + + +def test_generate_renamed_onnx_overwrites_external_data(tmp_path: Path): + input_tensor = helper.make_tensor_value_info( + "input0", TensorProto.FLOAT, [1, 1024] + ) + output_tensor = helper.make_tensor_value_info( + "output0", TensorProto.FLOAT, [1, 1024] + ) + bias_tensor = numpy_helper.from_array( + np.arange(1024, dtype=np.float32).reshape(1, 1024), name="bias" + ) + node = helper.make_node( + "Add", inputs=["input0", "bias"], outputs=["output0"] + ) + model = helper.make_model( + helper.make_graph( + [node], + "ExternalDataModel", + [input_tensor], + [output_tensor], + initializer=[bias_tensor], + ), + producer_name="DummyModelProducer", + ) + + input_path = tmp_path / "external_input.onnx" + output_path = tmp_path / "external_output.onnx" + + save_onnx_model( + model, + input_path, + save_as_external_data=True, + location=f"{input_path.name}_data", + ) + assert input_path.with_name(f"{input_path.name}_data").exists() + + generate_renamed_onnx(input_path, {"output0": "renamed0"}, output_path) + assert output_path.with_name(f"{output_path.name}_data").exists() + + generate_renamed_onnx(input_path, {"output0": "renamed1"}, output_path) + assert output_path.with_name(f"{output_path.name}_data").exists() + + +def test_onnx_graphsurgeon_imports_with_onnx_121(): + ensure_onnx_helper_compatibility() + import onnx_graphsurgeon as gs + + assert gs is not None From b5764dd46de07c42fa16c1fd4d2ba1f07443c275 Mon Sep 17 00:00:00 2001 From: klemen1999 Date: Wed, 8 Apr 2026 23:52:31 +0700 Subject: [PATCH 2/4] pinning ONNX versions per platform --- modelconverter/packages/hailo/requirements.txt | 1 + modelconverter/packages/rvc4/requirements.txt | 1 + 2 files changed, 2 
insertions(+) diff --git a/modelconverter/packages/hailo/requirements.txt b/modelconverter/packages/hailo/requirements.txt index b7211a9..267bfb3 100644 --- a/modelconverter/packages/hailo/requirements.txt +++ b/modelconverter/packages/hailo/requirements.txt @@ -3,3 +3,4 @@ nvidia-dali-tf-plugin-cuda120==1.49.0 protobuf==3.20.3 matplotlib==3.10.6 pyparsing==2.4.7 +onnx==1.18.0 # Hailo SDK still imports onnx.mapping, removed in newer ONNX. diff --git a/modelconverter/packages/rvc4/requirements.txt b/modelconverter/packages/rvc4/requirements.txt index bbd159c..a595c24 100644 --- a/modelconverter/packages/rvc4/requirements.txt +++ b/modelconverter/packages/rvc4/requirements.txt @@ -3,3 +3,4 @@ psutil numpy<2 polars pytest # this is actually required by snpe packages +onnx==1.18.0 # SNPE's ONNX importer fails with onnx==1.21.0 in CI. From e93796426dc4aeb9390770a5c00ef2e1818835f6 Mon Sep 17 00:00:00 2001 From: klemen1999 Date: Thu, 9 Apr 2026 13:25:25 +0700 Subject: [PATCH 3/4] Don't rebuild image every time in CI because it causes device to run out of space. 
Also some docker utils refactoring --- modelconverter/__main__.py | 20 +++++++++++-- modelconverter/utils/__init__.py | 4 +++ modelconverter/utils/docker_utils.py | 45 +++++++++++++++++++++++++++- 3 files changed, 65 insertions(+), 4 deletions(-) diff --git a/modelconverter/__main__.py b/modelconverter/__main__.py index b4ce3f0..a5711d5 100644 --- a/modelconverter/__main__.py +++ b/modelconverter/__main__.py @@ -33,6 +33,8 @@ archive_from_model, docker_build, docker_exec, + get_default_target_version, + get_local_docker_image, in_docker, resolve_path, upload_to_remote, @@ -630,9 +632,21 @@ def launcher( target = bound.arguments["target"] if dev: - docker_build( - target.value, bare_tag=tag, version=tool_version, image=image - ) + version = tool_version or get_default_target_version(target.value) + # CI invokes multiple dev docker commands per job; reuse the first + # local build so later commands don't rebuild the same image again. + if not ( + os.getenv("CI") == "true" + and get_local_docker_image( + target.value, + bare_tag=tag, + version=version, + image=image, + ) + ): + docker_build( + target.value, bare_tag=tag, version=version, image=image + ) docker_exec( target.value, diff --git a/modelconverter/utils/__init__.py b/modelconverter/utils/__init__.py index 5315e14..beb3f6e 100644 --- a/modelconverter/utils/__init__.py +++ b/modelconverter/utils/__init__.py @@ -7,7 +7,9 @@ docker_exec, get_container_memory_available, get_container_memory_limit, + get_default_target_version, get_docker_image, + get_local_docker_image, in_docker, ) from .environ import environ @@ -64,8 +66,10 @@ "get_archive_input", "get_container_memory_available", "get_container_memory_limit", + "get_default_target_version", "get_docker_image", "get_extra_quant_tensors", + "get_local_docker_image", "get_metadata", "get_protocol", "guess_new_layout", diff --git a/modelconverter/utils/docker_utils.py b/modelconverter/utils/docker_utils.py index 4c8195d..2540e90 100644 --- 
a/modelconverter/utils/docker_utils.py +++ b/modelconverter/utils/docker_utils.py @@ -346,9 +346,26 @@ def get_docker_image( ) -> str: check_docker() + local_image = get_local_docker_image(target, bare_tag, version, image) + if local_image is not None: + return local_image + + candidate_images = _get_candidate_docker_images( + target, bare_tag, version, image + ) + return _get_or_build_docker_image( + target, bare_tag, version, candidate_images, image + ) + + +def _get_candidate_docker_images( + target: Literal["rvc2", "rvc3", "rvc4", "hailo"], + bare_tag: str, + version: str, + image: str | None = None, +) -> list[str]: tag_version = rvc4_tag_version(version) if target == "rvc4" else version tag = f"{tag_version}-{bare_tag}" - client = get_docker_client_from_active_context() if image is not None: image_repo, image_tag = parse_repository_tag(image) @@ -368,6 +385,21 @@ def get_docker_image( if tag_version != version and image_tag is None: candidate_images.append(f"{image_repo}:{version}-{bare_tag}") + return candidate_images + + +def get_local_docker_image( + target: Literal["rvc2", "rvc3", "rvc4", "hailo"], + bare_tag: str, + version: str, + image: str | None = None, +) -> str | None: + check_docker() + + candidate_images = _get_candidate_docker_images( + target, bare_tag, version, image + ) + client = get_docker_client_from_active_context() candidate_tags = set() for candidate in candidate_images: candidate_tags.add(candidate) @@ -379,6 +411,17 @@ def get_docker_image( if tags: return next(iter(tags)) + return None + + +def _get_or_build_docker_image( + target: Literal["rvc2", "rvc3", "rvc4", "hailo"], + bare_tag: str, + version: str, + candidate_images: list[str], + image: str | None = None, +) -> str: + client = get_docker_client_from_active_context() for candidate in candidate_images: logger.warning( f"Image '{candidate}' not found locally, pulling " From 6209646f5da2345920885cd0dcf139f2318fc8e0 Mon Sep 17 00:00:00 2001 From: klemen1999 Date: Thu, 9 Apr 2026 
13:25:46 +0700 Subject: [PATCH 4/4] pin ONNX to 1.17 for Hailo --- modelconverter/packages/hailo/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelconverter/packages/hailo/requirements.txt b/modelconverter/packages/hailo/requirements.txt index 267bfb3..def2ebc 100644 --- a/modelconverter/packages/hailo/requirements.txt +++ b/modelconverter/packages/hailo/requirements.txt @@ -3,4 +3,4 @@ nvidia-dali-tf-plugin-cuda120==1.49.0 protobuf==3.20.3 matplotlib==3.10.6 pyparsing==2.4.7 -onnx==1.18.0 # Hailo SDK still imports onnx.mapping, removed in newer ONNX. +onnx==1.17.0 # Hailo SDK still imports onnx.mapping and still pins protobuf==3.20.3.