Skip to content

Recurse through typed IR so branches behind call boundaries are detected#61

Merged
ChrisRackauckas merged 2 commits into
SciML:mainfrom
ChrisRackauckas-Claude:fix/recursive-typed-ir-branch-detection
Jun 27, 2026
Merged

Recurse through typed IR so branches behind call boundaries are detected#61
ChrisRackauckas merged 2 commits into
SciML:mainfrom
ChrisRackauckas-Claude:fix/recursive-typed-ir-branch-detection

Conversation

@ChrisRackauckas-Claude

Copy link
Copy Markdown
Contributor

Please ignore until reviewed by @ChrisRackauckas.

Problem

#55 fixed the Julia 1.12 incompatibility by dropping Cassette and replacing the IR pass with a single code_typed(f, argtypes; optimize = false) scan. That scan only inspects the immediate IR of f:

ci = first(code_typed(f, argtypes; optimize = false))[1]
return any(isa(s, Core.GotoIfNot) for s in ci.code)

The old Cassette pass instead recursed through the entire call tree (Cassette.recurse), flagging a branch found in any nested method. So the new version regresses on value-dependent branches that live in a helper the entry function calls but which is not inlined at optimize = false:

@noinline branchy_helper(x) = x < 0 ? -x : x
rhs(u, p, t) = branchy_helper(u) + p   # branch is invisible in rhs's own IR

hasbranching(rhs, 1.0, 2.0, 0.0) returns false on the merged code, but the function does branch on its input.

For SciMLSensitivity this is the unsafe direction. concrete_solve.jl uses !hasbranching(f, ...) to decide whether to compile a ReverseDiff tape. A false negative means a tape gets compiled on a branchy function, which bakes in one branch and silently produces wrong gradients for inputs that take the other branch. (A false positive only costs a skipped optimization.)

Fix

Restore the recursive semantics of the old Cassette pass, but on type-inferred IR rather than a Cassette context:

  • Scan the entry function's IR for GotoIfNot, then recurse through statically resolved calls into the callees' IR (reconstructing each callee signature from ssavaluetypes/slottypes, since optimize = false emits :call, not :invoke).
  • Treat Base, Core, and stdlib methods as leaves (via Base.moduleroot / Sys.STDLIB). Their internal branches are structural / compile-time, not value-dependent user logic; recursing into matrix multiply, broadcasting, or getindex bounds checks would reintroduce exactly the false positives the old code suppressed with its hand-curated leaf list (+, *, getindex, broadcasted, materialize, the DiffRules functions, …). The module rule subsumes that whole list.
  • is_leaf overrides now also stop recursion into the marked callee.
  • Unexpanded generated functions come back from code_typed_by_type as Method rather than CodeInfo (e.g. ComponentArrays._getindex on 1.12); these are treated as leaves.
  • A seen set plus a depth backstop bound the recursion.

Tests

Added to test/core_tests.jl:

  • a branch behind a @noinline helper → true (the regression),
  • a branch-free nested helper → false,
  • an is_leaf-opted-out branchy helper → false.

The existing broadcast / matmul / affine-activation / neural-ODE cases (which must stay false) are unchanged and still pass — that suite is what pins the no-false-positive requirement.

Verified locally with Pkg.test() on Julia 1.10 and 1.12 (Core group: 14/14; QA group: Aqua/ExplicitImports pass, the two pre-existing @test_broken placeholders from #54 remain).

Notes

  • This restores a behavioral contract that Fix Julia 1.12: replace Cassette with code_typed-based branch detection #55 narrowed; the extension API (is_leaf) is unchanged and gains recursion-stopping power.
  • version bumped to 0.1.5.
  • Tradeoff vs. the merged code: recursion costs more analysis time at setup and is conservative across module boundaries (stops at Base/stdlib). If a future case wants to recurse into a non-stdlib package that produces a false positive, is_leaf is the escape hatch.

PR SciML#55 replaced the Cassette-based pass with a single `code_typed(...;
optimize = false)` scan that only inspects the immediate IR of `f`. That made
`hasbranching` miss value-dependent branches that live in a helper the entry
function calls but that does not get inlined — a false negative. For
SciMLSensitivity that is the dangerous direction: a missed branch lets a
ReverseDiff tape be compiled on a branchy function, silently producing wrong
gradients.

Restore the recursive semantics of the old Cassette pass, but on type-inferred
IR instead of a Cassette context:

- Scan the entry function's IR for `GotoIfNot`, then recurse through statically
  resolved calls into the callees' IR.
- Treat `Base`, `Core`, and stdlib methods as leaves (detected via
  `Base.moduleroot` / `Sys.STDLIB`). Their internal branches are structural /
  compile-time, not value-dependent user logic; recursing into matrix multiply,
  broadcasting, or `getindex` bounds checks would reintroduce the false
  positives the old code suppressed with a hand-curated leaf list. This module
  rule subsumes that list.
- `is_leaf` overrides now also stop recursion into the marked callee.
- Unexpanded generated functions return `Method` rather than `CodeInfo` and are
  treated as leaves.

Tests cover a branch behind a `@noinline` helper, a branch-free nested helper,
and an `is_leaf`-opted-out helper, alongside the existing broadcast / matmul /
neural-ODE false-positive cases. Passes on Julia 1.10 and 1.12.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Comment thread src/FunctionProperties.jl Outdated
@ChrisRackauckas ChrisRackauckas marked this pull request as ready for review June 27, 2026 14:38
@ChrisRackauckas ChrisRackauckas merged commit 4e13b9b into SciML:main Jun 27, 2026
6 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants