diff --git a/Project.toml b/Project.toml index dd7deb9e..23dba828 100644 --- a/Project.toml +++ b/Project.toml @@ -6,8 +6,10 @@ version = "0.14.4" [deps] LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f" @@ -31,8 +33,10 @@ Combinatorics = "1" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" +OhMyThreads = "0.7.0" PackageExtensionCompat = "1" Random = "1" +ScopedValues = "1.3.0" SparseArrays = "1" Strided = "2" TensorKitSectors = "0.1" diff --git a/docs/src/lib/tensors.md b/docs/src/lib/tensors.md index 4fd32e34..aa7ea5bc 100644 --- a/docs/src/lib/tensors.md +++ b/docs/src/lib/tensors.md @@ -200,7 +200,6 @@ TensorKit.add_transpose! ```@docs compose(::AbstractTensorMap, ::AbstractTensorMap) trace_permute! -contract! ⊗(::AbstractTensorMap, ::AbstractTensorMap) ``` diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 045f83c0..ade7ff6b 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -82,7 +82,7 @@ export OrthogonalFactorizationAlgorithm, QR, QRpos, QL, QLpos, LQ, LQpos, RQ, RQ # tensor operations export @tensor, @tensoropt, @ncon, ncon, @planar, @plansor -export scalar, add!, contract! +export scalar, add! 
# truncation schemes export notrunc, truncerr, truncdim, truncspace, truncbelow @@ -101,6 +101,8 @@ using TensorOperations: IndexTuple, Index2Tuple, linearize, AbstractBackend const TO = TensorOperations using LRUCache +using OhMyThreads +using ScopedValues using TensorKitSectors import TensorKitSectors: dim, BraidingStyle, FusionStyle, ⊠, ⊗ @@ -184,6 +186,7 @@ include("spaces/vectorspaces.jl") #------------------------------------- # general definitions include("tensors/abstracttensor.jl") +include("tensors/backends.jl") include("tensors/blockiterator.jl") include("tensors/tensor.jl") include("tensors/adjoint.jl") diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index 121e5a3c..082db9e4 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -142,7 +142,8 @@ function planarcontract!(C::AbstractTensorMap, α::Number, β::Number, backend, allocator) if BraidingStyle(sectortype(C)) == Bosonic() - return contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator) + return TO.tensorcontract!(C, A, pA, false, B, pB, false, pAB, + α, β, backend, allocator) end codA, domA = codomainind(A), domainind(A) diff --git a/src/tensors/backends.jl b/src/tensors/backends.jl new file mode 100644 index 00000000..7218c942 --- /dev/null +++ b/src/tensors/backends.jl @@ -0,0 +1,84 @@ +# Scheduler implementation +# ------------------------ +""" + const blockscheduler = ScopedValue{Scheduler}(SerialScheduler()) + +The default scheduler used when looping over different blocks in the matrix representation of a +tensor. + +For controlling this value, see also [`set_blockscheduler`](@ref) and [`with_blockscheduler`](@ref). +""" +const blockscheduler = ScopedValue{Scheduler}(SerialScheduler()) + +""" + const subblockscheduler = ScopedValue{Scheduler}(SerialScheduler()) + +The default scheduler used when looping over different subblocks in a tensor. 
+
+For controlling this value, see also [`set_subblockscheduler`](@ref) and [`with_subblockscheduler`](@ref).
+"""
+const subblockscheduler = ScopedValue{Scheduler}(SerialScheduler())
+
+function select_scheduler(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
+    return if scheduler == OhMyThreads.Implementation.NotGiven() && isempty(kwargs)
+        Threads.nthreads() == 1 ? SerialScheduler() : DynamicScheduler()
+    else
+        OhMyThreads.Implementation._scheduler_from_userinput(scheduler; kwargs...)
+    end
+end
+
+"""
+    with_blockscheduler(f, [scheduler]; kwargs...)
+
+Run `f` in a scope where the `blockscheduler` is determined by `scheduler` and `kwargs...`.
+
+See also [`with_subblockscheduler`](@ref).
+"""
+@inline function with_blockscheduler(f, scheduler=OhMyThreads.Implementation.NotGiven();
+                                     kwargs...)
+    @with blockscheduler => select_scheduler(scheduler; kwargs...) f()
+end
+
+"""
+    with_subblockscheduler(f, [scheduler]; kwargs...)
+
+Run `f` in a scope where the [`subblockscheduler`](@ref) is determined by `scheduler` and `kwargs...`.
+The arguments to this function are either an `OhMyThreads.Scheduler` or a `Symbol` with an optional
+set of keyword arguments. For a detailed description, consult the
+[`OhMyThreads` documentation](https://juliafolds2.github.io/OhMyThreads.jl/stable/refs/api/#Schedulers).
+
+See also [`with_blockscheduler`](@ref).
+"""
+@inline function with_subblockscheduler(f, scheduler=OhMyThreads.Implementation.NotGiven();
+                                        kwargs...)
+    @with subblockscheduler => select_scheduler(scheduler; kwargs...) f()
+end
+
+# Backend implementation
+# ----------------------
+# TODO: figure out a name
+# TODO: what should be the default scheduler?
+@kwdef struct TensorKitBackend{B<:AbstractBackend,BS,SBS} <: AbstractBackend + arraybackend::B = TO.DefaultBackend() + blockscheduler::BS = blockscheduler[] + subblockscheduler::SBS = subblockscheduler[] +end + +function TO.select_backend(::typeof(TO.tensoradd!), C::AbstractTensorMap, + A::AbstractTensorMap) + return TensorKitBackend() +end +function TO.select_backend(::typeof(TO.tensortrace!), C::AbstractTensorMap, + A::AbstractTensorMap) + return TensorKitBackend() +end +function TO.select_backend(::typeof(TO.tensorcontract!), C::AbstractTensorMap, + A::AbstractTensorMap, B::AbstractTensorMap) + return TensorKitBackend() +end + +function add_transform! end +function TO.select_backend(::typeof(add_transform!), C::AbstractTensorMap, + A::AbstractTensorMap) + return TensorKitBackend() +end diff --git a/src/tensors/blockiterator.jl b/src/tensors/blockiterator.jl index b4ec4b87..06576dc6 100644 --- a/src/tensors/blockiterator.jl +++ b/src/tensors/blockiterator.jl @@ -13,3 +13,5 @@ Base.IteratorEltype(::BlockIterator) = Base.HasEltype() Base.eltype(::Type{<:BlockIterator{T}}) where {T} = blocktype(T) Base.length(iter::BlockIterator) = length(iter.structure) Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...) + +Base.haskey(iter::BlockIterator, c) = haskey(iter.structure, c) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 01c0f94c..37769c15 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -81,7 +81,7 @@ end d = (dims(codomain(b), f₁.uncoupled)..., dims(domain(b), f₂.uncoupled)...) 
n1 = d[1] * d[2] n2 = d[3] * d[4] - data = sreshape(StridedView(Matrix{eltype(b)}(undef, n1, n2)), d) + data = sreshape(StridedView(blocktype(b)(undef, n1, n2)), d) fill!(data, zero(eltype(b))) if f₁.uncoupled == reverse(f₂.uncoupled) braiddict = artin_braid(f₂, 1; inv=b.adjoint) @@ -104,13 +104,27 @@ Base.copy(b::BraidingTensor) = b TensorMap(b::BraidingTensor) = copy!(similar(b), b) Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b) +# Blocks iterator +# --------------- +blocks(b::BraidingTensor) = BlockIterator(b, blocksectors(b)) +blocktype(::Type{TT}) where {TT<:BraidingTensor} = Matrix{eltype(TT)} + +# TODO: efficient iterator +function Base.iterate(iter::BlockIterator{<:BraidingTensor}, state...) + next = iterate(iter.structure, state...) + isnothing(next) && return next + c, state = next + return c => block(iter.t, c), state +end +@inline Base.getindex(iter::BlockIterator{<:BraidingTensor}, c::Sector) = block(iter.t, c) + function block(b::BraidingTensor, s::Sector) sectortype(b) == typeof(s) || throw(SectorMismatch()) # TODO: probably always square? m = blockdim(codomain(b), s) n = blockdim(domain(b), s) - data = Matrix{eltype(b)}(undef, (m, n)) + data = blocktype(b)(undef, (m, n)) length(data) == 0 && return data # s ∉ blocksectors(b) @@ -149,6 +163,30 @@ function block(b::BraidingTensor, s::Sector) return data end +# Linear Algebra +# -------------- +function LinearAlgebra.mul!(C::AbstractTensorMap, A::AbstractTensorMap, B::BraidingTensor, + α::Number, β::Number) + compose(space(A), space(B)) == space(C) || + throw(SpaceMismatch(lazy"$(space(C)) ≠ $(space(A)) * $(space(B))")) + levels = B.adjoint ? (1, 2, 3, 4) : (1, 2, 4, 3) + return add_braid!(C, A, ((1, 2), (4, 3)), levels, α, β) +end +function LinearAlgebra.mul!(C::AbstractTensorMap, A::BraidingTensor, B::AbstractTensorMap, + α::Number, β::Number) + compose(space(A), space(B)) == space(C) || + throw(SpaceMismatch(lazy"$(space(C)) ≠ $(space(A)) * $(space(B))")) + levels = A.adjoint ? 
(2, 1, 3, 4) : (1, 2, 3, 4) + return add_transpose!(C, B, ((2, 1), (3, 4)), levels, α, β) +end +# TODO: implement this? +function LinearAlgebra.mul!(C::AbstractTensorMap, A::BraidingTensor, B::BraidingTensor, + α::Number, β::Number) + compose(space(A), space(B)) == space(C) || + throw(SpaceMismatch(lazy"$(space(C)) ≠ $(space(A)) * $(space(B))")) + return mul!(C, TensorMap(A), B, α, β) +end + # Index manipulations # ------------------- has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false @@ -158,9 +196,9 @@ function add_transform!(tdst::AbstractTensorMap, fusiontreetransform, α::Number, β::Number, - backend::AbstractBackend...) + backend::TensorKitBackend, allocator) return add_transform!(tdst, TensorMap(tsrc), (p₁, p₂), fusiontreetransform, α, β, - backend...) + backend, allocator) end # VectorInterface @@ -173,8 +211,8 @@ end function TO.tensoradd!(C::AbstractTensorMap, A::BraidingTensor, pA::Index2Tuple, conjA::Symbol, - α::Number, β::Number, backend=TO.DefaultBackend(), - allocator=TO.DefaultAllocator()) + α::Number, β::Number, backend::AbstractBackend, + allocator) return TO.tensoradd!(C, TensorMap(A), pA, conjA, α, β, backend, allocator) end diff --git a/src/tensors/indexmanipulations.jl b/src/tensors/indexmanipulations.jl index d1d16803..e067a36b 100644 --- a/src/tensors/indexmanipulations.jl +++ b/src/tensors/indexmanipulations.jl @@ -377,7 +377,7 @@ end #------------------------------------- """ add_permute!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple, - α::Number, β::Number, backend::AbstractBackend...) + α::Number, β::Number, backend...) Return the updated `tdst`, which is the result of adding `α * tsrc` to `tdst` after permuting the indices of `tsrc` according to `(p₁, p₂)`. @@ -389,14 +389,14 @@ See also [`permute`](@ref), [`permute!`](@ref), [`add_braid!`](@ref), [`add_tran p::Index2Tuple, α::Number, β::Number, - backend::AbstractBackend...) + backend...) 
transformer = treepermuter(tdst, tsrc, p) return add_transform!(tdst, tsrc, p, transformer, α, β, backend...) end """ add_braid!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple, - levels::IndexTuple, α::Number, β::Number, backend::AbstractBackend...) + levels::IndexTuple, α::Number, β::Number, backend...) Return the updated `tdst`, which is the result of adding `α * tsrc` to `tdst` after braiding the indices of `tsrc` according to `(p₁, p₂)` and `levels`. @@ -409,7 +409,7 @@ See also [`braid`](@ref), [`braid!`](@ref), [`add_permute!`](@ref), [`add_transp levels::IndexTuple, α::Number, β::Number, - backend::AbstractBackend...) + backend...) length(levels) == numind(tsrc) || throw(ArgumentError("incorrect levels $levels for tensor map $(codomain(tsrc)) ← $(domain(tsrc))")) @@ -422,7 +422,7 @@ end """ add_transpose!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple, - α::Number, β::Number, backend::AbstractBackend...) + α::Number, β::Number, backend...) Return the updated `tdst`, which is the result of adding `α * tsrc` to `tdst` after transposing the indices of `tsrc` according to `(p₁, p₂)`. @@ -434,18 +434,50 @@ See also [`transpose`](@ref), [`transpose!`](@ref), [`add_permute!`](@ref), [`ad p::Index2Tuple, α::Number, β::Number, - backend::AbstractBackend...) + backend...) transformer = treetransposer(tdst, tsrc, p) return add_transform!(tdst, tsrc, p, transformer, α, β, backend...) end +# Implementation +# -------------- +""" + add_transform!(C, A, pA, transformer, α, β, [backend], [allocator]) + +Return the updated `C`, which is the result of adding `α * A` to `β * B`, +permuting the data with `pA` while transforming the fusiontrees with `transformer`. 
+""" +function add_transform!(C::AbstractTensorMap, A::AbstractTensorMap, pA::Index2Tuple, + transformer, α::Number, β::Number) + return add_transform!(C, A, pA, transformer, α, β, TO.DefaultBackend()) +end +function add_transform!(C::AbstractTensorMap, A::AbstractTensorMap, pA::Index2Tuple, + transformer, α::Number, β::Number, backend) + return add_transform!(C, A, pA, transformer, α, β, backend, TO.DefaultAllocator()) +end +function add_transform!(C::AbstractTensorMap, A::AbstractTensorMap, pA::Index2Tuple, + transformer, α::Number, β::Number, backend, allocator) + if backend isa TO.DefaultBackend + newbackend = TO.select_backend(add_transform!, C, A) + return add_transform!(C, A, pA, transformer, α, β, newbackend, allocator) + elseif backend isa TO.NoBackend # error for missing backend + TC = typeof(C) + TA = typeof(A) + throw(ArgumentError("No suitable backend found for `add_transform!` and tensor types $TC and $TA")) + else # error for unknown backend + TC = typeof(C) + TA = typeof(A) + throw(ArgumentError("Unknown backend $backend for `add_transform!` and tensor types $TC and $TA")) + end +end + function add_transform!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple, transformer, α::Number, β::Number, - backend::AbstractBackend...) + backend::TensorKitBackend, allocator) @boundscheck begin permute(space(tsrc), (p₁, p₂)) == space(tdst) || throw(SpaceMismatch("source = $(codomain(tsrc))←$(domain(tsrc)), @@ -455,7 +487,7 @@ function add_transform!(tdst::AbstractTensorMap, if p₁ === codomainind(tsrc) && p₂ === domainind(tsrc) add!(tdst, tsrc, α, β) else - add_transform_kernel!(tdst, tsrc, (p₁, p₂), transformer, α, β, backend...) + add_transform_kernel!(tdst, tsrc, (p₁, p₂), transformer, α, β, backend, allocator) end return tdst @@ -467,8 +499,9 @@ function add_transform_kernel!(tdst::TensorMap, ::TrivialTreeTransformer, α::Number, β::Number, - backend::AbstractBackend...) 
- return TO.tensoradd!(tdst[], tsrc[], (p₁, p₂), false, α, β, backend...) + backend::TensorKitBackend, allocator) + return TO.tensoradd!(tdst[], tsrc[], (p₁, p₂), false, α, β, backend.arraybackend, + allocator) end function add_transform_kernel!(tdst::TensorMap, @@ -477,22 +510,23 @@ function add_transform_kernel!(tdst::TensorMap, transformer::AbelianTreeTransformer, α::Number, β::Number, - backend::AbstractBackend...) + backend::TensorKitBackend, allocator) structure_dst = transformer.structure_dst.fusiontreestructure structure_src = transformer.structure_src.fusiontreestructure - # TODO: this could be multithreaded - for (row, col, val) in zip(transformer.rows, transformer.cols, transformer.vals) + tforeach(transformer.rows, transformer.cols, transformer.vals; + scheduler=backend.subblockscheduler) do row, col, val sz_dst, str_dst, offset_dst = structure_dst[col] subblock_dst = StridedView(tdst.data, sz_dst, str_dst, offset_dst) sz_src, str_src, offset_src = structure_src[row] subblock_src = StridedView(tsrc.data, sz_src, str_src, offset_src) - TO.tensoradd!(subblock_dst, subblock_src, (p₁, p₂), false, α * val, β, backend...) + return TO.tensoradd!(subblock_dst, subblock_src, (p₁, p₂), false, α * val, β, + backend.arraybackend, allocator) end - return nothing + return tdst end function add_transform_kernel!(tdst::TensorMap, @@ -501,15 +535,14 @@ function add_transform_kernel!(tdst::TensorMap, transformer::GenericTreeTransformer, α::Number, β::Number, - backend::AbstractBackend...) 
+ backend::TensorKitBackend, allocator) structure_dst = transformer.structure_dst.fusiontreestructure structure_src = transformer.structure_src.fusiontreestructure rows = rowvals(transformer.matrix) vals = nonzeros(transformer.matrix) - # TODO: this could be multithreaded - for j in axes(transformer.matrix, 2) + tforeach(axes(transformer.matrix, 2); scheduler=backend.subblockscheduler) do j sz_dst, str_dst, offset_dst = structure_dst[j] subblock_dst = StridedView(tdst.data, sz_dst, str_dst, offset_dst) nzrows = nzrange(transformer.matrix, j) @@ -519,14 +552,14 @@ function add_transform_kernel!(tdst::TensorMap, subblock_src = StridedView(tsrc.data, sz_src, str_src, offset_src) TO.tensoradd!(subblock_dst, subblock_src, (p₁, p₂), false, α * vals[first(nzrows)], β, - backend...) + backend.arraybackend, allocator) # treat remaining entries for i in @view(nzrows[2:end]) sz_src, str_src, offset_src = structure_src[rows[i]] subblock_src = StridedView(tsrc.data, sz_src, str_src, offset_src) TO.tensoradd!(subblock_dst, subblock_src, (p₁, p₂), false, α * vals[i], One(), - backend...) + backend.arraybackend, allocator) end end @@ -539,79 +572,80 @@ function add_transform_kernel!(tdst::AbstractTensorMap, fusiontreetransform::Function, α::Number, β::Number, - backend::AbstractBackend...) + backend::TensorKitBackend, allocator) I = sectortype(spacetype(tdst)) if I === Trivial - _add_trivial_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...) + _add_trivial_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, + backend, allocator) elseif FusionStyle(I) isa UniqueFusion - _add_abelian_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...) + _add_abelian_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, + backend, allocator) else - _add_general_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...) 
+ _add_general_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, + backend, allocator) end return nothing end # internal methods: no argument types -function _add_trivial_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...) - TO.tensoradd!(tdst[], tsrc[], p, false, α, β, backend...) +function _add_trivial_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, + backend::TensorKitBackend, allocator) + TO.tensoradd!(tdst[], tsrc[], p, false, α, β, backend.arraybackend, allocator) return nothing end -function _add_abelian_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...) - if Threads.nthreads() > 1 - Threads.@sync for (f₁, f₂) in fusiontrees(tsrc) - Threads.@spawn _add_abelian_block!(tdst, tsrc, p, fusiontreetransform, - f₁, f₂, α, β, backend...) - end - else - for (f₁, f₂) in fusiontrees(tsrc) - _add_abelian_block!(tdst, tsrc, p, fusiontreetransform, - f₁, f₂, α, β, backend...) - end +function _add_abelian_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, + backend::TensorKitBackend, allocator) + tforeach(fusiontrees(tsrc); scheduler=backend.subblockscheduler) do (f₁, f₂) + return _add_abelian_block!(tdst, tsrc, p, fusiontreetransform, + f₁, f₂, α, β, backend.arraybackend, allocator) end return nothing end -function _add_abelian_block!(tdst, tsrc, p, fusiontreetransform, f₁, f₂, α, β, backend...) +function _add_abelian_block!(tdst, tsrc, p, fusiontreetransform, f₁, f₂, α, β, + backend, allocator) (f₁′, f₂′), coeff = first(fusiontreetransform(f₁, f₂)) @inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β, - backend...) + backend, allocator) return nothing end -function _add_general_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...) 
+function _add_general_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend, allocator) if iszero(β) tdst = zerovector!(tdst) elseif β != 1 tdst = scale!(tdst, β) end β′ = One() - if Threads.nthreads() > 1 - Threads.@sync for s₁ in sectors(codomain(tsrc)), s₂ in sectors(domain(tsrc)) - Threads.@spawn _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁, - s₂, α, β′, backend...) - end - else + if backend.subblockscheduler isa SerialScheduler for (f₁, f₂) in fusiontrees(tsrc) for ((f₁′, f₂′), coeff) in fusiontreetransform(f₁, f₂) @inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, - β′, backend...) + β′, backend.arraybackend, allocator) end end + else + tforeach(collect(Iterators.product(sectors(codomain(tsrc)), sectors(domain(tsrc)))); + scheduler=backend.subblockscheduler) do (s₁, + s₂) + return _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁, s₂, α, + β′, backend.arraybackend, allocator) + end end return nothing end # TODO: β argument is weird here because it has to be 1 function _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁, s₂, α, β, - backend...) + backend, allocator) for (f₁, f₂) in fusiontrees(tsrc) (f₁.uncoupled == s₁ && f₂.uncoupled == s₂) || continue for ((f₁′, f₂′), coeff) in fusiontreetransform(f₁, f₂) @inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β, - backend...) 
+ backend, allocator) end end return nothing diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index 395c49f2..b7f0ee1c 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -309,12 +309,47 @@ function LinearAlgebra.tr(t::AbstractTensorMap) end # TensorMap multiplication -function LinearAlgebra.mul!(tC::AbstractTensorMap, - tA::AbstractTensorMap, - tB::AbstractTensorMap, α=true, β=false) +function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, + tB::AbstractTensorMap, + α::Number, β::Number, + backend::AbstractBackend=TO.DefaultBackend()) + if backend isa TO.DefaultBackend + newbackend = TO.select_backend(mul!, tC, tA, tB) + return mul!(tC, tA, tB, α, β, newbackend) + elseif backend isa TO.NoBackend # error for missing backend + TC = typeof(tC) + TA = typeof(tA) + TB = typeof(tB) + throw(ArgumentError("No suitable backend found for `mul!` and tensor types $TC, $TA and $TB")) + else # error for unknown backend + TC = typeof(tC) + TA = typeof(tA) + TB = typeof(tB) + throw(ArgumentError("Unknown backend for `mul!` and tensor types $TC, $TA and $TB")) + end +end + +function TO.select_backend(::typeof(mul!), C::AbstractTensorMap, A::AbstractTensorMap, + B::AbstractTensorMap) + return TensorKitBackend() +end + +function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, + tB::AbstractTensorMap, α::Number, β::Number, + backend::TensorKitBackend) compose(space(tA), space(tB)) == space(tC) || throw(SpaceMismatch(lazy"$(space(tC)) ≠ $(space(tA)) * $(space(tB))")) + scheduler = backend.blockscheduler + if isnothing(scheduler) + return sequential_mul!(tC, tA, tB, α, β) + else + return threaded_mul!(tC, tA, tB, α, β, scheduler) + end +end + +function sequential_mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, + tB::AbstractTensorMap, α::Number, β::Number) iterC = blocks(tC) iterA = blocks(tA) iterB = blocks(tB) @@ -336,13 +371,13 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap, elseif cB < cC nextB = iterate(iterB, 
stateB) else - if β != one(β) + if !isone(β) rmul!(C, β) end nextC = iterate(iterC, stateC) end else - if β != one(β) + if !isone(β) rmul!(C, β) end nextC = iterate(iterC, stateC) @@ -351,7 +386,23 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap, return tC end -# TODO: consider spawning threads for different blocks, support backends +function threaded_mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, tB::AbstractTensorMap, + α::Number, β::Number, scheduler::Scheduler) + # obtain cached data before multithreading + bCs, bAs, bBs = blocks(tC), blocks(tA), blocks(tB) + + # Note: using blocksectors instead of blocks to support chunksplitting + # TODO: experiment with sorting/splitting strategies + tforeach(blocksectors(tC); scheduler) do c + if haskey(bAs, c) # then also bBs should have it + mul!(bCs[c], bAs[c], bBs[c], α, β) + elseif !isone(β) + scale!(bCs[c], β) + end + end + + return tC +end # TensorMap inverse function Base.inv(t::AbstractTensorMap) diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index cf977690..ed36ec2b 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -94,17 +94,17 @@ function TO.tensorcontract!(C::AbstractTensorMap, pA′ = adjointtensorindices(A, pA) B′ = B' pB′ = adjointtensorindices(B, pB) - contract!(C, A′, pA′, B′, pB′, pAB′, α, β, backend, allocator) + TO.blas_contract!(C, A′, pA′, B′, pB′, pAB′, α, β, backend, allocator) elseif conjA A′ = A' pA′ = adjointtensorindices(A, pA) - contract!(C, A′, pA′, B, pB, pAB′, α, β, backend, allocator) + TO.blas_contract!(C, A′, pA′, B, pB, pAB′, α, β, backend, allocator) elseif conjB B′ = B' pB′ = adjointtensorindices(B, pB) - contract!(C, A, pA, B′, pB′, pAB′, α, β, backend, allocator) + TO.blas_contract!(C, A, pA, B′, pB′, pAB′, α, β, backend, allocator) else - contract!(C, A, pA, B, pB, pAB′, α, β, backend, allocator) + TO.blas_contract!(C, A, pA, B, pB, pAB′, α, β, backend, allocator) end return C end @@ -222,114 +222,105 @@ end 
# TODO: contraction with either A or B a rank (1, 1) tensor does not require to # permute the fusion tree and should therefore be special cased. This will speed # up MPS algorithms -""" - contract!(C::AbstractTensorMap, - A::AbstractTensorMap, (oindA, cindA)::Index2Tuple, - B::AbstractTensorMap, (cindB, oindB)::Index2Tuple, - (p₁, p₂)::Index2Tuple, - α::Number, β::Number, - backend, allocator) - -Return the updated `C`, which is the result of adding `α * A * B` to `C` after permuting -the indices of `A` and `B` according to `(oindA, cindA)` and `(cindB, oindB)` respectively. -""" -function contract!(C::AbstractTensorMap, - A::AbstractTensorMap, (oindA, cindA)::Index2Tuple, - B::AbstractTensorMap, (cindB, oindB)::Index2Tuple, - (p₁, p₂)::Index2Tuple, - α::Number, β::Number, - backend, allocator) - length(cindA) == length(cindB) || - throw(IndexError("number of contracted indices does not match")) - N₁, N₂ = length(oindA), length(oindB) - - # find optimal contraction scheme - hsp = has_shared_permute - ipAB = TupleTools.invperm((p₁..., p₂...)) - oindAinC = TupleTools.getindices(ipAB, ntuple(n -> n, N₁)) - oindBinC = TupleTools.getindices(ipAB, ntuple(n -> n + N₁, N₂)) - - qA = TupleTools.sortperm(cindA) - cindA′ = TupleTools.getindices(cindA, qA) - cindB′ = TupleTools.getindices(cindB, qA) - - qB = TupleTools.sortperm(cindB) - cindA′′ = TupleTools.getindices(cindA, qB) - cindB′′ = TupleTools.getindices(cindB, qB) - - dA, dB, dC = dim(A), dim(B), dim(C) - - # keep order A en B, check possibilities for cind - memcost1 = memcost2 = dC * (!hsp(C, (oindAinC, oindBinC))) - memcost1 += dA * (!hsp(A, (oindA, cindA′))) + - dB * (!hsp(B, (cindB′, oindB))) - memcost2 += dA * (!hsp(A, (oindA, cindA′′))) + - dB * (!hsp(B, (cindB′′, oindB))) - - # reverse order A en B, check possibilities for cind - memcost3 = memcost4 = dC * (!hsp(C, (oindBinC, oindAinC))) - memcost3 += dB * (!hsp(B, (oindB, cindB′))) + - dA * (!hsp(A, (cindA′, oindA))) - memcost4 += dB * (!hsp(B, (oindB, 
cindB′′))) + - dA * (!hsp(A, (cindA′′, oindA))) - - if min(memcost1, memcost2) <= min(memcost3, memcost4) - if memcost1 <= memcost2 - return _contract!(α, A, B, β, C, oindA, cindA′, oindB, cindB′, p₁, p₂) + +# this is a copy of the TensorOperations implementation, adding two ways to +# permute the contracted indices for 4 total possible implementations +function TO.blas_contract!(C::AbstractTensorMap, A::AbstractTensorMap, pA, + B::AbstractTensorMap, pB, pAB, α, β, backend, allocator) + # index permutations for reverse contraction + indCinoBA = let N₁ = TO.numout(pA), N₂ = TO.numin(pB) + map(n -> ifelse(n > N₁, n - N₁, n + N₂), TO.linearize(pAB)) + end + tpAB = TO.trivialpermutation(pAB) + pBA = (TupleTools.getindices(indCinoBA, tpAB[1]), + TupleTools.getindices(indCinoBA, tpAB[2])) + + # permutations of contracted indices + qA = TupleTools.sortperm(pA[2]) + pA′ = (pA[1], TupleTools.getindices(pA[2], qA)) + pB′ = (TupleTools.getindices(pB[1], qA), pB[2]) + + qB = TupleTools.sortperm(pB[1]) + pA″ = (pA[1], TupleTools.getindices(pA[2], qB)) + pB″ = (TupleTools.getindices(pB[1], qB), pB[2]) + + memcost1 = TO.contract_memcost(C, A, pA′, B, pB′, pAB) + memcost2 = TO.contract_memcost(C, A, pA″, B, pB″, pAB) + memcost3 = TO.contract_memcost(C, B, reverse(pB′), A, reverse(pA′), pBA) + memcost4 = TO.contract_memcost(C, B, reverse(pB″), A, reverse(pA″), pBA) + + return if min(memcost1, memcost2) ≤ min(memcost3, memcost4) + if memcost1 ≤ memcost2 + _blas_contract!(C, A, pA′, B, pB′, pAB, α, β, backend, allocator) else - return _contract!(α, A, B, β, C, oindA, cindA′′, oindB, cindB′′, p₁, p₂) + _blas_contract!(C, A, pA″, B, pB″, pAB, α, β, backend, allocator) end else - p1′ = map(n -> ifelse(n > N₁, n - N₁, n + N₂), p₁) - p2′ = map(n -> ifelse(n > N₁, n - N₁, n + N₂), p₂) - if memcost3 <= memcost4 - return _contract!(α, B, A, β, C, oindB, cindB′, oindA, cindA′, p1′, p2′) + if memcost3 ≤ memcost4 + _blas_contract!(C, B, reverse(pB′), A, reverse(pA′), pBA, α, β, backend, + 
allocator)
         else
-            return _contract!(α, B, A, β, C, oindB, cindB′′, oindA, cindA′′, p1′, p2′)
+            _blas_contract!(C, B, reverse(pB″), A, reverse(pA″), pBA, α, β, backend,
+                            allocator)
         end
     end
 end
 
-# TODO: also transform _contract! into new interface, and add backend support
-function _contract!(α, A::AbstractTensorMap, B::AbstractTensorMap,
-                    β, C::AbstractTensorMap,
-                    oindA::IndexTuple, cindA::IndexTuple,
-                    oindB::IndexTuple, cindB::IndexTuple,
-                    p₁::IndexTuple, p₂::IndexTuple)
-    if !(BraidingStyle(sectortype(C)) isa SymmetricBraiding)
-        throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead"))
-    end
-    N₁, N₂ = length(oindA), length(oindB)
-    copyA = false
-    if BraidingStyle(sectortype(A)) isa Fermionic
-        for i in cindA
-            if !isdual(space(A, i))
-                copyA = true
-            end
-        end
+function TO.contract_memcost(C::AbstractTensorMap, A::AbstractTensorMap, pA,
+                             B::AbstractTensorMap, pB, pAB)
+    ipAB = TO.oindABinC(pAB, pA, pB)
+    return dim(A) * !isblascontractable(A, pA) + dim(B) * !isblascontractable(B, pB) +
+           dim(C) * !isblasdestination(C, ipAB)
+end
+
+# TODO: deliberately not importing private TO functions from here on out. Should we?
+function _blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator) + TC = eltype(C) + + A_, pA, flagA = makeblascontractable(A, pA, TC, backend, allocator, true) + B_, pB, flagB = makeblascontractable(B, pB, TC, backend, allocator, false) + + ipAB = TO.oindABinC(pAB, pA, pB) + flagC = isblasdestination(C, ipAB) + if flagC + mul!(C, A_, B_, α, β, backend) + else + C_ = TO.tensoralloc_add(TC, C, ipAB, false, Val(true), allocator) + mul!(C_, A_, B_, One(), Zero(), backend) + TO.tensoradd!(C, C_, pAB, false, α, β, backend, allocator) + TO.tensorfree!(C_, allocator) end - A′ = permute(A, (oindA, cindA); copy=copyA) - B′ = permute(B, (cindB, oindB)) - if BraidingStyle(sectortype(A)) isa Fermionic - for i in domainind(A′) - if !isdual(space(A′, i)) - A′ = twist!(A′, i) + flagA || TO.tensorfree!(A_, allocator) + flagB || TO.tensorfree!(B_, allocator) + return C +end + +isblascontractable(A, pA) = (pA[1] == codomainind(A) && pA[2] == domainind(A)) +function isblasdestination(A::AbstractTensorMap, p::Index2Tuple) + return (p[1] == codomainind(A) && p[2] == domainind(A)) +end + +@inline function makeblascontractable(A::AbstractTensorMap, pA, TC, backend, allocator, + dotwist::Bool=false) + flagA = (scalartype(A) === TC && isblascontractable(A, pA) && !dotwist) + if !flagA + A_ = TO.tensoralloc_add(TC, A, pA, false, Val(true), allocator) + Anew = TO.tensoradd!(A_, A, pA, false, One(), Zero(), backend, allocator) + if dotwist && (BraidingStyle(sectortype(A)) isa Fermionic) + for i in domainind(Anew) + if !isdual(space(Anew, i)) + twist!(Anew, i) + end end + # TODO: this seems type-unstable: + # Anew = twist!(Anew, filter(i -> !isdual(space(Anew, i)), domainind(Anew))) end - # A′ = twist!(A′, filter(i -> !isdual(space(A′, i)), domainind(A′))) - # commented version leads to boxing of `A′` and type instabilities in the result - end - ipAB = TupleTools.invperm((p₁..., p₂...)) - oindAinC = TupleTools.getindices(ipAB, ntuple(n -> n, N₁)) - oindBinC = 
TupleTools.getindices(ipAB, ntuple(n -> n + N₁, N₂))
-    if has_shared_permute(C, (oindAinC, oindBinC))
-        C′ = permute(C, (oindAinC, oindBinC))
-        mul!(C′, A′, B′, α, β)
+        pAnew = TO.trivialpermutation(pA)
     else
-        C′ = A′ * B′
-        add_permute!(C, C′, (p₁, p₂), α, β)
+        Anew = A
+        pAnew = pA
     end
-    return C
+    return Anew, pAnew, flagA
 end
 
 # Scalar implementation
diff --git a/test/runtests.jl b/test/runtests.jl
index d0cd9945..56139563 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -55,8 +55,8 @@ sectorlist = (Z2Irrep, Z3Irrep, Z4Irrep, Z3Irrep ⊠ Z4Irrep,
               Z2Irrep ⊠ FibonacciAnyon ⊠ FibonacciAnyon)
 
 Ti = time()
-include("fusiontrees.jl")
-include("spaces.jl")
+include("fusiontrees.jl")
+include("spaces.jl")
 include("tensors.jl")
 include("diagonal.jl")
 include("planar.jl")