Skip to content

Commit 623ebe9

Browse files
committed
Merge pull request #6096 from JuliaLang/anj/sparsemul
Fix A_mul_B! for y inputs with NaNs by setting elements to zero when β is zero.
2 parents 8ae5b00 + 7f18a51 commit 623ebe9

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

base/linalg/sparse.jl

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ end
1313
function A_mul_B!::Number, A::SparseMatrixCSC, x::AbstractVector, β::Number, y::AbstractVector)
1414
A.n == length(x) || throw(DimensionMismatch(""))
1515
A.m == length(y) || throw(DimensionMismatch(""))
16-
for i = 1:A.m; y[i] *= β; end
16+
if β != 1
17+
β != 0 ? scale!(y,β) : fill!(y,zero(eltype(y)))
18+
end
1719
nzv = A.nzval
1820
rv = A.rowval
1921
for col = 1 : A.n
@@ -28,17 +30,24 @@ A_mul_B!(y::AbstractVector, A::SparseMatrixCSC, x::AbstractVector) = A_mul_B!(on
2830

2931
function *{TA,S,Tx}(A::SparseMatrixCSC{TA,S}, x::AbstractVector{Tx})
3032
T = promote_type(TA,Tx)
31-
A_mul_B!(one(T), A, x, zero(T), zeros(T, A.m))
33+
A_mul_B!(one(T), A, x, zero(T), Array(T, A.m))
3234
end
3335

3436
function Ac_mul_B!::Number, A::SparseMatrixCSC, x::AbstractVector, β::Number, y::AbstractVector)
3537
A.n == length(y) || throw(DimensionMismatch(""))
3638
A.m == length(x) || throw(DimensionMismatch(""))
3739
nzv = A.nzval
3840
rv = A.rowval
41+
zro = zero(eltype(y))
3942
@inbounds begin
4043
for i = 1 : A.n
41-
y[i] *= β
44+
if β != 1
45+
if β != 0
46+
y[i] *= β
47+
else
48+
y[i] = zro
49+
end
50+
end
4251
tmp = zero(eltype(y))
4352
for j = A.colptr[i] : (A.colptr[i+1]-1)
4453
tmp += conj(nzv[j])*x[rv[j]]
@@ -54,9 +63,16 @@ function At_mul_B!(α::Number, A::SparseMatrixCSC, x::AbstractVector, β::Number
5463
A.m == length(x) || throw(DimensionMismatch(""))
5564
nzv = A.nzval
5665
rv = A.rowval
66+
zro = zero(eltype(y))
5767
@inbounds begin
5868
for i = 1 : A.n
59-
y[i] *= β
69+
if β != 1
70+
if β != 0
71+
y[i] *= β
72+
else
73+
y[i] = zro
74+
end
75+
end
6076
tmp = zero(eltype(y))
6177
for j = A.colptr[i] : (A.colptr[i+1]-1)
6278
tmp += nzv[j]*x[rv[j]]
@@ -69,11 +85,11 @@ end
6985
At_mul_B!(y::AbstractVector, A::SparseMatrixCSC, x::AbstractVector) = At_mul_B!(one(eltype(x)), A, x, zero(eltype(y)), y)
7086
function Ac_mul_B{TA,S,Tx}(A::SparseMatrixCSC{TA,S}, x::AbstractVector{Tx})
7187
T = promote_type(TA, Tx)
72-
Ac_mul_B!(one(T), A, x, zero(T), zeros(T, A.n))
88+
Ac_mul_B!(one(T), A, x, zero(T), Array(T, A.n))
7389
end
7490
function At_mul_B{TA,S,Tx}(A::SparseMatrixCSC{TA,S}, x::AbstractVector{Tx})
7591
T = promote_type(TA, Tx)
76-
At_mul_B!(one(T), A, x, zero(T), zeros(T, A.n))
92+
At_mul_B!(one(T), A, x, zero(T), Array(T, A.n))
7793
end
7894

7995
*(X::BitArray{1}, A::SparseMatrixCSC) = invoke(*, (AbstractVector, SparseMatrixCSC), X, A)

0 commit comments

Comments
 (0)