Skip to content

Commit 4cdb50b

Browse files
authored
Specialize cholesky for Diagonal (#384)
1 parent 57bf31a commit 4cdb50b

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

src/host/linalg.jl

+14
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,20 @@ end
101101

102102
Base.copy(D::Diagonal{T, <:AbstractGPUArray{T, N}}) where {T, N} = Diagonal(copy(D.diag))
103103

104+
# prevent scalar indexing
105+
function LinearAlgebra.cholesky!(D::Diagonal{T, <:AbstractGPUArray{T, N}},
106+
::Val{false} = Val(false); check::Bool = true
107+
) where {T, N}
108+
info = 0
109+
if mapreduce(x -> isreal(x) && isposdef(x), &, D.diag)
110+
D.diag .= sqrt.(D.diag)
111+
else
112+
info = findfirst(x -> !isreal(x) || !isposdef(x), collect(D.diag))
113+
check && throw(PosDefException(info))
114+
end
115+
Cholesky(D, 'U', convert(LinearAlgebra.BlasInt, info))
116+
end
117+
104118

105119
## matrix multiplication
106120

test/testsuite/linalg.jl

+14
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
end
2222
end
2323

24+
2425
@testset "symmetric" begin
2526
@testset "Hermitian" begin
2627
A = rand(Float32,2,2)
@@ -103,6 +104,19 @@
103104
@test collect(D) == collect(C)
104105
end
105106

107+
@testset "cholesky + Diagonal" begin
108+
n = 128
109+
d = AT(rand(Float32, n))
110+
D = Diagonal(d)
111+
F = collect(D)
112+
@test collect(cholesky(D).U) collect(cholesky(F).U)
113+
@test collect(cholesky(D).L) collect(cholesky(F).L)
114+
115+
d = AT([1f0, 2f0, -1f0, 0f0])
116+
D = Diagonal(d)
117+
@test cholesky(D, check = false).info == 3
118+
end
119+
106120
@testset "$f! with diagonal $d" for (f, f!) in ((triu, triu!), (tril, tril!)),
107121
d in -2:2
108122
A = randn(Float32, 10, 10)

0 commit comments

Comments
 (0)