diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index 1f3829d3727bf..ca401f3c1c441 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -414,6 +414,84 @@ function triu(A::Symmetric, k::Integer=0) end end +function dot(A::Symmetric, B::Symmetric) + n = size(A, 2) + if n != size(B, 2) + throw(DimensionMismatch("A has dimensions $(size(A)) but B has dimensions $(size(B))")) + end + + dotprod = zero(dot(first(A), first(B))) + @inbounds if A.uplo == 'U' && B.uplo == 'U' + for j in 1:n + for i in 1:(j - 1) + dotprod += 2 * dot(A.data[i, j], B.data[i, j]) + end + dotprod += dot(A[j, j], B[j, j]) + end + elseif A.uplo == 'L' && B.uplo == 'L' + for j in 1:n + dotprod += dot(A[j, j], B[j, j]) + for i in (j + 1):n + dotprod += 2 * dot(A.data[i, j], B.data[i, j]) + end + end + elseif A.uplo == 'U' && B.uplo == 'L' + for j in 1:n + for i in 1:(j - 1) + dotprod += 2 * dot(A.data[i, j], transpose(B.data[j, i])) + end + dotprod += dot(A[j, j], B[j, j]) + end + else + for j in 1:n + dotprod += dot(A[j, j], B[j, j]) + for i in (j + 1):n + dotprod += 2 * dot(A.data[i, j], transpose(B.data[j, i])) + end + end + end + return dotprod +end + +function dot(A::Hermitian, B::Hermitian) + n = size(A, 2) + if n != size(B, 2) + throw(DimensionMismatch("A has dimensions $(size(A)) but B has dimensions $(size(B))")) + end + + dotprod = zero(dot(first(A), first(B))) + @inbounds if A.uplo == 'U' && B.uplo == 'U' + for j in 1:n + for i in 1:(j - 1) + dotprod += 2 * real(dot(A.data[i, j], B.data[i, j])) + end + dotprod += dot(A[j, j], B[j, j]) + end + elseif A.uplo == 'L' && B.uplo == 'L' + for j in 1:n + dotprod += dot(A[j, j], B[j, j]) + for i in (j + 1):n + dotprod += 2 * real(dot(A.data[i, j], B.data[i, j])) + end + end + elseif A.uplo == 'U' && B.uplo == 'L' + for j in 1:n + for i in 1:(j - 1) + dotprod += 2 * real(dot(A.data[i, j], adjoint(B.data[j, i]))) + end + dotprod += dot(A[j, j], B[j, j]) + end + else + for j in 1:n + dotprod += dot(A[j, j], B[j, j]) + for i in (j + 1):n + dotprod += 2 * real(dot(A.data[i, j], adjoint(B.data[j, i]))) + end + end + end + return dotprod +end + (-)(A::Symmetric) = Symmetric(-A.data, sym_uplo(A.uplo)) (-)(A::Hermitian) = Hermitian(-A.data, sym_uplo(A.uplo)) diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index ebdfc9c93c72f..c8984be219fd3 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -377,6 +377,40 @@ end end end end + + @testset "dot product of symmetric and Hermitian matrices" begin + for mtype in (Symmetric, Hermitian) + symau = mtype(a, :U) + symal = mtype(a, :L) + msymau = Matrix(symau) + msymal = Matrix(symal) + @test_throws DimensionMismatch dot(symau, mtype(zeros(eltya, n-1, n-1))) + for eltyc in (Float32, Float64, ComplexF32, ComplexF64, BigFloat, Int) + creal = randn(n, n)/2 + cimag = randn(n, n)/2 + c = eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, eltya <: Complex ? complex.(creal, cimag) : creal) + symcu = mtype(c, :U) + symcl = mtype(c, :L) + msymcu = Matrix(symcu) + msymcl = Matrix(symcl) + @test dot(symau, symcu) ≈ dot(msymau, msymcu) + @test dot(symau, symcl) ≈ dot(msymau, msymcl) + @test dot(symal, symcu) ≈ dot(msymal, msymcu) + @test dot(symal, symcl) ≈ dot(msymal, msymcl) + end + + # block matrices + blockm = [eltya == Int ? rand(1:7, 3, 3) : convert(Matrix{eltya}, eltya <: Complex ? complex.(randn(3, 3)/2, randn(3, 3)/2) : randn(3, 3)/2) for _ in 1:3, _ in 1:3] + symblockmu = mtype(blockm, :U) + symblockml = mtype(blockm, :L) + msymblockmu = Matrix(symblockmu) + msymblockml = Matrix(symblockml) + @test dot(symblockmu, symblockmu) ≈ dot(msymblockmu, msymblockmu) + @test dot(symblockmu, symblockml) ≈ dot(msymblockmu, msymblockml) + @test dot(symblockml, symblockmu) ≈ dot(msymblockml, msymblockmu) + @test dot(symblockml, symblockml) ≈ dot(msymblockml, msymblockml) + end + end end end