|
| 1 | +# matrix functions of dense matrices |
| 2 | +# https://en.wikipedia.org/wiki/Matrix_function |
| 3 | + |
| 4 | +# NOTE: for a matrix function f, the pushforward and pullback can be computed using the |
| 5 | +# Fréchet derivative and its adjoint, respectively. |
| 6 | +# https://en.wikipedia.org/wiki/Fréchet_derivative |
| 7 | + |
| 8 | +# The pushforwards and pullbacks are related by matrix adjoints. If the pushforward of f(A) |
| 9 | +# at A is (f_*)_A(ΔA), then the pullback at A is (f^*)_A(ΔY) = ((f_*)_A(ΔY'))'. |
| 10 | +# If f has a power series representation with real coefficients, then this simplifies to |
| 11 | +# (f^*)_A(ΔY) = (f_*)_{A'}(ΔY) |
| 12 | +# So we reuse the code from the pushforward to implement the pullback. |
| 13 | + |
| 14 | +##### |
| 15 | +##### interface function definitions |
| 16 | +##### |
| 17 | + |
| 18 | +""" |
| 19 | + _matfun(f, A) -> (Y, intermediates) |
| 20 | +
|
| 21 | +Compute the matrix function `Y=f(A)` for matrix `A`. |
| 22 | +The function returns a tuple containing the result and a tuple of intermediates to be |
| 23 | +reused by [`_matfun_frechet`](@ref) to compute the Fréchet derivative. |
| 24 | +""" |
| 25 | +_matfun |
| 26 | + |
| 27 | +""" |
| 28 | + _matfun!(f, A) -> (Y, intermediates) |
| 29 | +
|
| 30 | +Similar to [`_matfun`](@ref), but where `A` may be overwritten. |
| 31 | +""" |
| 32 | +_matfun! |
| 33 | + |
| 34 | +""" |
| 35 | + _matfun_frechet(f, E, A, Y, intermediates) |
| 36 | +
|
| 37 | +Compute the Fréchet derivative of the matrix function ``Y = f(A)`` at ``A`` in the direction |
| 38 | +of ``E``, where `intermediates` is the second argument returned by [`_matfun`](@ref). |
| 39 | +
|
| 40 | +The Fréchet derivative is the unique linear map ``L_f \\colon E → L_f(A, E)``, such that |
| 41 | +```math |
| 42 | +L_f(A, E) = f(A + E) - f(A) + o(\\lVert E \\rVert). |
| 43 | +``` |
| 44 | +
|
| 45 | +[^Higham08]: |
| 46 | + > Higham, Nicholas J. Chapter 3: Conditioning. Functions of Matrices. 2008, 55-70. |
| 47 | + > doi: 10.1137/1.9780898717778.ch3 |
| 48 | +""" |
| 49 | +_matfun_frechet |
| 50 | + |
| 51 | +""" |
| 52 | + _matfun_frechet!(f, E, A, Y, intermediates) |
| 53 | +
|
| 54 | +Similar to [`_matfun_frechet`](@ref), but where `E` may be overwritten. |
| 55 | +""" |
| 56 | +_matfun_frechet! |
| 57 | + |
"""
    _matfun_frechet_adjoint(f, E, A, Y, intermediates)

Compute the adjoint of the Fréchet derivative of the matrix function ``Y = f(A)`` at ``A``
in the direction of ``E``, where `intermediates` is the second argument returned by
[`_matfun`](@ref).

Given the Fréchet derivative ``L_f(A, E)`` computed by [`_matfun_frechet`](@ref), its
adjoint ``L_f^⋆(A, E)`` is defined by the identity
```math
\\langle B, L_f(A, C) \\rangle = \\langle L_f^⋆(A, B), C \\rangle.
```
This identity is satisfied by ``L_f^⋆(A, E) = L_f(A, E')'``, which is how the adjoint is
evaluated here. See [^Higham08] for background.

[^Higham08]:
    > Higham, Nicholas J. Chapter 3: Conditioning. Functions of Matrices. 2008, 55-70.
    > doi: 10.1137/1.9780898717778.ch3
"""
function _matfun_frechet_adjoint(f, E, A, Y, intermediates)
    Eadj = adjoint(E)
    if Eadj isa Adjoint
        # materialize the lazy wrapper in case _matfun_frechet can't handle an Adjoint
        Eadj = copy(Eadj)
    end
    result = adjoint(_matfun_frechet(f, Eadj, A, Y, intermediates))
    if result isa Adjoint
        # hand back a materialized array rather than a lazy Adjoint
        return copy(result)
    end
    return result
end
| 84 | + |
"""
    _matfun_frechet_adjoint!(f, E, A, Y, intermediates)

Similar to [`_matfun_frechet_adjoint`](@ref), but where `E` may be overwritten.

The derivative is evaluated with [`_matfun_frechet!`](@ref) on `E'`. When `E'` is a lazy
`Adjoint` it is copied first, so in that case the array passed on is a fresh copy.
"""
function _matfun_frechet_adjoint!(f, E, A, Y, intermediates)
    Eadj = adjoint(E)
    if Eadj isa Adjoint
        # materialize the lazy wrapper in case _matfun_frechet! can't handle an Adjoint
        Eadj = copy(Eadj)
    end
    result = adjoint(_matfun_frechet!(f, Eadj, A, Y, intermediates))
    if result isa Adjoint
        # hand back a materialized array rather than a lazy Adjoint
        return copy(result)
    end
    return result
end
| 98 | + |
| 99 | +##### |
| 100 | +##### `exp`/`exp!` |
| 101 | +##### |
| 102 | + |
# Forward rule for the destructive matrix exponential `LinearAlgebra.exp!`.
# Returns the primal `X` together with its tangent `∂X` for the input tangent `ΔA`.
# Hermitian inputs go through the non-mutating `Hermitian` methods of `_matfun` /
# `_matfun_frechet`; all other inputs use the in-place dense methods, which may overwrite
# both `A` and `ΔA`.
function frule((_, ΔA), ::typeof(LinearAlgebra.exp!), A::StridedMatrix{<:BlasFloat})
    if !ishermitian(A)
        X, intermediates = _matfun!(exp, A)
        return X, _matfun_frechet!(exp, ΔA, A, X, intermediates)
    end
    hermA = Hermitian(A)
    hermX, intermediates = _matfun(exp, hermA)
    ∂hermX = _matfun_frechet(exp, ΔA, hermA, hermX, intermediates)
    # convert the results back to dense matrices before returning
    return Matrix(hermX), Matrix(∂hermX)
end
| 116 | + |
# Reverse rule for the matrix exponential of a dense BLAS-float matrix.
# Returns the primal `X = exp(A0)` and a pullback mapping `ΔX` to `(NO_FIELDS, ∂A)`.
# `A0` itself is never modified.
function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat})
    # TODO: try to make this more type-stable
    if !ishermitian(A0)
        # `_matfun!` may overwrite its argument, so give it a private copy of `A0`
        A = copy(A0)
        X, intermediates = _matfun!(exp, A)
        function exp_pullback(ΔX)
            return NO_FIELDS, _matfun_frechet_adjoint!(exp, ΔX, A, X, intermediates)
        end
        return X, exp_pullback
    end
    # call _matfun directly rather than the Hermitian rrule, so that ∂A is not
    # hermitrized in the pullback
    hermA = Hermitian(A0)
    hermX, hermX_intermediates = _matfun(exp, hermA)
    function exp_pullback_hermitian(ΔX)
        ∂hermA = _matfun_frechet_adjoint(exp, ΔX, hermA, hermX, hermX_intermediates)
        return NO_FIELDS, Matrix(∂hermA)
    end
    return Matrix(hermX), exp_pullback_hermitian
end
| 138 | + |
## Destructive matrix exponential using algorithm from Higham, 2008,
## "Functions of Matrices: Theory and Computation", SIAM
## Adapted from LinearAlgebra.exp! with return of intermediates
## https://github.com/JuliaLang/julia/blob/f613b55/stdlib/LinearAlgebra/src/dense.jl#L583-L666
##
## Returns `(X, intermediates)` with `X = exp(A)` and
## `intermediates = (ilo, ihi, scale, C, si, Apows, W, F, Xpows)`, where
## - `ilo, ihi, scale` are the balancing data returned by `LAPACK.gebal!('B', A)`,
## - `C` holds the Padé coefficients that were selected,
## - `si` is the number of squarings (values `<= 0` mean that no scaling/squaring was done),
## - `Apows[k]` is the `2k`-th power of the balanced and scaled `A`,
## - `W` is the polynomial in `A^2` such that `U = A * W`,
## - `F` is the LU factorization of `V - U`,
## - `Xpows[1]` is the Padé approximant in balanced coordinates and `Xpows[t + 1]` is the
##   result of squaring it `t` times.
## On return `A` holds the balanced (and, if `si > 0`, scaled) matrix.
function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat}
    n = LinearAlgebra.checksquare(A)
    ilo, ihi, scale = LAPACK.gebal!('B', A)  # modifies A
    nA = opnorm(A, 1)
    Inn = Matrix{T}(I, n, n)
    ## For sufficiently small nA, use lower order Padé-Approximations
    if (nA <= 2.1)
        if nA > 0.95
            C = T[
                17643225600.0,
                8821612800.0,
                2075673600.0,
                302702400.0,
                30270240.0,
                2162160.0,
                110880.0,
                3960.0,
                90.0,
                1.0,
            ]
        elseif nA > 0.25
            C = T[17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0, 1.0]
        elseif nA > 0.015
            C = T[30240.0, 15120.0, 3360.0, 420.0, 30.0, 1.0]
        else
            C = T[120.0, 60.0, 12.0, 1.0]
        end
        si = 0
    else
        C = T[
            64764752532480000.0,
            32382376266240000.0,
            7771770303897600.0,
            1187353796428800.0,
            129060195264000.0,
            10559470521600.0,
            670442572800.0,
            33522128640.0,
            1323241920.0,
            40840800.0,
            960960.0,
            16380.0,
            182.0,
            1.0,
        ]
        s = log2(nA / 5.4)  # power of 2 later reversed by squaring
        si = ceil(Int, s)
    end

    if si > 0
        A ./= convert(T, 2^si)
    end

    # accumulate the even (V) and odd (U = A * W) parts of the Padé numerator, keeping
    # every even power of A for reuse by the Fréchet derivative
    A2 = A * A
    P = copy(Inn)
    W = C[2] * P
    V = C[1] * P
    Apows = typeof(P)[]
    for k in 1:(div(size(C, 1), 2) - 1)
        k2 = 2 * k
        P *= A2
        push!(Apows, P)
        W += C[k2 + 2] * P
        V += C[k2 + 1] * P
    end
    U = A * W
    X = V + U
    F = lu!(V - U) # NOTE: use lu! instead of LAPACK.gesv! so we can reuse factorization
    ldiv!(F, X)
    Xpows = typeof(X)[X]
    if si > 0  # squaring to reverse dividing by power of 2
        for t in 1:si
            X *= X
            push!(Xpows, X)
        end
    else
        # No squaring happened, so `X` is still the very array stored as `Xpows[1]`.
        # `_unbalance!` works in place, and `_matfun_frechet!` reads `first(Xpows)` in
        # balanced coordinates, so unbalance a copy and leave the intermediate untouched.
        X = copy(X)
    end

    _unbalance!(X, ilo, ihi, scale, n)
    return X, (ilo, ihi, scale, C, si, Apows, W, F, Xpows)
end
| 223 | + |
# Application of the chain rule to exp!, also Algorithm 6.4 from
# Al-Mohy, Awad H. and Higham, Nicholas J. (2009).
# "Computing the Fréchet Derivative of the Matrix Exponential, with an application to
# Condition Number Estimation", SIAM. 30 (4). pp. 1639-1657.
# http://eprints.maths.manchester.ac.uk/id/eprint/1218
#
# Arguments:
# - `ΔA`: direction of differentiation; it is balanced and rescaled in place, i.e. overwritten
# - `A`: the matrix as left behind by `_matfun!(exp, A)`
# - `X`: the primal result; it is not read here (the name is reused for a local below)
# - the last argument is the tuple of intermediates returned by `_matfun!(exp, A)`
# Returns the Fréchet derivative `∂X` of `exp` at the original matrix in the direction of
# the original `ΔA`.
function _matfun_frechet!(
    ::typeof(exp), ΔA, A::StridedMatrix{T}, X, (ilo, ihi, scale, C, si, Apows, W, F, Xpows)
) where {T<:BlasFloat}
    n = LinearAlgebra.checksquare(A)
    # apply to ΔA the same balancing and scaling that the primal applied to A
    _balance!(ΔA, ilo, ihi, scale, n)

    if si > 0
        ΔA ./= convert(T, 2^si)
    end

    # ∂(A^2) = A * ΔA + ΔA * A
    ∂A2 = mul!(A * ΔA, ΔA, A, true, true)
    A2 = first(Apows)
    # we will repeatedly overwrite ∂temp and ∂P below
    ∂temp = Matrix{eltype(∂A2)}(undef, n, n)
    ∂P = copy(∂A2)
    ∂W = C[4] * ∂P
    ∂V = C[3] * ∂P
    # The primal accumulates the powers Apows[k] = A^(2k) for every k in 1:length(Apows).
    # The k = 1 term is handled above, so the loop must run up to and including
    # length(Apows); stopping one short would drop the derivative of the highest power.
    for k in 2:length(Apows)
        k2 = 2 * k
        P = Apows[k - 1]
        # product rule: ∂(A^(2k)) = ∂(A^(2k-2)) * A^2 + A^(2k-2) * ∂(A^2)
        ∂P, ∂temp = mul!(mul!(∂temp, ∂P, A2), P, ∂A2, true, true), ∂P
        axpy!(C[k2 + 2], ∂P, ∂W)
        axpy!(C[k2 + 1], ∂P, ∂V)
    end
    # ∂U = A * ∂W + ΔA * W; the buffer of ∂W is recycled as ∂temp
    ∂U, ∂temp = mul!(mul!(∂temp, A, ∂W), ΔA, W, true, true), ∂W
    ∂temp .= ∂U .- ∂V
    ∂X = add!!(∂U, ∂V)
    # differentiate (V - U) * X = V + U:  ∂X = (V - U) \ (∂V + ∂U + (∂U - ∂V) * X)
    mul!(∂X, ∂temp, first(Xpows), true, true)
    ldiv!(F, ∂X)

    if si > 0
        # differentiate through each squaring step: ∂(X^2) = X * ∂X + ∂X * X
        for t in 1:(length(Xpows) - 1)
            X = Xpows[t]
            ∂X, ∂temp = mul!(mul!(∂temp, X, ∂X), ∂X, X, true, true), ∂X
        end
    end

    _unbalance!(∂X, ilo, ihi, scale, n)
    return ∂X
end
| 269 | + |
| 270 | +##### |
| 271 | +##### utils |
| 272 | +##### |
| 273 | + |
# Given (ilo, ihi, scale) returned by LAPACK.gebal!('B', A), apply same balancing to X.
#
# `X` is modified in place and returned. The `n` argument is accepted for symmetry with
# `_unbalance!`, but the size actually used is `size(X, 1)`.
#
# For indices outside `ilo:ihi`, `scale[j]` is the index of the row and column that was
# interchanged with row and column `j`; inside `ilo:ihi` it is the scaling factor.
# LAPACK performs the interchanges in the order n down to ihi+1, then 1 up to ilo-1.
# Interchanges that share an index do not commute, so they must be replayed in exactly
# that order (which is also the reverse of the order used by `_unbalance!`).
function _balance!(X, ilo, ihi, scale, n)
    n = size(X, 1)
    # upper interchanges, from the last row/column towards ihi + 1
    for j in n:-1:(ihi + 1)
        LinearAlgebra.rcswap!(j, Int(scale[j]), X)
    end
    # lower interchanges, from the first row/column towards ilo - 1
    for j in 1:(ilo - 1)
        LinearAlgebra.rcswap!(j, Int(scale[j]), X)
    end

    # diagonal similarity scaling: divide row j and multiply column j by scale[j]
    for j in ilo:ihi
        scj = scale[j]
        for i in 1:n
            X[j, i] /= scj
        end
        for i in 1:n
            X[i, j] *= scj
        end
    end
    return X
end
| 299 | + |
# Reverse of _balance!: first undo the diagonal scaling on rows/columns ilo:ihi (over the
# leading n entries of each), then replay the recorded row/column interchanges, lower ones
# from ilo-1 down to 1 followed by upper ones from ihi+1 up to n.
# `X` is modified in place and returned.
function _unbalance!(X, ilo, ihi, scale, n)
    span = 1:n
    for idx in ilo:ihi
        factor = scale[idx]
        @views X[idx, span] .*= factor
        @views X[span, idx] ./= factor
    end

    # both ranges are empty when there is nothing to interchange
    for idx in (ilo - 1):-1:1
        LinearAlgebra.rcswap!(idx, Int(scale[idx]), X)
    end
    for idx in (ihi + 1):n
        LinearAlgebra.rcswap!(idx, Int(scale[idx]), X)
    end
    return X
end
0 commit comments