Skip to content
This repository was archived by the owner on Mar 23, 2025. It is now read-only.

allow switching backends #69

Merged
merged 5 commits into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ function __init__()
enable_tblis()
include("tblis.jl")
end
@require Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" begin
include("octavian.jl")
end
end

end # module NDTensors
71 changes: 47 additions & 24 deletions src/dense.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#
# Dense storage
#
using LinearAlgebra: BlasFloat

struct Dense{ElT, VecT<:AbstractVector} <: TensorStorage{ElT}
data::VecT
Expand Down Expand Up @@ -441,39 +442,61 @@ function outer!(R::DenseTensor{ElR},
return R
end

# BLAS matmul
export backend_auto, backend_blas, backend_generic

# Field-less tag type used to pick a matrix-multiplication backend through
# dispatch; the type parameter `T` carries the backend name (e.g. `:BLAS`).
# The definition goes through `@eval` so that a raw `new` expression can be
# spliced into the inner constructor below.
@eval struct GemmBackend{T}
# Zero-argument inner constructor: applies to whatever subtype `f` of
# `GemmBackend` is being called and builds the instance with `new` on `f`.
(f::Type{<:GemmBackend})() = $(Expr(:new, :f))
end
# Build the backend tag from a name; `s` is passed through `Symbol(...)`, so
# either a `Symbol` or a string works.
function GemmBackend(s)
  tag = Symbol(s)
  return GemmBackend{tag}()
end
# String macro: `GemmBackend"BLAS"` expands to the type `GemmBackend{:BLAS}`
# (the type itself, not an instance), which is handy in method signatures.
macro GemmBackend_str(s)
  quoted_name = Expr(:quote, Symbol(s))
  return :(GemmBackend{$quoted_name})
end

# Currently selected matmul backend, stored as a `Symbol`. `:Auto` means the
# backend is chosen from the argument types on each `_gemm!` call.
const gemm_backend = Ref{Symbol}(:Auto)

# Setters for the global backend choice. Each one stores the backend name and
# returns the symbol that was stored.
backend_auto() = (gemm_backend[] = :Auto)
backend_blas() = (gemm_backend[] = :BLAS)
backend_generic() = (gemm_backend[] = :Generic)

# Automatic choice when all three operand types are strided vectors/matrices
# with BLAS-compatible element types: use the BLAS backend.
@inline auto_select_backend(::Type{<:StridedVecOrMat{<:BlasFloat}},
                            ::Type{<:StridedVecOrMat{<:BlasFloat}},
                            ::Type{<:StridedVecOrMat{<:BlasFloat}}) = GemmBackend{:BLAS}()

# Automatic choice for any other combination of vector/matrix types: fall back
# to the generic backend.
@inline auto_select_backend(::Type{<:AbstractVecOrMat},
                            ::Type{<:AbstractVecOrMat},
                            ::Type{<:AbstractVecOrMat}) = GemmBackend{:Generic}()

# Backend-dispatching entry point for `C = alpha * op(A) * op(B) + beta * C`.
#
# Arguments follow the `BLAS.gemm!` order: `tA`/`tB` are the transpose flags
# for `A`/`B`, `alpha`/`beta` the scalars, and `C` the output that is updated
# in place. The call is forwarded to the `_gemm!` method for the selected
# backend and returns whatever that method returns.
function _gemm!(tA, tB, alpha,
                A::TA,
                B::TB,
                beta, C::TC) where {TA<:AbstractVecOrMat, TB<:AbstractVecOrMat, TC<:AbstractVecOrMat}
  # Read the global setting once so the check and the dispatch below cannot
  # observe two different values.
  backend = gemm_backend[]
  if backend == :Auto
    # Let the argument types decide which backend to use.
    return _gemm!(auto_select_backend(TA, TB, TC), tA, tB, alpha, A, B, beta, C)
  end
  # An explicitly chosen backend is used as-is; a name with no matching
  # `_gemm!(::GemmBackend{name}, ...)` method surfaces as a `MethodError`.
  return _gemm!(GemmBackend(backend), tA, tB, alpha, A, B, beta, C)
end

# BLAS backend: forward all arguments unchanged to `BLAS.gemm!` and return its
# result. (A `@timeit_debug` timer around this call is currently disabled.)
function _gemm!(::GemmBackend{:BLAS}, tA, tB, alpha,
                A::AbstractVecOrMat, B::AbstractVecOrMat,
                beta, C::AbstractVecOrMat)
  return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
end

# Generic matmul: updates `C` in place to `alpha * op(A) * op(B) + beta * C`
# using the five-argument `mul!`, and returns `C`.
#
# `tA` / `tB` use the same flag characters as `BLAS.gemm!`:
#   'T' -> transpose, 'C' -> conjugate transpose; any other value (normally
#   'N') leaves the operand unchanged.
function _gemm!(::GemmBackend{:Generic}, tA, tB, alpha,
                A::AbstractVecOrMat, B::AbstractVecOrMat,
                beta, C::AbstractVecOrMat)
  # `transpose` / `adjoint` are lazy wrappers, so no copy of A or B is made.
  opA = tA == 'T' ? transpose(A) : (tA == 'C' ? adjoint(A) : A)
  opB = tB == 'T' ? transpose(B) : (tB == 'C' ? adjoint(B) : B)
  mul!(C, opA, opB, alpha, beta)
  return C
end

Expand Down
14 changes: 14 additions & 0 deletions src/octavian.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using .Octavian

export backend_octavian

# Select the Octavian backend for subsequent `_gemm!` calls; returns the
# stored symbol `:Octavian`.
backend_octavian() = (gemm_backend[] = :Octavian)

# Octavian backend: hand the product to `Octavian.matmul!` and return its
# result. Only the 'T' flag is interpreted here; any other `tA` / `tB` value
# leaves the corresponding operand unwrapped.
function _gemm!(::GemmBackend{:Octavian}, tA, tB, alpha,
                A::AbstractVecOrMat, B::AbstractVecOrMat,
                beta, C::AbstractVecOrMat)
  lhs = tA == 'T' ? transpose(A) : A
  rhs = tB == 'T' ? transpose(B) : B
  return Octavian.matmul!(C, lhs, rhs, alpha, beta)
end
39 changes: 39 additions & 0 deletions test/dense.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
using NDTensors,
Test

@static if VERSION >= v"1.5"
using Pkg
Pkg.add("Octavian")
using Octavian
end

@testset "Dense Tensors" begin

@testset "DenseTensor basic functionality" begin
Expand Down Expand Up @@ -210,6 +216,39 @@ end
end
end

# Checks that switching the global gemm backend is reflected in
# `NDTensors.gemm_backend` and that the Auto, BLAS and Generic paths agree on
# the same product.
@testset "change backends" begin
  a, b, c = [randn(5, 5) for i in 1:3]
  try
    backend_auto()
    @test NDTensors.gemm_backend[] == :Auto
    @test NDTensors.auto_select_backend(typeof.((a, b, c))...) == NDTensors.GemmBackend(:BLAS)
    res1 = NDTensors._gemm!('N', 'N', 2.0, a, b, 0.2, copy(c))
    backend_blas()
    @test NDTensors.gemm_backend[] == :BLAS
    res2 = NDTensors._gemm!('N', 'N', 2.0, a, b, 0.2, copy(c))
    backend_generic()
    @test NDTensors.gemm_backend[] == :Generic
    res3 = NDTensors._gemm!('N', 'N', 2.0, a, b, 0.2, copy(c))
    # Auto resolves to BLAS for these inputs, so the results are bit-identical;
    # the generic path may differ by rounding.
    @test res1 == res2
    @test res1 ≈ res3
  finally
    # The backend is global state: restore the default even if a call above
    # throws, so later tests are not run with a leftover backend.
    backend_auto()
  end
end

# Octavian is only loaded on Julia >= 1.5 (see the guarded `using` at the top
# of this file), so the Octavian backend test is guarded the same way.
@static if VERSION >= v"1.5"
  # Named separately from the BLAS/Generic testset so the two can be told apart
  # in the test report.
  @testset "change backends (Octavian)" begin
    a, b, c = [randn(5, 5) for i in 1:3]
    try
      backend_auto()
      @test NDTensors.gemm_backend[] == :Auto
      @test NDTensors.auto_select_backend(typeof.((a, b, c))...) == NDTensors.GemmBackend(:BLAS)
      res1 = NDTensors._gemm!('N', 'N', 2.0, a, b, 0.2, copy(c))
      backend_octavian()
      @test NDTensors.gemm_backend[] == :Octavian
      res4 = NDTensors._gemm!('N', 'N', 2.0, a, b, 0.2, copy(c))
      @test res1 ≈ res4
    finally
      # The backend is global state: restore the default even if a call above
      # throws, so later tests are not run with the Octavian backend selected.
      backend_auto()
    end
  end
end

end

nothing