Skip to content

Commit 2f426e4

Browse files
goggleandreasnoack
authored andcommitted
Reduce code duplication for dot product of symmetric/Hermitian matrices (#33269)
1 parent 3135102 commit 2f426e4

File tree

1 file changed

+36
-71
lines changed

1 file changed

+36
-71
lines changed

stdlib/LinearAlgebra/src/symmetric.jl

Lines changed: 36 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -414,82 +414,47 @@ function triu(A::Symmetric, k::Integer=0)
414414
end
415415
end
416416

417-
function dot(A::Symmetric, B::Symmetric)
418-
n = size(A, 2)
419-
if n != size(B, 2)
420-
throw(DimensionMismatch("A has dimensions $(size(A)) but B has dimensions $(size(B))"))
421-
end
422-
423-
dotprod = zero(dot(first(A), first(B)))
424-
@inbounds if A.uplo == 'U' && B.uplo == 'U'
425-
for j in 1:n
426-
for i in 1:(j - 1)
427-
dotprod += 2 * dot(A.data[i, j], B.data[i, j])
428-
end
429-
dotprod += dot(A[j, j], B[j, j])
430-
end
431-
elseif A.uplo == 'L' && B.uplo == 'L'
432-
for j in 1:n
433-
dotprod += dot(A[j, j], B[j, j])
434-
for i in (j + 1):n
435-
dotprod += 2 * dot(A.data[i, j], B.data[i, j])
436-
end
437-
end
438-
elseif A.uplo == 'U' && B.uplo == 'L'
439-
for j in 1:n
440-
for i in 1:(j - 1)
441-
dotprod += 2 * dot(A.data[i, j], transpose(B.data[j, i]))
442-
end
443-
dotprod += dot(A[j, j], B[j, j])
444-
end
445-
else
446-
for j in 1:n
447-
dotprod += dot(A[j, j], B[j, j])
448-
for i in (j + 1):n
449-
dotprod += 2 * dot(A.data[i, j], transpose(B.data[j, i]))
417+
for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:Hermitian, :adjoint, :real)]
418+
@eval begin
419+
function dot(A::$T, B::$T)
420+
n = size(A, 2)
421+
if n != size(B, 2)
422+
throw(DimensionMismatch("A has dimensions $(size(A)) but B has dimensions $(size(B))"))
450423
end
451-
end
452-
end
453-
return dotprod
454-
end
455424

456-
function dot(A::Hermitian, B::Hermitian)
457-
n = size(A, 2)
458-
if n != size(B, 2)
459-
throw(DimensionMismatch("A has dimensions $(size(A)) but B has dimensions $(size(B))"))
460-
end
461-
462-
dotprod = zero(dot(first(A), first(B)))
463-
@inbounds if A.uplo == 'U' && B.uplo == 'U'
464-
for j in 1:n
465-
for i in 1:(j - 1)
466-
dotprod += 2 * real(dot(A.data[i, j], B.data[i, j]))
467-
end
468-
dotprod += dot(A[j, j], B[j, j])
469-
end
470-
elseif A.uplo == 'L' && B.uplo == 'L'
471-
for j in 1:n
472-
dotprod += dot(A[j, j], B[j, j])
473-
for i in (j + 1):n
474-
dotprod += 2 * real(dot(A.data[i, j], B.data[i, j]))
475-
end
476-
end
477-
elseif A.uplo == 'U' && B.uplo == 'L'
478-
for j in 1:n
479-
for i in 1:(j - 1)
480-
dotprod += 2 * real(dot(A.data[i, j], adjoint(B.data[j, i])))
481-
end
482-
dotprod += dot(A[j, j], B[j, j])
483-
end
484-
else
485-
for j in 1:n
486-
dotprod += dot(A[j, j], B[j, j])
487-
for i in (j + 1):n
488-
dotprod += 2 * real(dot(A.data[i, j], adjoint(B.data[j, i])))
425+
dotprod = zero(dot(first(A), first(B)))
426+
@inbounds if A.uplo == 'U' && B.uplo == 'U'
427+
for j in 1:n
428+
for i in 1:(j - 1)
429+
dotprod += 2 * $real(dot(A.data[i, j], B.data[i, j]))
430+
end
431+
dotprod += dot(A[j, j], B[j, j])
432+
end
433+
elseif A.uplo == 'L' && B.uplo == 'L'
434+
for j in 1:n
435+
dotprod += dot(A[j, j], B[j, j])
436+
for i in (j + 1):n
437+
dotprod += 2 * $real(dot(A.data[i, j], B.data[i, j]))
438+
end
439+
end
440+
elseif A.uplo == 'U' && B.uplo == 'L'
441+
for j in 1:n
442+
for i in 1:(j - 1)
443+
dotprod += 2 * $real(dot(A.data[i, j], $trans(B.data[j, i])))
444+
end
445+
dotprod += dot(A[j, j], B[j, j])
446+
end
447+
else
448+
for j in 1:n
449+
dotprod += dot(A[j, j], B[j, j])
450+
for i in (j + 1):n
451+
dotprod += 2 * $real(dot(A.data[i, j], $trans(B.data[j, i])))
452+
end
453+
end
489454
end
455+
return dotprod
490456
end
491457
end
492-
return dotprod
493458
end
494459

495460
(-)(A::Symmetric) = Symmetric(-A.data, sym_uplo(A.uplo))

0 commit comments

Comments
 (0)