Skip to content

Commit a21980b

Browse files
committed
Add macros to generate matrix functions for symmetric and Hermitian matrices, fix type returned by logm and sqrtm
1 parent e66175b commit a21980b

File tree

3 files changed

+63
-14
lines changed

3 files changed

+63
-14
lines changed

base/linalg/dense.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,9 @@ expm(x::Number) = exp(x)
189189
## "Functions of Matrices: Theory and Computation", SIAM
190190
function expm!{T<:BlasFloat}(A::StridedMatrix{T})
191191
n = chksquare(A)
192-
n<2 && return exp(A)
192+
if ishermitian(A)
193+
return full(expm(Hermitian(A)))
194+
end
193195
ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A
194196
nA = norm(A, 1)
195197
I = eye(T,n)
@@ -278,10 +280,12 @@ function rcswap!{T<:Number}(i::Integer, j::Integer, X::StridedMatrix{T})
278280
end
279281
end
280282

283+
expm(x::Number) = exp(x)
284+
281285
function logm(A::StridedMatrix)
282286
# If possible, use diagonalization
283287
if ishermitian(A)
284-
return logm(Hermitian(A))
288+
return full(logm(Hermitian(A)))
285289
end
286290

287291
# Use Schur decomposition
@@ -313,12 +317,13 @@ function logm(A::StridedMatrix)
313317
return retmat
314318
end
315319
end
320+
316321
logm(a::Number) = (b = log(complex(a)); imag(b) == 0 ? real(b) : b)
317322
logm(a::Complex) = log(a)
318323

319324
function sqrtm{T<:Real}(A::StridedMatrix{T})
320325
if issym(A)
321-
return sqrtm(Symmetric(A))
326+
return full(sqrtm(Symmetric(A)))
322327
end
323328
n = chksquare(A)
324329
SchurF = schurfact(complex(A))
@@ -328,13 +333,14 @@ function sqrtm{T<:Real}(A::StridedMatrix{T})
328333
end
329334
function sqrtm{T<:Complex}(A::StridedMatrix{T})
330335
if ishermitian(A)
331-
return sqrtm(Hermitian(A))
336+
return full(sqrtm(Hermitian(A)))
332337
end
333338
n = chksquare(A)
334339
SchurF = schurfact(A)
335340
R = full(sqrtm(UpperTriangular(SchurF[:T])))
336341
SchurF[:vectors]*R*SchurF[:vectors]'
337342
end
343+
338344
sqrtm(a::Number) = (b = sqrt(complex(a)); imag(b) == 0 ? real(b) : b)
339345
sqrtm(a::Complex) = sqrt(a)
340346

base/linalg/eigen.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function getindex(A::Union{Eigen,GeneralizedEigen}, d::Symbol)
2323
throw(KeyError(d))
2424
end
2525

26-
isposdef(A::Union{Eigen,GeneralizedEigen}) = all(A.values .> 0)
26+
isposdef(A::Union{Eigen,GeneralizedEigen}) = isreal(A.values) && all(A.values .> 0)
2727

2828
function eigfact!{T<:BlasReal}(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true)
2929
n = size(A, 2)

base/linalg/symmetric.jl

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,57 @@ function svdvals!{T<:Real,S}(A::Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{
126126
end
127127

128128
#Matrix-valued functions
129-
expm{T<:Real}(A::RealHermSymComplexHerm{T}) = (F = eigfact(A); F.vectors*Diagonal(exp(F.values))*F.vectors')
130-
function logm{T<:Real}(A::RealHermSymComplexHerm{T})
131-
F = eigfact(A)
132-
isposdef(F) && return F.vectors*Diagonal(log(F.values))*F.vectors'
133-
return F.vectors*Diagonal(log(complex(F.values)))*F.vectors'
129+
function expm(A::Symmetric)
130+
F = eigfact(full(A))
131+
return Symmetric(F.vectors * Diagonal(exp(F.values)) * F.vectors')
134132
end
135-
function sqrtm{T<:Real}(A::RealHermSymComplexHerm{T})
136-
F = eigfact(A)
137-
isposdef(F) && return F.vectors*Diagonal(sqrt(F.values))*F.vectors'
138-
return F.vectors*Diagonal(sqrt(complex(F.values)))*F.vectors'
133+
function expm{T}(A::Hermitian{T})
134+
n = chksquare(A)
135+
F = eigfact(full(A))
136+
retmat = F.vectors * Diagonal(exp(F.values)) * F.vectors'
137+
if T <: Real
138+
return real(Hermitian(retmat))
139+
else
140+
for i = 1:n
141+
retmat[i,i] = real(retmat[i,i])
142+
end
143+
return Hermitian(retmat)
144+
end
145+
end
146+
147+
for (funm, func) in ([:logm,:log], [:sqrtm,:sqrt])
148+
149+
@eval begin
150+
151+
function ($funm)(A::Symmetric)
152+
F = eigfact(full(A))
153+
if isposdef(F)
154+
retmat = F.vectors * Diagonal(($func)(F.values)) * F.vectors'
155+
else
156+
retmat = F.vectors * Diagonal(($func)(complex(F.values))) * F.vectors'
157+
end
158+
return Symmetric(retmat)
159+
end
160+
161+
function ($funm){T}(A::Hermitian{T})
162+
n = chksquare(A)
163+
F = eigfact(A)
164+
if isposdef(F)
165+
retmat = F.vectors * Diagonal(($func)(F.values)) * F.vectors'
166+
if T <: Real
167+
return real(Hermitian(retmat))
168+
else
169+
for i = 1:n
170+
retmat[i,i] = real(retmat[i,i])
171+
end
172+
return Hermitian(retmat)
173+
end
174+
else
175+
retmat = F.vectors * Diagonal(($func)(complex(F.values))) * F.vectors'
176+
return retmat
177+
end
178+
end
179+
180+
end
181+
139182
end

0 commit comments

Comments
 (0)