Skip to content

Commit

Permalink
Merge #888
Browse files Browse the repository at this point in the history
888: Remove rules for matrix exponential r=DhairyaLGandhi a=sethaxen

JuliaDiff/ChainRules.jl#351 added rules for the dense matrix exponential to ChainRules. This PR removes the corresponding adjoint from Zygote.

Co-authored-by: Seth Axen <[email protected]>
  • Loading branch information
bors[bot] and sethaxen authored Jan 21, 2021
2 parents 71370b6 + 0d9a1c8 commit a9968d3
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 25 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.1"
version = "0.6.2"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "0.7.47"
ChainRules = "0.7.49"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11"
ForwardDiff = "0.10"
Expand Down
23 changes: 0 additions & 23 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -555,29 +555,6 @@ Base.@propagate_inbounds function _pairdiffquotmat(f, n, x, fx, dfx, d²fx = not
return Δfij.(Base.OneTo(n), Base.OneTo(n)')
end

# Adjoint based on the Theano implementation, which uses the differential as described
# in Brančík, "Matlab programs for matrix exponential function derivative evaluation"
@adjoint exp(A::AbstractMatrix) = exp(A), function(F̄)
n = size(A, 1)
E = eigen(A)
w = E.values
ew = exp.(w)
X = _pairdiffquotmat(exp, n, w, ew, ew, ew)
V = E.vectors
VF = factorize(V)
Āc = (V * ((VF \' * V) .* X) / VF)'
Ā = isreal(A) && isreal(F̄) ? real(Āc) : Āc
return (Ā,)
end

# The adjoint for exp(::AbstractArray) intercepts ChainRules' rrule for exp(::Hermitian),
# so we call it manually. This can be removed when the generic rule for exp is moved to
# ChainRules
@adjoint function exp(A::LinearAlgebra.RealHermSymComplexHerm)
Y, back = chain_rrule(exp, A)
return Y, Δ -> (back(Δ)[2],)
end

# Hermitian/Symmetric matrix functions that can be written as power series
_realifydiag!(A::AbstractArray{<:Real}) = A
function _realifydiag!(A)
Expand Down

2 comments on commit a9968d3

@sethaxen
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DhairyaLGandhi can you register a new release as well?

@DhairyaLGandhi
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I have another PR queued up and I was going to release those together

Please sign in to comment.