Recurse through typed IR so branches behind call boundaries are detected#61
Merged
ChrisRackauckas merged 2 commits intoJun 27, 2026
Conversation
PR SciML#55 replaced the Cassette-based pass with a single `code_typed(...; optimize = false)` scan that only inspects the immediate IR of `f`. That made `hasbranching` miss value-dependent branches that live in a helper the entry function calls but that does not get inlined — a false negative. For SciMLSensitivity that is the dangerous direction: a missed branch lets a ReverseDiff tape be compiled on a branchy function, silently producing wrong gradients. Restore the recursive semantics of the old Cassette pass, but on type-inferred IR instead of a Cassette context: - Scan the entry function's IR for `GotoIfNot`, then recurse through statically resolved calls into the callees' IR. - Treat `Base`, `Core`, and stdlib methods as leaves (detected via `Base.moduleroot` / `Sys.STDLIB`). Their internal branches are structural / compile-time, not value-dependent user logic; recursing into matrix multiply, broadcasting, or `getindex` bounds checks would reintroduce the false positives the old code suppressed with a hand-curated leaf list. This module rule subsumes that list. - `is_leaf` overrides now also stop recursion into the marked callee. - Unexpanded generated functions return `Method` rather than `CodeInfo` and are treated as leaves. Tests cover a branch behind a `@noinline` helper, a branch-free nested helper, and an `is_leaf`-opted-out helper, alongside the existing broadcast / matmul / neural-ODE false-positive cases. Passes on Julia 1.10 and 1.12. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
#55 fixed the Julia 1.12 incompatibility by dropping Cassette and replacing the IR pass with a single
code_typed(f, argtypes; optimize = false)scan. That scan only inspects the immediate IR off:The old Cassette pass instead recursed through the entire call tree (
Cassette.recurse), flagging a branch found in any nested method. So the new version regresses on value-dependent branches that live in a helper the entry function calls but which is not inlined atoptimize = false:hasbranching(rhs, 1.0, 2.0, 0.0)returnsfalseon the merged code, but the function does branch on its input.For SciMLSensitivity this is the unsafe direction.
concrete_solve.jluses!hasbranching(f, ...)to decide whether to compile a ReverseDiff tape. A false negative means a tape gets compiled on a branchy function, which bakes in one branch and silently produces wrong gradients for inputs that take the other branch. (A false positive only costs a skipped optimization.)Fix
Restore the recursive semantics of the old Cassette pass, but on type-inferred IR rather than a Cassette context:
GotoIfNot, then recurse through statically resolved calls into the callees' IR (reconstructing each callee signature fromssavaluetypes/slottypes, sinceoptimize = falseemits:call, not:invoke).Base,Core, and stdlib methods as leaves (viaBase.moduleroot/Sys.STDLIB). Their internal branches are structural / compile-time, not value-dependent user logic; recursing into matrix multiply, broadcasting, orgetindexbounds checks would reintroduce exactly the false positives the old code suppressed with its hand-curated leaf list (+,*,getindex,broadcasted,materialize, the DiffRules functions, …). The module rule subsumes that whole list.is_leafoverrides now also stop recursion into the marked callee.code_typed_by_typeasMethodrather thanCodeInfo(e.g.ComponentArrays._getindexon 1.12); these are treated as leaves.seenset plus a depth backstop bound the recursion.Tests
Added to
test/core_tests.jl:@noinlinehelper →true(the regression),false,is_leaf-opted-out branchy helper →false.The existing broadcast / matmul / affine-activation / neural-ODE cases (which must stay
false) are unchanged and still pass — that suite is what pins the no-false-positive requirement.Verified locally with
Pkg.test()on Julia 1.10 and 1.12 (Core group: 14/14; QA group: Aqua/ExplicitImports pass, the two pre-existing@test_brokenplaceholders from #54 remain).Notes
is_leaf) is unchanged and gains recursion-stopping power.versionbumped to0.1.5.Base/stdlib). If a future case wants to recurse into a non-stdlib package that produces a false positive,is_leafis the escape hatch.