diff --git a/src/Coloring/Coloring.jl b/src/Coloring/Coloring.jl index 1ef8dc5..c97a7f4 100644 --- a/src/Coloring/Coloring.jl +++ b/src/Coloring/Coloring.jl @@ -30,14 +30,14 @@ IndexedSet(n::Integer) = IndexedSet(zeros(Int, n), trues(n), 0) function Base.push!(v::IndexedSet, i::Integer) if v.empty[i] # new index - v.nzidx[v.nnz+=1] = i + v.nzidx[v.nnz += 1] = i v.empty[i] = false end return end function Base.empty!(v::IndexedSet) - for i in 1:v.nnz + for i in 1:(v.nnz) v.empty[v.nzidx[i]] = true end v.nnz = 0 @@ -58,7 +58,7 @@ function Base.resize!(v::IndexedSet, n::Integer) return end -Base.collect(v::IndexedSet) = v.nzidx[1:v.nnz] +Base.collect(v::IndexedSet) = v.nzidx[1:(v.nnz)] function Base.union!(v::IndexedSet, s) for x in s diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index db93906..767c0aa 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -133,7 +133,7 @@ function _forward_eval( len = length(tape_range) copyto!( view(f.forward_storage, tape_range), - view(x, node.index:(node.index+len-1)), + view(x, (node.index):(node.index+len-1)), ) elseif node.type == NODE_VALUE_BLOCK # Pre-loaded into `forward_storage` at construction. @@ -745,6 +745,16 @@ function __reverse_broadcasted_div(f, ilhs, irhs, dout, dlhs, drhs) ) end +# Reverse for `:sum_dims`. `y = sum(x; dims=d)` collapses one axis, so +# ∂y[i…]/∂x[j…] is 1 iff `j` matches `i` on every non-reduced axis (any +# `j` along the reduced axis maps to the same `y` slot). Broadcasting the +# (m,1) or (1,n) parent adjoint across the full child shape produces +# exactly that pattern. +function _reverse_sum_dims!(rev_arr, rev_parent) + rev_arr .= rev_parent + return +end + """ _reverse_eval(f::_SubexpressionStorage) @@ -1169,7 +1179,7 @@ function _extract_reverse_pass_inner( if node.type == NODE_VARIABLE_BLOCK tape_range = _storage_range(f.sizes, k) len = length(tape_range) - x_range = node.index:(node.index+len-1) + x_range = (node.index):(node.index+len-1) view(output, x_range) .+= scale .* view(f.reverse_storage, tape_range) elseif node.type == NODE_VARIABLE diff --git a/src/sizes.jl b/src/sizes.jl index f250a0d..0a853ef 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -340,18 +340,28 @@ function _assert_scalar_children(sizes, children_arr, children_indices, op) end """ - infer_sizes(op, child_sizes::Tuple...) -> Tuple + infer_sizes(op, child_sizes...) -> shape Return the output shape of applying `op` to arguments of shapes `child_sizes`. -Each `child_sizes[i]` is `()` if argument `i` is a scalar, or a tuple of -positive integers if it is an array. The returned shape is `()` for a scalar -result. +Each `child_sizes[i]` is empty if argument `i` is a scalar (`()` or `Int[]`), +or any indexable container of positive integers if it is an array. The +returned shape is empty for a scalar output. + +Inputs may be tuples (JuMP-side, from `size()`) or `AbstractVector{Int}` +(tape-side — typically views into a `Sizes` buffer). The returned shape can +likewise be a tuple or an `AbstractVector{Int}` (including a view): pick +whichever makes the method type-stable. `_infer_sizes` and +`_build_user_op_expr` accept either. The default implementation constructs dummy arguments with `zeros(sz)` (or `0.0` for scalars) and calls `op(args...)`. Specialise on `op`'s `typeof` to avoid the allocation, to support operators that error on zero inputs, or to compute the output shape symbolically. +For tape-internal use, built-in operator symbols dispatch via +`infer_sizes(::Val{op}, shapes...)`, and broadcasted variants dispatch via +`infer_sizes(::Broadcasted{op}, shapes...)`. + ## Example For multiplication, `infer_sizes` can be implemented as follows @@ -367,14 +377,199 @@ function infer_sizes(::typeof(*), lhs, rhs) end ``` """ -function infer_sizes(op, child_sizes::Tuple...) +function infer_sizes(op, child_sizes...) args = map(child_sizes) do sz - return isempty(sz) ? 0.0 : zeros(sz) + return isempty(sz) ? 0.0 : zeros(sz...) end y = op(args...) return y isa AbstractArray ? size(y) : () end +# ── Built-in tape operators: non-broadcasted ──────────────────────────────── +# These methods are written generically over indexable shape containers +# (tuple or `AbstractVector`), so the same definition serves both the +# JuMP-side (tuples from `size()`) and the tape-side (views into `Sizes`). + +# Pure scalar reductions +infer_sizes(::typeof(sum), shapes...) = () +infer_sizes(::typeof(LinearAlgebra.norm), shapes...) = () +infer_sizes(::typeof(LinearAlgebra.dot), shapes...) = () + +# vect: N scalar children → 1-D vector of length N +function infer_sizes(::typeof(Base.vect), shapes...) + @assert all(isempty, shapes) "`vect` expects scalar children" + return (length(shapes),) +end + +# +, - : non-broadcasted; all children share the same shape, output = first +infer_sizes(::typeof(+), shape, more...) = shape +infer_sizes(::typeof(-), shape, more...) = shape + +function infer_sizes(::typeof(ifelse), cond, lhs, rhs) + @assert lhs == rhs + return lhs +end + +# hcat: rows from first arg, total cols summed across children +function infer_sizes(::typeof(hcat), shapes...) + total_cols = sum(s -> length(s) <= 1 ? 1 : s[2], shapes) + if isempty(shapes[1]) + return (1, total_cols) + end + @assert length(shapes[1]) <= 2 "hcat with ndims > 2 is not supported yet" + return (shapes[1][1], total_cols) +end + +# vcat: cols from first arg, total rows summed across children +function infer_sizes(::typeof(vcat), shapes...) + total_rows = sum(s -> length(s) <= 1 ? 1 : s[1], shapes) + if isempty(shapes[1]) + return (total_rows, 1) + end + @assert length(shapes[1]) <= 2 "vcat with ndims > 2 is not supported yet" + return (total_rows, shapes[1][2]) +end + +# *: matmul-like inner-dim reduction; scalar children are ignored. Returns a +# `Vector{Int}` so the accumulator is type-stable across the loop (a tuple +# accumulator would change type each iteration as the length varies). +function infer_sizes(::typeof(*), shapes...) + out = Int[] + for s in shapes + if isempty(s) + continue + end + if isempty(out) + append!(out, s) + else + @assert length(out) > 1 + @assert s[1] == out[end] + pop!(out) + for j in 2:length(s) + push!(out, s[j]) + end + end + end + return out +end + +# ^ and / (non-broadcasted): first arg's shape, second must be scalar +function infer_sizes(::typeof(^), base, exp) + @assert isempty(exp) "`^` expects scalar exponent" + return base +end +function infer_sizes(::typeof(/), num, den) + @assert isempty(den) "`/` expects scalar denominator" + return num +end + +# ── Built-in tape operators: broadcasted ──────────────────────────────────── +# A broadcasted call lowers to `Base.broadcasted(op, args...)`; we use the +# same shape — dispatching on `(typeof(broadcasted), typeof(op))` instead of +# defining a parallel marker type. + +# Broadcasted +, -, *, /: combine via Julia's broadcasting axis-combination +# rules. Output ndims = max child ndims; size in each dim is the max size > 1. +function infer_sizes( + ::typeof(Base.broadcasted), + ::Union{typeof(+),typeof(-),typeof(*),typeof(/)}, + shapes..., +) + nd = maximum(length, shapes; init = 0) + out = ones(Int, nd) + for sz in shapes + for j in eachindex(sz) + if sz[j] > 1 + if out[j] == 1 + out[j] = sz[j] + else + @assert out[j] == sz[j] + end + end + end + end + return out +end + +# Broadcasted ^: scalar exponent, base shape preserved +function infer_sizes(::typeof(Base.broadcasted), ::typeof(^), base, exp) + @assert isempty(exp) "broadcasted `^` expects scalar exponent" + return base +end + +# `:row` is an MOI tape primitive (one row of a matrix literal) without a +# corresponding Julia function. Define an empty function so `infer_sizes` can +# still dispatch on `typeof(_row_op)` — `_row_op` is never actually called. +function _row_op end + +function infer_sizes(::typeof(_row_op), shapes...) + @assert all(isempty, shapes) "`row` expects scalar children" + return (1, length(shapes)) +end + +# Map a built-in operator symbol to its Julia function so `infer_sizes` can +# dispatch on `typeof(fn)`. Returns `nothing` for `:sum_dims`, whose shape +# depends on the constant dims vector and is handled inline by `_infer_sizes`. +function _default_op_function(sym::Symbol) + if sym === :sum_dims + return nothing + end + if sym === :row + return _row_op + end + if sym === :dot || sym === :norm + return getfield(LinearAlgebra, sym) + end + return getfield(Base, sym) +end + +# `:sum_dims` is the one tape op whose shape depends on data outside the +# child-shape tuple (the constant dims vector), so it can't fit the generic +# `infer_sizes(op, shapes...)` signature. Handled inline by `_infer_sizes`. +function _sum_dims_shape( + sizes, + nodes, + children_arr, + children_indices, + block_shapes, + const_values, +) + @assert length(children_indices) == 2 "`sum_dims` expects (array, dims_vector)" + arr_id = children_arr[first(children_indices)] + dims_id = children_arr[first(children_indices)+1] + @assert nodes[dims_id].type == NODE_VALUE_BLOCK "`sum_dims` requires constant dims (NODE_VALUE_BLOCK)" + dims_len = prod(block_shapes[dims_id]) + start = nodes[dims_id].index + dims_vec = const_values[(start-1) .+ (1:dims_len)] + in_ndims = sizes.ndims[arr_id] + return map(1:in_ndims) do d + return d in dims_vec ? 1 : _size(sizes, arr_id, d) + end +end + +# Resolve a multivariate tape node's operator to the Julia `Function` whose +# `infer_sizes` method computes its output shape. Returns the bare op symbol +# (`:sum_dims`) when the shape rule needs out-of-band data and is handled +# inline by `_infer_sizes`, or `nothing` for unrecognised operators. +# Type-unstable on purpose — called only at setup time. +function _shape_op(node::Node, operators) + @assert node.type == NODE_CALL_MULTIVARIATE || + node.type == NODE_CALL_MULTIVARIATE_BROADCASTED + if node.index in eachindex(DEFAULT_MULTIVARIATE_OPERATORS) + sym = DEFAULT_MULTIVARIATE_OPERATORS[node.index] + func = _default_op_function(sym) + return func === nothing ? sym : func + end + if operators !== nothing && + node.index in eachindex(operators.multivariate_operators) + op_sym = operators.multivariate_operators[node.index] + if haskey(operators.chainrules_operators, op_sym) + return operators.chainrules_operators[op_sym] + end + end + return nothing +end + function _infer_sizes( nodes::Vector{Node}, adj::SparseArrays.SparseMatrixCSC{Bool,Int}, @@ -396,189 +591,40 @@ function _infer_sizes( if node.type == NODE_VARIABLE_BLOCK || node.type == NODE_VALUE_BLOCK || node.type == NODE_MOI_VARIABLE_BLOCK - shape = block_shapes[k] - _add_size!(sizes, k, shape) + _add_size!(sizes, k, block_shapes[k]) continue end children_indices = SparseArrays.nzrange(adj, k) N = length(children_indices) - if node.type == NODE_CALL_MULTIVARIATE - if !(node.index in eachindex(DEFAULT_MULTIVARIATE_OPERATORS)) - if operators !== nothing && - node.index in eachindex(operators.multivariate_operators) - op_sym = operators.multivariate_operators[node.index] - if haskey(operators.chainrules_operators, op_sym) - f = operators.chainrules_operators[op_sym] - child_shapes = Tuple( - ntuple( - d -> _size(sizes, children_arr[c_idx], d), - sizes.ndims[children_arr[c_idx]], - ) for c_idx in children_indices - ) - out_sz = infer_sizes(f, child_shapes...) - if !isempty(out_sz) - _add_size!(sizes, k, out_sz) - end - # Scalar output → ndims = 0 (already initialised). - continue - end - end - # TODO user-defined operators - continue + if node.type == NODE_CALL_MULTIVARIATE || + node.type == NODE_CALL_MULTIVARIATE_BROADCASTED + op = _shape_op(node, operators) + if op === nothing + continue # TODO user-defined operators end - op = DEFAULT_MULTIVARIATE_OPERATORS[node.index] - if op == :vect - _assert_scalar_children( - sizes, - children_arr, - children_indices, - op, - ) - _add_size!(sizes, k, (N,)) - elseif op == :row - _assert_scalar_children( + if op === :sum_dims + out_sz = _sum_dims_shape( sizes, + nodes, children_arr, children_indices, - op, + block_shapes, + const_values, ) - _add_size!(sizes, k, (1, N)) - elseif op == :dot - # TODO assert all arguments have same size - elseif op == :norm - # TODO actually norm should be moved to univariate - elseif op == :sum - # sum reduces array to scalar, ndims stays 0 - elseif op == :+ || op == :- - # TODO assert all arguments have same size - _copy_size!(sizes, k, children_arr[first(children_indices)]) - elseif op == :hcat - total_cols = 0 - for c_idx in children_indices - total_cols += - sizes.ndims[children_arr[c_idx]] <= 1 ? 1 : - _size(sizes, children_arr[c_idx], 2) - end - if sizes.ndims[children_arr[first(children_indices)]] == 0 - shape = (1, total_cols) - else - @assert sizes.ndims[children_arr[first( - children_indices, - )]] <= 2 "Hcat with ndims > 2 is not supported yet" - shape = ( - _size(sizes, children_arr[first(children_indices)], 1), - total_cols, - ) - end - _add_size!(sizes, k, tuple(shape...)) - elseif op == :vcat - total_rows = 0 - for c_idx in children_indices - total_rows += - sizes.ndims[children_arr[c_idx]] <= 1 ? 1 : - _size(sizes, children_arr[c_idx], 1) - end - if sizes.ndims[children_arr[first(children_indices)]] == 0 - shape = (total_rows, 1) - else - @assert sizes.ndims[children_arr[first( - children_indices, - )]] <= 2 "Hcat with ndims > 2 is not supported yet" - shape = ( - total_rows, - _size(sizes, children_arr[first(children_indices)], 2), - ) - end - _add_size!(sizes, k, tuple(shape...)) - elseif op == :* - sizes.ndims[k] = 0 - for child in children_indices - id = children_arr[child] - ndims = sizes.ndims[id] - if !iszero(ndims) - sz = _size(sizes, id) - if iszero(sizes.ndims[k]) - sizes.size_offset[k] = length(sizes.size) - append!(sizes.size, sz) - sizes.ndims[k] = ndims - else - @assert sizes.ndims[k] > 1 - @assert sz[1] == sizes.size[end] - pop!(sizes.size) - append!(sizes.size, @view(sz[2:end])) - sizes.ndims[k] += ndims - 2 - end - end - end - elseif op == :^ || op == :/ - @assert N == 2 - _assert_scalar_children( - sizes, - children_arr, - children_indices[2:end], - op, - ) - _copy_size!(sizes, k, children_arr[first(children_indices)]) - elseif op == :sum_dims - # Two args: (array, Vector{Float64}(dims)). Output keeps the - # input ndims with the reduced dims collapsed to size 1. - @assert N == 2 "`sum_dims` expects (array, dims_vector)" - arr_id = children_arr[first(children_indices)] - dims_id = children_arr[first(children_indices)+1] - @assert nodes[dims_id].type == NODE_VALUE_BLOCK "`sum_dims` requires constant dims (NODE_VALUE_BLOCK)" - # Read the dims values out of `const_values`. The block was - # appended at `nodes[dims_id].index` with length recorded in - # `block_shapes`. - dims_len = prod(block_shapes[dims_id]) - start = nodes[dims_id].index - dims_vec = const_values[(start-1) .+ (1:dims_len)] - in_ndims = sizes.ndims[arr_id] - out_shape = map(1:in_ndims) do d - return d in dims_vec ? 1 : _size(sizes, arr_id, d) - end - _add_size!(sizes, k, out_shape) else - _assert_scalar_children( - sizes, - children_arr, - children_indices, - op, + # Pass shape views directly — no Tuple/Vector conversion. The + # `infer_sizes` methods are generic over indexable containers. + child_shapes = ntuple( + i -> _size(sizes, children_arr[children_indices[i]]), + N, ) - end - elseif node.type == NODE_CALL_MULTIVARIATE_BROADCASTED - if !(node.index in eachindex(DEFAULT_MULTIVARIATE_OPERATORS)) - # TODO user-defined operators - continue - end - op = DEFAULT_MULTIVARIATE_OPERATORS[node.index] - if op == :+ || op == :- || op == :* || op == :/ - sizes.ndims[k] = maximum(children_indices, init = 0) do i - return sizes.ndims[children_arr[i]] - end - sizes.size_offset[k] = length(sizes.size) - for _ in 1:sizes.ndims[k] - push!(sizes.size, 1) - end - sz_parent = _size(sizes, k) - for i in children_indices - id = children_arr[i] - sz = _size(sizes, id) - for j in eachindex(sz) - if sz[j] > 1 - if sz_parent[j] == 1 - sz_parent[j] = sz[j] - else - @assert sz_parent[j] == sz[j] - end - end - end + out_sz = if node.type == NODE_CALL_MULTIVARIATE_BROADCASTED + infer_sizes(Base.broadcasted, op, child_shapes...) + else + infer_sizes(op, child_shapes...) end - elseif op == :^ - # Broadcasted ^ with scalar exponent preserves base shape - @assert length(children_indices) == 2 "Expected two arguments for broadcasted operator `$op`, got $(length(children_indices))" - @assert iszero(sizes.ndims[children_arr[children_indices[2]]]) "Expected scalar exponent for broadcasted operator `$op`" - _copy_size!(sizes, k, children_arr[first(children_indices)]) end + _add_size!(sizes, k, out_sz) elseif node.type == NODE_CALL_UNIVARIATE if !( node.index in @@ -615,10 +661,8 @@ function _infer_sizes( continue end error("TODO user-defined operators") - continue end @assert N == 1 - op = MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS[node.index] _copy_size!(sizes, k, children_arr[first(children_indices)]) end end @@ -667,7 +711,7 @@ struct _SubexpressionStorage{T<:Real,S<:AbstractVector{T}} j = sizes.storage_offset[k] + 1 len = _length(sizes, k) cpu_buffer[j:(j+len-1)] .= - view(const_values, node.index:(node.index+len-1)) + view(const_values, (node.index):(node.index+len-1)) end end forward_storage = convert(S, cpu_buffer) diff --git a/test/ArrayDiff.jl b/test/ArrayDiff.jl index 067eae8..b4bda2e 100644 --- a/test/ArrayDiff.jl +++ b/test/ArrayDiff.jl @@ -29,7 +29,7 @@ function test_objective_dot_univariate() MOI.initialize(evaluator, [:Grad, :Hess]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 1, 0, 1, 0] - @test sizes.size_offset == [0, 1, 0, 0, 0] + @test sizes.size_offset == [2, 1, 0, 0, 0] @test sizes.size == [1, 1] @test sizes.storage_offset == [0, 1, 2, 3, 4, 5] xv = [1.2] @@ -50,7 +50,7 @@ function test_objective_dot_univariate_and_scalar_mult() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 0, 0, 1, 0, 1, 0] - @test sizes.size_offset == [0, 0, 0, 1, 0, 0, 0] + @test sizes.size_offset == [2, 0, 2, 1, 0, 0, 0] @test sizes.size == [1, 1] @test sizes.storage_offset == [0, 1, 2, 3, 4, 5, 6, 7] xv = [1.2] @@ -75,7 +75,7 @@ function test_objective_dot_bivariate() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0] - @test sizes.size_offset == [0, 6, 5, 0, 0, 4, 0, 0, 3, 2, 1, 0, 0, 0, 0, 0] + @test sizes.size_offset == [7, 6, 5, 0, 0, 4, 0, 0, 3, 2, 1, 0, 0, 0, 0, 0] @test sizes.size == [2, 2, 2, 2, 2, 2, 2] @test sizes.storage_offset == [0, 1, 3, 5, 6, 7, 9, 10, 11, 13, 15, 17, 18, 19, 21, 22, 23] @@ -100,7 +100,7 @@ function test_objective_hcat_scalars() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 0, 0, 2, 0, 0] - @test sizes.size_offset == [0, 2, 0, 0, 0, 0, 0] + @test sizes.size_offset == [4, 2, 0, 0, 0, 0, 0] @test sizes.size == [1, 2, 1, 2] @test sizes.storage_offset == [0, 1, 3, 4, 5, 7, 8, 9] x1 = 1.0 @@ -131,7 +131,7 @@ function test_objective_hcat_vectors() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 1, 0, 1, 0, 2, 1, 0, 1, 0] - @test sizes.size_offset == [0, 6, 5, 0, 4, 0, 2, 1, 0, 0, 0] + @test sizes.size_offset == [8, 6, 5, 0, 4, 0, 2, 1, 0, 0, 0] @test sizes.size == [1, 1, 1, 2, 1, 1, 1, 2] @test sizes.storage_offset == [0, 1, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13] x1 = 1.0 @@ -158,7 +158,7 @@ function test_objective_dot_bivariate_on_rows() sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 2, 0, 0, 2, 0, 0, 2, 2, 2, 0, 0, 2, 0, 0] @test sizes.size_offset == - [0, 12, 10, 0, 0, 8, 0, 0, 6, 4, 2, 0, 0, 0, 0, 0] + [14, 12, 10, 0, 0, 8, 0, 0, 6, 4, 2, 0, 0, 0, 0, 0] @test sizes.size == [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2] @test sizes.storage_offset == [0, 1, 3, 5, 6, 7, 9, 10, 11, 13, 15, 17, 18, 19, 21, 22, 23] @@ -180,7 +180,7 @@ function test_objective_norm_univariate() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 1, 0] - @test sizes.size_offset == [0, 0, 0] + @test sizes.size_offset == [1, 0, 0] @test sizes.size == [1] @test sizes.storage_offset == [0, 1, 2, 3] xv = [1.2] @@ -202,7 +202,7 @@ function test_objective_norm_bivariate() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 1, 0, 0] - @test sizes.size_offset == [0, 0, 0, 0] + @test sizes.size_offset == [1, 0, 0, 0] @test sizes.size == [2] @test sizes.storage_offset == [0, 1, 3, 4, 5] xv = [3.0, 4.0] @@ -231,7 +231,7 @@ function test_objective_norm_of_row_vector() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 0, 0] - @test sizes.size_offset == [0, 0, 0, 0] + @test sizes.size_offset == [2, 0, 0, 0] @test sizes.size == [1, 2] @test sizes.storage_offset == [0, 1, 3, 4, 5] x1 = 1.0 @@ -257,7 +257,7 @@ function test_objective_norm_of_vcat_vector() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 0, 0] - @test sizes.size_offset == [0, 0, 0, 0] + @test sizes.size_offset == [2, 0, 0, 0] @test sizes.size == [2, 1] @test sizes.storage_offset == [0, 1, 3, 4, 5] x1 = 1.0 @@ -285,7 +285,7 @@ function test_objective_norm_of_vcat_matrix() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 2, 0, 0, 2, 0, 0] - @test sizes.size_offset == [0, 4, 2, 0, 0, 0, 0, 0] + @test sizes.size_offset == [6, 4, 2, 0, 0, 0, 0, 0] @test sizes.size == [1, 2, 1, 2, 2, 2] @test sizes.storage_offset == [0, 1, 5, 7, 8, 9, 11, 12, 13] x1 = 1.0 @@ -316,7 +316,7 @@ function test_objective_norm_of_row() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 0, 0] - @test sizes.size_offset == [0, 0, 0, 0] + @test sizes.size_offset == [2, 0, 0, 0] @test sizes.size == [1, 2] @test sizes.storage_offset == [0, 1, 3, 4, 5] x1 = 1.0 @@ -342,7 +342,7 @@ function test_objective_norm_of_matrix() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 2, 0, 0, 2, 0, 0] - @test sizes.size_offset == [0, 4, 2, 0, 0, 0, 0, 0] + @test sizes.size_offset == [6, 4, 2, 0, 0, 0, 0, 0] @test sizes.size == [1, 2, 1, 2, 2, 2] @test sizes.storage_offset == [0, 1, 5, 7, 8, 9, 11, 12, 13] x1 = 1.0 @@ -376,7 +376,7 @@ function test_objective_norm_of_matrix_with_sum() sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 2, 2, 0, 0, 2, 0, 0, 2, 2, 0, 0, 2, 0, 0] @test sizes.size_offset == - [0, 12, 10, 8, 0, 0, 6, 0, 0, 4, 2, 0, 0, 0, 0, 0] + [14, 12, 10, 8, 0, 0, 6, 0, 0, 4, 2, 0, 0, 0, 0, 0] @test sizes.size == [1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2] @test sizes.storage_offset == [0, 1, 5, 9, 11, 12, 13, 15, 16, 17, 21, 23, 24, 25, 27, 28, 29] @@ -406,7 +406,7 @@ function test_objective_norm_of_product_of_matrices() sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 2, 2, 0, 0, 2, 0, 0, 2, 2, 0, 0, 2, 0, 0] @test sizes.size_offset == - [0, 12, 10, 8, 0, 0, 6, 0, 0, 4, 2, 0, 0, 0, 0, 0] + [14, 12, 10, 8, 0, 0, 6, 0, 0, 4, 2, 0, 0, 0, 0, 0] @test sizes.size == [1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2] @test sizes.storage_offset == [0, 1, 5, 9, 11, 12, 13, 15, 16, 17, 21, 23, 24, 25, 27, 28, 29] @@ -469,7 +469,7 @@ function test_objective_norm_of_product_of_matrices_with_sum() 0, ] @test sizes.size_offset == [ - 0, + 22, 20, 18, 16, @@ -553,7 +553,7 @@ function test_objective_norm_of_mtx_vector_product() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0] - @test sizes.size_offset == [0, 8, 6, 4, 0, 0, 2, 0, 0, 0, 0, 0] + @test sizes.size_offset == [10, 8, 6, 4, 0, 0, 2, 0, 0, 0, 0, 0] @test sizes.size == [2, 1, 1, 2, 1, 2, 2, 2, 2, 1] @test sizes.storage_offset == [0, 1, 3, 7, 9, 10, 11, 13, 14, 15, 17, 18, 19] @@ -608,7 +608,7 @@ function test_objective_broadcasted_product() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 1, 1, 0, 0, 1, 0, 0] - @test sizes.size_offset == [0, 2, 1, 0, 0, 0, 0, 0] + @test sizes.size_offset == [3, 2, 1, 0, 0, 0, 0, 0] @test sizes.size == [2, 2, 2] @test sizes.storage_offset == [0, 1, 3, 5, 6, 7, 9, 10, 11] x1 = 1.0 @@ -640,7 +640,7 @@ function test_objective_broadcasted_matrix_product() sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 2, 2, 2, 0, 0, 2, 0, 0, 2, 2, 0, 0, 2, 0, 0] @test sizes.size_offset == - [0, 12, 10, 8, 0, 0, 6, 0, 0, 4, 2, 0, 0, 0, 0, 0] + [14, 12, 10, 8, 0, 0, 6, 0, 0, 4, 2, 0, 0, 0, 0, 0] @test sizes.size == [1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2] @test sizes.storage_offset == [0, 1, 5, 9, 11, 12, 13, 15, 16, 17, 21, 23, 24, 25, 27, 28, 29] @@ -673,7 +673,7 @@ function test_objective_broadcasted_tanh() MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes @test sizes.ndims == [0, 1, 1, 0, 0] - @test sizes.size_offset == [0, 1, 0, 0, 0] + @test sizes.size_offset == [2, 1, 0, 0, 0] @test sizes.size == [2, 2] @test sizes.storage_offset == [0, 1, 3, 5, 6, 7] x1 = 1.0 diff --git a/test/JuMP.jl b/test/JuMP.jl index 35ba6f6..e7f4490 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -47,9 +47,9 @@ function my_crossentropy2(p, q) return -sum(q .* log.(p)) end -function ArrayDiff.infer_sizes(::typeof(my_crossentropy2), ::Tuple, ::Tuple) - return () -end +# Override works for any indexable shape kind (tuple from JuMP, view from +# tape) — leaving the args untyped exercises that. +ArrayDiff.infer_sizes(::typeof(my_crossentropy2), s1, s2) = () function ChainRulesCore.rrule( ::typeof(my_crossentropy1), @@ -503,6 +503,58 @@ function test_size_vec_vect() return end +# `:sum_dims` is emitted by `Base.sum(::AbstractJuMPArray; dims=…)` and is the +# reduction building block both `LayerNorm` and `softmax` rely on, but no +# existing test reaches it (`grep "sum.*dims" test/` only matches plain Julia +# arrays). These tests exercise the forward path, the `_sum_dims_shape` +# branch in `_infer_sizes`, and the reverse-mode contribution. The analytic +# gradient of `f = ‖sum(W; dims=d)‖` is `s[i_d] / ‖s‖` broadcast over the +# reduced axis. +function test_sum_dims_along_rows() + rows, cols = 2, 3 + model = Model() + @variable(model, W[1:rows, 1:cols], container = ArrayDiff.ArrayOfVariables) + expr = sum(W; dims = 2) + @test expr isa ArrayDiff.MatrixExpr + @test expr.head == :sum_dims + @test size(expr) == (rows, 1) + x = Float64.(collect(1:(rows*cols))) + W_val = reshape(x, rows, cols) + s = sum(W_val; dims = 2) + sizes, val, g = _eval(model, LinearAlgebra.norm(expr), x) + @test val ≈ LinearAlgebra.norm(s) + # `s` is the (rows, 1) column; broadcast across the reduced (cols) axis. + @test g ≈ vec(repeat(s, 1, cols)) ./ LinearAlgebra.norm(s) + # Tape: norm (k=1, scalar) → sum_dims (k=2, (rows, 1)). + @test sizes.ndims[1] == 0 + @test sizes.ndims[2] == 2 + sd_off = sizes.size_offset[2] + @test sizes.size[sd_off+1] == rows + @test sizes.size[sd_off+2] == 1 + return +end + +function test_sum_dims_along_cols() + rows, cols = 2, 3 + model = Model() + @variable(model, W[1:rows, 1:cols], container = ArrayDiff.ArrayOfVariables) + expr = sum(W; dims = 1) + @test expr.head == :sum_dims + @test size(expr) == (1, cols) + x = Float64.(collect(1:(rows*cols))) + W_val = reshape(x, rows, cols) + s = sum(W_val; dims = 1) + sizes, val, g = _eval(model, LinearAlgebra.norm(expr), x) + @test val ≈ LinearAlgebra.norm(s) + # `s` is the (1, cols) row; broadcast across the reduced (rows) axis. + @test g ≈ vec(repeat(s, rows, 1)) ./ LinearAlgebra.norm(s) + @test sizes.ndims[2] == 2 + sd_off = sizes.size_offset[2] + @test sizes.size[sd_off+1] == 1 + @test sizes.size[sd_off+2] == cols + return +end + function test_broadcast_nonsquare_matrix() model = Model() @variable(model, W[1:2, 1:3], container = ArrayDiff.ArrayOfVariables) @@ -520,7 +572,7 @@ function test_broadcast_nonsquare_matrix() # (2, 3). The old bug would report (2, 2) for the broadcast node. @test sizes.ndims == [0, 2, 2, 2] @test sizes.size == [2, 3, 2, 3, 2, 3] - @test sizes.size_offset == [0, 4, 2, 0] + @test sizes.size_offset == [6, 4, 2, 0] @test sizes.storage_offset == [0, 1, 7, 13, 19] @test val ≈ LinearAlgebra.norm(ref_mat) ref_g = if op == :+