Skip to content

Commit f3ad067

Browse files
ranochaandreasnoack
authored andcommitted
Improve performance of generic dot products (#27678)
* improve performance of generic dot products * update generic dot as suggested by @haampie
1 parent 3791357 commit f3ad067

File tree

2 files changed

+20
-39
lines changed

2 files changed

+20
-39
lines changed

stdlib/LinearAlgebra/src/generic.jl

Lines changed: 13 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -621,21 +621,6 @@ opnorm(v::TransposeAbsVec) = norm(v.parent)
621621

622622
norm(v::Union{TransposeAbsVec,AdjointAbsVec}, p::Real) = norm(v.parent, p)
623623

624-
function dot(x::AbstractArray, y::AbstractArray)
625-
lx = _length(x)
626-
if lx != _length(y)
627-
throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(_length(y))."))
628-
end
629-
if lx == 0
630-
return dot(zero(eltype(x)), zero(eltype(y)))
631-
end
632-
s = zero(dot(first(x), first(y)))
633-
for (Ix, Iy) in zip(eachindex(x), eachindex(y))
634-
@inbounds s += dot(x[Ix], y[Iy])
635-
end
636-
s
637-
end
638-
639624
"""
640625
dot(x, y)
641626
x ⋅ y
@@ -678,14 +663,13 @@ function dot(x, y) # arbitrary iterables
678663
while true
679664
ix = iterate(x, xs)
680665
iy = iterate(y, ys)
681-
if (ix == nothing) || (iy == nothing)
682-
break
683-
end
666+
ix === nothing && break
667+
iy === nothing && break
684668
(vx, xs), (vy, ys) = ix, iy
685669
s += dot(vx, vy)
686670
end
687-
if !(iy == nothing && ix == nothing)
688-
throw(DimensionMismatch("x and y are of different lengths!"))
671+
if !(iy === nothing && ix === nothing)
672+
throw(DimensionMismatch("x and y are of different lengths!"))
689673
end
690674
return s
691675
end
@@ -709,24 +693,19 @@ julia> dot([im; im], [1; 1])
709693
0 - 2im
710694
```
711695
"""
712-
function dot(x::AbstractVector, y::AbstractVector)
713-
if length(LinearIndices(x)) != length(LinearIndices(y))
714-
throw(DimensionMismatch("dot product arguments have unequal lengths $(length(LinearIndices(x))) and $(length(LinearIndices(y)))"))
696+
function dot(x::AbstractArray, y::AbstractArray)
697+
lx = _length(x)
698+
if lx != _length(y)
699+
throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(_length(y))."))
715700
end
716-
ix = iterate(x)
717-
if ix === nothing
718-
# we only need to check the first vector, since equal lengths have been asserted
701+
if lx == 0
719702
return dot(zero(eltype(x)), zero(eltype(y)))
720703
end
721-
iy = iterate(y)
722-
s = dot(ix[1], iy[1])
723-
ix, iy = iterate(x, ix[2]), iterate(y, iy[2])
724-
while ix != nothing
725-
s += dot(ix[1], iy[1])
726-
ix = iterate(x, ix[2])
727-
iy = iterate(y, iy[2])
704+
s = zero(dot(first(x), first(y)))
705+
for (Ix, Iy) in zip(eachindex(x), eachindex(y))
706+
@inbounds s += dot(x[Ix], y[Iy])
728707
end
729-
return s
708+
s
730709
end
731710

732711

stdlib/LinearAlgebra/test/matmul.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -232,16 +232,18 @@ end
232232
@test dot(Z, Z) == convert(elty, 34.0)
233233
end
234234

235-
dot_(x,y) = invoke(dot, Tuple{Any,Any}, x,y)
235+
dot1(x,y) = invoke(dot, Tuple{Any,Any}, x,y)
236+
dot2(x,y) = invoke(dot, Tuple{AbstractArray,AbstractArray}, x,y)
236237
@testset "generic dot" begin
237238
AA = [1+2im 3+4im; 5+6im 7+8im]
238239
BB = [2+7im 4+1im; 3+8im 6+5im]
239240
for A in (copy(AA), view(AA, 1:2, 1:2)), B in (copy(BB), view(BB, 1:2, 1:2))
240-
@test dot(A,B) == dot(vec(A),vec(B)) == dot_(A,B) == dot(float.(A),float.(B))
241-
@test dot(Int[], Int[]) == 0 == dot_(Int[], Int[])
241+
@test dot(A,B) == dot(vec(A),vec(B)) == dot1(A,B) == dot2(A,B) == dot(float.(A),float.(B))
242+
@test dot(Int[], Int[]) == 0 == dot1(Int[], Int[]) == dot2(Int[], Int[])
242243
@test_throws MethodError dot(Any[], Any[])
243-
@test_throws MethodError dot_(Any[], Any[])
244-
for n1 = 0:2, n2 = 0:2, d in (dot, dot_)
244+
@test_throws MethodError dot1(Any[], Any[])
245+
@test_throws MethodError dot2(Any[], Any[])
246+
for n1 = 0:2, n2 = 0:2, d in (dot, dot1, dot2)
245247
if n1 != n2
246248
@test_throws DimensionMismatch d(1:n1, 1:n2)
247249
else

0 commit comments

Comments
 (0)