Skip to content

Commit b92da50

Browse files
sethaxenoxinabox
andauthored
Add rules for dense matrix exponential (#351)
* Add matfun.jl file * Add matfun docstrings * Add exp matrix function * At least store one intermediate * Test exp! * Make pullback type-inferrable * Add clearer test label * Create as hermitian * Test rrule * Add comment about relationship between pushforward and pullback * Add header * Add reference to Frechet deriv paper * Run JuliaFormatter * Reduce comment spacing from code * Update src/rulesets/LinearAlgebra/matfun.jl * Correctly handle balancing * Test imbalanced matrix A * Increment version number * Apply suggestions from code review Co-authored-by: Lyndon White <[email protected]> * Change signature of _matfun_frechet * Give math for Frechet derivative * Change Frechet notation * Add _matfun_frechet_adjoint * Simplify hermitian code * Correct comment * Remove comments * Use abbreviated SHA * Link * Update comment * Move comment up * Move comment further up * Update docstrings * Push header to same level as rules Co-authored-by: Lyndon White <[email protected]>
1 parent 9004ee0 commit b92da50

File tree

6 files changed

+385
-14
lines changed

6 files changed

+385
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.48"
3+
version = "0.7.49"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/ChainRules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ include("rulesets/LinearAlgebra/utils.jl")
4444
include("rulesets/LinearAlgebra/blas.jl")
4545
include("rulesets/LinearAlgebra/dense.jl")
4646
include("rulesets/LinearAlgebra/norm.jl")
47+
include("rulesets/LinearAlgebra/matfun.jl")
4748
include("rulesets/LinearAlgebra/structured.jl")
4849
include("rulesets/LinearAlgebra/symmetric.jl")
4950
include("rulesets/LinearAlgebra/factorization.jl")

src/rulesets/LinearAlgebra/matfun.jl

Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
# matrix functions of dense matrices
2+
# https://en.wikipedia.org/wiki/Matrix_function
3+
4+
# NOTE: for a matrix function f, the pushforward and pullback can be computed using the
5+
# Fréchet derivative and its adjoint, respectively.
6+
# https://en.wikipedia.org/wiki/Fréchet_derivative
7+
8+
# The pushforwards and pullbacks are related by matrix adjoints. If the pushforward of f(A)
9+
# at A is (f_*)_A(ΔA), then the pullback at A is (f^*)_A(ΔY) = ((f_*)_A(ΔY'))'.
10+
# If f has a power series representation with real coefficients, then this simplifies to
11+
# (f^*)_A(ΔY) = (f_*)_{A'}(ΔY)
12+
# So we reuse the code from the pushforward to implement the pullback.
13+
14+
#####
15+
##### interface function definitions
16+
#####
17+
18+
"""
19+
_matfun(f, A) -> (Y, intermediates)
20+
21+
Compute the matrix function `Y=f(A)` for matrix `A`.
22+
The function returns a tuple containing the result and a tuple of intermediates to be
23+
reused by [`_matfun_frechet`](@ref) to compute the Fréchet derivative.
24+
"""
25+
_matfun
26+
27+
"""
28+
_matfun!(f, A) -> (Y, intermediates)
29+
30+
Similar to [`_matfun`](@ref), but where `A` may be overwritten.
31+
"""
32+
_matfun!
33+
34+
"""
35+
_matfun_frechet(f, E, A, Y, intermediates)
36+
37+
Compute the Fréchet derivative of the matrix function ``Y = f(A)`` at ``A`` in the direction
38+
of ``E``, where `intermediates` is the second argument returned by [`_matfun`](@ref).
39+
40+
The Fréchet derivative is the unique linear map ``L_f \\colon E → L_f(A, E)``, such that
41+
```math
42+
L_f(A, E) = f(A + E) - f(A) + o(\\lVert E \\rVert).
43+
```
44+
45+
[^Higham08]:
46+
> Higham, Nicholas J. Chapter 3: Conditioning. Functions of Matrices. 2008, 55-70.
47+
> doi: 10.1137/1.9780898717778.ch3
48+
"""
49+
_matfun_frechet
50+
51+
"""
52+
_matfun_frechet!(f, E, A, Y, intermediates)
53+
54+
Similar to [`_matfun_frechet`](@ref), but where `E` may be overwritten.
55+
"""
56+
_matfun_frechet!
57+
58+
"""
59+
_matfun_frechet_adjoint(f, E, A, Y, intermediates)
60+
61+
Compute the adjoint of the Fréchet derivative of the matrix function ``Y = f(A)`` at ``A``
62+
in the direction of ``E``, where `intermediates` is the second argument returned by
63+
[`_matfun`](@ref).
64+
65+
Given the Fréchet derivative ``L_f(A, E)`` computed by [`_matfun_frechet`](@ref), then its adjoint
66+
``L_f^⋆(A, E)`` is defined by the identity
67+
```math
68+
\\langle B, L_f(A, C) \\rangle = \\langle L_f^⋆(A, B), C \\rangle.
69+
```
70+
This identity is satisfied by ``L_f^⋆(A, E) = L_f(A, E')'``.
71+
72+
[^Higham08]:
73+
> Higham, Nicholas J. Chapter 3: Conditioning. Functions of Matrices. 2008, 55-70.
74+
> doi: 10.1137/1.9780898717778.ch3
75+
"""
76+
function _matfun_frechet_adjoint(f, E, A, Y, intermediates)
    # Materialize a lazy `Adjoint` wrapper before handing the direction over, because the
    # `_matfun_frechet` method that gets selected may not accept one.
    Eadj = adjoint(E)
    if Eadj isa Adjoint
        Eadj = copy(Eadj)
    end
    L = _matfun_frechet(f, Eadj, A, Y, intermediates)
    # Likewise hand back a materialized array rather than a lazy `Adjoint` wrapper.
    Ladj = adjoint(L)
    if Ladj isa Adjoint
        return copy(Ladj)
    end
    return Ladj
end
84+
85+
"""
86+
_matfun_frechet_adjoint!(f, E, A, Y, intermediates)
87+
88+
Similar to [`_matfun_frechet_adjoint`](@ref), but where `E` may be overwritten.
89+
"""
90+
function _matfun_frechet_adjoint!(f, E, A, Y, intermediates)
    # Materialize a lazy `Adjoint` wrapper before handing the direction over, because the
    # `_matfun_frechet!` method that gets selected may not accept one. The array passed on
    # (and possibly overwritten) is this copy whenever `E'` is lazy.
    Eadj = adjoint(E)
    if Eadj isa Adjoint
        Eadj = copy(Eadj)
    end
    L = _matfun_frechet!(f, Eadj, A, Y, intermediates)
    # Likewise hand back a materialized array rather than a lazy `Adjoint` wrapper.
    Ladj = adjoint(L)
    if Ladj isa Adjoint
        return copy(Ladj)
    end
    return Ladj
end
98+
99+
#####
100+
##### `exp`/`exp!`
101+
#####
102+
103+
function frule((_, ΔA), ::typeof(LinearAlgebra.exp!), A::StridedMatrix{<:BlasFloat})
    # General (non-Hermitian) input: use the in-place dense routines, which overwrite
    # both `A` and `ΔA`.
    if !ishermitian(A)
        X, intermediates = _matfun!(exp, A)
        ∂X = _matfun_frechet!(exp, ΔA, A, X, intermediates)
        return X, ∂X
    end

    # Hermitian input: route through the `Hermitian` methods, then convert the primal and
    # the tangent back to dense matrices.
    hermA = Hermitian(A)
    hermX, intermediates = _matfun(exp, hermA)
    ∂hermX = _matfun_frechet(exp, ΔA, hermA, hermX, intermediates)
    return Matrix(hermX), Matrix(∂hermX)
end
116+
117+
function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat})
    # TODO: try to make this more type-stable
    if !ishermitian(A0)
        # `_matfun!` overwrites its argument, so operate on a private copy of the input
        A = copy(A0)
        X, intermediates = _matfun!(exp, A)
        function exp_pullback(ΔX)
            ∂A = _matfun_frechet_adjoint!(exp, ΔX, A, X, intermediates)
            return NO_FIELDS, ∂A
        end
        return X, exp_pullback
    end

    # call _matfun instead of the rrule to avoid hermitrizing ∂A in the pullback
    hermA = Hermitian(A0)
    hermX, hermX_intermediates = _matfun(exp, hermA)
    function exp_pullback_hermitian(ΔX)
        ∂hermA = _matfun_frechet_adjoint(exp, ΔX, hermA, hermX, hermX_intermediates)
        return NO_FIELDS, Matrix(∂hermA)
    end
    return Matrix(hermX), exp_pullback_hermitian
end
138+
139+
## Destructive matrix exponential using algorithm from Higham, 2008,
140+
## "Functions of Matrices: Theory and Computation", SIAM
141+
## Adapted from LinearAlgebra.exp! with return of intermediates
142+
## https://github.com/JuliaLang/julia/blob/f613b55/stdlib/LinearAlgebra/src/dense.jl#L583-L666
143+
# Returns `(X, intermediates)` where `X = exp(A)` and
# `intermediates = (ilo, ihi, scale, C, si, Apows, W, F, Xpows)`:
#   - `ilo, ihi, scale`: balancing data returned by `LAPACK.gebal!('B', A)`
#   - `C`: Padé coefficients selected from the 1-norm of the balanced `A`
#   - `si`: number of scaling (and later squaring) steps
#   - `Apows`: the even powers `A^2, A^4, …` of the balanced, scaled `A`
#   - `W`: the matrix with `U = A * W`
#   - `F`: LU factorization of `V - U`
#   - `Xpows`: the Padé approximant followed by each of its successive squares
# `A` is overwritten with its balanced and scaled version.
function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat}
    n = LinearAlgebra.checksquare(A)
    ilo, ihi, scale = LAPACK.gebal!('B', A)  # modifies A
    nA = opnorm(A, 1)
    Inn = Matrix{T}(I, n, n)
    ## For sufficiently small nA, use lower order Padé-Approximations
    if (nA <= 2.1)
        if nA > 0.95
            C = T[
                17643225600.0,
                8821612800.0,
                2075673600.0,
                302702400.0,
                30270240.0,
                2162160.0,
                110880.0,
                3960.0,
                90.0,
                1.0,
            ]
        elseif nA > 0.25
            C = T[17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0, 1.0]
        elseif nA > 0.015
            C = T[30240.0, 15120.0, 3360.0, 420.0, 30.0, 1.0]
        else
            C = T[120.0, 60.0, 12.0, 1.0]
        end
        si = 0
    else
        C = T[
            64764752532480000.0,
            32382376266240000.0,
            7771770303897600.0,
            1187353796428800.0,
            129060195264000.0,
            10559470521600.0,
            670442572800.0,
            33522128640.0,
            1323241920.0,
            40840800.0,
            960960.0,
            16380.0,
            182.0,
            1.0,
        ]
        s = log2(nA / 5.4)  # power of 2 later reversed by squaring
        si = ceil(Int, s)
    end

    if si > 0
        A ./= convert(T, 2^si)
    end

    # Accumulate the even part V and the factor W of the odd part U = A * W, keeping every
    # even power of A for reuse by the Fréchet derivative.
    A2 = A * A
    P = copy(Inn)
    W = C[2] * P
    V = C[1] * P
    Apows = typeof(P)[]
    for k in 1:(div(size(C, 1), 2) - 1)
        k2 = 2 * k
        P *= A2
        push!(Apows, P)
        W += C[k2 + 2] * P
        V += C[k2 + 1] * P
    end
    U = A * W
    X = V + U
    F = lu!(V - U)  # NOTE: use lu! instead of LAPACK.gesv! so we can reuse factorization
    ldiv!(F, X)
    Xpows = typeof(X)[X]
    if si > 0  # squaring to reverse dividing by power of 2
        for t in 1:si
            X *= X
            push!(Xpows, X)
        end
    end

    # `_unbalance!` works in place. Without squaring steps, `X` is the very array stored
    # as `Xpows[1]`, which `_matfun_frechet!` reads as the balanced Padé approximant, so
    # unbalance a copy in that case to keep the stored intermediate intact.
    Xout = si > 0 ? X : copy(X)
    _unbalance!(Xout, ilo, ihi, scale, n)
    return Xout, (ilo, ihi, scale, C, si, Apows, W, F, Xpows)
end
223+
224+
# Application of the chain rule to exp!, also Algorithm 6.4 from
225+
# Al-Mohy, Awad H. and Higham, Nicholas J. (2009).
226+
# "Computing the Fréchet Derivative of the Matrix Exponential, with an application to
227+
# Condition Number Estimation", SIAM. 30 (4). pp. 1639-1657.
228+
# http://eprints.maths.manchester.ac.uk/id/eprint/1218
229+
# Differentiates each step of `_matfun!(exp, A)` in the same order. `A` is expected to be
# the matrix as overwritten by `_matfun!` (balanced and scaled), and the last argument is
# the tuple of intermediates that call returned. `ΔA` is overwritten.
function _matfun_frechet!(
    ::typeof(exp), ΔA, A::StridedMatrix{T}, X, (ilo, ihi, scale, C, si, Apows, W, F, Xpows)
) where {T<:BlasFloat}
    n = LinearAlgebra.checksquare(A)
    # apply to ΔA the same balancing and scaling that were applied to A
    _balance!(ΔA, ilo, ihi, scale, n)

    if si > 0
        ΔA ./= convert(T, 2^si)
    end

    # ∂(A^2) = A * ΔA + ΔA * A
    ∂A2 = mul!(A * ΔA, ΔA, A, true, true)
    A2 = first(Apows)
    # we will repeatedly overwrite ∂temp and ∂P below
    ∂temp = Matrix{eltype(∂A2)}(undef, n, n)
    ∂P = copy(∂A2)
    ∂W = C[4] * ∂P
    ∂V = C[3] * ∂P
    # `_matfun!` accumulates one term of W and V per entry of `Apows`, so the derivative
    # must visit every k up to `length(Apows)`; stopping one short drops the contribution
    # of the highest even power.
    for k in 2:length(Apows)
        k2 = 2 * k
        P = Apows[k - 1]
        # ∂(A^(2k)) = ∂(A^(2k-2)) * A^2 + A^(2k-2) * ∂(A^2)
        ∂P, ∂temp = mul!(mul!(∂temp, ∂P, A2), P, ∂A2, true, true), ∂P
        axpy!(C[k2 + 2], ∂P, ∂W)
        axpy!(C[k2 + 1], ∂P, ∂V)
    end
    # ∂U = A * ∂W + ΔA * W, since U = A * W
    ∂U, ∂temp = mul!(mul!(∂temp, A, ∂W), ΔA, W, true, true), ∂W
    # differentiate the solve (V - U) * X = V + U:
    # ∂X = (V - U) \ ((∂U + ∂V) + (∂U - ∂V) * X)
    ∂temp .= ∂U .- ∂V
    ∂X = add!!(∂U, ∂V)
    mul!(∂X, ∂temp, first(Xpows), true, true)
    ldiv!(F, ∂X)

    if si > 0
        # differentiate each squaring step: ∂(X^2) = X * ∂X + ∂X * X
        for t in 1:(length(Xpows) - 1)
            Xt = Xpows[t]
            ∂X, ∂temp = mul!(mul!(∂temp, Xt, ∂X), ∂X, Xt, true, true), ∂X
        end
    end

    _unbalance!(∂X, ilo, ihi, scale, n)
    return ∂X
end
269+
270+
#####
271+
##### utils
272+
#####
273+
274+
# Given (ilo, ihi, scale) returned by LAPACK.gebal!('B', A), apply same balancing to X
275+
function _balance!(X, ilo, ihi, scale, n)
    # the passed `n` is ignored in favour of the actual size of `X`
    n = size(X, 1)

    # This must be the exact inverse of `_unbalance!`, which applies the row/column swaps
    # for j = ilo-1, …, 1 and then for j = ihi+1, …, n. Swaps that share an index do not
    # commute, so they are replayed here in the reverse order: j = n, …, ihi+1 and then
    # j = 1, …, ilo-1. Both ranges are empty when there is nothing to permute.
    for j in n:-1:(ihi + 1)
        LinearAlgebra.rcswap!(j, Int(scale[j]), X)
    end
    for j in 1:(ilo - 1)
        LinearAlgebra.rcswap!(j, Int(scale[j]), X)
    end

    # diagonal scaling: divide row j and multiply column j by scale[j]
    for j in ilo:ihi
        scj = scale[j]
        for i in 1:n
            X[j, i] /= scj
        end
        for i in 1:n
            X[i, j] *= scj
        end
    end
    return X
end
299+
300+
# Reverse of _balance!
301+
function _unbalance!(X, ilo, ihi, scale, n)
    # undo the diagonal scaling: multiply row j and divide column j by scale[j]
    span = 1:n
    for j in ilo:ihi
        s = scale[j]
        rowj = view(X, j, span)
        rowj .*= s
        colj = view(X, span, j)
        colj ./= s
    end

    # undo the permutations: indices below `ilo` in descending order, then indices above
    # `ihi` in ascending order (both ranges are empty when there is nothing to swap)
    for j in (ilo - 1):-1:1
        LinearAlgebra.rcswap!(j, Int(scale[j]), X)
    end
    for j in (ihi + 1):n
        LinearAlgebra.rcswap!(j, Int(scale[j]), X)
    end
    return X
end

src/rulesets/LinearAlgebra/symmetric.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :a
292292
@eval begin
293293
function frule((_, ΔA), ::typeof($func), A::LinearAlgebra.RealHermSymComplexHerm)
294294
Y, intermediates = _matfun($func, A)
295-
= _matfun_frechet($func, A, Y, ΔA, intermediates)
295+
= _matfun_frechet($func, ΔA, A, Y, intermediates)
296296
# If ΔA was hermitian, then ∂Y has the same structure as Y
297297
∂Y = if ishermitian(ΔA) && (isa(Y, Symmetric) || isa(Y, Hermitian))
298298
_symhermlike!(Ȳ, Y)
@@ -308,8 +308,7 @@ for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :a
308308
# for Hermitian Y, we don't need to realify the diagonal of ΔY, since the
309309
# effect is the same as applying _hermitrizelike! at the end
310310
∂Y = eltype(Y) <: Real ? real(ΔY) : ΔY
311-
# for matrix functions, the pullback is related to the pushforward by an adjoint
312-
= _matfun_frechet($func, A, Y, ∂Y', intermediates)
311+
= _matfun_frechet_adjoint($func, ∂Y, A, Y, intermediates)
313312
# the cotangent of Hermitian A should be Hermitian
314313
∂A = _hermitrizelike!(Ā, A)
315314
return NO_FIELDS, ∂A
@@ -344,9 +343,9 @@ function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm)
344343
ΔsinA, ΔcosA = real(ΔsinA), real(ΔcosA)
345344
end
346345
if ΔcosA isa AbstractZero
347-
= _matfun_frechet(sin, A, sinA, ΔsinA, (λ, U, sinλ, cosλ))
346+
= _matfun_frechet_adjoint(sin, ΔsinA, A, sinA, (λ, U, sinλ, cosλ))
348347
elseif ΔsinA isa AbstractZero
349-
= _matfun_frechet(cos, A, cosA, ΔcosA, (λ, U, cosλ, -sinλ))
348+
= _matfun_frechet_adjoint(cos, ΔcosA, A, cosA, (λ, U, cosλ, -sinλ))
350349
else
351350
# we will overwrite tmp with various temporary values during this computation
352351
tmp = mul!(similar(U, Base.promote_eltype(U, ΔsinA, ΔcosA)), ΔsinA, U)
@@ -367,7 +366,8 @@ end
367366
368367
Compute the matrix function `f(A)` for real or complex hermitian `A`.
369368
The function returns a tuple containing the result and a tuple of intermediates to be
370-
reused by `_matfun_frechet` to compute the Fréchet derivative.
369+
reused by [`_matfun_frechet`](@ref) to compute the Fréchet derivative.
370+
371371
Note any function `f` used with this **must** have a `frule` defined on it.
372372
"""
373373
function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm)
@@ -392,13 +392,7 @@ function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm)
392392
end
393393

394394
# Computes ∂Y = U * (P .* (U' * ΔA * U)) * U' with fewer allocations
395-
"""
396-
_matfun_frechet(f, A::RealHermSymComplexHerm, Y, ΔA, intermediates)
397-
398-
Compute the Fréchet derivative of the matrix function `Y=f(A)`, where the Fréchet derivative
399-
of `A` is `ΔA`, and `intermediates` is the second argument returned by `_matfun`.
400-
"""
401-
function _matfun_frechet(f, A::LinearAlgebra.RealHermSymComplexHerm, Y, ΔA, (λ, U, fλ, df_dλ))
395+
function _matfun_frechet(f, ΔA, A::LinearAlgebra.RealHermSymComplexHerm, Y, (λ, U, fλ, df_dλ))
402396
# We will overwrite tmp matrix several times to hold different values
403397
tmp = mul!(similar(U, Base.promote_eltype(U, ΔA)), ΔA, U)
404398
∂Λ = mul!(similar(tmp), U', tmp)

0 commit comments

Comments
 (0)