Add rules for dense matrix exponential #351

Merged
merged 34 commits into from
Jan 20, 2021
Commits
6e8358c
Add matfun.jl file
sethaxen Jan 18, 2021
f987f35
Add matfun docstrings
sethaxen Jan 18, 2021
e6b92c9
Add exp matrix function
sethaxen Jan 18, 2021
7624ee5
At least store one intermediate
sethaxen Jan 18, 2021
a7792a5
Test exp!
sethaxen Jan 18, 2021
6d6b4cb
Make pullback type-inferrable
sethaxen Jan 18, 2021
3645d75
Add clearer test label
sethaxen Jan 18, 2021
937f2ac
Create as hermitian
sethaxen Jan 18, 2021
b48204c
Test rrule
sethaxen Jan 18, 2021
2c19bba
Add comment about relationship between pushforward and pullback
sethaxen Jan 18, 2021
58f6005
Add header
sethaxen Jan 18, 2021
6ee1759
Add reference to Frechet deriv paper
sethaxen Jan 18, 2021
b1a2980
Run JuliaFormatter
sethaxen Jan 18, 2021
e860b3e
Reduce comment spacing from code
sethaxen Jan 18, 2021
8f665ac
Update src/rulesets/LinearAlgebra/matfun.jl
sethaxen Jan 18, 2021
9e565ae
Correctly handle balancing
sethaxen Jan 18, 2021
71134fd
Test imbalanced matrix A
sethaxen Jan 18, 2021
bd48565
Increment version number
sethaxen Jan 18, 2021
062b11d
Merge branch 'exp2' of https://github.com/sethaxen/ChainRules.jl into…
sethaxen Jan 18, 2021
dc1b1ab
Apply suggestions from code review
sethaxen Jan 19, 2021
57aea17
Change signature of _matfun_frechet
sethaxen Jan 20, 2021
e2e6605
Give math for Frechet derivative
sethaxen Jan 20, 2021
976af09
Change Frechet notation
sethaxen Jan 20, 2021
d7d20ba
Add _matfun_frechet_adjoint
sethaxen Jan 20, 2021
b0ae61c
Simplify hermitian code
sethaxen Jan 20, 2021
62b963b
Correct comment
sethaxen Jan 20, 2021
87e4c53
Remove comments
sethaxen Jan 20, 2021
9bd06b1
Use abbreviated SHA
sethaxen Jan 20, 2021
2ed06e6
Link
sethaxen Jan 20, 2021
156e6f5
Update comment
sethaxen Jan 20, 2021
9a63d13
Move comment up
sethaxen Jan 20, 2021
5ba193d
Move comment further up
sethaxen Jan 20, 2021
49df929
Update docstrings
sethaxen Jan 20, 2021
8c27276
Push header to same level as rules
sethaxen Jan 20, 2021
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
-version = "0.7.48"
+version = "0.7.49"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1 change: 1 addition & 0 deletions src/ChainRules.jl
@@ -44,6 +44,7 @@ include("rulesets/LinearAlgebra/utils.jl")
include("rulesets/LinearAlgebra/blas.jl")
include("rulesets/LinearAlgebra/dense.jl")
include("rulesets/LinearAlgebra/norm.jl")
+include("rulesets/LinearAlgebra/matfun.jl")
include("rulesets/LinearAlgebra/structured.jl")
include("rulesets/LinearAlgebra/symmetric.jl")
include("rulesets/LinearAlgebra/factorization.jl")
323 changes: 323 additions & 0 deletions src/rulesets/LinearAlgebra/matfun.jl
@@ -0,0 +1,323 @@
# matrix functions of dense matrices
# https://en.wikipedia.org/wiki/Matrix_function

# NOTE: for a matrix function f, the pushforward and pullback can be computed using the
# Fréchet derivative and its adjoint, respectively.
# https://en.wikipedia.org/wiki/Fréchet_derivative

# The pushforwards and pullbacks are related by matrix adjoints. If the pushforward of f
# at A is (f_*)_A(ΔA), then the pullback at A is (f^*)_A(ΔY) = ((f_*)_A(ΔY'))'.
# If f has a power series representation with real coefficients, then this simplifies to
# (f^*)_A(ΔY) = (f_*)_{A'}(ΔY)
# So we reuse the code from the pushforward to implement the pullback.
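#
# A short derivation of that simplification (an illustrative sketch; the coefficient names
# c_k are introduced here only for exposition): if f(A) = Σ_k c_k A^k with real c_k, the
# product rule gives
#   (f_*)_A(ΔA) = Σ_k c_k Σ_{j=0}^{k-1} A^j ΔA A^(k-1-j).
# Substituting ΔA = ΔY' and taking the adjoint reverses each product:
#   ((f_*)_A(ΔY'))' = Σ_k c_k Σ_{j=0}^{k-1} (A')^(k-1-j) ΔY (A')^j,
# and relabelling j -> k-1-j shows that the right-hand side equals (f_*)_{A'}(ΔY).
# The step that drops the conjugate on c_k is where real coefficients are needed.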

#####
##### interface function definitions
#####

"""
    _matfun(f, A) -> (Y, intermediates)

Compute the matrix function `Y=f(A)` for matrix `A`.
The function returns a tuple containing the result and a tuple of intermediates to be
reused by [`_matfun_frechet`](@ref) to compute the Fréchet derivative.
"""
_matfun

"""
    _matfun!(f, A) -> (Y, intermediates)

Similar to [`_matfun`](@ref), but where `A` may be overwritten.
"""
_matfun!

"""
    _matfun_frechet(f, E, A, Y, intermediates)

Compute the Fréchet derivative of the matrix function ``Y = f(A)`` at ``A`` in the direction
of ``E``, where `intermediates` is the second argument returned by [`_matfun`](@ref).

The Fréchet derivative[^Higham08] is the unique linear map ``L_f \\colon E → L_f(A, E)``, such that
```math
L_f(A, E) = f(A + E) - f(A) + o(\\lVert E \\rVert).
```
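
As an illustrative sketch that does not depend on this package's internals, the definition
can be checked numerically for `f = exp` using only `LinearAlgebra`. It assumes the
standard block identity `exp([A E; 0 A]) = [exp(A) L; 0 exp(A)]` with `L = L_exp(A, E)`
and compares `L` against a central finite difference. The names `L`, `L_fd`, and `h` are
chosen for this example only.

```julia
using LinearAlgebra

A = [0.1 0.4; -0.3 0.2]
E = [0.5 -0.2; 0.1 0.3]
n = size(A, 1)

# The upper-right block of exp([A E; 0 A]) is the Fréchet derivative of exp at A along E
L = exp([A E; zero(A) A])[1:n, (n + 1):(2n)]

# Central finite difference approximation of the same directional derivative
h = 1e-5
L_fd = (exp(A + h * E) - exp(A - h * E)) / (2h)

@assert isapprox(L, L_fd; rtol=1e-6)
```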

[^Higham08]:
    > Higham, Nicholas J. Chapter 3: Conditioning. Functions of Matrices. 2008, 55-70.
    > doi: 10.1137/1.9780898717778.ch3
"""
_matfun_frechet

"""
    _matfun_frechet!(f, E, A, Y, intermediates)

Similar to [`_matfun_frechet`](@ref), but where `E` may be overwritten.
"""
_matfun_frechet!

"""
    _matfun_frechet_adjoint(f, E, A, Y, intermediates)

Compute the adjoint of the Fréchet derivative of the matrix function ``Y = f(A)`` at ``A``
in the direction of ``E``, where `intermediates` is the second argument returned by
[`_matfun`](@ref).

Given the Fréchet derivative ``L_f(A, E)`` computed by [`_matfun_frechet`](@ref), its
adjoint ``L_f^⋆(A, E)`` is defined by the identity[^Higham08]
```math
\\langle B, L_f(A, C) \\rangle = \\langle L_f^⋆(A, B), C \\rangle.
```
This identity is satisfied by ``L_f^⋆(A, E) = L_f(A, E')'``.
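
As an illustrative sketch (not part of this package's API), the identity can be checked
numerically for `f = exp` with the Frobenius inner product `dot(X, Y) = tr(X' * Y)`. The
helper `frechet_exp` is hypothetical; it assumes the standard block identity
`exp([A E; 0 A]) = [exp(A) L_exp(A, E); 0 exp(A)]`.

```julia
using LinearAlgebra

# Hypothetical helper: Fréchet derivative of exp at A along E via the block identity
function frechet_exp(A, E)
    n = size(A, 1)
    return exp([A E; zero(A) A])[1:n, (n + 1):(2n)]
end

A = [0.1 0.4; -0.3 0.2]
B = [1.0 2.0; 0.0 1.0]
C = [0.5 -0.2; 0.1 0.3]

# Adjoint of the Fréchet derivative applied to B, computed as L(A, B')'
LstarB = Matrix(frechet_exp(A, Matrix(B'))')

lhs = dot(B, frechet_exp(A, C))  # inner product of B with L(A, C)
rhs = dot(LstarB, C)             # inner product of L⋆(A, B) with C
@assert isapprox(lhs, rhs; rtol=1e-10)
```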

[^Higham08]:
    > Higham, Nicholas J. Chapter 3: Conditioning. Functions of Matrices. 2008, 55-70.
    > doi: 10.1137/1.9780898717778.ch3
"""
function _matfun_frechet_adjoint(f, E, A, Y, intermediates)
    E′ = E'
    # avoid passing an Adjoint to _matfun_frechet in case it can't handle it
    E′ = E′ isa Adjoint ? copy(E′) : E′
    LE = adjoint(_matfun_frechet(f, E′, A, Y, intermediates))
    # avoid returning an Adjoint
    return LE isa Adjoint ? copy(LE) : LE
end

"""
    _matfun_frechet_adjoint!(f, E, A, Y, intermediates)

Similar to [`_matfun_frechet_adjoint`](@ref), but where `E` may be overwritten.
"""
function _matfun_frechet_adjoint!(f, E, A, Y, intermediates)
    E′ = E'
    # avoid passing an Adjoint to _matfun_frechet! in case it can't handle it
    E′ = E′ isa Adjoint ? copy(E′) : E′
    LE = adjoint(_matfun_frechet!(f, E′, A, Y, intermediates))
    # avoid returning an Adjoint
    return LE isa Adjoint ? copy(LE) : LE
end

#####
##### `exp`/`exp!`
#####

function frule((_, ΔA), ::typeof(LinearAlgebra.exp!), A::StridedMatrix{<:BlasFloat})
    if ishermitian(A)
        hermA = Hermitian(A)
        hermX, intermediates = _matfun(exp, hermA)
        ∂hermX = _matfun_frechet(exp, ΔA, hermA, hermX, intermediates)
        X = Matrix(hermX)
        ∂X = Matrix(∂hermX)
    else
        X, intermediates = _matfun!(exp, A)
        ∂X = _matfun_frechet!(exp, ΔA, A, X, intermediates)
    end
    return X, ∂X
end

function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat})
    # TODO: try to make this more type-stable
    if ishermitian(A0)
        # call _matfun instead of the rrule to avoid hermitrizing ∂A in the pullback
        hermA = Hermitian(A0)
        hermX, hermX_intermediates = _matfun(exp, hermA)
        function exp_pullback_hermitian(ΔX)
            ∂hermA = _matfun_frechet_adjoint(exp, ΔX, hermA, hermX, hermX_intermediates)
            return NO_FIELDS, Matrix(∂hermA)
        end
        return Matrix(hermX), exp_pullback_hermitian
    else
        A = copy(A0)
        X, intermediates = _matfun!(exp, A)
        function exp_pullback(ΔX)
            ∂A = _matfun_frechet_adjoint!(exp, ΔX, A, X, intermediates)
            return NO_FIELDS, ∂A
        end
        return X, exp_pullback
    end
end

## Destructive matrix exponential using algorithm from Higham, 2008,
## "Functions of Matrices: Theory and Computation", SIAM
## Adapted from LinearAlgebra.exp! with return of intermediates
## https://github.com/JuliaLang/julia/blob/f613b55/stdlib/LinearAlgebra/src/dense.jl#L583-L666
function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat}
    n = LinearAlgebra.checksquare(A)
    ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A
    nA = opnorm(A, 1)
    Inn = Matrix{T}(I, n, n)
    ## For sufficiently small nA, use lower order Padé-Approximations
    if (nA <= 2.1)
        if nA > 0.95
            C = T[
                17643225600.0,
                8821612800.0,
                2075673600.0,
                302702400.0,
                30270240.0,
                2162160.0,
                110880.0,
                3960.0,
                90.0,
                1.0,
            ]
        elseif nA > 0.25
            C = T[17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0, 1.0]
        elseif nA > 0.015
            C = T[30240.0, 15120.0, 3360.0, 420.0, 30.0, 1.0]
        else
            C = T[120.0, 60.0, 12.0, 1.0]
        end
        si = 0
    else
        C = T[
            64764752532480000.0,
            32382376266240000.0,
            7771770303897600.0,
            1187353796428800.0,
            129060195264000.0,
            10559470521600.0,
            670442572800.0,
            33522128640.0,
            1323241920.0,
            40840800.0,
            960960.0,
            16380.0,
            182.0,
            1.0,
        ]
        s = log2(nA / 5.4) # power of 2 later reversed by squaring
        si = ceil(Int, s)
    end

    if si > 0
        A ./= convert(T, 2^si)
    end

    A2 = A * A
    P = copy(Inn)
    W = C[2] * P
    V = C[1] * P
    Apows = typeof(P)[]
    for k in 1:(div(size(C, 1), 2) - 1)
        k2 = 2 * k
        P *= A2
        push!(Apows, P)
        W += C[k2 + 2] * P
        V += C[k2 + 1] * P
    end
    U = A * W
    X = V + U
    F = lu!(V - U) # NOTE: use lu! instead of LAPACK.gesv! so we can reuse factorization
    ldiv!(F, X)
    Xpows = typeof(X)[X]
    if si > 0 # squaring to reverse dividing by power of 2
        for t in 1:si
            X *= X
            push!(Xpows, X)
        end
    end

    _unbalance!(X, ilo, ihi, scale, n)
    return X, (ilo, ihi, scale, C, si, Apows, W, F, Xpows)
end

# Application of the chain rule to exp!, also Algorithm 6.4 from
# Al-Mohy, Awad H. and Higham, Nicholas J. (2009).
# "Computing the Fréchet Derivative of the Matrix Exponential, with an Application to
# Condition Number Estimation", SIAM Journal on Matrix Analysis and Applications.
# 30 (4). pp. 1639-1657.
# http://eprints.maths.manchester.ac.uk/id/eprint/1218
function _matfun_frechet!(
    ::typeof(exp), ΔA, A::StridedMatrix{T}, X, (ilo, ihi, scale, C, si, Apows, W, F, Xpows)
) where {T<:BlasFloat}
    n = LinearAlgebra.checksquare(A)
    _balance!(ΔA, ilo, ihi, scale, n)

    if si > 0
        ΔA ./= convert(T, 2^si)
    end

    ∂A2 = mul!(A * ΔA, ΔA, A, true, true)
    A2 = first(Apows)
    # we will repeatedly overwrite ∂temp and ∂P below
    ∂temp = Matrix{eltype(∂A2)}(undef, n, n)
    ∂P = copy(∂A2)
    ∂W = C[4] * ∂P
    ∂V = C[3] * ∂P
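    # Apows holds every even power of A that enters W and V in _matfun!, so the loop must
    # run through the last entry to include the highest-order terms of ∂W and ∂V
    for k in 2:length(Apows)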
        k2 = 2 * k
        P = Apows[k - 1]
        ∂P, ∂temp = mul!(mul!(∂temp, ∂P, A2), P, ∂A2, true, true), ∂P
        axpy!(C[k2 + 2], ∂P, ∂W)
        axpy!(C[k2 + 1], ∂P, ∂V)
    end
    ∂U, ∂temp = mul!(mul!(∂temp, A, ∂W), ΔA, W, true, true), ∂W
    ∂temp .= ∂U .- ∂V
    ∂X = add!!(∂U, ∂V)
    mul!(∂X, ∂temp, first(Xpows), true, true)
    ldiv!(F, ∂X)

    if si > 0
        for t in 1:(length(Xpows) - 1)
            X = Xpows[t]
            ∂X, ∂temp = mul!(mul!(∂temp, X, ∂X), ∂X, X, true, true), ∂X
        end
    end

    _unbalance!(∂X, ilo, ihi, scale, n)
    return ∂X
end

#####
##### utils
#####

# Given (ilo, ihi, scale) returned by LAPACK.gebal!('B', A), apply same balancing to X
function _balance!(X, ilo, ihi, scale, n)
    n = size(X, 1)
    if ihi < n
        for j in (ihi + 1):n
            LinearAlgebra.rcswap!(j, Int(scale[j]), X)
        end
    end
    if ilo > 1
        for j in (ilo - 1):-1:1
            LinearAlgebra.rcswap!(j, Int(scale[j]), X)
        end
    end

    for j in ilo:ihi
        scj = scale[j]
        for i in 1:n
            X[j, i] /= scj
        end
        for i in 1:n
            X[i, j] *= scj
        end
    end
    return X
end

# Reverse of _balance!
function _unbalance!(X, ilo, ihi, scale, n)
    for j in ilo:ihi
        scj = scale[j]
        for i in 1:n
            X[j, i] *= scj
        end
        for i in 1:n
            X[i, j] /= scj
        end
    end

    if ilo > 1
        for j in (ilo - 1):-1:1
            LinearAlgebra.rcswap!(j, Int(scale[j]), X)
        end
    end
    if ihi < n
        for j in (ihi + 1):n
            LinearAlgebra.rcswap!(j, Int(scale[j]), X)
        end
    end
    return X
end
20 changes: 7 additions & 13 deletions src/rulesets/LinearAlgebra/symmetric.jl
@@ -292,7 +292,7 @@ for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :a
@eval begin
function frule((_, ΔA), ::typeof($func), A::LinearAlgebra.RealHermSymComplexHerm)
Y, intermediates = _matfun($func, A)
-Ȳ = _matfun_frechet($func, A, Y, ΔA, intermediates)
+Ȳ = _matfun_frechet($func, ΔA, A, Y, intermediates)
# If ΔA was hermitian, then ∂Y has the same structure as Y
∂Y = if ishermitian(ΔA) && (isa(Y, Symmetric) || isa(Y, Hermitian))
_symhermlike!(Ȳ, Y)
@@ -308,8 +308,7 @@ for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :a
# for Hermitian Y, we don't need to realify the diagonal of ΔY, since the
# effect is the same as applying _hermitrizelike! at the end
∂Y = eltype(Y) <: Real ? real(ΔY) : ΔY
-# for matrix functions, the pullback is related to the pushforward by an adjoint
-Ā = _matfun_frechet($func, A, Y, ∂Y', intermediates)
+Ā = _matfun_frechet_adjoint($func, ∂Y, A, Y, intermediates)
# the cotangent of Hermitian A should be Hermitian
∂A = _hermitrizelike!(Ā, A)
return NO_FIELDS, ∂A
@@ -344,9 +343,9 @@ function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm)
ΔsinA, ΔcosA = real(ΔsinA), real(ΔcosA)
end
if ΔcosA isa AbstractZero
-Ā = _matfun_frechet(sin, A, sinA, ΔsinA, (λ, U, sinλ, cosλ))
+Ā = _matfun_frechet_adjoint(sin, ΔsinA, A, sinA, (λ, U, sinλ, cosλ))
elseif ΔsinA isa AbstractZero
-Ā = _matfun_frechet(cos, A, cosA, ΔcosA, (λ, U, cosλ, -sinλ))
+Ā = _matfun_frechet_adjoint(cos, ΔcosA, A, cosA, (λ, U, cosλ, -sinλ))
else
# we will overwrite tmp with various temporary values during this computation
tmp = mul!(similar(U, Base.promote_eltype(U, ΔsinA, ΔcosA)), ΔsinA, U)
@@ -367,7 +366,8 @@ end

Compute the matrix function `f(A)` for real or complex hermitian `A`.
The function returns a tuple containing the result and a tuple of intermediates to be
-reused by `_matfun_frechet` to compute the Fréchet derivative.
+reused by [`_matfun_frechet`](@ref) to compute the Fréchet derivative.
+
Note any function `f` used with this **must** have a `frule` defined on it.
"""
function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm)
@@ -392,13 +392,7 @@ function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm)
end

# Computes ∂Y = U * (P .* (U' * ΔA * U)) * U' with fewer allocations
-"""
-_matfun_frechet(f, A::RealHermSymComplexHerm, Y, ΔA, intermediates)
-
-Compute the Fréchet derivative of the matrix function `Y=f(A)`, where the Fréchet derivative
-of `A` is `ΔA`, and `intermediates` is the second argument returned by `_matfun`.
-"""
-function _matfun_frechet(f, A::LinearAlgebra.RealHermSymComplexHerm, Y, ΔA, (λ, U, fλ, df_dλ))
+function _matfun_frechet(f, ΔA, A::LinearAlgebra.RealHermSymComplexHerm, Y, (λ, U, fλ, df_dλ))
# We will overwrite tmp matrix several times to hold different values
tmp = mul!(similar(U, Base.promote_eltype(U, ΔA)), ΔA, U)
∂Λ = mul!(similar(tmp), U', tmp)