diff --git a/Project.toml b/Project.toml index 0aa9ab7..7c385ec 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FunctionProperties" uuid = "f62d2435-5019-4c03-9749-2d4c77af0cbc" -version = "0.1.6" +version = "0.1.7" authors = ["SciML"] [deps] diff --git a/src/FunctionProperties.jl b/src/FunctionProperties.jl index bd17be9..c20f4fc 100644 --- a/src/FunctionProperties.jl +++ b/src/FunctionProperties.jl @@ -20,6 +20,27 @@ FunctionProperties.is_leaf(::typeof(my_fn)) = true """ is_leaf(f, args...) = false +""" + is_leaf_sig(sig::Type{<:Tuple}) -> Bool + +Signature-level counterpart to [`is_leaf`](@ref), consulted while recursing through statically +resolved calls. `sig` is the call's `Tuple{typeof(f), argtypes...}`. Return `true` to treat the +call as branch-free and stop recursing into it. + +Use this (instead of `is_leaf`) when the exemption depends on the *argument types*, not just the +function. The motivating case is value-independent plumbing whose branch is on an index/type +rather than on traced values — e.g. selecting a buffer by integer index inside a parameter +container, where each real call site passes a literal index that constant-folds the branch away, +but the recursion only sees the widened argument type. + +## Example + +```julia +FunctionProperties.is_leaf_sig(::Type{<:Tuple{typeof(getindex), <:MyParamContainer, Vararg}}) = true +``` +""" +is_leaf_sig(@nospecialize(sig)) = false + """ hasbranching(f, x...) @@ -75,14 +96,40 @@ function _hasbranching(@nospecialize(sig), seen, depth) # Generated functions that were not expanded come back as `Method`, not `CodeInfo`; # there is no body to scan, so treat them as leaves. 
ci isa Core.CodeInfo || continue - any(stmt -> isa(stmt, GotoIfNot), ci.code) && return true for stmt in ci.code - _recurse_call(stmt, ci, seen, depth) && return true + if isa(stmt, GotoIfNot) + _is_const_gotoifnot(stmt, ci) || return true + elseif _recurse_call(stmt, ci, seen, depth) + return true + end end end return false end +# A `GotoIfNot` whose condition type inference has *proven* constant is a compile-time branch, +# not a value-dependent one: e.g. an `x isa T` test on a concretely-typed field (the SciML +# `ODEFunction` wrapper) or the device/type-introspection dispatch inside ML library layers +# (SciML/FunctionProperties.jl#46). Such a branch can never be taken differently under a tracing +# AD, so it is not the branching `hasbranching` is meant to surface. A condition that is a literal +# `true`/`false` written directly into the IR is deliberately *not* skipped: that is a genuine +# syntactic branch in user code (e.g. `true ? a : b`). Only conditions typed `Core.Const` in +# `ssavaluetypes`/`slottypes` are dropped; anything else (or no type entry) is kept. +function _is_const_gotoifnot(stmt::GotoIfNot, ci) + cond = stmt.cond + t = if cond isa Core.SSAValue + types = ci.ssavaluetypes + types isa AbstractVector && checkbounds(Bool, types, cond.id) ? types[cond.id] : nothing + elseif cond isa Core.Argument + ci.slottypes === nothing ? nothing : get(ci.slottypes, cond.n, nothing) + elseif cond isa Core.SlotNumber + ci.slottypes === nothing ? nothing : get(ci.slottypes, cond.id, nothing) + else + nothing + end + return t isa Core.Const +end + # Inspect a single IR statement: if it is a statically resolvable call into a non-library # method, recurse into that method's IR. Returns `true` if a branch is found downstream. 
function _recurse_call(@nospecialize(stmt), ci, seen, depth) @@ -100,15 +147,52 @@ function _recurse_call(@nospecialize(stmt), ci, seen, depth) end Meta.isexpr(call, :call) || return false + if _is_apply(call.args[1]) + return _recurse_apply(call, ci, seen, depth) + end ftype, fval = _resolve_callee(call.args[1], ci) ftype === nothing && return false argtypes = Any[_value_type(a, ci) for a in @view call.args[2:end]] return _recurse_sig(Tuple{ftype, argtypes...}, fval, seen, depth) end +_is_apply(@nospecialize(f)) = + f isa GlobalRef && f.mod === Core && (f.name === :_apply_iterate || f.name === :_apply) + +# A splatted call `g(a, bs...)` lowers to `Core._apply_iterate(iter, g, groups...)` (or, on +# older lowerings, `Core._apply(g, groups...)`). The real callee `g` is therefore an *argument* +# of a `Core` builtin, so the plain `:call` path would resolve the callee to `_apply_iterate`, +# treat it as library, and dead-end — missing every branch behind the forwarder. SciML/MTK RHS +# objects are exactly such forwarders (`ODEFunction` -> `GeneratedFunctionWrapper` -> +# `RuntimeGeneratedFunction` -> `generated_callfunc`, each `f(args...)`), so the generated body's +# branches only become reachable by following the apply through to `g`. The splatted groups are +# the actual positional arguments; recover their element types from the (concrete) tuple types so +# the right method specialization is selected downstream. +function _recurse_apply(call, ci, seen, depth) + args = call.args + fpos = args[1].name === :_apply_iterate ? 3 : 2 + length(args) >= fpos || return false + ftype, fval = _resolve_callee(args[fpos], ci) + ftype === nothing && return false + argtypes = Any[] + for a in @view args[(fpos + 1):end] + at = _value_type(a, ci) + if at isa DataType && at <: Tuple && Base.isconcretetype(at) + append!(argtypes, at.parameters) + else + # Splatted container whose element types we cannot recover statically (e.g. 
a + # non-concrete `Vararg` tuple or an array): bail rather than guess a wrong signature. + return false + end + end + return _recurse_sig(Tuple{ftype, argtypes...}, fval, seen, depth) +end + function _recurse_sig(@nospecialize(callsig), @nospecialize(fval), seen, depth) # Honor user `is_leaf` overrides when the concrete function value is recoverable. fval !== nothing && is_leaf(fval) && return false + # Signature-level overrides: exemptions that depend on the argument types. + is_leaf_sig(callsig) && return false m = try Base.which(callsig) catch @@ -170,6 +254,6 @@ _widen(@nospecialize t) = t isa Core.PartialStruct ? t.typ : isa(t, Type) ? t : Any -export hasbranching, is_leaf +export hasbranching, is_leaf, is_leaf_sig end diff --git a/test/core_tests.jl b/test/core_tests.jl index e166bd7..376e831 100644 --- a/test/core_tests.jl +++ b/test/core_tests.jl @@ -102,3 +102,60 @@ end x0 = [-4.0f0, 0.0f0] ts = Float32.(collect(0.0:0.01:tspan[2])) @test !FunctionProperties.hasbranching(dxdt_, copy(x0), x0, p, tspan[1]) + +# --------------------------------------------------------------------------------------------- +# Value-independent (compile-time-constant) branches must not be reported. +# +# A `GotoIfNot` whose condition inference proves `Core.Const` cannot be taken differently under a +# tracing AD, so it is wrapper/dispatch plumbing rather than the value-dependent branching +# `hasbranching` is meant to surface. This is the shape of the SciML `ODEFunction` functor +# (`if f.f isa AbstractSciMLOperator`) and of ML-library device/type-introspection dispatch +# (SciML/FunctionProperties.jl#46). A *literal* `true`/`false` condition is still a genuine branch +# and is kept (covered by the `f_branch` test above). 
+abstract type FakeOperator end +struct CondWrap{F} + f::F +end +function (w::CondWrap)(x) + if w.f isa FakeOperator # concretely-typed field => `isa` folds to a constant + return zero(x) + else + return w.f(x) + end +end +branchfree_inner(x) = x * x + one(x) +branchy_inner(x) = x < 0 ? -x : x +@test !FunctionProperties.hasbranching(CondWrap(branchfree_inner), 1.0) # const `isa` skipped +@test FunctionProperties.hasbranching(CondWrap(branchy_inner), 1.0) # real inner branch kept + +# --------------------------------------------------------------------------------------------- +# Branches behind a *splatted* call boundary must be detected. +# +# `g(args...)` lowers to `Core._apply_iterate(iter, g, args)`, hiding the real callee `g` as an +# argument of a `Core` builtin. The scan must follow the apply through to `g`, otherwise every +# branch behind a splat forwarder is missed. This is the SciML/MTK RHS shape (`ODEFunction` -> +# `GeneratedFunctionWrapper` -> `RuntimeGeneratedFunction` -> `generated_callfunc`, each a +# `f(args...)` forwarder). +@noinline splat_target_branchy(x) = x < 0 ? -x : x +@noinline splat_target_free(x) = x * x +splat_forward_branchy(args...) = splat_target_branchy(args...) +splat_forward_free(args...) = splat_target_free(args...) +@test FunctionProperties.hasbranching(splat_forward_branchy, -1.0) +@test !FunctionProperties.hasbranching(splat_forward_free, -1.0) + +# --------------------------------------------------------------------------------------------- +# `is_leaf_sig`: signature-level exemptions for value-independent plumbing. +# +# A branch on an integer index that selects a buffer (the MTK `getindex(::MTKParameters, ::Int)` +# pattern) is value-independent: each real call site passes a literal index that constant-folds the +# branch, but the recursion only sees the widened `Int` and so reports it. Such a call can be marked +# branch-free by signature; the `@test`s below check before and after the override. 
+struct TwoBuffers + a::Float64 + b::Float64 +end +@noinline select_buffer(c::TwoBuffers, i::Int) = i == 1 ? c.a : c.b +rhs_with_plumbing(u, p, t) = select_buffer(p, 1) * u +@test FunctionProperties.hasbranching(rhs_with_plumbing, 1.0, TwoBuffers(1.0, 2.0), 0.0) +FunctionProperties.is_leaf_sig(::Type{<:Tuple{typeof(select_buffer), TwoBuffers, Vararg}}) = true +@test !FunctionProperties.hasbranching(rhs_with_plumbing, 1.0, TwoBuffers(1.0, 2.0), 0.0)