From b774b799567dd27ef54b52f4750c804cb912ed5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 26 May 2026 16:54:11 +0200 Subject: [PATCH 1/9] Simplify _infer_size --- src/sizes.jl | 384 +++++++++++++++++++++++++++------------------------ test/JuMP.jl | 8 +- 2 files changed, 209 insertions(+), 183 deletions(-) diff --git a/src/sizes.jl b/src/sizes.jl index f250a0d..401e45a 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -340,18 +340,25 @@ function _assert_scalar_children(sizes, children_arr, children_indices, op) end """ - infer_sizes(op, child_sizes::Tuple...) -> Tuple + infer_sizes(op, child_sizes...) -> shape Return the output shape of applying `op` to arguments of shapes `child_sizes`. -Each `child_sizes[i]` is `()` if argument `i` is a scalar, or a tuple of -positive integers if it is an array. The returned shape is `()` for a scalar -result. +Each `child_sizes[i]` is empty (`()` or `Int[]`) if argument `i` is a scalar, +or a tuple/vector of positive integers if it is an array. The returned shape +is empty for a scalar output. Methods are written generically over indexable +containers so the same definition works whether the caller passes tuples +(JuMP-side) or views into a `Sizes` buffer (tape-side); the returned shape +typically matches the input kind. The default implementation constructs dummy arguments with `zeros(sz)` (or `0.0` for scalars) and calls `op(args...)`. Specialise on `op`'s `typeof` to avoid the allocation, to support operators that error on zero inputs, or to compute the output shape symbolically. +For tape-internal use, built-in operator symbols dispatch via +`infer_sizes(::Val{op}, shapes...)`, and broadcasted variants dispatch via +`infer_sizes(::Broadcasted{op}, shapes...)`. + ## Example For multiplication, `infer_sizes` can be implemented as follows @@ -367,14 +374,188 @@ function infer_sizes(::typeof(*), lhs, rhs) end ``` """ -function infer_sizes(op, child_sizes::Tuple...) +function infer_sizes(op, child_sizes...) args = map(child_sizes) do sz - return isempty(sz) ? 0.0 : zeros(sz) + return isempty(sz) ? 0.0 : zeros(Tuple(sz)) end y = op(args...) return y isa AbstractArray ? size(y) : () end +# Symbol → Val redirect: built-in tape operators carry a `Symbol`, so callers +# from `_infer_sizes` invoke `infer_sizes(:op, shapes...)` which lands here. +infer_sizes(op::Symbol, shapes...) = infer_sizes(Val(op), shapes...) + +""" + Broadcasted{op} + +Marker used to dispatch `infer_sizes` for the broadcasted variant of a +tape operator. Mirrors `Val(op)` but selects shape-combination rules that +follow Julia's broadcasting axis-combination rules instead of the +non-broadcasted (matmul/identity/etc.) rules. +""" +struct Broadcasted{S} end +Broadcasted(s::Symbol) = Broadcasted{s}() + +# ── Built-in tape operators: non-broadcasted ──────────────────────────────── +# These methods are written generically over indexable shape containers +# (tuple or `AbstractVector`), so the same definition serves both the +# JuMP-side (tuples from `size()`) and the tape-side (views into `Sizes`). + +# Pure scalar reductions +infer_sizes(::Val{:sum}, shapes...) = () +infer_sizes(::Val{:norm}, shapes...) = () +infer_sizes(::Val{:dot}, shapes...) = () + +# vect: N scalar children → 1-D vector of length N +function infer_sizes(::Val{:vect}, shapes...) + @assert all(isempty, shapes) "`vect` expects scalar children" + return (length(shapes),) +end + +# row: N scalar children → 1×N row vector +function infer_sizes(::Val{:row}, shapes...) + @assert all(isempty, shapes) "`row` expects scalar children" + return (1, length(shapes)) +end + +# +, - : non-broadcasted; all children share the same shape, output = first +infer_sizes(::Val{:+}, shape, more...) = shape +infer_sizes(::Val{:-}, shape, more...) = shape + +# hcat: rows from first arg, total cols summed across children +function infer_sizes(::Val{:hcat}, shapes...) + total_cols = sum(s -> length(s) <= 1 ? 1 : s[2], shapes) + if isempty(shapes[1]) + return (1, total_cols) + end + @assert length(shapes[1]) <= 2 "hcat with ndims > 2 is not supported yet" + return (shapes[1][1], total_cols) +end + +# vcat: cols from first arg, total rows summed across children +function infer_sizes(::Val{:vcat}, shapes...) + total_rows = sum(s -> length(s) <= 1 ? 1 : s[1], shapes) + if isempty(shapes[1]) + return (total_rows, 1) + end + @assert length(shapes[1]) <= 2 "vcat with ndims > 2 is not supported yet" + return (total_rows, shapes[1][2]) +end + +# *: matmul-like inner-dim reduction; scalar children are ignored +function infer_sizes(::Val{:*}, shapes...) + out::Tuple{Vararg{Int}} = () + for s in shapes + if isempty(s) + continue + end + if isempty(out) + out = Tuple(s) + else + @assert length(out) > 1 + @assert s[1] == out[end] + out = (out[1:(end-1)]..., s[2:end]...) + end + end + return out +end + +# ^ and / (non-broadcasted): first arg's shape, second must be scalar +function infer_sizes(::Val{:^}, base, exp) + @assert isempty(exp) "`^` expects scalar exponent" + return base +end +function infer_sizes(::Val{:/}, num, den) + @assert isempty(den) "`/` expects scalar denominator" + return num +end + +# ── Built-in tape operators: broadcasted ──────────────────────────────────── + +# Broadcasted +, -, *, /: combine via Julia's broadcasting axis-combination +# rules. Output ndims = max child ndims; size in each dim is the max size > 1. +function infer_sizes(::Broadcasted{op}, shapes...) where {op} + if !(op in (:+, :-, :*, :/)) + error("Unsupported broadcasted op `$op`") + end + nd = maximum(length, shapes; init = 0) + out = ones(Int, nd) + for sz in shapes + for j in eachindex(sz) + if sz[j] > 1 + if out[j] == 1 + out[j] = sz[j] + else + @assert out[j] == sz[j] + end + end + end + end + return out +end + +# Broadcasted ^: scalar exponent, base shape preserved +function infer_sizes(::Broadcasted{:^}, base, exp) + @assert isempty(exp) "broadcasted `^` expects scalar exponent" + return base +end + +# `:sum_dims` is the one tape op whose shape depends on data outside the +# child-shape tuple (the constant dims vector), so it can't fit the generic +# `infer_sizes(op, shapes...)` signature. Handled inline by `_infer_sizes`. +function _sum_dims_shape( + sizes, + nodes, + children_arr, + children_indices, + block_shapes, + const_values, +) + @assert length(children_indices) == 2 "`sum_dims` expects (array, dims_vector)" + arr_id = children_arr[first(children_indices)] + dims_id = children_arr[first(children_indices)+1] + @assert nodes[dims_id].type == NODE_VALUE_BLOCK "`sum_dims` requires constant dims (NODE_VALUE_BLOCK)" + dims_len = prod(block_shapes[dims_id]) + start = nodes[dims_id].index + dims_vec = const_values[(start-1) .+ (1:dims_len)] + in_ndims = sizes.ndims[arr_id] + return map(1:in_ndims) do d + return d in dims_vec ? 1 : _size(sizes, arr_id, d) + end +end + +# Resolve a tape node's "operator handle" — a `Symbol` (built-in op), a +# `Function` (user-defined ChainRules op), or `nothing` (unrecognised / +# scalar-only operator). For broadcasted multivariate nodes we wrap the +# symbol in `Broadcasted` so the corresponding `infer_sizes` method fires. +# This is type-unstable but it is only called from `_infer_sizes` for which +# performance isn't critical since it's just called at setup time. +# If performance is an issue, we can generate code with if-else like we +# already do at other places in this package. +function _shape_op(node::Node, operators) + @assert node.type == NODE_CALL_MULTIVARIATE || + node.type == NODE_CALL_MULTIVARIATE_BROADCASTED + func = nothing + if node.index in eachindex(DEFAULT_MULTIVARIATE_OPERATORS) + func = DEFAULT_MULTIVARIATE_OPERATORS[node.index] + end + if operators !== nothing && + node.index in eachindex(operators.multivariate_operators) + op_sym = operators.multivariate_operators[node.index] + if haskey(operators.chainrules_operators, op_sym) + func = operators.chainrules_operators[op_sym] + end + end + if isnothing(func) + return nothing + end + if node.type == NODE_CALL_MULTIVARIATE_BROADCASTED + func = Broadcasted(func) + end + return func +end + function _infer_sizes( nodes::Vector{Node}, adj::SparseArrays.SparseMatrixCSC{Bool,Int}, @@ -396,189 +577,36 @@ function _infer_sizes( if node.type == NODE_VARIABLE_BLOCK || node.type == NODE_VALUE_BLOCK || node.type == NODE_MOI_VARIABLE_BLOCK - shape = block_shapes[k] - _add_size!(sizes, k, shape) + _add_size!(sizes, k, block_shapes[k]) continue end children_indices = SparseArrays.nzrange(adj, k) N = length(children_indices) - if node.type == NODE_CALL_MULTIVARIATE - if !(node.index in eachindex(DEFAULT_MULTIVARIATE_OPERATORS)) - if operators !== nothing && - node.index in eachindex(operators.multivariate_operators) - op_sym = operators.multivariate_operators[node.index] - if haskey(operators.chainrules_operators, op_sym) - f = operators.chainrules_operators[op_sym] - child_shapes = Tuple( - ntuple( - d -> _size(sizes, children_arr[c_idx], d), - sizes.ndims[children_arr[c_idx]], - ) for c_idx in children_indices - ) - out_sz = infer_sizes(f, child_shapes...) - if !isempty(out_sz) - _add_size!(sizes, k, out_sz) - end - # Scalar output → ndims = 0 (already initialised). - continue - end - end - # TODO user-defined operators - continue + if node.type == NODE_CALL_MULTIVARIATE || + node.type == NODE_CALL_MULTIVARIATE_BROADCASTED + op = _shape_op(node, operators) + if op === nothing + continue # TODO user-defined operators end - op = DEFAULT_MULTIVARIATE_OPERATORS[node.index] - if op == :vect - _assert_scalar_children( - sizes, - children_arr, - children_indices, - op, - ) - _add_size!(sizes, k, (N,)) - elseif op == :row - _assert_scalar_children( + if op === :sum_dims + out_sz = _sum_dims_shape( sizes, + nodes, children_arr, children_indices, - op, - ) - _add_size!(sizes, k, (1, N)) - elseif op == :dot - # TODO assert all arguments have same size - elseif op == :norm - # TODO actually norm should be moved to univariate - elseif op == :sum - # sum reduces array to scalar, ndims stays 0 - elseif op == :+ || op == :- - # TODO assert all arguments have same size - _copy_size!(sizes, k, children_arr[first(children_indices)]) - elseif op == :hcat - total_cols = 0 - for c_idx in children_indices - total_cols += - sizes.ndims[children_arr[c_idx]] <= 1 ? 1 : - _size(sizes, children_arr[c_idx], 2) - end - if sizes.ndims[children_arr[first(children_indices)]] == 0 - shape = (1, total_cols) - else - @assert sizes.ndims[children_arr[first( - children_indices, - )]] <= 2 "Hcat with ndims > 2 is not supported yet" - shape = ( - _size(sizes, children_arr[first(children_indices)], 1), - total_cols, - ) - end - _add_size!(sizes, k, tuple(shape...)) - elseif op == :vcat - total_rows = 0 - for c_idx in children_indices - total_rows += - sizes.ndims[children_arr[c_idx]] <= 1 ? 1 : - _size(sizes, children_arr[c_idx], 1) - end - if sizes.ndims[children_arr[first(children_indices)]] == 0 - shape = (total_rows, 1) - else - @assert sizes.ndims[children_arr[first( - children_indices, - )]] <= 2 "Hcat with ndims > 2 is not supported yet" - shape = ( - total_rows, - _size(sizes, children_arr[first(children_indices)], 2), - ) - end - _add_size!(sizes, k, tuple(shape...)) - elseif op == :* - sizes.ndims[k] = 0 - for child in children_indices - id = children_arr[child] - ndims = sizes.ndims[id] - if !iszero(ndims) - sz = _size(sizes, id) - if iszero(sizes.ndims[k]) - sizes.size_offset[k] = length(sizes.size) - append!(sizes.size, sz) - sizes.ndims[k] = ndims - else - @assert sizes.ndims[k] > 1 - @assert sz[1] == sizes.size[end] - pop!(sizes.size) - append!(sizes.size, @view(sz[2:end])) - sizes.ndims[k] += ndims - 2 - end - end - end - elseif op == :^ || op == :/ - @assert N == 2 - _assert_scalar_children( - sizes, - children_arr, - children_indices[2:end], - op, + block_shapes, + const_values, ) - _copy_size!(sizes, k, children_arr[first(children_indices)]) - elseif op == :sum_dims - # Two args: (array, Vector{Float64}(dims)). Output keeps the - # input ndims with the reduced dims collapsed to size 1. - @assert N == 2 "`sum_dims` expects (array, dims_vector)" - arr_id = children_arr[first(children_indices)] - dims_id = children_arr[first(children_indices)+1] - @assert nodes[dims_id].type == NODE_VALUE_BLOCK "`sum_dims` requires constant dims (NODE_VALUE_BLOCK)" - # Read the dims values out of `const_values`. The block was - # appended at `nodes[dims_id].index` with length recorded in - # `block_shapes`. - dims_len = prod(block_shapes[dims_id]) - start = nodes[dims_id].index - dims_vec = const_values[(start-1) .+ (1:dims_len)] - in_ndims = sizes.ndims[arr_id] - out_shape = map(1:in_ndims) do d - return d in dims_vec ? 1 : _size(sizes, arr_id, d) - end - _add_size!(sizes, k, out_shape) else - _assert_scalar_children( - sizes, - children_arr, - children_indices, - op, + # Pass shape views directly — no Tuple/Vector conversion. The + # `infer_sizes` methods are generic over indexable containers. + child_shapes = ntuple( + i -> _size(sizes, children_arr[children_indices[i]]), + N, ) + out_sz = infer_sizes(op, child_shapes...) end - elseif node.type == NODE_CALL_MULTIVARIATE_BROADCASTED - if !(node.index in eachindex(DEFAULT_MULTIVARIATE_OPERATORS)) - # TODO user-defined operators - continue - end - op = DEFAULT_MULTIVARIATE_OPERATORS[node.index] - if op == :+ || op == :- || op == :* || op == :/ - sizes.ndims[k] = maximum(children_indices, init = 0) do i - return sizes.ndims[children_arr[i]] - end - sizes.size_offset[k] = length(sizes.size) - for _ in 1:sizes.ndims[k] - push!(sizes.size, 1) - end - sz_parent = _size(sizes, k) - for i in children_indices - id = children_arr[i] - sz = _size(sizes, id) - for j in eachindex(sz) - if sz[j] > 1 - if sz_parent[j] == 1 - sz_parent[j] = sz[j] - else - @assert sz_parent[j] == sz[j] - end - end - end - end - elseif op == :^ - # Broadcasted ^ with scalar exponent preserves base shape - @assert length(children_indices) == 2 "Expected two arguments for broadcasted operator `$op`, got $(length(children_indices))" - @assert iszero(sizes.ndims[children_arr[children_indices[2]]]) "Expected scalar exponent for broadcasted operator `$op`" - _copy_size!(sizes, k, children_arr[first(children_indices)]) - end + _add_size!(sizes, k, out_sz) elseif node.type == NODE_CALL_UNIVARIATE if !( node.index in @@ -615,10 +643,8 @@ function _infer_sizes( continue end error("TODO user-defined operators") - continue end @assert N == 1 - op = MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS[node.index] _copy_size!(sizes, k, children_arr[first(children_indices)]) end end diff --git a/test/JuMP.jl b/test/JuMP.jl index 35ba6f6..4f13770 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -47,9 +47,9 @@ function my_crossentropy2(p, q) return -sum(q .* log.(p)) end -function ArrayDiff.infer_sizes(::typeof(my_crossentropy2), ::Tuple, ::Tuple) - return () -end +# Override works for any indexable shape kind (tuple from JuMP, view from +# tape) — leaving the args untyped exercises that. +ArrayDiff.infer_sizes(::typeof(my_crossentropy2), s1, s2) = () function ChainRulesCore.rrule( ::typeof(my_crossentropy1), @@ -520,7 +520,7 @@ function test_broadcast_nonsquare_matrix() # (2, 3). The old bug would report (2, 2) for the broadcast node. @test sizes.ndims == [0, 2, 2, 2] @test sizes.size == [2, 3, 2, 3, 2, 3] - @test sizes.size_offset == [0, 4, 2, 0] + @test sizes.size_offset == [6, 4, 2, 0] @test sizes.storage_offset == [0, 1, 7, 13, 19] @test val ≈ LinearAlgebra.norm(ref_mat) ref_g = if op == :+ From 0b898938a6a7584d512609b7bf4c7f6367250ab0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 26 May 2026 17:05:34 +0200 Subject: [PATCH 2/9] Simplify --- src/sizes.jl | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/src/sizes.jl b/src/sizes.jl index 401e45a..8313e3f 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -343,12 +343,15 @@ end infer_sizes(op, child_sizes...) -> shape Return the output shape of applying `op` to arguments of shapes `child_sizes`. -Each `child_sizes[i]` is empty (`()` or `Int[]`) if argument `i` is a scalar, -or a tuple/vector of positive integers if it is an array. The returned shape -is empty for a scalar output. Methods are written generically over indexable -containers so the same definition works whether the caller passes tuples -(JuMP-side) or views into a `Sizes` buffer (tape-side); the returned shape -typically matches the input kind. +Each `child_sizes[i]` is empty if argument `i` is a scalar (`()` or `Int[]`), +or any indexable container of positive integers if it is an array. The +returned shape is empty for a scalar output. + +Inputs may be tuples (JuMP-side, from `size()`) or `AbstractVector{Int}` +(tape-side — typically views into a `Sizes` buffer). The returned shape can +likewise be a tuple or an `AbstractVector{Int}` (including a view): pick +whichever makes the method type-stable. `_infer_sizes` and +`_build_user_op_expr` accept either. The default implementation constructs dummy arguments with `zeros(sz)` (or `0.0` for scalars) and calls `op(args...)`. Specialise on `op`'s @@ -376,7 +379,7 @@ end """ function infer_sizes(op, child_sizes...) args = map(child_sizes) do sz - return isempty(sz) ? 0.0 : zeros(Tuple(sz)) + return isempty(sz) ? 0.0 : zeros(sz...) end y = op(args...) return y isa AbstractArray ? size(y) : () @@ -443,19 +446,24 @@ function infer_sizes(::Val{:vcat}, shapes...) return (total_rows, shapes[1][2]) end -# *: matmul-like inner-dim reduction; scalar children are ignored +# *: matmul-like inner-dim reduction; scalar children are ignored. Returns a +# `Vector{Int}` so the accumulator is type-stable across the loop (a tuple +# accumulator would change type each iteration as the length varies). function infer_sizes(::Val{:*}, shapes...) - out::Tuple{Vararg{Int}} = () + out = Int[] for s in shapes if isempty(s) continue end if isempty(out) - out = Tuple(s) + append!(out, s) else @assert length(out) > 1 @assert s[1] == out[end] - out = (out[1:(end-1)]..., s[2:end]...) + pop!(out) + for j in 2:length(s) + push!(out, s[j]) + end end end return out From 61316ed395b4102b3b4afe773dd6b69b826d845f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 26 May 2026 18:20:05 +0200 Subject: [PATCH 3/9] Use functions, not Val --- src/sizes.jl | 117 +++++++++++++++++++++++++++------------------------ 1 file changed, 61 insertions(+), 56 deletions(-) diff --git a/src/sizes.jl b/src/sizes.jl index 8313e3f..0534591 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -385,49 +385,28 @@ function infer_sizes(op, child_sizes...) return y isa AbstractArray ? size(y) : () end -# Symbol → Val redirect: built-in tape operators carry a `Symbol`, so callers -# from `_infer_sizes` invoke `infer_sizes(:op, shapes...)` which lands here. -infer_sizes(op::Symbol, shapes...) = infer_sizes(Val(op), shapes...) - -""" - Broadcasted{op} - -Marker used to dispatch `infer_sizes` for the broadcasted variant of a -tape operator. Mirrors `Val(op)` but selects shape-combination rules that -follow Julia's broadcasting axis-combination rules instead of the -non-broadcasted (matmul/identity/etc.) rules. -""" -struct Broadcasted{S} end -Broadcasted(s::Symbol) = Broadcasted{s}() - # ── Built-in tape operators: non-broadcasted ──────────────────────────────── # These methods are written generically over indexable shape containers # (tuple or `AbstractVector`), so the same definition serves both the # JuMP-side (tuples from `size()`) and the tape-side (views into `Sizes`). # Pure scalar reductions -infer_sizes(::Val{:sum}, shapes...) = () -infer_sizes(::Val{:norm}, shapes...) = () -infer_sizes(::Val{:dot}, shapes...) = () +infer_sizes(::typeof(sum), shapes...) = () +infer_sizes(::typeof(LinearAlgebra.norm), shapes...) = () +infer_sizes(::typeof(LinearAlgebra.dot), shapes...) = () # vect: N scalar children → 1-D vector of length N -function infer_sizes(::Val{:vect}, shapes...) +function infer_sizes(::typeof(Base.vect), shapes...) @assert all(isempty, shapes) "`vect` expects scalar children" return (length(shapes),) end -# row: N scalar children → 1×N row vector -function infer_sizes(::Val{:row}, shapes...) - @assert all(isempty, shapes) "`row` expects scalar children" - return (1, length(shapes)) -end - # +, - : non-broadcasted; all children share the same shape, output = first -infer_sizes(::Val{:+}, shape, more...) = shape -infer_sizes(::Val{:-}, shape, more...) = shape +infer_sizes(::typeof(+), shape, more...) = shape +infer_sizes(::typeof(-), shape, more...) = shape # hcat: rows from first arg, total cols summed across children -function infer_sizes(::Val{:hcat}, shapes...) +function infer_sizes(::typeof(hcat), shapes...) total_cols = sum(s -> length(s) <= 1 ? 1 : s[2], shapes) if isempty(shapes[1]) return (1, total_cols) @@ -437,7 +416,7 @@ function infer_sizes(::Val{:hcat}, shapes...) end # vcat: cols from first arg, total rows summed across children -function infer_sizes(::Val{:vcat}, shapes...) +function infer_sizes(::typeof(vcat), shapes...) total_rows = sum(s -> length(s) <= 1 ? 1 : s[1], shapes) if isempty(shapes[1]) return (total_rows, 1) @@ -449,7 +428,7 @@ end # *: matmul-like inner-dim reduction; scalar children are ignored. Returns a # `Vector{Int}` so the accumulator is type-stable across the loop (a tuple # accumulator would change type each iteration as the length varies). -function infer_sizes(::Val{:*}, shapes...) +function infer_sizes(::typeof(*), shapes...) out = Int[] for s in shapes if isempty(s) @@ -470,23 +449,27 @@ function infer_sizes(::Val{:*}, shapes...) end # ^ and / (non-broadcasted): first arg's shape, second must be scalar -function infer_sizes(::Val{:^}, base, exp) +function infer_sizes(::typeof(^), base, exp) @assert isempty(exp) "`^` expects scalar exponent" return base end -function infer_sizes(::Val{:/}, num, den) +function infer_sizes(::typeof(/), num, den) @assert isempty(den) "`/` expects scalar denominator" return num end # ── Built-in tape operators: broadcasted ──────────────────────────────────── +# A broadcasted call lowers to `Base.broadcasted(op, args...)`; we use the +# same shape — dispatching on `(typeof(broadcasted), typeof(op))` instead of +# defining a parallel marker type. # Broadcasted +, -, *, /: combine via Julia's broadcasting axis-combination # rules. Output ndims = max child ndims; size in each dim is the max size > 1. -function infer_sizes(::Broadcasted{op}, shapes...) where {op} - if !(op in (:+, :-, :*, :/)) - error("Unsupported broadcasted op `$op`") - end +function infer_sizes( + ::typeof(Base.broadcasted), + ::Union{typeof(+),typeof(-),typeof(*),typeof(/)}, + shapes..., +) nd = maximum(length, shapes; init = 0) out = ones(Int, nd) for sz in shapes @@ -504,11 +487,37 @@ function infer_sizes(::Broadcasted{op}, shapes...) where {op} end # Broadcasted ^: scalar exponent, base shape preserved -function infer_sizes(::Broadcasted{:^}, base, exp) +function infer_sizes(::typeof(Base.broadcasted), ::typeof(^), base, exp) @assert isempty(exp) "broadcasted `^` expects scalar exponent" return base end +# `:row` is an MOI tape primitive (one row of a matrix literal) without a +# corresponding Julia function. Define an empty function so `infer_sizes` can +# still dispatch on `typeof(_row_op)` — `_row_op` is never actually called. +function _row_op end + +function infer_sizes(::typeof(_row_op), shapes...) + @assert all(isempty, shapes) "`row` expects scalar children" + return (1, length(shapes)) +end + +# Map a built-in operator symbol to its Julia function so `infer_sizes` can +# dispatch on `typeof(fn)`. Returns `nothing` for `:sum_dims`, whose shape +# depends on the constant dims vector and is handled inline by `_infer_sizes`. +function _default_op_function(sym::Symbol) + if sym === :sum_dims + return nothing + end + if sym === :row + return _row_op + end + if sym === :dot || sym === :norm + return getfield(LinearAlgebra, sym) + end + return getfield(Base, sym) +end + # `:sum_dims` is the one tape op whose shape depends on data outside the # child-shape tuple (the constant dims vector), so it can't fit the generic # `infer_sizes(op, shapes...)` signature. Handled inline by `_infer_sizes`. @@ -533,35 +542,27 @@ function _sum_dims_shape( end end -# Resolve a tape node's "operator handle" — a `Symbol` (built-in op), a -# `Function` (user-defined ChainRules op), or `nothing` (unrecognised / -# scalar-only operator). For broadcasted multivariate nodes we wrap the -# symbol in `Broadcasted` so the corresponding `infer_sizes` method fires. -# This is type-unstable but it is only called from `_infer_sizes` for which -# performance isn't critical since it's just called at setup time. -# If performance is an issue, we can generate code with if-else like we -# already do at other places in this package. +# Resolve a multivariate tape node's operator to the Julia `Function` whose +# `infer_sizes` method computes its output shape. Returns the bare op symbol +# (`:sum_dims`) when the shape rule needs out-of-band data and is handled +# inline by `_infer_sizes`, or `nothing` for unrecognised operators. +# Type-unstable on purpose — called only at setup time. function _shape_op(node::Node, operators) @assert node.type == NODE_CALL_MULTIVARIATE || node.type == NODE_CALL_MULTIVARIATE_BROADCASTED - func = nothing if node.index in eachindex(DEFAULT_MULTIVARIATE_OPERATORS) - func = DEFAULT_MULTIVARIATE_OPERATORS[node.index] + sym = DEFAULT_MULTIVARIATE_OPERATORS[node.index] + func = _default_op_function(sym) + return func === nothing ? sym : func end if operators !== nothing && node.index in eachindex(operators.multivariate_operators) op_sym = operators.multivariate_operators[node.index] if haskey(operators.chainrules_operators, op_sym) - func = operators.chainrules_operators[op_sym] + return operators.chainrules_operators[op_sym] end end - if isnothing(func) - return nothing - end - if node.type == NODE_CALL_MULTIVARIATE_BROADCASTED - func = Broadcasted(func) - end - return func + return nothing end function _infer_sizes( @@ -612,7 +613,11 @@ function _infer_sizes( i -> _size(sizes, children_arr[children_indices[i]]), N, ) - out_sz = infer_sizes(op, child_shapes...) + out_sz = if node.type == NODE_CALL_MULTIVARIATE_BROADCASTED + infer_sizes(Base.broadcasted, op, child_shapes...) + else + infer_sizes(op, child_shapes...) + end end _add_size!(sizes, k, out_sz) elseif node.type == NODE_CALL_UNIVARIATE From c45a2979a1d543f0453d7497d59da956d2dc7376 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 21:29:54 +0200 Subject: [PATCH 4/9] Fix --- src/sizes.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/sizes.jl b/src/sizes.jl index 0534591..9788679 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -405,6 +405,11 @@ end infer_sizes(::typeof(+), shape, more...) = shape infer_sizes(::typeof(-), shape, more...) = shape +function infer_sizes(::typeof(ifelse), cond, lhs, rhs) + @assert lhs == rhs + return lhs +end + # hcat: rows from first arg, total cols summed across children function infer_sizes(::typeof(hcat), shapes...) total_cols = sum(s -> length(s) <= 1 ? 1 : s[2], shapes) From 89a316eecd8b8a1ef9350863266f5ee6a6813657 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 21:32:04 +0200 Subject: [PATCH 5/9] Update test --- test/ArrayDiff.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/ArrayDiff.jl b/test/ArrayDiff.jl index 067eae8..eac5625 100644 --- a/test/ArrayDiff.jl +++ b/test/ArrayDiff.jl @@ -29,7 +29,7 @@ function test_objective_dot_univariate() MOI.initialize(evaluator, [:Grad, :Hess]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 1, 0, 1, 0] - @test sizes.size_offset == [0, 1, 0, 0, 0] + @test sizes.size_offset == [2, 1, 0, 0, 0] @test sizes.size == [1, 1] @test sizes.storage_offset == [0, 1, 2, 3, 4, 5] xv = [1.2] @@ -75,7 +75,7 @@ function test_objective_dot_bivariate() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0] - @test sizes.size_offset == [0, 6, 5, 0, 0, 4, 0, 0, 3, 2, 1, 0, 0, 0, 0, 0] + @test sizes.size_offset == [7, 6, 5, 0, 0, 4, 0, 0, 3, 2, 1, 0, 0, 0, 0, 0] @test sizes.size == [2, 2, 2, 2, 2, 2, 2] @test sizes.storage_offset == [0, 1, 3, 5, 6, 7, 9, 10, 11, 13, 15, 17, 18, 19, 21, 22, 23] @@ -158,7 +158,7 @@ function test_objective_dot_bivariate_on_rows() sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 2, 0, 0, 2, 0, 0, 2, 2, 2, 0, 0, 2, 0, 0] @test sizes.size_offset == - [0, 12, 10, 0, 0, 8, 0, 0, 6, 4, 2, 0, 0, 0, 0, 0] + [14, 12, 10, 0, 0, 8, 0, 0, 6, 4, 2, 0, 0, 0, 0, 0] @test sizes.size == [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2] @test sizes.storage_offset == [0, 1, 3, 5, 6, 7, 9, 10, 11, 13, 15, 17, 18, 19, 21, 22, 23] @@ -608,7 +608,7 @@ function test_objective_broadcasted_product() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 1, 1, 0, 0, 1, 0, 0] - @test sizes.size_offset == [0, 2, 1, 0, 0, 0, 0, 0] + @test sizes.size_offset == [3, 2, 1, 0, 0, 0, 0, 0] @test sizes.size == [2, 2, 2] @test sizes.storage_offset == [0, 1, 3, 5, 6, 7, 9, 10, 11] x1 = 1.0 @@ -640,7 +640,7 @@ function test_objective_broadcasted_matrix_product() sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 2, 2, 0, 0, 2, 0, 0, 2, 2, 0, 0, 2, 0, 0] @test sizes.size_offset == - [0, 12, 10, 8, 0, 0, 6, 0, 0, 4, 2, 0, 0, 0, 0, 0] + [14, 12, 10, 8, 0, 0, 6, 0, 0, 4, 2, 0, 0, 0, 0, 0] @test sizes.size == [1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2] @test sizes.storage_offset == [0, 1, 5, 9, 11, 12, 13, 15, 16, 17, 21, 23, 24, 25, 27, 28, 29] @@ -673,7 +673,7 @@ function test_objective_broadcasted_tanh() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 1, 1, 0, 0] - @test sizes.size_offset == [0, 1, 0, 0, 0] + @test sizes.size_offset == [2, 1, 0, 0, 0] @test sizes.size == [2, 2] @test sizes.storage_offset == [0, 1, 3, 5, 6, 7] x1 = 1.0 From 6c17fdd5d122bc5dd7bac0f4f5017514ef291dc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 21:39:45 +0200 Subject: [PATCH 6/9] Fixes --- test/ArrayDiff.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/ArrayDiff.jl b/test/ArrayDiff.jl index eac5625..8df8625 100644 --- a/test/ArrayDiff.jl +++ b/test/ArrayDiff.jl @@ -50,7 +50,7 @@ function test_objective_dot_univariate_and_scalar_mult() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 0, 0, 1, 0, 1, 0] - @test sizes.size_offset == [0, 0, 0, 1, 0, 0, 0] + @test sizes.size_offset == [2, 0, 2, 1, 0, 0, 0] @test sizes.size == [1, 1] @test sizes.storage_offset == [0, 1, 2, 3, 4, 5, 6, 7] xv = [1.2] @@ -100,7 +100,7 @@ function test_objective_hcat_scalars() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 0, 0, 2, 0, 0] - @test sizes.size_offset == [0, 2, 0, 0, 0, 0, 0] + @test sizes.size_offset == [4, 2, 0, 0, 0, 0, 0] @test sizes.size == [1, 2, 1, 2] @test sizes.storage_offset == [0, 1, 3, 4, 5, 7, 8, 9] x1 = 1.0 @@ -131,7 +131,7 @@ function test_objective_hcat_vectors() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 1, 0, 1, 0, 2, 1, 0, 1, 0] - @test sizes.size_offset == [0, 6, 5, 0, 4, 0, 2, 1, 0, 0, 0] + @test sizes.size_offset == [8, 6, 5, 0, 4, 0, 2, 1, 0, 0, 0] @test sizes.size == [1, 1, 1, 2, 1, 1, 1, 2] @test sizes.storage_offset == [0, 1, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13] x1 = 1.0 @@ -202,7 +202,7 @@ function test_objective_norm_bivariate() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 1, 0, 0] - @test sizes.size_offset == [0, 0, 0, 0] + @test sizes.size_offset == [1, 0, 0, 0] @test sizes.size == [2] @test sizes.storage_offset == [0, 1, 3, 4, 5] xv = [3.0, 4.0] @@ -231,7 +231,7 @@ function test_objective_norm_of_row_vector() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 0, 0] - @test sizes.size_offset == [0, 0, 0, 0] + @test sizes.size_offset == [2, 0, 0, 0] @test sizes.size == [1, 2] @test sizes.storage_offset == [0, 1, 3, 4, 5] x1 = 1.0 @@ -316,7 +316,7 @@ function test_objective_norm_of_row() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 0, 0] - @test sizes.size_offset == [0, 0, 0, 0] + @test sizes.size_offset == [2, 0, 0, 0] @test sizes.size == [1, 2] @test sizes.storage_offset == [0, 1, 3, 4, 5] x1 = 1.0 @@ -342,7 +342,7 @@ function test_objective_norm_of_matrix() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 2, 0, 0, 2, 0, 0] - @test sizes.size_offset == [0, 4, 2, 0, 0, 0, 0, 0] + @test sizes.size_offset == [6, 4, 2, 0, 0, 0, 0, 0] @test sizes.size == [1, 2, 1, 2, 2, 2] @test sizes.storage_offset == [0, 1, 5, 7, 8, 9, 11, 12, 13] x1 = 1.0 @@ -376,7 +376,7 @@ function test_objective_norm_of_matrix_with_sum() sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 2, 2, 0, 0, 2, 0, 0, 2, 2, 0, 0, 2, 0, 0] @test sizes.size_offset == - [0, 12, 10, 8, 0, 0, 6, 0, 0, 4, 2, 0, 0, 0, 0, 0] + [14, 12, 10, 8, 0, 0, 6, 0, 0, 4, 2, 0, 0, 0, 0, 0] @test sizes.size == [1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2] @test sizes.storage_offset == [0, 1, 5, 9, 11, 12, 13, 15, 16, 17, 21, 23, 24, 25, 27, 28, 29] @@ -406,7 +406,7 @@ function test_objective_norm_of_product_of_matrices() sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 2, 2, 0, 0, 2, 0, 0, 2, 2, 0, 0, 2, 0, 0] @test sizes.size_offset == - [0, 12, 10, 8, 0, 0, 6, 0, 0, 4, 2, 0, 0, 0, 0, 0] + [14, 12, 10, 8, 0, 0, 6, 0, 0, 4, 2, 0, 0, 0, 0, 0] @test sizes.size == [1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2] @test sizes.storage_offset == [0, 1, 5, 9, 11, 12, 13, 15, 16, 17, 21, 23, 24, 25, 27, 28, 29] @@ -469,7 +469,7 @@ function test_objective_norm_of_product_of_matrices_with_sum() 0, ] @test sizes.size_offset == [ - 0, + 22, 20, 18, 16, @@ -553,7 +553,7 @@ function test_objective_norm_of_mtx_vector_product() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0] - @test sizes.size_offset == [0, 8, 6, 4, 0, 0, 2, 0, 0, 0, 0, 0] + @test sizes.size_offset == [10, 8, 6, 4, 0, 0, 2, 0, 0, 0, 0, 0] @test sizes.size == [2, 1, 1, 2, 1, 2, 2, 2, 2, 1] @test sizes.storage_offset == [0, 1, 3, 7, 9, 10, 11, 13, 14, 15, 17, 18, 19] From 22cf85c4797cd3778fa97c498b719b56380c55cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 21:52:57 +0200 Subject: [PATCH 7/9] up tests --- test/ArrayDiff.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/ArrayDiff.jl b/test/ArrayDiff.jl index 8df8625..b4bda2e 100644 --- a/test/ArrayDiff.jl +++ b/test/ArrayDiff.jl @@ -180,7 +180,7 @@ function test_objective_norm_univariate() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 1, 0] - @test sizes.size_offset == [0, 0, 0] + @test sizes.size_offset == [1, 0, 0] @test sizes.size == [1] @test sizes.storage_offset == [0, 1, 2, 3] xv = [1.2] @@ -257,7 +257,7 @@ function test_objective_norm_of_vcat_vector() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 0, 0] - @test sizes.size_offset == [0, 0, 0, 0] + @test sizes.size_offset == [2, 0, 0, 0] @test sizes.size == [2, 1] @test sizes.storage_offset == [0, 1, 3, 4, 5] x1 = 1.0 @@ -285,7 +285,7 @@ function test_objective_norm_of_vcat_matrix() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 2, 0, 0, 2, 0, 0] - @test sizes.size_offset == [0, 4, 2, 0, 0, 0, 0, 0] + @test sizes.size_offset == [6, 4, 2, 0, 0, 0, 0, 0] @test sizes.size == [1, 2, 1, 2, 2, 2] @test sizes.storage_offset == [0, 1, 5, 7, 8, 9, 11, 12, 13] x1 = 1.0 From 0f7935c16d5a5dd56ac509d05d9d4d03901e8824 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 21:58:11 +0200 Subject: [PATCH 8/9] Fix format --- src/Coloring/Coloring.jl | 6 +++--- src/reverse_mode.jl | 4 ++-- src/sizes.jl | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Coloring/Coloring.jl b/src/Coloring/Coloring.jl index 1ef8dc5..c97a7f4 100644 --- a/src/Coloring/Coloring.jl +++ b/src/Coloring/Coloring.jl @@ -30,14 +30,14 @@ IndexedSet(n::Integer) = IndexedSet(zeros(Int, n), trues(n), 0) function Base.push!(v::IndexedSet, i::Integer) if v.empty[i] # new index - v.nzidx[v.nnz+=1] = i + v.nzidx[v.nnz += 1] = i v.empty[i] = false end return end function Base.empty!(v::IndexedSet) - for i in 1:v.nnz + for i in 1:(v.nnz) v.empty[v.nzidx[i]] = true end v.nnz = 0 @@ -58,7 +58,7 @@ function Base.resize!(v::IndexedSet, n::Integer) return end -Base.collect(v::IndexedSet) = v.nzidx[1:v.nnz] +Base.collect(v::IndexedSet) = v.nzidx[1:(v.nnz)] function Base.union!(v::IndexedSet, s) for x in s diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index db93906..2fffee0 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -133,7 +133,7 @@ function _forward_eval( len = length(tape_range) copyto!( view(f.forward_storage, tape_range), - view(x, node.index:(node.index+len-1)), + view(x, (node.index):(node.index+len-1)), ) elseif node.type == NODE_VALUE_BLOCK # Pre-loaded into `forward_storage` at construction. @@ -1169,7 +1169,7 @@ function _extract_reverse_pass_inner( if node.type == NODE_VARIABLE_BLOCK tape_range = _storage_range(f.sizes, k) len = length(tape_range) - x_range = node.index:(node.index+len-1) + x_range = (node.index):(node.index+len-1) view(output, x_range) .+= scale .* view(f.reverse_storage, tape_range) elseif node.type == NODE_VARIABLE diff --git a/src/sizes.jl b/src/sizes.jl index 9788679..0a853ef 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -711,7 +711,7 @@ struct _SubexpressionStorage{T<:Real,S<:AbstractVector{T}} j = sizes.storage_offset[k] + 1 len = _length(sizes, k) cpu_buffer[j:(j+len-1)] .= - view(const_values, node.index:(node.index+len-1)) + view(const_values, (node.index):(node.index+len-1)) end end forward_storage = convert(S, cpu_buffer) From d0bda4e6390a8874ee492e0ae2093eeb5bc4df1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 27 May 2026 23:24:22 +0200 Subject: [PATCH 9/9] Add tests --- src/reverse_mode.jl | 10 +++++++++ test/JuMP.jl | 52 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 2fffee0..767c0aa 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -745,6 +745,16 @@ function __reverse_broadcasted_div(f, ilhs, irhs, dout, dlhs, drhs) ) end +# Reverse for `:sum_dims`. `y = sum(x; dims=d)` collapses one axis, so +# ∂y[i…]/∂x[j…] is 1 iff `j` matches `i` on every non-reduced axis (any +# `j` along the reduced axis maps to the same `y` slot). Broadcasting the +# (m,1) or (1,n) parent adjoint across the full child shape produces +# exactly that pattern. +function _reverse_sum_dims!(rev_arr, rev_parent) + rev_arr .= rev_parent + return +end + """ _reverse_eval(f::_SubexpressionStorage) diff --git a/test/JuMP.jl b/test/JuMP.jl index 4f13770..e7f4490 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -503,6 +503,58 @@ function test_size_vec_vect() return end +# `:sum_dims` is emitted by `Base.sum(::AbstractJuMPArray; dims=…)` and is the +# reduction building block both `LayerNorm` and `softmax` rely on, but no +# existing test reaches it (`grep "sum.*dims" test/` only matches plain Julia +# arrays). These tests exercise the forward path, the `_sum_dims_shape` +# branch in `_infer_sizes`, and the reverse-mode contribution. The analytic +# gradient of `f = ‖sum(W; dims=d)‖` is `s[i_d] / ‖s‖` broadcast over the +# reduced axis. +function test_sum_dims_along_rows() + rows, cols = 2, 3 + model = Model() + @variable(model, W[1:rows, 1:cols], container = ArrayDiff.ArrayOfVariables) + expr = sum(W; dims = 2) + @test expr isa ArrayDiff.MatrixExpr + @test expr.head == :sum_dims + @test size(expr) == (rows, 1) + x = Float64.(collect(1:(rows*cols))) + W_val = reshape(x, rows, cols) + s = sum(W_val; dims = 2) + sizes, val, g = _eval(model, LinearAlgebra.norm(expr), x) + @test val ≈ LinearAlgebra.norm(s) + # `s` is the (rows, 1) column; broadcast across the reduced (cols) axis. + @test g ≈ vec(repeat(s, 1, cols)) ./ LinearAlgebra.norm(s) + # Tape: norm (k=1, scalar) → sum_dims (k=2, (rows, 1)). + @test sizes.ndims[1] == 0 + @test sizes.ndims[2] == 2 + sd_off = sizes.size_offset[2] + @test sizes.size[sd_off+1] == rows + @test sizes.size[sd_off+2] == 1 + return +end + +function test_sum_dims_along_cols() + rows, cols = 2, 3 + model = Model() + @variable(model, W[1:rows, 1:cols], container = ArrayDiff.ArrayOfVariables) + expr = sum(W; dims = 1) + @test expr.head == :sum_dims + @test size(expr) == (1, cols) + x = Float64.(collect(1:(rows*cols))) + W_val = reshape(x, rows, cols) + s = sum(W_val; dims = 1) + sizes, val, g = _eval(model, LinearAlgebra.norm(expr), x) + @test val ≈ LinearAlgebra.norm(s) + # `s` is the (1, cols) row; broadcast across the reduced (rows) axis. + @test g ≈ vec(repeat(s, rows, 1)) ./ LinearAlgebra.norm(s) + @test sizes.ndims[2] == 2 + sd_off = sizes.size_offset[2] + @test sizes.size[sd_off+1] == 1 + @test sizes.size[sd_off+2] == cols + return +end + function test_broadcast_nonsquare_matrix() model = Model() @variable(model, W[1:2, 1:3], container = ArrayDiff.ArrayOfVariables)