diff --git a/lib/mkl/linalg.jl b/lib/mkl/linalg.jl
index 6848a616..db16da69 100644
--- a/lib/mkl/linalg.jl
+++ b/lib/mkl/linalg.jl
@@ -1,11 +1,11 @@
 # interfacing with LinearAlgebra standard library
 
 import LinearAlgebra
-using LinearAlgebra: Transpose, Adjoint,
+using LinearAlgebra: Transpose, Adjoint, AdjOrTrans,
                      Hermitian, Symmetric,
                      LowerTriangular, UnitLowerTriangular,
                      UpperTriangular, UnitUpperTriangular,
-                     MulAddMul, wrap
+                     UpperOrLowerTriangular, MulAddMul, wrap
 
 #
 # BLAS 1
@@ -163,12 +163,55 @@ function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStr
     GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
 end
 
+# A strided oneAPI matrix, either bare or behind a lazy adjoint/transpose wrapper.
+const AdjOrTransOroneMatrix{T} = Union{oneStridedMatrix{T}, AdjOrTrans{<:T,<:oneStridedMatrix}}
+
+# Triangular * triangular product, written into `C` through `trmm!`.
+# `uplocA`, `isunitcA` and `tfunA` say how `A` is to be read; the right factor
+# arrives still wrapped in `triB` and is unpacked here.
+# NOTE(review): `triu!`/`tril!` overwrite the operand they are applied to.
+function LinearAlgebra.generic_trimatmul!(
+    C::oneStridedMatrix{T}, uplocA, isunitcA,
+    tfunA::Function, A::oneStridedMatrix{T},
+    triB::UpperOrLowerTriangular{T, <: AdjOrTransOroneMatrix{T}},
+) where {T<:onemklFloat}
+    uplocB = LinearAlgebra.uplo_char(triB)
+    isunitcB = LinearAlgebra.isunit_char(triB)
+    B = parent(triB)
+    tfunB = LinearAlgebra.wrapperop(B)
+    transa = tfunA === identity ? 'N' : tfunA === transpose ? 'T' : 'C'
+    transb = tfunB === identity ? 'N' : tfunB === transpose ? 'T' : 'C'
+    if uplocA == 'L' && tfunA === identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N' # lower * upper
+        triu!(B)
+        trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
+    elseif uplocA == 'U' && tfunA === identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N' # upper * lower
+        tril!(B)
+        trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
+    elseif uplocA == 'U' && tfunA === identity && tfunB !== identity && uplocB == 'U' && isunitcA == 'N'
+        # operation is reversed to avoid executing the transpose
+        triu!(A)
+        trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C)
+    elseif uplocA == 'L' && tfunA !== identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N'
+        tril!(B)
+        trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
+    elseif uplocA == 'U' && tfunA !== identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N'
+        triu!(B)
+        trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
+    elseif uplocA == 'L' && tfunA === identity && tfunB !== identity && uplocB == 'L' && isunitcA == 'N'
+        tril!(A)
+        trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C)
+    else
+        throw(ArgumentError("mixed triangular-triangular multiplication")) # TODO: rethink
+    end
+    return C
+end
+
 # triangular
 LinearAlgebra.generic_trimatmul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
-    trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
+    trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
 LinearAlgebra.generic_mattrimul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
-    trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
-LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
-    trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
-LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
-    trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
+    trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, A, C)
+LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::AbstractMatrix{T}) where {T<:onemklFloat} =
+    trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
+LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
+    trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, A, C)
diff --git a/lib/mkl/wrappers_blas.jl b/lib/mkl/wrappers_blas.jl
index 4e038372..e01ffd2b 100644
--- a/lib/mkl/wrappers_blas.jl
+++ b/lib/mkl/wrappers_blas.jl
@@ -1139,6 +1139,86 @@ function trsm(side::Char,
     trsm!(side, uplo, transa, diag, alpha, A, copy(B))
 end
 
+# Out-of-place triangular multiply (`trmm!`) and solve (`trsm!`). Both forward to
+# the oneMKL `*_variant` routines with `B` as input and `C` as destination; the
+# tests expect `alpha*A*B + beta*C` (resp. `alpha*(A\B) + beta*C`) to land in `C`.
+for (mmname_variant, smname_variant, elty) in
+        ((:onemklDtrmm_variant, :onemklDtrsm_variant, :Float64),
+         (:onemklStrmm_variant, :onemklStrsm_variant, :Float32),
+         (:onemklZtrmm_variant, :onemklZtrsm_variant, :ComplexF64),
+         (:onemklCtrmm_variant, :onemklCtrsm_variant, :ComplexF32))
+    @eval begin
+        function trmm!(side::Char,
+                       uplo::Char,
+                       transa::Char,
+                       diag::Char,
+                       alpha::Number,
+                       beta::Number,
+                       A::oneStridedMatrix{$elty},
+                       B::oneStridedMatrix{$elty},
+                       C::oneStridedMatrix{$elty})
+            m, n = size(B)
+            mA, nA = size(A)
+            if mA != nA throw(DimensionMismatch("A must be square")) end
+            if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trmm!")) end
+            # C is the destination, so it has to match B's shape
+            if size(C) != (m, n) throw(DimensionMismatch("C must have the same size as B")) end
+            lda = max(1,stride(A,2))
+            ldb = max(1,stride(B,2))
+            ldc = max(1,stride(C,2))
+            queue = global_queue(context(A), device())
+            $mmname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, C, ldc)
+            # hand back the destination, not the input operand
+            C
+        end
+
+        function trsm!(side::Char,
+                       uplo::Char,
+                       transa::Char,
+                       diag::Char,
+                       alpha::Number,
+                       beta::Number,
+                       A::oneStridedMatrix{$elty},
+                       B::oneStridedMatrix{$elty},
+                       C::oneStridedMatrix{$elty})
+            m, n = size(B)
+            mA, nA = size(A)
+            if mA != nA throw(DimensionMismatch("A must be square")) end
+            if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trsm!")) end
+            # C is the destination, so it has to match B's shape
+            if size(C) != (m, n) throw(DimensionMismatch("C must have the same size as B")) end
+            lda = max(1,stride(A,2))
+            ldb = max(1,stride(B,2))
+            ldc = max(1,stride(C,2))
+            queue = global_queue(context(A), device())
+            $smname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, C, ldc)
+            # hand back the destination, not the input operand
+            C
+        end
+    end
+end
+# Convenience methods without `beta`: forward with `beta = zero(T)`.
+function trmm!(side::Char,
+               uplo::Char,
+               transa::Char,
+               diag::Char,
+               alpha::Number,
+               A::oneStridedMatrix{T},
+               B::oneStridedMatrix{T},
+               C::oneStridedMatrix{T}) where T
+    trmm!(side, uplo, transa, diag, alpha, zero(T), A, B, C)
+end
+function trsm!(side::Char,
+               uplo::Char,
+               transa::Char,
+               diag::Char,
+               alpha::Number,
+               A::oneStridedMatrix{T},
+               B::oneStridedMatrix{T},
+               C::oneStridedMatrix{T}) where T
+    trsm!(side, uplo, transa, diag, alpha, zero(T), A, B, C)
+end
+
 ## hemm
 for (fname, elty) in ((:onemklZhemm,:ComplexF64),
                       (:onemklChemm,:ComplexF32))
diff --git a/test/onemkl.jl b/test/onemkl.jl
index 19f06926..2f61bc7b 100644
--- a/test/onemkl.jl
+++ b/test/onemkl.jl
@@ -661,6 +661,14 @@ end
             # move to host and compare
             h_C = Array(dB)
             @test C ≈ h_C
+
+            C = rand(T,m,n)
+            dC = oneArray(C)
+            beta = zero(T) # rand(T)
+            oneMKL.trmm!('L','U','N','N',alpha,beta,dA,dB,dC)
+            h_C = Array(dC)
+            D = alpha*A*B + beta*C
+            @test D ≈ h_C
         end
 
         @testset "trmm" begin
@@ -684,6 +692,14 @@ end
             dC = copy(dB)
             oneMKL.trsm!('L','U','N','N',alpha,dA,dC)
             @test C ≈ Array(dC)
+
+            C = rand(T,m,n)
+            dC = oneArray(C)
+            beta = rand(T)
+            oneMKL.trsm!('L','U','N','N',alpha,beta,dA,dB,dC)
+            h_C = Array(dC)
+            D = alpha*(A\B) + beta*C
+            @test D ≈ h_C
         end
 
         @testset "left trsm" begin
@@ -725,6 +741,15 @@ end
             dC = copy(dA)
             oneMKL.trsm!('R','U','N','N',alpha,dB,dC)
             @test C ≈ Array(dC)
+
+            C = rand(T,m,m)
+            dC = oneArray(C)
+            beta = rand(T)
+            # right solve: B is the triangular factor (as in the call above), so it goes first
+            oneMKL.trsm!('R','U','N','N',alpha,beta,dB,dA,dC)
+            h_C = Array(dC)
+            D = alpha*(A/B) + beta*C
+            @test D ≈ h_C
         end
 
         @testset "right trsm" begin