Skip to content

Commit a9968d3

Browse files
bors[bot]sethaxen
andauthored
Merge #888
888: Remove rules for matrix exponential r=DhairyaLGandhi a=sethaxen JuliaDiff/ChainRules.jl#351 added rules for the dense matrix exponential to ChainRules. This PR removes the corresponding adjoint from Zygote. Co-authored-by: Seth Axen <[email protected]>
2 parents 71370b6 + 0d9a1c8 commit a9968d3

File tree

2 files changed

+2
-25
lines changed

2 files changed

+2
-25
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Zygote"
22
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
3-
version = "0.6.1"
3+
version = "0.6.2"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2222

2323
[compat]
2424
AbstractFFTs = "0.5, 1.0"
25-
ChainRules = "0.7.47"
25+
ChainRules = "0.7.49"
2626
DiffRules = "1.0"
2727
FillArrays = "0.8, 0.9, 0.10, 0.11"
2828
ForwardDiff = "0.10"

src/lib/array.jl

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -555,29 +555,6 @@ Base.@propagate_inbounds function _pairdiffquotmat(f, n, x, fx, dfx, d²fx = not
555555
return Δfij.(Base.OneTo(n), Base.OneTo(n)')
556556
end
557557

558-
# Adjoint based on the Theano implementation, which uses the differential as described
559-
# in Brančík, "Matlab programs for matrix exponential function derivative evaluation"
560-
@adjoint exp(A::AbstractMatrix) = exp(A), function(F̄)
561-
n = size(A, 1)
562-
E = eigen(A)
563-
w = E.values
564-
ew = exp.(w)
565-
X = _pairdiffquotmat(exp, n, w, ew, ew, ew)
566-
V = E.vectors
567-
VF = factorize(V)
568-
Āc = (V * ((VF \' * V) .* X) / VF)'
569-
Ā = isreal(A) && isreal(F̄) ? real(Āc) : Āc
570-
return (Ā,)
571-
end
572-
573-
# The adjoint for exp(::AbstractArray) intercepts ChainRules' rrule for exp(::Hermitian),
574-
# so we call it manually. This can be removed when the generic rule for exp is moved to
575-
# ChainRules
576-
@adjoint function exp(A::LinearAlgebra.RealHermSymComplexHerm)
577-
Y, back = chain_rrule(exp, A)
578-
return Y, Δ -> (back(Δ)[2],)
579-
end
580-
581558
# Hermitian/Symmetric matrix functions that can be written as power series
582559
_realifydiag!(A::AbstractArray{<:Real}) = A
583560
function _realifydiag!(A)

0 commit comments

Comments
 (0)