Skip to content

Consistency of return type for matrix functions #12408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 10, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ expm(x::Number) = exp(x)
## "Functions of Matrices: Theory and Computation", SIAM
function expm!{T<:BlasFloat}(A::StridedMatrix{T})
n = chksquare(A)
n<2 && return exp(A)
if ishermitian(A)
return full(expm(Hermitian(A)))
end
ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A
nA = norm(A, 1)
I = eye(T,n)
Expand Down Expand Up @@ -278,10 +280,12 @@ function rcswap!{T<:Number}(i::Integer, j::Integer, X::StridedMatrix{T})
end
end

expm(x::Number) = exp(x)

function logm(A::StridedMatrix)
# If possible, use diagonalization
if ishermitian(A)
return logm(Hermitian(A))
return full(logm(Hermitian(A)))
end

# Use Schur decomposition
Expand Down Expand Up @@ -313,12 +317,13 @@ function logm(A::StridedMatrix)
return retmat
end
end

logm(a::Number) = (b = log(complex(a)); imag(b) == 0 ? real(b) : b)
logm(a::Complex) = log(a)

function sqrtm{T<:Real}(A::StridedMatrix{T})
if issym(A)
return sqrtm(Symmetric(A))
return full(sqrtm(Symmetric(A)))
end
n = chksquare(A)
SchurF = schurfact(complex(A))
Expand All @@ -328,13 +333,14 @@ function sqrtm{T<:Real}(A::StridedMatrix{T})
end
function sqrtm{T<:Complex}(A::StridedMatrix{T})
if ishermitian(A)
return sqrtm(Hermitian(A))
return full(sqrtm(Hermitian(A)))
end
n = chksquare(A)
SchurF = schurfact(A)
R = full(sqrtm(UpperTriangular(SchurF[:T])))
SchurF[:vectors]*R*SchurF[:vectors]'
end

sqrtm(a::Number) = (b = sqrt(complex(a)); imag(b) == 0 ? real(b) : b)
sqrtm(a::Complex) = sqrt(a)

Expand Down
2 changes: 1 addition & 1 deletion base/linalg/eigen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function getindex(A::Union{Eigen,GeneralizedEigen}, d::Symbol)
throw(KeyError(d))
end

isposdef(A::Union{Eigen,GeneralizedEigen}) = all(A.values .> 0)
isposdef(A::Union{Eigen,GeneralizedEigen}) = isreal(A.values) && all(A.values .> 0)

function eigfact!{T<:BlasReal}(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true)
n = size(A, 2)
Expand Down
61 changes: 52 additions & 9 deletions base/linalg/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,57 @@ function svdvals!{T<:Real,S}(A::Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{
end

#Matrix-valued functions
expm{T<:Real}(A::RealHermSymComplexHerm{T}) = (F = eigfact(A); F.vectors*Diagonal(exp(F.values))*F.vectors')
function logm{T<:Real}(A::RealHermSymComplexHerm{T})
F = eigfact(A)
isposdef(F) && return F.vectors*Diagonal(log(F.values))*F.vectors'
return F.vectors*Diagonal(log(complex(F.values)))*F.vectors'
function expm(A::Symmetric)
F = eigfact(full(A))
return Symmetric((F.vectors * Diagonal(exp(F.values))) * F.vectors')
end
function sqrtm{T<:Real}(A::RealHermSymComplexHerm{T})
F = eigfact(A)
isposdef(F) && return F.vectors*Diagonal(sqrt(F.values))*F.vectors'
return F.vectors*Diagonal(sqrt(complex(F.values)))*F.vectors'
function expm{T}(A::Hermitian{T})
n = chksquare(A)
F = eigfact(full(A))
retmat = (F.vectors * Diagonal(exp(F.values))) * F.vectors'
if T <: Real
return real(Hermitian(retmat))
else
for i = 1:n
retmat[i,i] = real(retmat[i,i])
end
return Hermitian(retmat)
end
end

for (funm, func) in ([:logm,:log], [:sqrtm,:sqrt])

@eval begin

function ($funm)(A::Symmetric)
F = eigfact(full(A))
if isposdef(F)
retmat = (F.vectors * Diagonal(($func)(F.values))) * F.vectors'
else
retmat = (F.vectors * Diagonal(($func)(complex(F.values)))) * F.vectors'
end
return Symmetric(retmat)
end

function ($funm){T}(A::Hermitian{T})
n = chksquare(A)
F = eigfact(A)
if isposdef(F)
retmat = (F.vectors * Diagonal(($func)(F.values))) * F.vectors'
if T <: Real
return Hermitian(retmat)
else
for i = 1:n
retmat[i,i] = real(retmat[i,i])
end
return Hermitian(retmat)
end
else
retmat = (F.vectors * Diagonal(($func)(complex(F.values)))) * F.vectors'
return retmat
end
end

end

end