Skip to content

Commit 37e60ca

Browse files
authored
Merge pull request #162 from fverdugo/spmv
Additional performance improvements in sparse matrix-vector product
2 parents e47b5ec + 79e395a commit 37e60ca

6 files changed

+112
-2
lines changed

CHANGELOG.md

+10
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [0.5.1] - 2024-07-26
9+
10+
### Added
11+
12+
- Function `spmv!`.
13+
14+
### Fixed
15+
16+
- Performance improvements in sparse matrix-vector multiplication.
17+
818
## [0.5.0] - 2024-07-26
919

1020
### Changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PartitionedArrays"
22
uuid = "5a9dfac6-5c52-46f7-8278-5e2210713be9"
33
authors = ["Francesc Verdugo <[email protected]> and contributors"]
4-
version = "0.5.0"
4+
version = "0.5.1"
55

66
[deps]
77
CircularArrays = "7a955b69-7140-5f4e-a0ed-f168c5e2e749"

src/PartitionedArrays.jl

+2
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ export dense_diag
153153
export dense_diag!
154154
export rap
155155
export rap!
156+
export spmv!
157+
export spmtv!
156158
export spmm
157159
export spmm!
158160
export spmtm

src/p_sparse_matrix.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1829,7 +1829,7 @@ function LinearAlgebra.mul!(c::PVector,a::PSparseMatrix,b::PVector)
18291829
return mul!(c,a,b,1,0)
18301830
end
18311831
t = consistent!(b)
1832-
foreach(mul!,own_values(c),own_own_values(a),own_values(b))
1832+
foreach(spmv!,own_values(c),own_own_values(a),own_values(b))
18331833
wait(t)
18341834
foreach(muladd!,own_values(c),own_ghost_values(a),ghost_values(b))
18351835
c

src/sparse_utils.jl

+83
Original file line numberDiff line numberDiff line change
@@ -597,3 +597,86 @@ function csrr_to_csc_step_2(
597597
nothing
598598
end
599599

600+
@inline function spmv!(b,A,x)
601+
mul!(b,A,x)
602+
end
603+
604+
@inline function spmtv!(b,A,x)
605+
mul!(b,transpose(A),x)
606+
end
607+
608+
function spmv!(b,A::SparseMatrixCSR{1},x)
609+
@boundscheck begin
610+
@assert length(b) == size(A,1)
611+
@assert length(x) == size(A,2)
612+
end
613+
spmv_csr!(b,x,A.rowptr,A.colval,A.nzval)
614+
end
615+
616+
function spmtv!(b,A::SparseMatrixCSR{1},x)
617+
@boundscheck begin
618+
@assert length(b) == size(A,2)
619+
@assert length(x) == size(A,1)
620+
end
621+
spmv_csc!(b,x,A.rowptr,A.colval,A.nzval)
622+
end
623+
624+
function spmv!(b,A::SparseMatrixCSC,x)
625+
@boundscheck begin
626+
@assert length(b) == size(A,1)
627+
@assert length(x) == size(A,2)
628+
end
629+
spmv_csc!(b,x,A.colptr,A.rowval,A.nzval)
630+
end
631+
632+
function spmtv!(b,A::SparseMatrixCSC,x)
633+
@boundscheck begin
634+
@assert length(b) == size(A,2)
635+
@assert length(x) == size(A,1)
636+
end
637+
spmv_csr!(b,x,A.colptr,A.rowval,A.nzval)
638+
end
639+
640+
function spmv_csr!(b,x,rowptr_A,colval_A,nzval_A)
641+
ncols = length(x)
642+
nrows = length(b)
643+
u = one(eltype(rowptr_A))
644+
z = zero(eltype(b))
645+
@inbounds for row in 1:nrows
646+
pini = rowptr_A[row]
647+
pend = rowptr_A[row+1]
648+
bi = z
649+
p = pini
650+
while p < pend
651+
aij = nzval_A[p]
652+
col = colval_A[p]
653+
xj = x[col]
654+
bi += aij*xj
655+
p += u
656+
end
657+
b[row] = bi
658+
end
659+
b
660+
end
661+
662+
function spmv_csc!(b,x,colptr_A,rowval_A,nzval_A)
663+
ncols = length(x)
664+
nrows = length(b)
665+
u = one(eltype(colptr_A))
666+
z = zero(eltype(b))
667+
fill!(b,z)
668+
@inbounds for col in 1:ncols
669+
pini = colptr_A[col]
670+
pend = colptr_A[col+1]
671+
p = pini
672+
xj = x[col]
673+
while p < pend
674+
aij = nzval_A[p]
675+
row = rowval_A[p]
676+
b[row] += aij*xj
677+
p += u
678+
end
679+
end
680+
b
681+
end
682+

test/sparse_utils_tests.jl

+15
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,21 @@ function test_mat(T)
2929
B = compresscoo(T,I,J,V,m,n)
3030
@test typeof(B) == T
3131
@test A == B
32+
33+
b1 = ones(Tv,size(B,1))
34+
b2 = ones(Tv,size(B,1))
35+
x = collect(Tv,1:size(B,2))
36+
mul!(b1,B,x)
37+
spmv!(b2,B,x)
38+
@test norm(b1-b2)/norm(b1) + 1 1
39+
40+
b1 = ones(Tv,size(B,2))
41+
b2 = ones(Tv,size(B,2))
42+
x = collect(Tv,1:size(B,1))
43+
mul!(b1,transpose(B),x)
44+
spmtv!(b2,B,x)
45+
@test norm(b1-b2)/norm(b1) + 1 1
46+
3247

3348
i,j,v = findnz(B)
3449
for (k,(ki,kj,kv)) in enumerate(nziterator(B))

0 commit comments

Comments
 (0)