From f794dae02629f040339b1eb54fd520e0a5e5ffd6 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Fri, 3 Jul 2026 04:17:51 -0400 Subject: [PATCH] Release v1.0.0: constant refutation replaces is_leaf_sig Promotes FunctionProperties to a stable 1.0 public API surface -- `hasbranching` and `is_leaf` -- and makes constant-propagation-aware analysis the single mechanism for suppressing value-independent branches. Analysis change (follow-up to #62): the type recursion is the source of truth; when it reports a branch inside a call carrying `Core.Const` arguments, the callee is re-inferred with those constants preserved (no optimizer, so no library/structural branches are inlined into view) and, only if that folds the branch, the finding is refuted. This is a strict refinement -- it can only downgrade a reported branch to branch-free -- so it adds no false positives on arbitrary code (broadcast, ComponentArrays), unlike a "run const-prop everywhere" approach. It replaces the `is_leaf_sig` hook, so no per-container override is needed (e.g. split-mode MTK `getindex(::MTKParameters, ::Int)`). The refutation uses `Base.Compiler`/`Core.Compiler` internals whose API differs across versions, so it is functionally gated: a probe folds a constant-decided branch fixture at first use and only activates if the fold works; otherwise, and on any inference failure, the analysis is exactly the plain type recursion. Verified: lts 1.10.11 (gated off), 1.12.6, 1.13.0-rc1 all green; MTK RHS table correct through the `ODEFunction` with no container-specific overrides; docs build clean. Also fixes the docs config: `deploydocs` pointed at MultiScaleArrays.jl, and the docs environment pinned `FunctionProperties = "0.1.2"`. BREAKING (1.0.0): removes the exported `is_leaf_sig` (public in 0.1.7). Co-Authored-By: Chris Rackauckas --- Project.toml | 2 +- docs/Project.toml | 2 +- docs/make.jl | 2 +- src/FunctionProperties.jl | 283 ++++++++++++++++++++++++++++++-------- test/core_tests.jl | 41 ++++-- 5 files changed, 259 insertions(+), 71 deletions(-) diff --git a/Project.toml b/Project.toml index 7c385ec..32a42f2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FunctionProperties" uuid = "f62d2435-5019-4c03-9749-2d4c77af0cbc" -version = "0.1.7" +version = "1.0.0" authors = ["SciML"] [deps] diff --git a/docs/Project.toml b/docs/Project.toml index 40a63e3..edaad76 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,4 +4,4 @@ FunctionProperties = "f62d2435-5019-4c03-9749-2d4c77af0cbc" [compat] Documenter = "1" -FunctionProperties = "0.1.2" +FunctionProperties = "1" diff --git a/docs/make.jl b/docs/make.jl index 96f55d1..e74f9c6 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -28,6 +28,6 @@ makedocs( ) deploydocs( - repo = "github.com/SciML/MultiScaleArrays.jl.git"; + repo = "github.com/SciML/FunctionProperties.jl.git"; push_preview = true ) diff --git a/src/FunctionProperties.jl b/src/FunctionProperties.jl index c20f4fc..38db025 100644 --- a/src/FunctionProperties.jl +++ b/src/FunctionProperties.jl @@ -5,6 +5,18 @@ using Core: GotoIfNot # Backstop against pathological recursion depth; real call trees that matter here are shallow. const RECURSION_LIMIT = 256 +# `hasbranching` recurses through statically resolved calls. Ordinary analysis widens every argument +# to its type, which loses constants: a branch decided by a *constant* argument (e.g. selecting a +# buffer by a literal index inside a parameter container) then looks value-dependent even though +# every real call site folds it. When the running Julia's compiler cooperates, such a call is +# re-inferred with its `Core.Const` arguments preserved (no optimizer, so no library/structural +# branches are inlined into view) and the constant-decided branch folds to a `Core.Const` condition +# that `_is_const_gotoifnot` skips. This depends on `Base.Compiler`/`Core.Compiler` internals whose +# API changes across Julia versions, so it is *functionally* gated (see `_const_prop_capable`): it +# activates only where a probe confirms folding actually works, and otherwise the analysis falls +# back to the plain type recursion. +const _CC = isdefined(Base, :Compiler) ? Base.Compiler : Core.Compiler + """ is_leaf(f, args...) -> Bool @@ -20,27 +32,6 @@ FunctionProperties.is_leaf(::typeof(my_fn)) = true """ is_leaf(f, args...) = false -""" - is_leaf_sig(sig::Type{<:Tuple}) -> Bool - -Signature-level counterpart to [`is_leaf`](@ref), consulted while recursing through statically -resolved calls. `sig` is the call's `Tuple{typeof(f), argtypes...}`. Return `true` to treat the -call as branch-free and stop recursing into it. - -Use this (instead of `is_leaf`) when the exemption depends on the *argument types*, not just the -function. The motivating case is value-independent plumbing whose branch is on an index/type -rather than on traced values — e.g. selecting a buffer by integer index inside a parameter -container, where each real call site passes a literal index that constant-folds the branch away, -but the recursion only sees the widened argument type. - -## Example - -```julia -FunctionProperties.is_leaf_sig(::Type{<:Tuple{typeof(getindex), <:MyParamContainer, Vararg}}) = true -``` -""" -is_leaf_sig(@nospecialize(sig)) = false - """ hasbranching(f, x...) @@ -64,6 +55,10 @@ that branches living behind a non-inlined call boundary are still detected. Call are structural/compile-time rather than value-dependent user logic, and recursing into them (e.g. matrix multiply, broadcasting, `getindex` bounds checks) produces false positives. +Branches whose condition inference proves constant are ignored (they are not value-dependent), +and — where the compiler cooperates — a call with constant arguments is re-inferred with those +constants preserved so branches they decide fold away rather than being reported. + ## Customizing and Removing Functions from the Checks Some functions may produce false positives because their internal branches are compile-time @@ -96,12 +91,17 @@ function _hasbranching(@nospecialize(sig), seen, depth) # Generated functions that were not expanded come back as `Method`, not `CodeInfo`; # there is no body to scan, so treat them as leaves. ci isa Core.CodeInfo || continue - for stmt in ci.code - if isa(stmt, GotoIfNot) - _is_const_gotoifnot(stmt, ci) || return true - elseif _recurse_call(stmt, ci, seen, depth) - return true - end + _scan_codeinfo(ci, seen, depth) && return true + end + return false +end + +function _scan_codeinfo(ci, seen, depth) + for stmt in ci.code + if isa(stmt, GotoIfNot) + _is_const_gotoifnot(stmt, ci) || return true + elseif _recurse_call(stmt, ci, seen, depth) + return true end end return false @@ -109,12 +109,12 @@ end # A `GotoIfNot` whose condition type inference has *proven* constant is a compile-time branch, # not a value-dependent one: e.g. an `x isa T` test on a concretely-typed field (the SciML -# `ODEFunction` wrapper) or the device/type-introspection dispatch inside ML library layers -# (SciML/FunctionProperties.jl#46). Such a branch can never be taken differently under a tracing -# AD, so it is not the branching `hasbranching` is meant to surface. A condition that is a literal -# `true`/`false` written directly into the IR is deliberately *not* skipped: that is a genuine -# syntactic branch in user code (e.g. `true ? a : b`). Only conditions inference resolved to a -# `Core.Const` value are dropped; anything we cannot positively prove constant is kept. +# `ODEFunction` wrapper), the device/type-introspection dispatch inside ML library layers +# (SciML/FunctionProperties.jl#46), or a branch a constant argument folded via the constant-argument +# recursion below. Such a branch can never be taken differently under a tracing AD. A condition that +# is a literal `true`/`false` written directly into the IR is deliberately *not* skipped: that is a +# genuine syntactic branch in user code (e.g. `true ? a : b`). Only conditions inference resolved to +# a `Core.Const` value are dropped; anything we cannot positively prove constant is kept. function _is_const_gotoifnot(stmt::GotoIfNot, ci) cond = stmt.cond t = if cond isa Core.SSAValue @@ -130,8 +130,9 @@ function _is_const_gotoifnot(stmt::GotoIfNot, ci) return t isa Core.Const end -# Inspect a single IR statement: if it is a statically resolvable call into a non-library -# method, recurse into that method's IR. Returns `true` if a branch is found downstream. +# Inspect a single IR statement: if it is a statically resolvable call into a non-library method, +# recurse into that method (with any constant arguments preserved). Returns `true` if a branch is +# found downstream. function _recurse_call(@nospecialize(stmt), ci, seen, depth) call = Meta.isexpr(stmt, :(=)) ? stmt.args[2] : stmt @@ -139,11 +140,13 @@ function _recurse_call(@nospecialize(stmt), ci, seen, depth) mi = call.args[1] callsig = mi isa Core.MethodInstance ? mi.specTypes : ( - isdefined(mi, :def) && getfield(mi, :def) isa Core.MethodInstance ? + isdefined(mi, :def) && getfield(mi, :def) isa Core.MethodInstance ? getfield(mi, :def).specTypes : nothing - ) + ) callsig === nothing && return false - return _recurse_sig(callsig, nothing, seen, depth) + _, fval = _resolve_callee(call.args[2], ci) + arglat = Any[_arg_lattice(a, ci) for a in @view call.args[3:end]] + return _recurse_sig(callsig, fval, arglat, seen, depth) end Meta.isexpr(call, :call) || return false @@ -152,22 +155,21 @@ function _recurse_call(@nospecialize(stmt), ci, seen, depth) end ftype, fval = _resolve_callee(call.args[1], ci) ftype === nothing && return false - argtypes = Any[_value_type(a, ci) for a in @view call.args[2:end]] - return _recurse_sig(Tuple{ftype, argtypes...}, fval, seen, depth) + arglat = Any[_arg_lattice(a, ci) for a in @view call.args[2:end]] + return _recurse_sig(Tuple{ftype, (_lat_type(x) for x in arglat)...}, fval, arglat, seen, depth) end _is_apply(@nospecialize(f)) = f isa GlobalRef && f.mod === Core && (f.name === :_apply_iterate || f.name === :_apply) -# A splatted call `g(a, bs...)` lowers to `Core._apply_iterate(iter, g, groups...)` (or, on -# older lowerings, `Core._apply(g, groups...)`). The real callee `g` is therefore an *argument* -# of a `Core` builtin, so the plain `:call` path would resolve the callee to `_apply_iterate`, -# treat it as library, and dead-end — missing every branch behind the forwarder. SciML/MTK RHS -# objects are exactly such forwarders (`ODEFunction` -> `GeneratedFunctionWrapper` -> -# `RuntimeGeneratedFunction` -> `generated_callfunc`, each `f(args...)`), so the generated body's -# branches only become reachable by following the apply through to `g`. The splatted groups are -# the actual positional arguments; recover their element types from the (concrete) tuple types so -# the right method specialization is selected downstream. +# A splatted call `g(a, bs...)` lowers to `Core._apply_iterate(iter, g, groups...)` (or, on older +# lowerings, `Core._apply(g, groups...)`). The real callee `g` is therefore an *argument* of a +# `Core` builtin, so the plain `:call` path would resolve the callee to `_apply_iterate`, treat it +# as library, and dead-end — missing every branch behind the forwarder. SciML/MTK RHS objects are +# exactly such forwarders (`ODEFunction` -> `GeneratedFunctionWrapper` -> `RuntimeGeneratedFunction` +# -> `generated_callfunc`, each `f(args...)`), so the generated body's branches only become +# reachable by following the apply. The splatted groups are the actual positional arguments; +# recover their element types from the (concrete) tuple types. function _recurse_apply(call, ci, seen, depth) args = call.args fpos = args[1].name === :_apply_iterate ? 3 : 2 @@ -180,28 +182,195 @@ function _recurse_apply(call, ci, seen, depth) if at isa DataType && at <: Tuple && Base.isconcretetype(at) append!(argtypes, at.parameters) else - # Splatted container whose element types we cannot recover statically (e.g. a - # non-`isbits` `Vararg` tuple or an array): bail rather than guess a wrong signature. + # Splatted container whose element types we cannot recover statically: bail rather than + # guess a wrong signature. return false end end - return _recurse_sig(Tuple{ftype, argtypes...}, fval, seen, depth) + # Splatted arguments are recovered from tuple element *types*; constants are not available here. + return _recurse_sig(Tuple{ftype, argtypes...}, fval, Any[argtypes...], seen, depth) end -function _recurse_sig(@nospecialize(callsig), @nospecialize(fval), seen, depth) +function _recurse_sig(@nospecialize(callsig), @nospecialize(fval), arglat, seen, depth) # Honor user `is_leaf` overrides when the concrete function value is recoverable. fval !== nothing && is_leaf(fval) && return false - # Signature-level overrides: exemptions that depend on the argument types. - is_leaf_sig(callsig) && return false m = try Base.which(callsig) catch return false end _is_library_method(m) && return false - return _hasbranching(callsig, seen, depth + 1) + # The type recursion is the source of truth. If it finds no branch, we are done. + _hasbranching(callsig, seen, depth + 1) || return false + # It found a branch. If constant arguments decide that branch, re-inferring with the constants + # preserved folds it away -- so only then do we let the constant path *refute* the finding. This + # can only downgrade a reported branch to branch-free, never the reverse, and it is skipped + # entirely (leaving the branch reported) when there are no constant arguments, when the compiler + # internals do not cooperate, or when the constant inference errors. + if _const_prop_capable() && any(x -> x isa Core.Const, arglat) + funclat = fval !== nothing ? Core.Const(fval) : _first_param(callsig) + _const_refutes(callsig, Any[funclat, arglat...], depth) && return false + end + return true +end + +# Re-infer `sig` with the constant lattice elements preserved (no optimizer, so no library or +# structural branches are inlined into view) and report whether the result is branch-free. Returns +# `false` -- i.e. does not refute -- whenever the constant inference is unavailable, fails, or leaves +# a branch, so an inability to fold never suppresses a genuine branch. +function _const_refutes(@nospecialize(sig), argtypes, depth) + src = _const_infer_src(sig, argtypes) + src isa Core.CodeInfo || return false + return try + !_scan_codeinfo(src, Set{Any}(), depth) + catch + false + end +end + +_first_param(@nospecialize(sig)) = + (sig isa DataType && !isempty(sig.parameters)) ? sig.parameters[1] : Any +_lat_type(@nospecialize(x)) = x isa Core.Const ? Core.Typeof(x.val) : x + +# Argument lattice element: a `Core.Const` when the argument is a compile-time constant, otherwise +# the widened type. Preserving the `Core.Const` is what lets a constant index survive the recursion +# boundary so `_const_refutes` can fold the branch it decides. +function _arg_lattice(@nospecialize(a), ci) + if a isa Core.SSAValue + t = ci.ssavaluetypes[a.id] + return t isa Core.Const ? t : _widen(t) + elseif a isa Core.Argument + st = ci.slottypes + st === nothing && return Any + t = st[a.n] + return t isa Core.Const ? t : _widen(t) + elseif a isa Core.SlotNumber + st = ci.slottypes + st === nothing && return Any + t = st[a.id] + return t isa Core.Const ? t : _widen(t) + elseif a isa GlobalRef + return (isdefined(a.mod, a.name) && isconst(a.mod, a.name)) ? + Core.Const(getglobal(a.mod, a.name)) : Any + elseif a isa QuoteNode + return Core.Const(a.value) + elseif a isa Expr || a isa Core.GotoNode || a isa GotoIfNot || + a isa Core.NewvarNode || a isa Core.ReturnNode + return Any + else + # Raw literal constant embedded in the IR (e.g. an `Int` index). + return Core.Const(a) + end end +# ---- constant-argument inference ----------------------------------------------------------- + +# Run inference on `sig` with the given argument lattice (some `Core.Const`) preserved, and return +# the inferred (unoptimized) `CodeInfo`, or `nothing` if the compiler internals do not cooperate. +# The `InferenceState` construction and the inferred-source location differ across Julia versions: +# 1.12 accepts `InferenceState(result, cache_mode, interp)` and exposes the body on `result.src`, +# while 1.13 wants the uninferred source passed explicitly and exposes the body on `frame.src`. We +# try the explicit-source form first (works on both) with the non-caching `:volatile` mode, then +# fall back, and read whichever of `frame.src`/`result.src` is a `CodeInfo`. +function _const_infer_src(@nospecialize(sig), argtypes) + m = try + Base.which(sig) + catch + return nothing + end + mi = try + Base.specialize_method(m, sig, Core.svec()) + catch + return nothing + end + overridden = BitVector(x isa Core.Const for x in argtypes) + src0 = try + _CC.retrieve_code_info(mi, Base.get_world_counter()) + catch + nothing + end + # A fresh `InferenceResult`/`InferenceState` per attempt: an `InferenceResult` cannot be + # re-inferred once used. + for build in ( + interp -> src0 isa Core.CodeInfo ? + _CC.InferenceState(_new_result(mi, argtypes, overridden), src0, :volatile, interp) : + nothing, + interp -> _CC.InferenceState(_new_result(mi, argtypes, overridden), :volatile, interp), + ) + src = try + interp = _CC.NativeInterpreter() + frame = build(interp) + frame === nothing && continue + _CC.typeinf(interp, frame) + _inferred_src(frame) + catch + nothing + end + src isa Core.CodeInfo && return src + end + return nothing +end + +_new_result(mi, argtypes, overridden) = _CC.InferenceResult(mi, Any[argtypes...], overridden) + +function _inferred_src(frame) + if isdefined(frame, :src) && getfield(frame, :src) isa Core.CodeInfo + return getfield(frame, :src) + end + r = getfield(frame, :result) + return (r isa _CC.InferenceResult && r.src isa Core.CodeInfo) ? r.src : nothing +end + +_count_nonconst_gotoifnot(ci::Core.CodeInfo) = + count(s -> isa(s, GotoIfNot) && !_is_const_gotoifnot(s, ci), ci.code) + +# `nothing` until the functional probe has run; then `true`/`false`. +const _CONST_PROP_CAPABLE = Ref{Union{Nothing, Bool}}(nothing) + +# Fixture with a branch decided purely by a constant integer index -- the shape the constant-argument +# recursion must fold. Used only by the capability probe. +struct _ProbeContainer + a::Int + b::Int +end +@generated function _probe_indexed(x::_ProbeContainer, idx::Int) + quote + if idx == 1 + return x.a + else + return x.b + end + end +end + +# Verify, on the running Julia, that constant inference actually folds a constant-decided branch: +# the constant-index call must come back branch-free while the widened-index call must not. If the +# compiler internals we depend on have shifted shape, this returns `false` and the constant-argument +# recursion stays inert (behaviour identical to the plain type recursion). Probed once, then cached. +function _probe_const_prop() + sig = Tuple{typeof(_probe_indexed), _ProbeContainer, Int} + folded = _const_infer_src(sig, Any[Core.Const(_probe_indexed), _ProbeContainer, Core.Const(1)]) + widened = _const_infer_src(sig, Any[Core.Const(_probe_indexed), _ProbeContainer, Int]) + folded isa Core.CodeInfo || return false + widened isa Core.CodeInfo || return false + return _count_nonconst_gotoifnot(folded) == 0 && _count_nonconst_gotoifnot(widened) > 0 +end + +function _const_prop_capable() + v = _CONST_PROP_CAPABLE[] + if v === nothing + v = try + _probe_const_prop() + catch + false + end + _CONST_PROP_CAPABLE[] = v + end + return v +end + +# ---- callee/argument resolution ------------------------------------------------------------ + # Library code (`Base`, `Core`, stdlibs) is treated as a leaf: its branches are structural or # compile-time, not the value-dependent user logic `hasbranching` is meant to surface. function _is_library_method(m::Method) @@ -254,6 +423,6 @@ _widen(@nospecialize t) = t isa Core.PartialStruct ? t.typ : isa(t, Type) ? t : Any -export hasbranching, is_leaf, is_leaf_sig +export hasbranching, is_leaf end diff --git a/test/core_tests.jl b/test/core_tests.jl index 376e831..6cf6c6e 100644 --- a/test/core_tests.jl +++ b/test/core_tests.jl @@ -144,18 +144,37 @@ splat_forward_free(args...) = splat_target_free(args...) @test !FunctionProperties.hasbranching(splat_forward_free, -1.0) # --------------------------------------------------------------------------------------------- -# `is_leaf_sig`: signature-level exemptions for value-independent plumbing. +# Constant-decided branches (value-independent) must not be reported. # # A branch on an integer index that selects a buffer (the MTK `getindex(::MTKParameters, ::Int)` # pattern) is value-independent: each real call site passes a literal index that constant-folds the -# branch, but the recursion only sees the widened `Int` and so reports it. Such a call can be marked -# branch-free by signature. -struct TwoBuffers - a::Float64 - b::Float64 +# branch, but ordinary recursion widens the `Int` and so reports it. The constant-argument recursion +# re-infers the callee with the constant preserved so the branch folds away — where the running +# Julia's compiler cooperates (`_const_prop_capable()`). It stays conservative: a genuinely +# value-dependent branch, and a dynamic (non-constant) index, are always reported. +struct TwoBufferParams + a::Vector{Float64} + b::Vector{Float64} +end +@generated function pick_buffer(p::TwoBufferParams, idx::Int) + quote + if idx == 1 + return p.a + elseif idx == 2 + return p.b + else + throw(BoundsError(p, idx)) + end + end +end +cp_relu(x) = x > 0 ? x : zero(x) +rhs_const_index(p) = @inbounds pick_buffer(p, 1)[1] +rhs_dynamic_index(p, i) = @inbounds pick_buffer(p, i)[1] +rhs_real_branch(u, p) = cp_relu(u) + @inbounds pick_buffer(p, 1)[1] +tbp = TwoBufferParams([1.0], [2.0]) + +@test FunctionProperties.hasbranching(rhs_real_branch, 1.0, tbp) # genuine branch: always reported +@test FunctionProperties.hasbranching(rhs_dynamic_index, tbp, 1) # dynamic index: always reported +if FunctionProperties._const_prop_capable() + @test !FunctionProperties.hasbranching(rhs_const_index, tbp) # constant index folds away end -@noinline select_buffer(c::TwoBuffers, i::Int) = i == 1 ? c.a : c.b -rhs_with_plumbing(u, p, t) = select_buffer(p, 1) * u -@test FunctionProperties.hasbranching(rhs_with_plumbing, 1.0, TwoBuffers(1.0, 2.0), 0.0) -FunctionProperties.is_leaf_sig(::Type{<:Tuple{typeof(select_buffer), TwoBuffers, Vararg}}) = true -@test !FunctionProperties.hasbranching(rhs_with_plumbing, 1.0, TwoBuffers(1.0, 2.0), 0.0)