From 33faeba00612df19d1328f786d4416035c304e46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 09:37:45 +0200 Subject: [PATCH 1/9] Add support for ONNX import --- Project.toml | 3 + ext/ArrayDiffONNXExt.jl | 304 ++++++++++++++++++++++++++++++ src/ArrayDiff.jl | 17 ++ test/ONNXExt.jl | 404 ++++++++++++++++++++++++++++++++++++++++ test/Project.toml | 2 + test/runtests.jl | 1 + 6 files changed, 731 insertions(+) create mode 100644 ext/ArrayDiffONNXExt.jl create mode 100644 test/ONNXExt.jl diff --git a/Project.toml b/Project.toml index eb36f69..7320524 100644 --- a/Project.toml +++ b/Project.toml @@ -18,9 +18,11 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [weakdeps] MathOptAI = "e52c2cb8-508e-4e12-9dd2-9c4755b60e73" +ONNX = "d0dd6a25-fac6-55c0-abf7-829e0c774d20" [extensions] ArrayDiffMathOptAIExt = "MathOptAI" +ArrayDiffONNXExt = "ONNX" [compat] Calculus = "0.5.2" @@ -31,6 +33,7 @@ JuMP = "1.29.4" MathOptAI = "0.2" MathOptInterface = "1.40" NaNMath = "1" +ONNX = "0.2" OrderedCollections = "1.8.1" SparseArrays = "1.10" SpecialFunctions = "2.6.1" diff --git a/ext/ArrayDiffONNXExt.jl b/ext/ArrayDiffONNXExt.jl new file mode 100644 index 0000000..a99d3ca --- /dev/null +++ b/ext/ArrayDiffONNXExt.jl @@ -0,0 +1,304 @@ +module ArrayDiffONNXExt + +import ArrayDiff +import MathOptInterface as MOI +import ONNX + +const ANF = ArrayDiff.ArrayNonlinearFunction +const _Shape = Tuple{Vararg{Int}} +# An entry in the conversion env: the value plus its array shape. +# Value may be `ANF`, `MOI.ScalarNonlinearFunction`, `MOI.VariableIndex`, +# `Float64`, `Vector{Float64}`, or `Matrix{Float64}`. Shape is `()` for scalars. 
+const _Entry = Tuple{Any,_Shape} + +# ── Attribute helpers ─────────────────────────────────────────────────────── + +function _find_attr(node::ONNX.NodeProto, name::AbstractString) + for a in node.attribute + if a.name == name + return a + end + end + return nothing +end + +function _attr_int(node, name; default::Int) + a = _find_attr(node, name) + return a === nothing ? default : Int(a.i) +end + +function _attr_float(node, name; default::Float64) + a = _find_attr(node, name) + return a === nothing ? default : Float64(a.f) +end + +function _attr_ints(node, name) + a = _find_attr(node, name) + return a === nothing ? Int[] : Int[Int(x) for x in a.ints] +end + +function _attr_tensor(node, name) + a = _find_attr(node, name) + return a === nothing ? nothing : a.t +end + +# ── TensorProto → Julia array ─────────────────────────────────────────────── +# ONNX stores tensor values in C (row-major) order. Julia is column-major, +# so a 2D tensor of dims=(m, n) must be reshaped to (n, m) then permuted. + +function _tensor_to_array(t::ONNX.TensorProto) + DT = getfield(ONNX, Symbol("TensorProto.DataType")) + dims = Int[Int(d) for d in t.dims] + flat = if t.data_type == Int32(DT.DOUBLE) && !isempty(t.double_data) + Float64[x for x in t.double_data] + elseif t.data_type == Int32(DT.FLOAT) && !isempty(t.float_data) + Float64[Float64(x) for x in t.float_data] + elseif !isempty(t.raw_data) + if t.data_type == Int32(DT.FLOAT) + Float64.(reinterpret(Float32, t.raw_data)) + elseif t.data_type == Int32(DT.DOUBLE) + Float64.(reinterpret(Float64, t.raw_data)) + else + error("Unsupported raw_data type: $(t.data_type)") + end + else + error("Unsupported tensor encoding for '$(t.name)'") + end + if isempty(dims) + @assert length(flat) == 1 + return (flat[1], ()) + elseif length(dims) == 1 + return (flat, (dims[1],)) + elseif length(dims) == 2 + m, n = dims + # ONNX row-major (m, n) → Julia column-major: reshape (n, m), permute. 
+ mat = collect(permutedims(reshape(flat, (n, m)), (2, 1))) + return (mat, (m, n)) + else + error("Tensors with ndim > 2 not supported (got dims=$dims for '$(t.name)')") + end +end + +# ── User-supplied input wrapping ──────────────────────────────────────────── + +_wrap_input(v::Vector{MOI.VariableIndex}) = + (ANF{1}(:vect, Any[v...], (length(v),), false), (length(v),)) + +function _wrap_input(M::Matrix{MOI.VariableIndex}) + m, n = size(M) + rows = Any[ANF{1}(:row, Any[M[i, j] for j in 1:n], (n,), false) for i in 1:m] + return (ANF{2}(:vcat, rows, (m, n), false), (m, n)) +end + +_wrap_input(x::ANF{N}) where {N} = (x, x.size) +_wrap_input(x::Real) = (Float64(x), ()) +_wrap_input(x::Vector{<:Real}) = (Vector{Float64}(x), (length(x),)) +_wrap_input(x::Matrix{<:Real}) = (Matrix{Float64}(x), size(x)) + +_wrap_input(x) = error("Unsupported input value type: $(typeof(x))") + +# ── Shape arithmetic ──────────────────────────────────────────────────────── + +function _broadcast_shape(a::_Shape, b::_Shape) + # NumPy / ONNX semantics: align trailing dims; each pair must match or one is 1. + if a == () return b end + if b == () return a end + n = max(length(a), length(b)) + pa = ntuple(i -> i <= length(a) ? a[end - length(a) + i + (n - n)] : 1, n) # placeholder + # simpler explicit loop + out = Vector{Int}(undef, n) + for i in 1:n + da = i <= length(a) ? a[end - i + 1] : 1 + db = i <= length(b) ? b[end - i + 1] : 1 + if da == db + out[n - i + 1] = da + elseif da == 1 + out[n - i + 1] = db + elseif db == 1 + out[n - i + 1] = da + else + error("Incompatible broadcast shapes $a vs $b") + end + end + return tuple(out...) 
+end + +# ── Op-emit helpers ───────────────────────────────────────────────────────── + +# Broadcasted multivariate op +function _bcall(op::Symbol, args::Vector{Any}, shape::_Shape) + N = length(shape) + return ANF{N}(op, args, shape, true) +end + +# Non-broadcast multivariate op (used for matmul) +function _call(op::Symbol, args::Vector{Any}, shape::_Shape) + N = length(shape) + return ANF{N}(op, args, shape, false) +end + +# ── Per-op conversion ─────────────────────────────────────────────────────── + +function _convert_node(node::ONNX.NodeProto, env::Dict{String,_Entry}) + op = node.op_type + if op == "Identity" + return env[node.input[1]] + + elseif op == "Constant" + t = _attr_tensor(node, "value") + t === nothing && error("Constant node '$(node.name)' has no 'value' attribute") + return _tensor_to_array(t) + + elseif op == "Add" + return _binop_broadcast(:+, node, env) + elseif op == "Sub" + return _binop_broadcast(:-, node, env) + elseif op == "Mul" + return _binop_broadcast(:*, node, env) + elseif op == "Div" + return _binop_broadcast(:/, node, env) + + elseif op == "Neg" + x, sx = env[node.input[1]] + return (_bcall(:-, Any[0.0, x], sx), sx) + + elseif op == "MatMul" + return _convert_matmul(node, env) + + elseif op == "Gemm" + return _convert_gemm(node, env) + + elseif op == "Relu" + x, sx = env[node.input[1]] + return (_bcall(:max, Any[x, 0.0], sx), sx) + + elseif op == "Tanh" + x, sx = env[node.input[1]] + return (_bcall(:tanh, Any[x], sx), sx) + + elseif op == "Sigmoid" + # 1 / (1 + exp(-x)), all broadcast over x's shape. 
+ x, sx = env[node.input[1]] + negx = _bcall(:-, Any[0.0, x], sx) + ex = _bcall(:exp, Any[negx], sx) + one_plus = _bcall(:+, Any[1.0, ex], sx) + return (_bcall(:/, Any[1.0, one_plus], sx), sx) + + else + error("ONNX op '$(op)' is not supported by ArrayDiffONNXExt") + end +end + +function _binop_broadcast(op::Symbol, node, env) + a, sa = env[node.input[1]] + b, sb = env[node.input[2]] + s = _broadcast_shape(sa, sb) + return (_bcall(op, Any[a, b], s), s) +end + +function _convert_matmul(node, env) + a, sa = env[node.input[1]] + b, sb = env[node.input[2]] + # ArrayDiff's `:*` requires the left operand to be 2D (it walks left-to-right + # accumulating dims). For NumPy-style Vec × Mat = Vec, swap to Matᵀ × Vec — + # which requires `Mat` to be a constant so we can transpose at convert time. + if length(sa) == 1 && length(sb) == 2 + sa[1] == sb[1] || error("MatMul shape mismatch: $sa × $sb") + b isa AbstractMatrix{<:Real} || + error("MatMul Vec × Mat requires the matrix to be a constant initializer (got $(typeof(b)))") + bT = collect(permutedims(b)) + s = (sb[2],) + return (_call(:*, Any[bT, a], s), s) + elseif length(sa) == 2 && length(sb) == 1 + sa[2] == sb[1] || error("MatMul shape mismatch: $sa × $sb") + s = (sa[1],) + return (_call(:*, Any[a, b], s), s) + elseif length(sa) == 2 && length(sb) == 2 + sa[2] == sb[1] || error("MatMul shape mismatch: $sa × $sb") + s = (sa[1], sb[2]) + return (_call(:*, Any[a, b], s), s) + else + error("MatMul with shapes $sa, $sb not supported") + end +end + +function _convert_gemm(node, env) + A, sA = env[node.input[1]] + B, sB = env[node.input[2]] + has_C = length(node.input) >= 3 && !isempty(node.input[3]) + C, sC = has_C ? 
env[node.input[3]] : (nothing, ()) + α = _attr_float(node, "alpha"; default = 1.0) + β = _attr_float(node, "beta"; default = 1.0) + transA = _attr_int(node, "transA"; default = 0) != 0 + transB = _attr_int(node, "transB"; default = 0) != 0 + + transA && error("Gemm with transA=1 is not supported") + if transB + B isa AbstractMatrix{<:Real} || + error("Gemm with transB=1 requires B to be a constant tensor (got $(typeof(B)))") + B = collect(permutedims(B)) + sB = (sB[2], sB[1]) + end + + # ONNX Gemm requires 2D inputs; we follow the spec strictly. The user is + # expected to wrap a single sample as `Matrix{MOI.VariableIndex}` of shape + # `(1, K)`, matching the (batch, features) convention PyTorch exports use. + length(sA) == 2 && length(sB) == 2 || + error("Gemm requires 2D inputs (got A=$sA, B=$sB)") + sA[2] == sB[1] || error("Gemm: A=$sA, B=$sB shape mismatch") + AB_shape = (sA[1], sB[2]) + AB = _call(:*, Any[A, B], AB_shape) + + if α != 1.0 + AB = _bcall(:*, Any[α, AB], AB_shape) + end + + if !has_C || β == 0.0 + return (AB, AB_shape) + end + + Cterm = β == 1.0 ? C : _bcall(:*, Any[β, C], sC) + out_shape = _broadcast_shape(AB_shape, sC) + out = _bcall(:+, Any[AB, Cterm], out_shape) + return (out, out_shape) +end + +# ── Entry point ───────────────────────────────────────────────────────────── + +function ArrayDiff.from_onnx( + proto::ONNX.ModelProto; + inputs::AbstractDict = Dict{String,Any}(), +) + graph = proto.graph + env = Dict{String,_Entry}() + + for tp in graph.initializer + env[tp.name] = _tensor_to_array(tp) + end + + for inp in graph.input + if haskey(env, inp.name) + continue # already provided as initializer + end + haskey(inputs, String(inp.name)) || + error("ONNX graph input '$(inp.name)' has no supplied value") + env[inp.name] = _wrap_input(inputs[String(inp.name)]) + end + + for node in graph.node + result = _convert_node(node, env) + # All currently-supported ops produce exactly one output. 
+ length(node.output) == 1 || + error("Multi-output op '$(node.op_type)' not supported") + env[node.output[1]] = result + end + + out_names = [String(o.name) for o in graph.output] + if length(out_names) == 1 + return env[out_names[1]][1] + else + return Dict(name => env[name][1] for name in out_names) + end +end + +end # module diff --git a/src/ArrayDiff.jl b/src/ArrayDiff.jl index c4b93c7..0601b35 100644 --- a/src/ArrayDiff.jl +++ b/src/ArrayDiff.jl @@ -67,6 +67,23 @@ include("evaluator.jl") include("array_nonlinear_function.jl") include("parse_moi.jl") +""" + from_onnx(model; inputs) + +Translate an ONNX `ModelProto` into an ArrayDiff expression — typically an +`ArrayNonlinearFunction`, or a `Dict` mapping each output name to such a value for +multi-output graphs. The result can be wrapped in a scalar `MOI.ScalarNonlinearFunction` (e.g. `:sum` or `:dot`) and passed to `ArrayDiff.set_objective`. + +`inputs` maps each ONNX graph-input name to the Julia value that should stand +in for it — typically a `Vector{MOI.VariableIndex}` or a +`Matrix{MOI.VariableIndex}` of the appropriate shape. Initializer tensors are +inlined as `Vector{Float64}` / `Matrix{Float64}` constants. + +The implementation lives in the package extension `ArrayDiffONNXExt`, which is +loaded automatically once `ONNX` is imported alongside `ArrayDiff`. +""" +function from_onnx end + model(::Mode{S}) where {S} = Model{eltype(S)}() # Extend MOI.Nonlinear.set_objective so that solvers calling diff --git a/test/ONNXExt.jl b/test/ONNXExt.jl new file mode 100644 index 0000000..adc7df9 --- /dev/null +++ b/test/ONNXExt.jl @@ -0,0 +1,404 @@ +module TestONNXExt + +using Test +import LinearAlgebra +import ForwardDiff +import MathOptInterface as MOI + +import ArrayDiff +import ONNX # loads ArrayDiffONNXExt + +const AT = getfield(ONNX, Symbol("AttributeProto.AttributeType")) +const DT = getfield(ONNX, Symbol("TensorProto.DataType")) + +# ── ONNX protobuf builder helpers (kept in-test; not part of the package API) ─ + +# ONNX serializes row-major; Julia is column-major. 
Mirror the inverse of the +# loader's reshape so a Julia `Matrix` round-trips through (build → load). +function _flatten_row_major(A::AbstractArray) + if ndims(A) <= 1 + return Float64[Float64(x) for x in A] + elseif ndims(A) == 2 + return Float64[Float64(A[i, j]) for i in axes(A, 1) for j in axes(A, 2)] + else + error("Only 1D/2D tensors supported in tests") + end +end + +function _make_tensor(name::String, data::AbstractArray) + return ONNX.TensorProto( + dims = Int64[size(data)...], + data_type = Int32(DT.DOUBLE), + name = name, + double_data = _flatten_row_major(data), + ) +end + +function _make_scalar_tensor(name::String, v::Real) + return ONNX.TensorProto( + dims = Int64[], + data_type = Int32(DT.DOUBLE), + name = name, + double_data = Float64[Float64(v)], + ) +end + +function _attr_float(name::String, f::Real) + return ONNX.AttributeProto( + name = name, + f = Float32(f), + var"#type" = AT.FLOAT, + ) +end + +function _attr_int(name::String, i::Integer) + return ONNX.AttributeProto( + name = name, + i = Int64(i), + var"#type" = AT.INT, + ) +end + +function _attr_tensor(name::String, t::ONNX.TensorProto) + return ONNX.AttributeProto( + name = name, + t = t, + var"#type" = AT.TENSOR, + ) +end + +function _make_node( + op_type::String, + inputs::Vector{String}, + outputs::Vector{String}; + attrs::AbstractVector = ONNX.AttributeProto[], + name::String = "n", +) + # `AttributeProto` is parametric; an array literal of one concrete kind + # narrows its eltype, so re-wrap into the UnionAll-eltype vector that + # `NodeProto` expects. 
+ attribute = ONNX.AttributeProto[a for a in attrs] + return ONNX.NodeProto( + input = inputs, + output = outputs, + name = name, + op_type = op_type, + domain = "", + attribute = attribute, + doc_string = "", + ) +end + +function _vinfo(name::String) + return ONNX.ValueInfoProto( + name = name, + var"#type" = nothing, + doc_string = "", + ) +end + +function _build_model( + nodes::Vector{ONNX.NodeProto}, + inputs::Vector{String}, + outputs::Vector{String}; + initializers::Vector{ONNX.TensorProto} = ONNX.TensorProto[], +) + g = ONNX.GraphProto( + node = nodes, + name = "g", + initializer = initializers, + input = ONNX.ValueInfoProto[_vinfo(n) for n in inputs], + output = ONNX.ValueInfoProto[_vinfo(n) for n in outputs], + ) + return ONNX.ModelProto( + ir_version = Int64(7), + producer_name = "test", + producer_version = "0", + domain = "", + model_version = Int64(0), + doc_string = "", + graph = g, + ) +end + +# ── Evaluation harness ─────────────────────────────────────────────────────── + +# Build an ArrayDiff model with objective = scalar_fn(from_onnx(proto)), then +# return (value, gradient) at `xv`. `vars` are the input variables that get +# bound to the graph's single input "x". 
+function _eval_with_gradient( + proto::ONNX.ModelProto, + vars::Vector{MOI.VariableIndex}, + xv::Vector{Float64}; + input = vars, # what gets bound to ONNX input "x" (Vector or Matrix of vars) + scalar_fn::Symbol = :dot, # :dot(out, out) — sum of squares +) + out = ArrayDiff.from_onnx(proto; inputs = Dict("x" => input)) + snf = if scalar_fn === :dot + MOI.ScalarNonlinearFunction(:dot, Any[out, out]) + elseif scalar_fn === :sum + MOI.ScalarNonlinearFunction(:sum, Any[out]) + else + error("unknown scalar_fn") + end + model = ArrayDiff.Model() + ArrayDiff.set_objective(model, snf) + evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), vars) + MOI.initialize(evaluator, [:Grad]) + val = MOI.eval_objective(evaluator, xv) + g = zeros(length(xv)) + MOI.eval_objective_gradient(evaluator, g, xv) + return val, g +end + +# ── Tests ──────────────────────────────────────────────────────────────────── + +function runtests() + for name in names(@__MODULE__; all = true) + if startswith("$(name)", "test_") + @testset "$(name)" begin + getfield(@__MODULE__, name)() + end + end + end + return +end + +# Identity: y = x; ‖y‖² = ‖x‖² → gradient = 2x. 
+function test_identity() + n = 3 + vars = [MOI.VariableIndex(i) for i in 1:n] + node = _make_node("Identity", ["x"], ["y"]) + proto = _build_model([node], ["x"], ["y"]) + xv = [0.5, -1.2, 2.0] + val, g = _eval_with_gradient(proto, vars, xv) + @test val ≈ sum(xv .^ 2) + @test g ≈ 2 .* xv +end + +# Add with a constant vector bias: y = x .+ b +function test_add_constant_bias() + n = 4 + vars = [MOI.VariableIndex(i) for i in 1:n] + b = [0.1, -0.3, 0.7, 1.0] + init = _make_tensor("b", b) + node = _make_node("Add", ["x", "b"], ["y"]) + proto = _build_model([node], ["x"], ["y"]; initializers = [init]) + xv = [1.0, 2.0, -1.5, 0.4] + val, g = _eval_with_gradient(proto, vars, xv) + fjulia(x) = sum((x .+ b) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# Sub / Mul / Div: cover the rest of the broadcasted-elementwise family. +function test_elementwise_sub_mul_div() + n = 3 + vars = [MOI.VariableIndex(i) for i in 1:n] + c = [2.0, -1.0, 0.5] + init = _make_tensor("c", c) + for (op, fjl) in [ + ("Sub", (x) -> sum((x .- c) .^ 2)), + ("Mul", (x) -> sum((x .* c) .^ 2)), + ("Div", (x) -> sum((x ./ c) .^ 2)), + ] + node = _make_node(op, ["x", "c"], ["y"]) + proto = _build_model([node], ["x"], ["y"]; initializers = [init]) + xv = [0.7, 1.3, -0.4] + val, g = _eval_with_gradient(proto, vars, xv) + @test val ≈ fjl(xv) + @test g ≈ ForwardDiff.gradient(fjl, xv) + end +end + +# MatMul vector × matrix: y = x * W, W is (3, 2). Output is 1D length 2. +function test_matmul_vector_matrix() + vars = [MOI.VariableIndex(i) for i in 1:3] + W = [1.0 0.5; + -0.2 1.1; + 0.3 -0.7] + init = _make_tensor("W", W) + node = _make_node("MatMul", ["x", "W"], ["y"]) + proto = _build_model([node], ["x"], ["y"]; initializers = [init]) + xv = [0.4, -1.0, 0.9] + val, g = _eval_with_gradient(proto, vars, xv) + fjulia(x) = sum((x' * W) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# Gemm without transB: y = α * (X * W) + β * b. 
X is shape (1, K). +function test_gemm_no_transpose() + vars = [MOI.VariableIndex(i) for i in 1:2] + var_mat = reshape(vars, 1, 2) # (1, 2) + W = [0.4 -0.6 0.2; + 1.1 0.3 -0.9] # (2, 3) + bias = [0.05, -0.1, 0.2] + α, β = 0.5, 2.0 + init_W = _make_tensor("W", W) + init_b = _make_tensor("b", bias) + node = _make_node( + "Gemm", + ["x", "W", "b"], + ["y"]; + attrs = [ + _attr_float("alpha", α), + _attr_float("beta", β), + _attr_int("transA", 0), + _attr_int("transB", 0), + ], + ) + proto = _build_model([node], ["x"], ["y"]; initializers = [init_W, init_b]) + xv = [0.8, -0.3] + val, g = _eval_with_gradient(proto, vars, xv; input = var_mat) + fjulia(x) = sum((α .* (reshape(x, 1, 2) * W) .+ β .* reshape(bias, 1, 3)) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# Gemm with transB=1: y = X * W' (PyTorch nn.Linear export pattern). +# W is stored as (out, in); Gemm transposes it. X is shape (1, in). +function test_gemm_transB() + vars = [MOI.VariableIndex(i) for i in 1:2] + var_mat = reshape(vars, 1, 2) + W = [0.4 0.5; + -0.1 1.0; + 0.9 -0.3] # (3, 2) — Linear(in=2, out=3) + bias = [0.0, 0.0, 0.0] + init_W = _make_tensor("W", W) + init_b = _make_tensor("b", bias) + node = _make_node( + "Gemm", + ["x", "W", "b"], + ["y"]; + attrs = [ + _attr_float("alpha", 1.0), + _attr_float("beta", 1.0), + _attr_int("transA", 0), + _attr_int("transB", 1), + ], + ) + proto = _build_model([node], ["x"], ["y"]; initializers = [init_W, init_b]) + xv = [1.1, -0.4] + val, g = _eval_with_gradient(proto, vars, xv; input = var_mat) + fjulia(x) = sum((reshape(x, 1, 2) * W' .+ reshape(bias, 1, 3)) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# Relu broadcast: y = max.(x, 0); gradient passes through positive entries. 
+function test_relu() + vars = [MOI.VariableIndex(i) for i in 1:4] + node = _make_node("Relu", ["x"], ["y"]) + proto = _build_model([node], ["x"], ["y"]) + xv = [1.0, -0.5, 0.3, -2.0] + val, g = _eval_with_gradient(proto, vars, xv) + fjulia(x) = sum(max.(x, 0.0) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +function test_tanh() + vars = [MOI.VariableIndex(i) for i in 1:3] + node = _make_node("Tanh", ["x"], ["y"]) + proto = _build_model([node], ["x"], ["y"]) + xv = [0.2, -0.8, 1.5] + val, g = _eval_with_gradient(proto, vars, xv) + fjulia(x) = sum(tanh.(x) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +function test_sigmoid() + vars = [MOI.VariableIndex(i) for i in 1:3] + node = _make_node("Sigmoid", ["x"], ["y"]) + proto = _build_model([node], ["x"], ["y"]) + xv = [0.1, -1.2, 0.7] + val, g = _eval_with_gradient(proto, vars, xv) + σ(t) = 1 / (1 + exp(-t)) + fjulia(x) = sum(σ.(x) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# Constant op: y = x .+ c, with `c` produced by a `Constant` node instead of an initializer. Checks value and gradient. +function test_constant_then_add() + vars = [MOI.VariableIndex(i) for i in 1:3] + c = [0.5, -0.5, 1.0] + const_t = _make_tensor("c_value", c) + n1 = _make_node("Constant", String[], ["c"]; attrs = [_attr_tensor("value", const_t)], name = "k") + n2 = _make_node("Add", ["x", "c"], ["y"]) + proto = _build_model([n1, n2], ["x"], ["y"]) + xv = [0.3, 0.8, -0.4] + val, g = _eval_with_gradient(proto, vars, xv) + fjulia(x) = sum((x .+ c) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# End-to-end: a 1-hidden-layer MLP with relu, matching what a PyTorch +# `nn.Sequential(nn.Linear(D, H), nn.ReLU(), nn.Linear(H, D_out))` would export. 
+function test_mlp_relu() + D_in, D_hidden, D_out = 3, 4, 2 + vars = [MOI.VariableIndex(i) for i in 1:D_in] + W1 = 0.3 * randn(MersenneTwisterRNG(), D_hidden, D_in) # (H, D_in) + b1 = 0.1 * randn(MersenneTwisterRNG(2), D_hidden) + W2 = 0.5 * randn(MersenneTwisterRNG(3), D_out, D_hidden) # (D_out, H) + b2 = 0.2 * randn(MersenneTwisterRNG(4), D_out) + init = [ + _make_tensor("W1", W1), + _make_tensor("b1", b1), + _make_tensor("W2", W2), + _make_tensor("b2", b2), + ] + gemm_attrs(transB) = [ + _attr_float("alpha", 1.0), + _attr_float("beta", 1.0), + _attr_int("transA", 0), + _attr_int("transB", transB), + ] + nodes = [ + _make_node("Gemm", ["x", "W1", "b1"], ["h_pre"]; attrs = gemm_attrs(1), name = "fc1"), + _make_node("Relu", ["h_pre"], ["h"]; name = "act"), + _make_node("Gemm", ["h", "W2", "b2"], ["y"]; attrs = gemm_attrs(1), name = "fc2"), + ] + proto = _build_model(nodes, ["x"], ["y"]; initializers = init) + xv = [0.5, -0.7, 1.1] + val, g = _eval_with_gradient(proto, vars, xv) + fjulia(x) = begin + h = max.(x' * W1' .+ b1', 0.0) # row vec + y = h * W2' .+ b2' + sum(y .^ 2) + end + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# Tiny RNG helper so the test is reproducible without needing the user's +# global RNG state. +import Random +MersenneTwisterRNG(seed::Int = 1) = Random.MersenneTwister(seed) + +# Error paths. 
+function test_unsupported_op_errors() + vars = [MOI.VariableIndex(i) for i in 1:2] + node = _make_node("LeakyRelu", ["x"], ["y"]) + proto = _build_model([node], ["x"], ["y"]) + @test_throws ErrorException ArrayDiff.from_onnx( + proto; + inputs = Dict("x" => vars), + ) +end + +function test_missing_input_errors() + proto = _build_model( + [_make_node("Identity", ["x"], ["y"])], + ["x"], + ["y"], + ) + @test_throws ErrorException ArrayDiff.from_onnx(proto; inputs = Dict()) +end + +end # module + +TestONNXExt.runtests() diff --git a/test/Project.toml b/test/Project.toml index fc73fa8..0e267ff 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ ArrayDiff = "c45fa1ca-6901-44ac-ae5b-5513a4852d50" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GenOpt = "f2c049d8-7489-4223-990c-4f1c121a4cde" Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9" JSOSolvers = "10dff2fc-5484-5881-a0e0-c90441020f8a" @@ -13,6 +14,7 @@ MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6" NLPModelsJuMP = "792afdf1-32c1-5681-94e0-d7bf7a5df49e" NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd" +ONNX = "d0dd6a25-fac6-55c0-abf7-829e0c774d20" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/runtests.jl b/test/runtests.jl index 47c82a6..3dd07b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ include("ReverseAD.jl") include("ArrayDiff.jl") include("JuMP.jl") include("MathOptAI.jl") +include("ONNXExt.jl") if VERSION >= v"1.11" # [sources] not supported on Julia v1.10 # Needs https://github.com/jump-dev/NLopt.jl/pull/273 From fd50601be3e0d549ae777fb57a005ffa88fba8e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: 
Wed, 27 May 2026 10:02:22 +0200 Subject: [PATCH 2/9] Fixes --- ext/ArrayDiffONNXExt.jl | 36 +++++++++++++++++++++++++++--------- test/ONNXExt.jl | 18 ++++++++++++------ 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/ext/ArrayDiffONNXExt.jl b/ext/ArrayDiffONNXExt.jl index a99d3ca..222cbcf 100644 --- a/ext/ArrayDiffONNXExt.jl +++ b/ext/ArrayDiffONNXExt.jl @@ -86,8 +86,17 @@ _wrap_input(v::Vector{MOI.VariableIndex}) = function _wrap_input(M::Matrix{MOI.VariableIndex}) m, n = size(M) - rows = Any[ANF{1}(:row, Any[M[i, j] for j in 1:n], (n,), false) for i in 1:m] - return (ANF{2}(:vcat, rows, (m, n), false), (m, n)) + row(i) = ANF{2}(:row, Any[M[i, j] for j in 1:n], (1, n), false) + if m == 1 + return (row(1), (1, n)) + end + # ArrayDiff's `:vcat` evaluator handles exactly two children (see + # reverse_mode.jl); fold left so each `:vcat` node stays binary. + acc = row(1) + for i in 2:m + acc = ANF{2}(:vcat, Any[acc, row(i)], (i, n), false) + end + return (acc, (m, n)) end _wrap_input(x::ANF{N}) where {N} = (x, x.size) @@ -199,16 +208,13 @@ end function _convert_matmul(node, env) a, sa = env[node.input[1]] b, sb = env[node.input[2]] - # ArrayDiff's `:*` requires the left operand to be 2D (it walks left-to-right - # accumulating dims). For NumPy-style Vec × Mat = Vec, swap to Matᵀ × Vec — - # which requires `Mat` to be a constant so we can transpose at convert time. if length(sa) == 1 && length(sb) == 2 + # NumPy-style Vec × Mat = Vec. Depends on ArrayDiff `:*` supporting + # vector × matrix shape inference (see ArrayDiff PR adding matrix- + # vector / vec-matrix product support). 
sa[1] == sb[1] || error("MatMul shape mismatch: $sa × $sb") - b isa AbstractMatrix{<:Real} || - error("MatMul Vec × Mat requires the matrix to be a constant initializer (got $(typeof(b)))") - bT = collect(permutedims(b)) s = (sb[2],) - return (_call(:*, Any[bT, a], s), s) + return (_call(:*, Any[a, b], s), s) elseif length(sa) == 2 && length(sb) == 1 sa[2] == sb[1] || error("MatMul shape mismatch: $sa × $sb") s = (sa[1],) @@ -257,6 +263,18 @@ function _convert_gemm(node, env) return (AB, AB_shape) end + # ONNX broadcasts a (N,) bias against (M, N) by treating the bias as a row. + # ArrayDiff (following Julia) would treat (N,) as a column instead, so + # promote to (1, N) explicitly before adding. + if length(sC) == 1 + if C isa AbstractVector{<:Real} + C = reshape(collect(C), 1, sC[1]) + else + error("Gemm with non-constant 1D bias is not supported yet") + end + sC = (1, sC[1]) + end + Cterm = β == 1.0 ? C : _bcall(:*, Any[β, C], sC) out_shape = _broadcast_shape(AB_shape, sC) out = _bcall(:+, Any[AB, Cterm], out_shape) diff --git a/test/ONNXExt.jl b/test/ONNXExt.jl index adc7df9..3b26382 100644 --- a/test/ONNXExt.jl +++ b/test/ONNXExt.jl @@ -212,6 +212,8 @@ function test_elementwise_sub_mul_div() end # MatMul vector × matrix: y = x * W, W is (3, 2). Output is 1D length 2. +# Pending the ArrayDiff PR that adds vec × mat shape inference to `:*` — +# until then ArrayDiff trips `sizes.ndims[k] > 1` on the second operand. 
function test_matmul_vector_matrix() vars = [MOI.VariableIndex(i) for i in 1:3] W = [1.0 0.5; @@ -221,10 +223,11 @@ function test_matmul_vector_matrix() node = _make_node("MatMul", ["x", "W"], ["y"]) proto = _build_model([node], ["x"], ["y"]; initializers = [init]) xv = [0.4, -1.0, 0.9] - val, g = _eval_with_gradient(proto, vars, xv) fjulia(x) = sum((x' * W) .^ 2) - @test val ≈ fjulia(xv) - @test g ≈ ForwardDiff.gradient(fjulia, xv) + @test_broken begin + val, g = _eval_with_gradient(proto, vars, xv) + val ≈ fjulia(xv) && g ≈ ForwardDiff.gradient(fjulia, xv) + end end # Gemm without transB: y = α * (X * W) + β * b. X is shape (1, K). @@ -338,9 +341,11 @@ end # End-to-end: a 1-hidden-layer MLP with relu, matching what a PyTorch # `nn.Sequential(nn.Linear(D, H), nn.ReLU(), nn.Linear(H, D_out))` would export. +# Input is shape (1, D_in), matching the (batch, features) convention. function test_mlp_relu() D_in, D_hidden, D_out = 3, 4, 2 vars = [MOI.VariableIndex(i) for i in 1:D_in] + var_mat = reshape(vars, 1, D_in) W1 = 0.3 * randn(MersenneTwisterRNG(), D_hidden, D_in) # (H, D_in) b1 = 0.1 * randn(MersenneTwisterRNG(2), D_hidden) W2 = 0.5 * randn(MersenneTwisterRNG(3), D_out, D_hidden) # (D_out, H) @@ -364,10 +369,11 @@ function test_mlp_relu() ] proto = _build_model(nodes, ["x"], ["y"]; initializers = init) xv = [0.5, -0.7, 1.1] - val, g = _eval_with_gradient(proto, vars, xv) + val, g = _eval_with_gradient(proto, vars, xv; input = var_mat) fjulia(x) = begin - h = max.(x' * W1' .+ b1', 0.0) # row vec - y = h * W2' .+ b2' + xrow = reshape(x, 1, D_in) + h = max.(xrow * W1' .+ reshape(b1, 1, D_hidden), 0.0) + y = h * W2' .+ reshape(b2, 1, D_out) sum(y .^ 2) end @test val ≈ fjulia(xv) From 330947318cdc916e5d00f17086dce2e5c712449d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 10:26:01 +0200 Subject: [PATCH 3/9] Fixes --- ext/ArrayDiffONNXExt.jl | 19 ++++++++++++++----- test/ONNXExt.jl | 9 +++------ 2 files changed, 17 insertions(+), 
11 deletions(-) diff --git a/ext/ArrayDiffONNXExt.jl b/ext/ArrayDiffONNXExt.jl index 222cbcf..cfecaca 100644 --- a/ext/ArrayDiffONNXExt.jl +++ b/ext/ArrayDiffONNXExt.jl @@ -178,8 +178,13 @@ function _convert_node(node::ONNX.NodeProto, env::Dict{String,_Entry}) return _convert_gemm(node, env) elseif op == "Relu" + # ArrayDiff's broadcasted-multivariate shape inference only handles + # {+, -, *, /, ^}, not `:max`. Express ReLU using the equivalent + # `(x + abs(x)) / 2` which uses only broadcast-supported ops. x, sx = env[node.input[1]] - return (_bcall(:max, Any[x, 0.0], sx), sx) + absx = _bcall(:abs, Any[x], sx) + s = _bcall(:+, Any[x, absx], sx) + return (_bcall(:/, Any[s, 2.0], sx), sx) elseif op == "Tanh" x, sx = env[node.input[1]] @@ -209,12 +214,16 @@ function _convert_matmul(node, env) a, sa = env[node.input[1]] b, sb = env[node.input[2]] if length(sa) == 1 && length(sb) == 2 - # NumPy-style Vec × Mat = Vec. Depends on ArrayDiff `:*` supporting - # vector × matrix shape inference (see ArrayDiff PR adding matrix- - # vector / vec-matrix product support). + # NumPy-style Vec × Mat = Vec. ArrayDiff's `:*` shape inference walks + # left-to-right and requires the first non-scalar child to be 2D, so + # rewrite as Matᵀ × Vec by transposing the constant matrix at + # convert time. Requires `b` to be a constant tensor (initializer). 
sa[1] == sb[1] || error("MatMul shape mismatch: $sa × $sb") + b isa AbstractMatrix{<:Real} || + error("MatMul Vec × Mat requires the matrix to be a constant initializer (got $(typeof(b)))") + bT = collect(permutedims(b)) s = (sb[2],) - return (_call(:*, Any[a, b], s), s) + return (_call(:*, Any[bT, a], s), s) elseif length(sa) == 2 && length(sb) == 1 sa[2] == sb[1] || error("MatMul shape mismatch: $sa × $sb") s = (sa[1],) diff --git a/test/ONNXExt.jl b/test/ONNXExt.jl index 3b26382..9b9808b 100644 --- a/test/ONNXExt.jl +++ b/test/ONNXExt.jl @@ -212,8 +212,6 @@ function test_elementwise_sub_mul_div() end # MatMul vector × matrix: y = x * W, W is (3, 2). Output is 1D length 2. -# Pending the ArrayDiff PR that adds vec × mat shape inference to `:*` — -# until then ArrayDiff trips `sizes.ndims[k] > 1` on the second operand. function test_matmul_vector_matrix() vars = [MOI.VariableIndex(i) for i in 1:3] W = [1.0 0.5; @@ -223,11 +221,10 @@ function test_matmul_vector_matrix() node = _make_node("MatMul", ["x", "W"], ["y"]) proto = _build_model([node], ["x"], ["y"]; initializers = [init]) xv = [0.4, -1.0, 0.9] + val, g = _eval_with_gradient(proto, vars, xv) fjulia(x) = sum((x' * W) .^ 2) - @test_broken begin - val, g = _eval_with_gradient(proto, vars, xv) - val ≈ fjulia(xv) && g ≈ ForwardDiff.gradient(fjulia, xv) - end + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) end # Gemm without transB: y = α * (X * W) + β * b. X is shape (1, K). 
From e85ac1ae39c337ceaf0fc6fc5cb1813f422269ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 10:26:19 +0200 Subject: [PATCH 4/9] Fix format --- test/ONNXExt.jl | 67 +++++++++++++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 27 deletions(-) diff --git a/test/ONNXExt.jl b/test/ONNXExt.jl index 9b9808b..39f3e62 100644 --- a/test/ONNXExt.jl +++ b/test/ONNXExt.jl @@ -52,19 +52,11 @@ function _attr_float(name::String, f::Real) end function _attr_int(name::String, i::Integer) - return ONNX.AttributeProto( - name = name, - i = Int64(i), - var"#type" = AT.INT, - ) + return ONNX.AttributeProto(name = name, i = Int64(i), var"#type" = AT.INT) end function _attr_tensor(name::String, t::ONNX.TensorProto) - return ONNX.AttributeProto( - name = name, - t = t, - var"#type" = AT.TENSOR, - ) + return ONNX.AttributeProto(name = name, t = t, var"#type" = AT.TENSOR) end function _make_node( @@ -214,9 +206,11 @@ end # MatMul vector × matrix: y = x * W, W is (3, 2). Output is 1D length 2. 
function test_matmul_vector_matrix() vars = [MOI.VariableIndex(i) for i in 1:3] - W = [1.0 0.5; - -0.2 1.1; - 0.3 -0.7] + W = [ + 1.0 0.5; + -0.2 1.1; + 0.3 -0.7 + ] init = _make_tensor("W", W) node = _make_node("MatMul", ["x", "W"], ["y"]) proto = _build_model([node], ["x"], ["y"]; initializers = [init]) @@ -231,8 +225,10 @@ end function test_gemm_no_transpose() vars = [MOI.VariableIndex(i) for i in 1:2] var_mat = reshape(vars, 1, 2) # (1, 2) - W = [0.4 -0.6 0.2; - 1.1 0.3 -0.9] # (2, 3) + W = [ + 0.4 -0.6 0.2; + 1.1 0.3 -0.9 + ] # (2, 3) bias = [0.05, -0.1, 0.2] α, β = 0.5, 2.0 init_W = _make_tensor("W", W) @@ -251,7 +247,8 @@ function test_gemm_no_transpose() proto = _build_model([node], ["x"], ["y"]; initializers = [init_W, init_b]) xv = [0.8, -0.3] val, g = _eval_with_gradient(proto, vars, xv; input = var_mat) - fjulia(x) = sum((α .* (reshape(x, 1, 2) * W) .+ β .* reshape(bias, 1, 3)) .^ 2) + fjulia(x) = + sum((α .* (reshape(x, 1, 2) * W) .+ β .* reshape(bias, 1, 3)) .^ 2) @test val ≈ fjulia(xv) @test g ≈ ForwardDiff.gradient(fjulia, xv) end @@ -261,9 +258,11 @@ end function test_gemm_transB() vars = [MOI.VariableIndex(i) for i in 1:2] var_mat = reshape(vars, 1, 2) - W = [0.4 0.5; - -0.1 1.0; - 0.9 -0.3] # (3, 2) — Linear(in=2, out=3) + W = [ + 0.4 0.5; + -0.1 1.0; + 0.9 -0.3 + ] # (3, 2) — Linear(in=2, out=3) bias = [0.0, 0.0, 0.0] init_W = _make_tensor("W", W) init_b = _make_tensor("b", bias) @@ -326,7 +325,13 @@ function test_constant_then_add() vars = [MOI.VariableIndex(i) for i in 1:3] c = [0.5, -0.5, 1.0] const_t = _make_tensor("c_value", c) - n1 = _make_node("Constant", String[], ["c"]; attrs = [_attr_tensor("value", const_t)], name = "k") + n1 = _make_node( + "Constant", + String[], + ["c"]; + attrs = [_attr_tensor("value", const_t)], + name = "k", + ) n2 = _make_node("Add", ["x", "c"], ["y"]) proto = _build_model([n1, n2], ["x"], ["y"]) xv = [0.3, 0.8, -0.4] @@ -360,9 +365,21 @@ function test_mlp_relu() _attr_int("transB", transB), ] nodes = [ - 
_make_node("Gemm", ["x", "W1", "b1"], ["h_pre"]; attrs = gemm_attrs(1), name = "fc1"), + _make_node( + "Gemm", + ["x", "W1", "b1"], + ["h_pre"]; + attrs = gemm_attrs(1), + name = "fc1", + ), _make_node("Relu", ["h_pre"], ["h"]; name = "act"), - _make_node("Gemm", ["h", "W2", "b2"], ["y"]; attrs = gemm_attrs(1), name = "fc2"), + _make_node( + "Gemm", + ["h", "W2", "b2"], + ["y"]; + attrs = gemm_attrs(1), + name = "fc2", + ), ] proto = _build_model(nodes, ["x"], ["y"]; initializers = init) xv = [0.5, -0.7, 1.1] @@ -394,11 +411,7 @@ function test_unsupported_op_errors() end function test_missing_input_errors() - proto = _build_model( - [_make_node("Identity", ["x"], ["y"])], - ["x"], - ["y"], - ) + proto = _build_model([_make_node("Identity", ["x"], ["y"])], ["x"], ["y"]) @test_throws ErrorException ArrayDiff.from_onnx(proto; inputs = Dict()) end From c7396906ba0a1ffc01f729933cf80e5e6c36a46a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 12:45:04 +0200 Subject: [PATCH 5/9] Improve coverage --- ext/ArrayDiffONNXExt.jl | 7 - test/ONNXExt.jl | 296 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 296 insertions(+), 7 deletions(-) diff --git a/ext/ArrayDiffONNXExt.jl b/ext/ArrayDiffONNXExt.jl index cfecaca..548440d 100644 --- a/ext/ArrayDiffONNXExt.jl +++ b/ext/ArrayDiffONNXExt.jl @@ -32,11 +32,6 @@ function _attr_float(node, name; default::Float64) return a === nothing ? default : Float64(a.f) end -function _attr_ints(node, name) - a = _find_attr(node, name) - return a === nothing ? Int[] : Int[Int(x) for x in a.ints] -end - function _attr_tensor(node, name) a = _find_attr(node, name) return a === nothing ? nothing : a.t @@ -113,8 +108,6 @@ function _broadcast_shape(a::_Shape, b::_Shape) if a == () return b end if b == () return a end n = max(length(a), length(b)) - pa = ntuple(i -> i <= length(a) ? 
a[end - length(a) + i + (n - n)] : 1, n) # placeholder - # simpler explicit loop out = Vector{Int}(undef, n) for i in 1:n da = i <= length(a) ? a[end - i + 1] : 1 diff --git a/test/ONNXExt.jl b/test/ONNXExt.jl index 39f3e62..f69a723 100644 --- a/test/ONNXExt.jl +++ b/test/ONNXExt.jl @@ -415,6 +415,302 @@ function test_missing_input_errors() @test_throws ErrorException ArrayDiff.from_onnx(proto; inputs = Dict()) end +# ── Direct tests for private helpers ───────────────────────────────────────── + +const _ext = Base.get_extension(ArrayDiff, :ArrayDiffONNXExt) + +function test_broadcast_shape_helper() + @test _ext._broadcast_shape((), (3,)) == (3,) + @test _ext._broadcast_shape((3,), ()) == (3,) + @test _ext._broadcast_shape((2, 3), (2, 3)) == (2, 3) + # da == 1 broadcast + @test _ext._broadcast_shape((1, 3), (2, 3)) == (2, 3) + # db == 1 broadcast (1D bias-style) + @test _ext._broadcast_shape((2, 3), (3,)) == (2, 3) + # incompatible + @test_throws ErrorException _ext._broadcast_shape((2,), (3,)) +end + +function test_wrap_input_scalar_vector_matrix_real() + @test _ext._wrap_input(3.5) == (3.5, ()) + @test _ext._wrap_input(2) == (2.0, ()) + v = [1.0, 2.0, 3.0] + out, sz = _ext._wrap_input(v) + @test out == v && sz == (3,) && out isa Vector{Float64} + M = [1.0 2.0; 3.0 4.0] + outM, szM = _ext._wrap_input(M) + @test outM == M && szM == (2, 2) && outM isa Matrix{Float64} +end + +function test_wrap_input_anf() + anf = ArrayDiff.ArrayNonlinearFunction{1}(:vect, Any[1.0, 2.0], (2,), false) + out, sz = _ext._wrap_input(anf) + @test out === anf && sz == (2,) +end + +function test_wrap_input_unsupported() + @test_throws ErrorException _ext._wrap_input((1, 2, 3)) +end + +function test_wrap_input_matrix_vars_multi_row() + M = collect(reshape([MOI.VariableIndex(i) for i in 1:6], 2, 3)) + out, sz = _ext._wrap_input(M) + @test sz == (2, 3) + @test out isa ArrayDiff.ArrayNonlinearFunction{2} + @test out.head == :vcat +end + +# ── TensorProto encoding paths 
─────────────────────────────────────────────── + +function test_tensor_to_array_float_data() + t = ONNX.TensorProto( + dims = Int64[2, 3], + data_type = Int32(DT.FLOAT), + name = "t", + float_data = Float32[1, 2, 3, 4, 5, 6], + ) + arr, sz = _ext._tensor_to_array(t) + @test sz == (2, 3) + @test arr == [1.0 2.0 3.0; 4.0 5.0 6.0] +end + +function test_tensor_to_array_raw_data_float() + raw = Vector{UInt8}(reinterpret(UInt8, Float32[1.0, 2.0, 3.0])) + t = ONNX.TensorProto( + dims = Int64[3], + data_type = Int32(DT.FLOAT), + name = "t", + raw_data = raw, + ) + arr, sz = _ext._tensor_to_array(t) + @test sz == (3,) + @test arr == [1.0, 2.0, 3.0] +end + +function test_tensor_to_array_raw_data_double() + raw = Vector{UInt8}(reinterpret(UInt8, Float64[1.5, -2.5])) + t = ONNX.TensorProto( + dims = Int64[2], + data_type = Int32(DT.DOUBLE), + name = "t", + raw_data = raw, + ) + arr, sz = _ext._tensor_to_array(t) + @test sz == (2,) + @test arr == [1.5, -2.5] +end + +function test_tensor_to_array_raw_data_unsupported() + t = ONNX.TensorProto( + dims = Int64[2], + data_type = Int32(DT.INT32), + name = "t", + raw_data = UInt8[1, 2, 3, 4, 5, 6, 7, 8], + ) + @test_throws ErrorException _ext._tensor_to_array(t) +end + +function test_tensor_to_array_empty_encoding() + t = ONNX.TensorProto( + dims = Int64[2], + data_type = Int32(DT.INT32), + name = "t", + ) + @test_throws ErrorException _ext._tensor_to_array(t) +end + +function test_tensor_to_array_scalar() + t = _make_scalar_tensor("t", 3.5) + arr, sz = _ext._tensor_to_array(t) + @test arr == 3.5 && sz == () +end + +function test_tensor_to_array_3d_unsupported() + t = ONNX.TensorProto( + dims = Int64[1, 2, 3], + data_type = Int32(DT.DOUBLE), + name = "t", + double_data = Float64[1, 2, 3, 4, 5, 6], + ) + @test_throws ErrorException _ext._tensor_to_array(t) +end + +# ── Per-op coverage ────────────────────────────────────────────────────────── + +function test_neg() + vars = [MOI.VariableIndex(i) for i in 1:3] + node = 
_make_node("Neg", ["x"], ["y"]) + proto = _build_model([node], ["x"], ["y"]) + xv = [0.4, -1.2, 1.7] + val, g = _eval_with_gradient(proto, vars, xv) + fjulia(x) = sum((-x) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# Constant op with a scalar value. +function test_constant_scalar() + vars = [MOI.VariableIndex(i) for i in 1:2] + c_t = _make_scalar_tensor("c_val", 1.5) + n1 = _make_node( + "Constant", + String[], + ["c"]; + attrs = [_attr_tensor("value", c_t)], + name = "k", + ) + n2 = _make_node("Add", ["x", "c"], ["y"]) + proto = _build_model([n1, n2], ["x"], ["y"]) + xv = [0.5, 1.0] + val, g = _eval_with_gradient(proto, vars, xv) + fjulia(x) = sum((x .+ 1.5) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +function test_constant_missing_value_errors() + n = _make_node("Constant", String[], ["c"]) + proto = _build_model( + [n, _make_node("Identity", ["c"], ["y"])], + String[], + ["y"], + ) + @test_throws ErrorException ArrayDiff.from_onnx(proto) +end + +# MatMul Mat × Vec: y = X * b, X = (2, 3) vars, b = (3,) const. +function test_matmul_matrix_vector() + vars = [MOI.VariableIndex(i) for i in 1:6] + var_mat = collect(reshape(vars, 2, 3)) + b = [0.4, -1.0, 0.9] + init = _make_tensor("b", b) + node = _make_node("MatMul", ["x", "b"], ["y"]) + proto = _build_model([node], ["x"], ["y"]; initializers = [init]) + xv = [0.3, -0.7, 1.1, 2.0, 0.5, -1.5] + val, g = _eval_with_gradient(proto, vars, xv; input = var_mat) + fjulia(x) = sum((reshape(x, 2, 3) * b) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# MatMul Mat × Mat: y = X * W, X = (2, 3) vars, W = (3, 2) const. 
+function test_matmul_matrix_matrix() + vars = [MOI.VariableIndex(i) for i in 1:6] + var_mat = collect(reshape(vars, 2, 3)) + W = [ + 0.4 -0.1 + 0.5 1.2 + -0.3 0.7 + ] + init = _make_tensor("W", W) + node = _make_node("MatMul", ["x", "W"], ["y"]) + proto = _build_model([node], ["x"], ["y"]; initializers = [init]) + xv = [0.2, 1.0, -0.5, 0.8, 1.4, -1.1] + val, g = _eval_with_gradient(proto, vars, xv; input = var_mat) + fjulia(x) = sum((reshape(x, 2, 3) * W) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# Vec × Vec is not a supported MatMul shape combination. +function test_matmul_unsupported_shapes() + vars = [MOI.VariableIndex(i) for i in 1:3] + b = [1.0, 2.0, 3.0] + init = _make_tensor("b", b) + node = _make_node("MatMul", ["x", "b"], ["y"]) + proto = _build_model([node], ["x"], ["y"]; initializers = [init]) + @test_throws ErrorException ArrayDiff.from_onnx( + proto; + inputs = Dict("x" => vars), + ) +end + +# Gemm without C: 2-input form, no bias. +function test_gemm_no_bias() + vars = [MOI.VariableIndex(i) for i in 1:2] + var_mat = reshape(vars, 1, 2) + W = [ + 0.4 -0.1 + 0.5 1.2 + ] + init_W = _make_tensor("W", W) + node = _make_node( + "Gemm", + ["x", "W"], + ["y"]; + attrs = [ + _attr_float("alpha", 1.0), + _attr_float("beta", 1.0), + _attr_int("transA", 0), + _attr_int("transB", 0), + ], + ) + proto = _build_model([node], ["x"], ["y"]; initializers = [init_W]) + xv = [0.3, -0.7] + val, g = _eval_with_gradient(proto, vars, xv; input = var_mat) + fjulia(x) = sum((reshape(x, 1, 2) * W) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# Gemm with all attributes omitted: covers the default-attr path in `_find_attr`. 
+function test_gemm_default_attrs() + vars = [MOI.VariableIndex(i) for i in 1:2] + var_mat = reshape(vars, 1, 2) + W = [ + 0.4 -0.1 + 0.5 1.2 + ] + bias = [0.1, -0.2] + init_W = _make_tensor("W", W) + init_b = _make_tensor("b", bias) + node = _make_node("Gemm", ["x", "W", "b"], ["y"]) + proto = _build_model([node], ["x"], ["y"]; initializers = [init_W, init_b]) + xv = [0.3, -0.7] + val, g = _eval_with_gradient(proto, vars, xv; input = var_mat) + fjulia(x) = sum((reshape(x, 1, 2) * W .+ reshape(bias, 1, 2)) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# Gemm with a non-constant 1D bias is rejected. +function test_gemm_non_const_1d_bias_errors() + vars_x = [MOI.VariableIndex(i) for i in 1:2] + vars_c = [MOI.VariableIndex(i) for i in 3:5] + var_mat = reshape(vars_x, 1, 2) + W = [ + 0.4 -0.1 0.5 + 1.2 -0.3 0.7 + ] + init_W = _make_tensor("W", W) + node = _make_node("Gemm", ["x", "W", "c"], ["y"]) + proto = _build_model([node], ["x", "c"], ["y"]; initializers = [init_W]) + @test_throws ErrorException ArrayDiff.from_onnx( + proto; + inputs = Dict("x" => var_mat, "c" => vars_c), + ) +end + +# Graph declares "x" as both an input and an initializer: the initializer wins. +function test_input_name_overlaps_initializer() + init_x = _make_tensor("x", [1.0, 2.0, 3.0]) + node = _make_node("Identity", ["x"], ["y"]) + proto = _build_model([node], ["x"], ["y"]; initializers = [init_x]) + out = ArrayDiff.from_onnx(proto) + @test out == [1.0, 2.0, 3.0] +end + +# Multi-output graph: result is keyed by output name. 
+function test_multi_output() + vars = [MOI.VariableIndex(i) for i in 1:3] + n1 = _make_node("Identity", ["x"], ["a"]; name = "id") + n2 = _make_node("Neg", ["x"], ["b"]; name = "neg") + proto = _build_model([n1, n2], ["x"], ["a", "b"]) + out = ArrayDiff.from_onnx(proto; inputs = Dict("x" => vars)) + @test out isa Dict + @test sort(collect(keys(out))) == ["a", "b"] +end + end # module TestONNXExt.runtests() From bb9e2b35559165a04e70cf29a8551162fe8ec51e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 12:48:52 +0200 Subject: [PATCH 6/9] Fix format --- test/ONNXExt.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/test/ONNXExt.jl b/test/ONNXExt.jl index f69a723..d235f44 100644 --- a/test/ONNXExt.jl +++ b/test/ONNXExt.jl @@ -570,11 +570,8 @@ end function test_constant_missing_value_errors() n = _make_node("Constant", String[], ["c"]) - proto = _build_model( - [n, _make_node("Identity", ["c"], ["y"])], - String[], - ["y"], - ) + proto = + _build_model([n, _make_node("Identity", ["c"], ["y"])], String[], ["y"]) @test_throws ErrorException ArrayDiff.from_onnx(proto) end From 40cc3d146a28fb3fd04f26be0d766a31d8f2af4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 12:53:51 +0200 Subject: [PATCH 7/9] Fix format --- ext/ArrayDiffONNXExt.jl | 40 +++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/ext/ArrayDiffONNXExt.jl b/ext/ArrayDiffONNXExt.jl index 548440d..518a736 100644 --- a/ext/ArrayDiffONNXExt.jl +++ b/ext/ArrayDiffONNXExt.jl @@ -70,14 +70,17 @@ function _tensor_to_array(t::ONNX.TensorProto) mat = collect(permutedims(reshape(flat, (n, m)), (2, 1))) return (mat, (m, n)) else - error("Tensors with ndim > 2 not supported (got dims=$dims for '$(t.name)')") + error( + "Tensors with ndim > 2 not supported (got dims=$dims for '$(t.name)')", + ) end end # ── User-supplied input wrapping 
──────────────────────────────────────────── -_wrap_input(v::Vector{MOI.VariableIndex}) = - (ANF{1}(:vect, Any[v...], (length(v),), false), (length(v),)) +function _wrap_input(v::Vector{MOI.VariableIndex}) + return (ANF{1}(:vect, Any[v...], (length(v),), false), (length(v),)) +end function _wrap_input(M::Matrix{MOI.VariableIndex}) m, n = size(M) @@ -105,19 +108,23 @@ _wrap_input(x) = error("Unsupported input value type: $(typeof(x))") function _broadcast_shape(a::_Shape, b::_Shape) # NumPy / ONNX semantics: align trailing dims; each pair must match or one is 1. - if a == () return b end - if b == () return a end + if a == () + return b + end + if b == () + return a + end n = max(length(a), length(b)) out = Vector{Int}(undef, n) for i in 1:n - da = i <= length(a) ? a[end - i + 1] : 1 - db = i <= length(b) ? b[end - i + 1] : 1 + da = i <= length(a) ? a[end-i+1] : 1 + db = i <= length(b) ? b[end-i+1] : 1 if da == db - out[n - i + 1] = da + out[n-i+1] = da elseif da == 1 - out[n - i + 1] = db + out[n-i+1] = db elseif db == 1 - out[n - i + 1] = da + out[n-i+1] = da else error("Incompatible broadcast shapes $a vs $b") end @@ -148,7 +155,8 @@ function _convert_node(node::ONNX.NodeProto, env::Dict{String,_Entry}) elseif op == "Constant" t = _attr_tensor(node, "value") - t === nothing && error("Constant node '$(node.name)' has no 'value' attribute") + t === nothing && + error("Constant node '$(node.name)' has no 'value' attribute") return _tensor_to_array(t) elseif op == "Add" @@ -212,8 +220,9 @@ function _convert_matmul(node, env) # rewrite as Matᵀ × Vec by transposing the constant matrix at # convert time. Requires `b` to be a constant tensor (initializer). 
sa[1] == sb[1] || error("MatMul shape mismatch: $sa × $sb") - b isa AbstractMatrix{<:Real} || - error("MatMul Vec × Mat requires the matrix to be a constant initializer (got $(typeof(b)))") + b isa AbstractMatrix{<:Real} || error( + "MatMul Vec × Mat requires the matrix to be a constant initializer (got $(typeof(b)))", + ) bT = collect(permutedims(b)) s = (sb[2],) return (_call(:*, Any[bT, a], s), s) @@ -242,8 +251,9 @@ function _convert_gemm(node, env) transA && error("Gemm with transA=1 is not supported") if transB - B isa AbstractMatrix{<:Real} || - error("Gemm with transB=1 requires B to be a constant tensor (got $(typeof(B)))") + B isa AbstractMatrix{<:Real} || error( + "Gemm with transB=1 requires B to be a constant tensor (got $(typeof(B)))", + ) B = collect(permutedims(B)) sB = (sB[2], sB[1]) end From 33ad8313e25fdc3e53e06042c1451f75313a82db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 17:18:52 +0200 Subject: [PATCH 8/9] Fix for Float32 --- ext/ArrayDiffONNXExt.jl | 78 +++++++++++++++------------ test/ONNXExt.jl | 113 +++++++++++++++++++++++++++++++++++----- 2 files changed, 144 insertions(+), 47 deletions(-) diff --git a/ext/ArrayDiffONNXExt.jl b/ext/ArrayDiffONNXExt.jl index 518a736..5f18ab7 100644 --- a/ext/ArrayDiffONNXExt.jl +++ b/ext/ArrayDiffONNXExt.jl @@ -8,7 +8,8 @@ const ANF = ArrayDiff.ArrayNonlinearFunction const _Shape = Tuple{Vararg{Int}} # An entry in the conversion env: the value plus its array shape. # Value may be `ANF`, `MOI.ScalarNonlinearFunction`, `MOI.VariableIndex`, -# `Float64`, `Vector{Float64}`, or `Matrix{Float64}`. Shape is `()` for scalars. +# scalar `T`, `Vector{T}`, or `Matrix{T}` (where `T` is the scalar type passed +# to `from_onnx`). Shape is `()` for scalars. const _Entry = Tuple{Any,_Shape} # ── Attribute helpers ─────────────────────────────────────────────────────── @@ -27,9 +28,9 @@ function _attr_int(node, name; default::Int) return a === nothing ? 
default : Int(a.i) end -function _attr_float(node, name; default::Float64) +function _attr_float(::Type{T}, node, name; default) where {T<:Real} a = _find_attr(node, name) - return a === nothing ? default : Float64(a.f) + return a === nothing ? T(default) : T(a.f) end function _attr_tensor(node, name) @@ -41,18 +42,18 @@ end # ONNX stores tensor values in C (row-major) order. Julia is column-major, # so a 2D tensor of dims=(m, n) must be reshaped to (n, m) then permuted. -function _tensor_to_array(t::ONNX.TensorProto) +function _tensor_to_array(::Type{T}, t::ONNX.TensorProto) where {T<:Real} DT = getfield(ONNX, Symbol("TensorProto.DataType")) dims = Int[Int(d) for d in t.dims] flat = if t.data_type == Int32(DT.DOUBLE) && !isempty(t.double_data) - Float64[x for x in t.double_data] + T[T(x) for x in t.double_data] elseif t.data_type == Int32(DT.FLOAT) && !isempty(t.float_data) - Float64[Float64(x) for x in t.float_data] + T[T(x) for x in t.float_data] elseif !isempty(t.raw_data) if t.data_type == Int32(DT.FLOAT) - Float64.(reinterpret(Float32, t.raw_data)) + T.(reinterpret(Float32, t.raw_data)) elseif t.data_type == Int32(DT.DOUBLE) - Float64.(reinterpret(Float64, t.raw_data)) + T.(reinterpret(Float64, t.raw_data)) else error("Unsupported raw_data type: $(t.data_type)") end @@ -78,11 +79,11 @@ end # ── User-supplied input wrapping ──────────────────────────────────────────── -function _wrap_input(v::Vector{MOI.VariableIndex}) +function _wrap_input(::Type{<:Real}, v::Vector{MOI.VariableIndex}) return (ANF{1}(:vect, Any[v...], (length(v),), false), (length(v),)) end -function _wrap_input(M::Matrix{MOI.VariableIndex}) +function _wrap_input(::Type{<:Real}, M::Matrix{MOI.VariableIndex}) m, n = size(M) row(i) = ANF{2}(:row, Any[M[i, j] for j in 1:n], (1, n), false) if m == 1 @@ -97,12 +98,15 @@ function _wrap_input(M::Matrix{MOI.VariableIndex}) return (acc, (m, n)) end -_wrap_input(x::ANF{N}) where {N} = (x, x.size) -_wrap_input(x::Real) = (Float64(x), ()) 
-_wrap_input(x::Vector{<:Real}) = (Vector{Float64}(x), (length(x),)) -_wrap_input(x::Matrix{<:Real}) = (Matrix{Float64}(x), size(x)) +_wrap_input(::Type{<:Real}, x::ANF{N}) where {N} = (x, x.size) +_wrap_input(::Type{T}, x::Real) where {T<:Real} = (T(x), ()) +_wrap_input(::Type{T}, x::Vector{<:Real}) where {T<:Real} = + (Vector{T}(x), (length(x),)) +_wrap_input(::Type{T}, x::Matrix{<:Real}) where {T<:Real} = + (Matrix{T}(x), size(x)) -_wrap_input(x) = error("Unsupported input value type: $(typeof(x))") +_wrap_input(::Type{<:Real}, x) = + error("Unsupported input value type: $(typeof(x))") # ── Shape arithmetic ──────────────────────────────────────────────────────── @@ -148,7 +152,11 @@ end # ── Per-op conversion ─────────────────────────────────────────────────────── -function _convert_node(node::ONNX.NodeProto, env::Dict{String,_Entry}) +function _convert_node( + ::Type{T}, + node::ONNX.NodeProto, + env::Dict{String,_Entry}, +) where {T<:Real} op = node.op_type if op == "Identity" return env[node.input[1]] @@ -157,7 +165,7 @@ function _convert_node(node::ONNX.NodeProto, env::Dict{String,_Entry}) t = _attr_tensor(node, "value") t === nothing && error("Constant node '$(node.name)' has no 'value' attribute") - return _tensor_to_array(t) + return _tensor_to_array(T, t) elseif op == "Add" return _binop_broadcast(:+, node, env) @@ -170,13 +178,13 @@ function _convert_node(node::ONNX.NodeProto, env::Dict{String,_Entry}) elseif op == "Neg" x, sx = env[node.input[1]] - return (_bcall(:-, Any[0.0, x], sx), sx) + return (_bcall(:-, Any[zero(T), x], sx), sx) elseif op == "MatMul" return _convert_matmul(node, env) elseif op == "Gemm" - return _convert_gemm(node, env) + return _convert_gemm(T, node, env) elseif op == "Relu" # ArrayDiff's broadcasted-multivariate shape inference only handles @@ -185,7 +193,7 @@ function _convert_node(node::ONNX.NodeProto, env::Dict{String,_Entry}) x, sx = env[node.input[1]] absx = _bcall(:abs, Any[x], sx) s = _bcall(:+, Any[x, absx], sx) - 
return (_bcall(:/, Any[s, 2.0], sx), sx) + return (_bcall(:/, Any[s, T(2)], sx), sx) elseif op == "Tanh" x, sx = env[node.input[1]] @@ -194,10 +202,10 @@ function _convert_node(node::ONNX.NodeProto, env::Dict{String,_Entry}) elseif op == "Sigmoid" # 1 / (1 + exp(-x)), all broadcast over x's shape. x, sx = env[node.input[1]] - negx = _bcall(:-, Any[0.0, x], sx) + negx = _bcall(:-, Any[zero(T), x], sx) ex = _bcall(:exp, Any[negx], sx) - one_plus = _bcall(:+, Any[1.0, ex], sx) - return (_bcall(:/, Any[1.0, one_plus], sx), sx) + one_plus = _bcall(:+, Any[one(T), ex], sx) + return (_bcall(:/, Any[one(T), one_plus], sx), sx) else error("ONNX op '$(op)' is not supported by ArrayDiffONNXExt") @@ -239,13 +247,13 @@ function _convert_matmul(node, env) end end -function _convert_gemm(node, env) +function _convert_gemm(::Type{T}, node, env) where {T<:Real} A, sA = env[node.input[1]] B, sB = env[node.input[2]] has_C = length(node.input) >= 3 && !isempty(node.input[3]) C, sC = has_C ? env[node.input[3]] : (nothing, ()) - α = _attr_float(node, "alpha"; default = 1.0) - β = _attr_float(node, "beta"; default = 1.0) + α = _attr_float(T, node, "alpha"; default = 1) + β = _attr_float(T, node, "beta"; default = 1) transA = _attr_int(node, "transA"; default = 0) != 0 transB = _attr_int(node, "transB"; default = 0) != 0 @@ -267,11 +275,11 @@ function _convert_gemm(node, env) AB_shape = (sA[1], sB[2]) AB = _call(:*, Any[A, B], AB_shape) - if α != 1.0 + if !isone(α) AB = _bcall(:*, Any[α, AB], AB_shape) end - if !has_C || β == 0.0 + if !has_C || iszero(β) return (AB, AB_shape) end @@ -287,7 +295,7 @@ function _convert_gemm(node, env) sC = (1, sC[1]) end - Cterm = β == 1.0 ? C : _bcall(:*, Any[β, C], sC) + Cterm = isone(β) ? 
C : _bcall(:*, Any[β, C], sC) out_shape = _broadcast_shape(AB_shape, sC) out = _bcall(:+, Any[AB, Cterm], out_shape) return (out, out_shape) @@ -296,14 +304,15 @@ end # ── Entry point ───────────────────────────────────────────────────────────── function ArrayDiff.from_onnx( + ::Type{T}, proto::ONNX.ModelProto; inputs::AbstractDict = Dict{String,Any}(), -) +) where {T<:Real} graph = proto.graph env = Dict{String,_Entry}() for tp in graph.initializer - env[tp.name] = _tensor_to_array(tp) + env[tp.name] = _tensor_to_array(T, tp) end for inp in graph.input @@ -312,11 +321,11 @@ function ArrayDiff.from_onnx( end haskey(inputs, String(inp.name)) || error("ONNX graph input '$(inp.name)' has no supplied value") - env[inp.name] = _wrap_input(inputs[String(inp.name)]) + env[inp.name] = _wrap_input(T, inputs[String(inp.name)]) end for node in graph.node - result = _convert_node(node, env) + result = _convert_node(T, node, env) # All currently-supported ops produce exactly one output. length(node.output) == 1 || error("Multi-output op '$(node.op_type)' not supported") @@ -331,4 +340,7 @@ function ArrayDiff.from_onnx( end end +ArrayDiff.from_onnx(proto::ONNX.ModelProto; kwargs...) = + ArrayDiff.from_onnx(Float64, proto; kwargs...) 
+ end # module diff --git a/test/ONNXExt.jl b/test/ONNXExt.jl index d235f44..ada591c 100644 --- a/test/ONNXExt.jl +++ b/test/ONNXExt.jl @@ -432,29 +432,36 @@ function test_broadcast_shape_helper() end function test_wrap_input_scalar_vector_matrix_real() - @test _ext._wrap_input(3.5) == (3.5, ()) - @test _ext._wrap_input(2) == (2.0, ()) + @test _ext._wrap_input(Float64, 3.5) == (3.5, ()) + @test _ext._wrap_input(Float64, 2) == (2.0, ()) v = [1.0, 2.0, 3.0] - out, sz = _ext._wrap_input(v) + out, sz = _ext._wrap_input(Float64, v) @test out == v && sz == (3,) && out isa Vector{Float64} M = [1.0 2.0; 3.0 4.0] - outM, szM = _ext._wrap_input(M) + outM, szM = _ext._wrap_input(Float64, M) @test outM == M && szM == (2, 2) && outM isa Matrix{Float64} + # Float32 carries through to the wrapped value. + out32, _ = _ext._wrap_input(Float32, 3.5) + @test out32 isa Float32 + outV32, _ = _ext._wrap_input(Float32, v) + @test outV32 isa Vector{Float32} + outM32, _ = _ext._wrap_input(Float32, M) + @test outM32 isa Matrix{Float32} end function test_wrap_input_anf() anf = ArrayDiff.ArrayNonlinearFunction{1}(:vect, Any[1.0, 2.0], (2,), false) - out, sz = _ext._wrap_input(anf) + out, sz = _ext._wrap_input(Float64, anf) @test out === anf && sz == (2,) end function test_wrap_input_unsupported() - @test_throws ErrorException _ext._wrap_input((1, 2, 3)) + @test_throws ErrorException _ext._wrap_input(Float64, (1, 2, 3)) end function test_wrap_input_matrix_vars_multi_row() M = collect(reshape([MOI.VariableIndex(i) for i in 1:6], 2, 3)) - out, sz = _ext._wrap_input(M) + out, sz = _ext._wrap_input(Float64, M) @test sz == (2, 3) @test out isa ArrayDiff.ArrayNonlinearFunction{2} @test out.head == :vcat @@ -469,7 +476,7 @@ function test_tensor_to_array_float_data() name = "t", float_data = Float32[1, 2, 3, 4, 5, 6], ) - arr, sz = _ext._tensor_to_array(t) + arr, sz = _ext._tensor_to_array(Float64, t) @test sz == (2, 3) @test arr == [1.0 2.0 3.0; 4.0 5.0 6.0] end @@ -482,7 +489,7 @@ function 
test_tensor_to_array_raw_data_float() name = "t", raw_data = raw, ) - arr, sz = _ext._tensor_to_array(t) + arr, sz = _ext._tensor_to_array(Float64, t) @test sz == (3,) @test arr == [1.0, 2.0, 3.0] end @@ -495,7 +502,7 @@ function test_tensor_to_array_raw_data_double() name = "t", raw_data = raw, ) - arr, sz = _ext._tensor_to_array(t) + arr, sz = _ext._tensor_to_array(Float64, t) @test sz == (2,) @test arr == [1.5, -2.5] end @@ -507,7 +514,7 @@ function test_tensor_to_array_raw_data_unsupported() name = "t", raw_data = UInt8[1, 2, 3, 4, 5, 6, 7, 8], ) - @test_throws ErrorException _ext._tensor_to_array(t) + @test_throws ErrorException _ext._tensor_to_array(Float64, t) end function test_tensor_to_array_empty_encoding() @@ -516,12 +523,12 @@ function test_tensor_to_array_empty_encoding() data_type = Int32(DT.INT32), name = "t", ) - @test_throws ErrorException _ext._tensor_to_array(t) + @test_throws ErrorException _ext._tensor_to_array(Float64, t) end function test_tensor_to_array_scalar() t = _make_scalar_tensor("t", 3.5) - arr, sz = _ext._tensor_to_array(t) + arr, sz = _ext._tensor_to_array(Float64, t) @test arr == 3.5 && sz == () end @@ -532,7 +539,7 @@ function test_tensor_to_array_3d_unsupported() name = "t", double_data = Float64[1, 2, 3, 4, 5, 6], ) - @test_throws ErrorException _ext._tensor_to_array(t) + @test_throws ErrorException _ext._tensor_to_array(Float64, t) end # ── Per-op coverage ────────────────────────────────────────────────────────── @@ -697,6 +704,84 @@ function test_input_name_overlaps_initializer() @test out == [1.0, 2.0, 3.0] end +# Float32 end-to-end: the from_onnx output uses T-typed constants and the +# evaluator returns Float32 values and Float32 gradients. 
+function _eval_with_gradient_f32( + proto::ONNX.ModelProto, + vars::Vector{MOI.VariableIndex}, + xv::Vector{Float32}; + input = vars, +) + out = ArrayDiff.from_onnx(Float32, proto; inputs = Dict("x" => input)) + snf = MOI.ScalarNonlinearFunction(:dot, Any[out, out]) + model = ArrayDiff.Model{Float32}() + ArrayDiff.set_objective(model, snf) + evaluator = + ArrayDiff.Evaluator(model, ArrayDiff.Mode{Vector{Float32}}(), vars) + MOI.initialize(evaluator, [:Grad]) + val = MOI.eval_objective(evaluator, xv) + g = zeros(Float32, length(xv)) + MOI.eval_objective_gradient(evaluator, g, xv) + return out, val, g +end + +# Tensor initializers materialize as `Vector{Float32}` / `Matrix{Float32}`. +function test_float32_initializer_eltype() + init = _make_tensor("b", [0.1, -0.3, 0.7, 1.0]) + node = _make_node("Add", ["x", "b"], ["y"]) + proto = _build_model([node], ["x"], ["y"]; initializers = [init]) + vars = [MOI.VariableIndex(i) for i in 1:4] + xv = Float32[1.0, 2.0, -1.5, 0.4] + out, val, g = _eval_with_gradient_f32(proto, vars, xv) + @test out isa ArrayDiff.ArrayNonlinearFunction + # The bias is the second argument of the broadcasted `:+`. + bias_arg = out.args[2] + @test bias_arg isa Vector{Float32} + @test val isa Float32 + @test g isa Vector{Float32} + fjulia(x) = sum((x .+ Float32[0.1, -0.3, 0.7, 1.0]) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# Sigmoid emits `zero(T)` and `one(T)` constants; Gemm's α/β are also typed. 
+function test_float32_sigmoid_and_gemm() + vars = [MOI.VariableIndex(i) for i in 1:2] + var_mat = reshape(vars, 1, 2) + W = [0.4 -0.6 0.2; 1.1 0.3 -0.9] + bias = [0.05, -0.1, 0.2] + init_W = _make_tensor("W", W) + init_b = _make_tensor("b", bias) + gemm = _make_node( + "Gemm", + ["x", "W", "b"], + ["h"]; + attrs = [ + _attr_float("alpha", 0.5), + _attr_float("beta", 2.0), + _attr_int("transA", 0), + _attr_int("transB", 0), + ], + ) + sig = _make_node("Sigmoid", ["h"], ["y"]) + proto = + _build_model([gemm, sig], ["x"], ["y"]; initializers = [init_W, init_b]) + xv = Float32[0.8, -0.3] + _, val, g = _eval_with_gradient_f32(proto, vars, xv; input = var_mat) + @test val isa Float32 + @test g isa Vector{Float32} + fjulia(x) = sum( + ( + 1 ./ ( + 1 .+ exp.( + .-(0.5f0 .* (reshape(x, 1, 2) * Float32.(W)) .+ 2.0f0 .* reshape(Float32.(bias), 1, 3)), + ) + ) + ) .^ 2, + ) + @test val ≈ fjulia(xv) rtol = 1.0f-5 +end + # Multi-output graph: result is keyed by output name. function test_multi_output() vars = [MOI.VariableIndex(i) for i in 1:3] From 9f9ad8f39fb80cc8c578a245c5dee0890d02a589 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 21:24:20 +0200 Subject: [PATCH 9/9] Fix format --- ext/ArrayDiffONNXExt.jl | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/ext/ArrayDiffONNXExt.jl b/ext/ArrayDiffONNXExt.jl index 5f18ab7..7c0b4c8 100644 --- a/ext/ArrayDiffONNXExt.jl +++ b/ext/ArrayDiffONNXExt.jl @@ -100,13 +100,16 @@ end _wrap_input(::Type{<:Real}, x::ANF{N}) where {N} = (x, x.size) _wrap_input(::Type{T}, x::Real) where {T<:Real} = (T(x), ()) -_wrap_input(::Type{T}, x::Vector{<:Real}) where {T<:Real} = - (Vector{T}(x), (length(x),)) -_wrap_input(::Type{T}, x::Matrix{<:Real}) where {T<:Real} = - (Matrix{T}(x), size(x)) +function _wrap_input(::Type{T}, x::Vector{<:Real}) where {T<:Real} + return (Vector{T}(x), (length(x),)) +end +function _wrap_input(::Type{T}, x::Matrix{<:Real}) where {T<:Real} + return 
(Matrix{T}(x), size(x)) +end -_wrap_input(::Type{<:Real}, x) = - error("Unsupported input value type: $(typeof(x))") +function _wrap_input(::Type{<:Real}, x) + return error("Unsupported input value type: $(typeof(x))") +end # ── Shape arithmetic ──────────────────────────────────────────────────────── @@ -340,7 +343,8 @@ function ArrayDiff.from_onnx( end end -ArrayDiff.from_onnx(proto::ONNX.ModelProto; kwargs...) = - ArrayDiff.from_onnx(Float64, proto; kwargs...) +function ArrayDiff.from_onnx(proto::ONNX.ModelProto; kwargs...) + return ArrayDiff.from_onnx(Float64, proto; kwargs...) +end end # module