From f13e0a45d10bb13f48d6208e9c9d5b4a52b96732 Mon Sep 17 00:00:00 2001 From: Michael Goerz Date: Sun, 20 Mar 2022 18:21:22 -0400 Subject: [PATCH] Fix `rrule` for complex-valued determinants (#601) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix rrule for det of complex matrix * Add test of rules for complex determinant * `ΔΩ * inv(x)'` ⇒ `inv(x)' * dot(Ω, ΔΩ)` Co-authored-by: Seth Axen * Use a random matrix as a seed for the unitary * Bump patch number * Use `U = exp(B - B')` as complex `det` test matrix Co-authored-by: Seth Axen Co-authored-by: Seth Axen --- Project.toml | 2 +- src/rulesets/LinearAlgebra/dense.jl | 2 +- test/rulesets/LinearAlgebra/dense.jl | 6 ++++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 7b80db3e1..bfcd14aef 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.28.0" +version = "1.28.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 13d77d302..ebd0c07ce 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -129,7 +129,7 @@ frule((_, Δx), ::typeof(det), x::Number) = (det(x), Δx) function rrule(::typeof(det), x::Union{Number, AbstractMatrix}) Ω = det(x) function det_pullback(ΔΩ) - ∂x = x isa Number ? ΔΩ : Ω * ΔΩ * inv(x)' + ∂x = x isa Number ? ΔΩ : inv(x)' * dot(Ω, ΔΩ) return (NoTangent(), ∂x) end return Ω, det_pullback diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index d8eb50eb0..ec8288a84 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -119,6 +119,12 @@ test_rrule(f, B) end end + @testset "$f(complex determinant)" begin + B = randn(ComplexF64, 4, 4) + U = exp(B - B') + test_frule(f, U) + test_rrule(f, U) + end end @testset "logabsdet(::Matrix{$T})" for T in (Float64, ComplexF64) B = randn(T, 4, 4)