Preserve input element type in ExpMethodGeneric (fix Float32 → Float64 promotion)#238
Draft
ChrisRackauckas-Claude wants to merge 1 commit into
Conversation
…64 promotion)
`exponential!(x, ExpMethodGeneric())` on a Float32 immutable matrix (e.g.
`SMatrix{3,3,Float32}`) silently returned a Float64 result, and likewise
ForwardDiff jacobians came back as Float64. The cause was the specialized
(13,13) Padé numerator methods hardcoding `Float64` coefficients
(`UniformScaling{Float64}(...)` for matrices, `Float64` literals for scalars),
which forced the whole Horner evaluation to promote to Float64.
Take the coefficient type from `float(eltype(x))` at runtime instead. The
coefficients are the exact Padé rationals, so for Float64 inputs the resulting
Float64 coefficients are bit-identical to the previous literals (verified),
while Float32 (and ComplexF32, Double64, ForwardDiff Dual) inputs now retain
their type. This is deliberately not `@generated`: constructing the element
type inside a generated body is world-age blocked for ForwardDiff `Dual`s.
Reported at https://discourse.julialang.org/t/how-to-use-forwarddiff-jl-with-inv-and-matrix-exponential/137880/15
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
Reported on Discourse: How to use ForwardDiff.jl with
invand matrix exponential (post #15).exponential!(x, ExpMethodGeneric())silently promotesFloat32inputs toFloat64:Cause
The specialized
(13,13)Padé numerator methods hardcodedFloat64coefficients —UniformScaling{Float64}(...)for the immutable-matrix path andFloat64literals for the scalar path. The hardcodedFloat64coefficients force the whole@evalpolyHorner evaluation to promote toFloat64, regardless of the input element type. (The immutable/SMatrixpath is the one that hits these; mutable matrices go throughexp_generic_mutable, which already promotes toFloat64by design and is untouched here.)Fix
Take the coefficient type from
float(eltype(x))(matrix) /float(typeof(x))(scalar) at runtime, using the exact Padé rationals. ForFloat64inputs the resulting coefficients are bit-identical to the previous hardcoded literals (verified), soFloat64results are unchanged.Float32,ComplexF32,Double64, and ForwardDiffDualinputs now retain their element type.This is deliberately not
@generated: deriving the coefficient type inside a generated-function body requires constructing the element type (e.g. a ForwardDiffDual), which is world-age blocked and breaks differentiation. That is also why the existing generic@generated exp_pade_p(x, ::Val{k}, ::Val{m})method (used for non-13 orders) could not simply be reused for the defaultk = 13matrix/scalar case.Tests
Added a testset asserting element-type preservation for
ExpMethodGenericacrossFloat32/Float64static matrices (result eltype and ForwardDiff jacobian eltype),ComplexF32, and the scalar path. Full suite passes locally on Julia 1.10 (364 pass, 1 pre-existing broken).🤖 Generated with Claude Code