From 7104e73d63394385fbe4f212bd607ec945f014e6 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 14 Jun 2026 23:58:30 +0200 Subject: [PATCH] Fix matrix-function primal type for Hermitian{<:Real} on Julia 1.12 Julia 1.12 changed real-valued matrix functions of `Hermitian{<:Real}` (e.g. `exp`, `log`, `sqrt`, `cos`) to return `Hermitian` instead of `Symmetric`. `_matfun` unconditionally wrapped the real-eltype result as `Symmetric`, so the rrule/frule primal no longer matched `f(A)` (FluxML/Zygote.jl#1592). Wrap the result in the same type as the input when the output is real, keeping `Symmetric` for the complex-valued case (e.g. `acosh` of eigenvalues < 1), which Julia still returns as `Symmetric` for both `Symmetric` and `Hermitian` real inputs. The behavior is unchanged on Julia < 1.12. The inference assertions in the matrix-function tests are updated accordingly. Co-Authored-By: Claude Opus 4.8 (1M context) --- Project.toml | 2 +- src/rulesets/LinearAlgebra/symmetric.jl | 14 +++++++++++++- test/rulesets/LinearAlgebra/symmetric.jl | 10 ++++++++-- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index ff52dd8bb..ddb2fe3fd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.73.0" +version = "1.73.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index cdadca6c3..2e713c2ed 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -381,7 +381,19 @@ function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm) df_dλ = last.(unthunk.(fλ_df_dλ)) fA = (U * Diagonal(fλ)) * U' Y = if eltype(A) <: Real - Symmetric(fA) + if eltype(fλ) <: Complex + # a complex-valued function of a real Herm/Sym is always `Symmetric` + Symmetric(fA) + else + # Julia 1.12 made real-valued matrix functions of `Hermitian{<:Real}` + # return `Hermitian` (they previously returned `Symmetric`); match the + # type of the primal `f(A)`. + @static if VERSION >= v"1.12" + _symhermtype(A)(fA) + else + Symmetric(fA) + end + end elseif eltype(fλ) <: Complex fA else diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 593b82148..1940ca15e 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -328,8 +328,11 @@ if is_inferable(f, A) Y_ad, ∂Y_ad = @maybe_inferred frule((ZeroTangent(), ΔA), f, A) else + # real-valued functions of `TA{<:Real}` return `TA` on Julia + # 1.12+ (they previously always returned `Symmetric`) + SR = @static VERSION >= v"1.12" ? TA{T} : Symmetric{T} TY = T∂Y = if T <: Real - Union{Symmetric{Complex{T}},Symmetric{T}} + Union{Symmetric{Complex{T}},SR} else Union{Matrix{T},Hermitian{T}} end @@ -381,8 +384,11 @@ if is_inferable(f, A) Y_ad, back = @maybe_inferred rrule(f, A) else + # real-valued functions of `TA{<:Real}` return `TA` on Julia + # 1.12+ (they previously always returned `Symmetric`) + SR = @static VERSION >= v"1.12" ? TA{T} : Symmetric{T} TY = if T <: Real - Union{Symmetric{Complex{T}},Symmetric{T}} + Union{Symmetric{Complex{T}},SR} else Union{Matrix{T},Hermitian{T}} end