Skip to content
This repository was archived by the owner on Mar 23, 2025. It is now read-only.

Commit 54a9286

Browse files
authored
allow switching backends (#69)
* allow switching gemm backend, add Octavian as an optional backend for non-BLAS elemet types.
1 parent bb98268 commit 54a9286

File tree

4 files changed

+103
-24
lines changed

4 files changed

+103
-24
lines changed

src/NDTensors.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,9 @@ function __init__()
173173
enable_tblis()
174174
include("tblis.jl")
175175
end
176+
@require Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" begin
177+
include("octavian.jl")
178+
end
176179
end
177180

178181
end # module NDTensors

src/dense.jl

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#
22
# Dense storage
33
#
4+
using LinearAlgebra: BlasFloat
45

56
struct Dense{ElT, VecT<:AbstractVector} <: TensorStorage{ElT}
67
data::VecT
@@ -441,39 +442,61 @@ function outer!(R::DenseTensor{ElR},
441442
return R
442443
end
443444

444-
# BLAS matmul
445+
export backend_auto, backend_blas, backend_generic
446+
447+
@eval struct GemmBackend{T}
448+
(f::Type{<:GemmBackend})() = $(Expr(:new, :f))
449+
end
450+
GemmBackend(s) = GemmBackend{Symbol(s)}()
451+
macro GemmBackend_str(s)
452+
:(GemmBackend{$(Expr(:quote, Symbol(s)))})
453+
end
454+
455+
const gemm_backend = Ref(:Auto)
456+
function backend_auto()
457+
gemm_backend[] = :Auto
458+
end
459+
function backend_blas()
460+
gemm_backend[] = :BLAS
461+
end
462+
function backend_generic()
463+
gemm_backend[] = :Generic
464+
end
465+
466+
@inline function auto_select_backend(::Type{<:StridedVecOrMat{<:BlasFloat}}, ::Type{<:StridedVecOrMat{<:BlasFloat}}, ::Type{<:StridedVecOrMat{<:BlasFloat}})
467+
GemmBackend(:BLAS)
468+
end
469+
470+
@inline function auto_select_backend(::Type{<:AbstractVecOrMat}, ::Type{<:AbstractVecOrMat}, ::Type{<:AbstractVecOrMat})
471+
GemmBackend(:Generic)
472+
end
473+
445474
function _gemm!(tA, tB, alpha,
446-
A::AbstractVecOrMat{<:LinearAlgebra.BlasFloat},
447-
B::AbstractVecOrMat{<:LinearAlgebra.BlasFloat},
448-
beta, C::AbstractVecOrMat{<:LinearAlgebra.BlasFloat})
475+
A::TA,
476+
B::TB,
477+
beta, C::TC) where {TA<:AbstractVecOrMat, TB<:AbstractVecOrMat, TC<:AbstractVecOrMat}
478+
if gemm_backend[] == :Auto
479+
_gemm!(auto_select_backend(TA, TB, TC), tA, tB, alpha, A, B, beta, C)
480+
else
481+
_gemm!(GemmBackend(gemm_backend[]), tA, tB, alpha, A, B, beta, C)
482+
end
483+
end
484+
485+
# BLAS matmul
486+
function _gemm!(::GemmBackend{:BLAS}, tA, tB, alpha,
487+
A::AbstractVecOrMat,
488+
B::AbstractVecOrMat,
489+
beta, C::AbstractVecOrMat)
449490
#@timeit_debug timer "BLAS.gemm!" begin
450491
BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
451492
#end # @timeit
452493
end
453494

454495
# generic matmul
455-
function _gemm!(tA, tB, alpha::AT,
496+
function _gemm!(::GemmBackend{:Generic}, tA, tB, alpha::AT,
456497
A::AbstractVecOrMat, B::AbstractVecOrMat,
457498
beta::BT, C::AbstractVecOrMat) where {AT, BT}
458-
if tA == 'T'
459-
A = transpose(A)
460-
end
461-
if tB == 'T'
462-
B = transpose(B)
463-
end
464-
if beta == zero(BT)
465-
if alpha == one(AT)
466-
C .= A * B
467-
else
468-
C .= alpha .* (A * B)
469-
end
470-
else
471-
if alpha == one(AT)
472-
C .= (A * B) .+ beta .* C
473-
else
474-
C .= alpha .* (A * B) .+ beta .* C
475-
end
476-
end
499+
mul!(C, tA == 'T' ? transpose(A) : A, tB == 'T' ? transpose(B) : B, alpha, beta)
477500
return C
478501
end
479502

src/octavian.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using .Octavian
2+
3+
export backend_octavian
4+
5+
function backend_octavian()
6+
gemm_backend[] = :Octavian
7+
end
8+
9+
function _gemm!(::GemmBackend{:Octavian}, tA, tB, alpha,
10+
A::AbstractVecOrMat,
11+
B::AbstractVecOrMat,
12+
beta, C::AbstractVecOrMat)
13+
Octavian.matmul!(C, tA == 'T' ? transpose(A) : A, tB == 'T' ? transpose(B) : B, alpha, beta)
14+
end

test/dense.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
using NDTensors,
22
Test
33

4+
@static if VERSION >= v"1.5"
5+
using Pkg
6+
Pkg.add("Octavian")
7+
using Octavian
8+
end
9+
410
@testset "Dense Tensors" begin
511

612
@testset "DenseTensor basic functionality" begin
@@ -210,6 +216,39 @@ end
210216
end
211217
end
212218

219+
@testset "change backends" begin
220+
a, b, c = [randn(5,5) for i=1:3]
221+
backend_auto()
222+
@test NDTensors.gemm_backend[] == :Auto
223+
@test NDTensors.auto_select_backend(typeof.((a, b, c))...) == NDTensors.GemmBackend(:BLAS)
224+
res1 = NDTensors._gemm!('N', 'N', 2.0, a, b, 0.2, copy(c))
225+
backend_blas()
226+
@test NDTensors.gemm_backend[] == :BLAS
227+
res2 = NDTensors._gemm!('N', 'N', 2.0, a, b, 0.2, copy(c))
228+
backend_generic()
229+
@test NDTensors.gemm_backend[] == :Generic
230+
res3 = NDTensors._gemm!('N', 'N', 2.0, a, b, 0.2, copy(c))
231+
@test res1 == res2
232+
@test res1 res3
233+
backend_auto()
234+
end
235+
236+
@static if VERSION >= v"1.5"
237+
@testset "change backends" begin
238+
a, b, c = [randn(5,5) for i=1:3]
239+
backend_auto()
240+
@test NDTensors.gemm_backend[] == :Auto
241+
@test NDTensors.auto_select_backend(typeof.((a, b, c))...) == NDTensors.GemmBackend(:BLAS)
242+
res1 = NDTensors._gemm!('N', 'N', 2.0, a, b, 0.2, copy(c))
243+
backend_octavian()
244+
@test NDTensors.gemm_backend[] == :Octavian
245+
res4 = NDTensors._gemm!('N', 'N', 2.0, a, b, 0.2, copy(c))
246+
@test res1 res4
247+
backend_auto()
248+
end
249+
end
250+
213251
end
252+
214253
nothing
215254

0 commit comments

Comments
 (0)