diff --git a/Project.toml b/Project.toml index ff52dd8bb..ddb2fe3fd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.73.0" +version = "1.73.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index cdadca6c3..2e713c2ed 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -381,7 +381,19 @@ function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm) df_dλ = last.(unthunk.(fλ_df_dλ)) fA = (U * Diagonal(fλ)) * U' Y = if eltype(A) <: Real - Symmetric(fA) + if eltype(fλ) <: Complex + # a complex-valued function of a real Herm/Sym is always `Symmetric` + Symmetric(fA) + else + # Julia 1.12 made real-valued matrix functions of `Hermitian{<:Real}` + # return `Hermitian` (they previously returned `Symmetric`); match the + # type of the primal `f(A)`. + @static if VERSION >= v"1.12" + _symhermtype(A)(fA) + else + Symmetric(fA) + end + end elseif eltype(fλ) <: Complex fA else diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 593b82148..1940ca15e 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -328,8 +328,11 @@ if is_inferable(f, A) Y_ad, ∂Y_ad = @maybe_inferred frule((ZeroTangent(), ΔA), f, A) else + # real-valued functions of `TA{<:Real}` return `TA` on Julia + # 1.12+ (they previously always returned `Symmetric`) + SR = @static VERSION >= v"1.12" ? TA{T} : Symmetric{T} TY = T∂Y = if T <: Real - Union{Symmetric{Complex{T}},Symmetric{T}} + Union{Symmetric{Complex{T}},SR} else Union{Matrix{T},Hermitian{T}} end @@ -381,8 +384,11 @@ if is_inferable(f, A) Y_ad, back = @maybe_inferred rrule(f, A) else + # real-valued functions of `TA{<:Real}` return `TA` on Julia + # 1.12+ (they previously always returned `Symmetric`) + SR = @static VERSION >= v"1.12" ? TA{T} : Symmetric{T} TY = if T <: Real - Union{Symmetric{Complex{T}},Symmetric{T}} + Union{Symmetric{Complex{T}},SR} else Union{Matrix{T},Hermitian{T}} end