Skip to content

Planar additions [WIP] #124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 30 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d32b1c8
Add rrule planaradd!
lkdvos Jan 21, 2024
276f3bb
Add rrule `planarcontract`
lkdvos Jan 21, 2024
93db861
Add `planarcontract` (without `!`)
lkdvos Jan 21, 2024
d84ffb4
Fix TensorKitChainRulesCoreExt imports
lkdvos Jan 21, 2024
974173d
painful `reorder_indices` rewrite
lkdvos Jan 29, 2024
31bd0b3
Fix some things and make planar tests run
lkdvos Jan 30, 2024
13e4cdf
more fixes and rewrites
lkdvos Jan 30, 2024
3ed5ecf
more fixes and tests
lkdvos Jan 30, 2024
d38867d
Ad improvements
lkdvos Jan 30, 2024
e405b61
import `getindices`
lkdvos Jan 30, 2024
b42d119
Add planarcontract_indices
lkdvos May 14, 2024
ccb9ede
Move planar index functions to separate file
lkdvos May 14, 2024
a168fe9
Add planarcontract with conj flags
lkdvos May 14, 2024
fdad8f7
Add some extra checks in planarcontract implementation
lkdvos May 14, 2024
fccb958
Also move auxiliary index functions
lkdvos May 14, 2024
7eff5e3
Add planarcontract function
lkdvos May 14, 2024
ec18564
Update planarcontract_indices
lkdvos May 14, 2024
a20c2eb
Add some more auxiliary methods for tensor indices
lkdvos May 15, 2024
c2d7d13
Add planarcopy and planartrace
lkdvos May 15, 2024
0f8544e
more planar index stuff
lkdvos May 15, 2024
b50ad10
clean up planaroperations
lkdvos May 15, 2024
42db3f0
cherrypick some fix from master
lkdvos May 15, 2024
a5106f4
Add plancon
lkdvos May 15, 2024
d6b625b
update planar tests
lkdvos May 15, 2024
efd30e3
Merge branch 'master' into ld/planar-ad
lkdvos May 15, 2024
8bd8a15
Updates to make planar AD work
lkdvos May 17, 2024
6de6511
Prevent tensoroperations test on anyonic tests
lkdvos May 17, 2024
f9847a1
Fix `Base.summary(::TensorMap)`
lkdvos May 17, 2024
e8b900d
make otimes AD planar-compatible
lkdvos May 17, 2024
236e3f1
Formatter
lkdvos May 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 141 additions & 7 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
module TensorKitChainRulesCoreExt

using TensorOperations
using TensorOperations: Backend, promote_contract
using TensorKit
using TensorKit: planaradd!, planarcontract!, planarcontract, _canonicalize
using VectorInterface
using ChainRulesCore
using LinearAlgebra
using TupleTools
using TupleTools: getindices

# Utility
# -------

_conj(conjA::Symbol) = conjA == :C ? :N : :C
trivtuple(N) = ntuple(identity, N)
trivtuple(::Index2Tuple{N₁,N₂}) where {N₁,N₂} = trivtuple(N₁ + N₂)

function _repartition(p::IndexTuple, N₁::Int)
length(p) >= N₁ ||
Expand Down Expand Up @@ -112,18 +117,16 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe
ipA = (codomainind(A), domainind(A))
pB = (allind(B), ())
dA = zerovector(A,
TensorOperations.promote_contract(scalartype(ΔC),
scalartype(B)))
dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, B, pB, :C)
promote_contract(scalartype(ΔC), scalartype(B)))
dA = planarcontract!(dA, ΔC, pΔC, :N, B, pB, :C, ipA, One(), Zero())
return projectA(dA)
end
dB_ = @thunk begin
ipB = (codomainind(B), domainind(B))
pA = ((), allind(A))
dB = zerovector(B,
TensorOperations.promote_contract(scalartype(ΔC),
scalartype(A)))
dB = tensorcontract!(dB, ipB, A, pA, :C, ΔC, pΔC, :N)
promote_contract(scalartype(ΔC), scalartype(A)))
dB = planarcontract!(dB, A, pA, :C, ΔC, pΔC, :N, ipB, One(), Zero())
return projectB(dB)
end
return NoTangent(), dA_, dB_
Expand All @@ -134,7 +137,7 @@ end
function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple;
copy::Bool=false)
function permute_pullback(Δtdst)
invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc)
invp = _canonicalize(TupleTools.invperm(linearize(p)), tsrc)
return NoTangent(), permute(unthunk(Δtdst), invp; copy=true), NoTangent()
end
return permute(tsrc, p; copy=true), permute_pullback
Expand Down Expand Up @@ -632,6 +635,137 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix,
return ΔA
end

# Planar rrules
# --------------
function ChainRulesCore.rrule(::typeof(TensorKit.planaradd!),
C::AbstractTensorMap,
A::AbstractTensorMap, pA::Index2Tuple, conjA::Symbol,
α::Number, β::Number,
backend::Backend...)
C′ = planaradd!(copy(C), A, pA, conjA, α, β, backend...)

projectA = ProjectTo(A)
projectC = ProjectTo(C)
projectα = ProjectTo(α)
projectβ = ProjectTo(β)

function planaradd_pullback(ΔC′)
ΔC = unthunk(ΔC′)

dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
ip = _canonicalize(invperm(linearize(pA)), A)
_dA = zerovector(A, VectorInterface.promote_add(ΔC, α))
_dA = planaradd!(_dA, ΔC, ip, conjA, conjA == :N ? conj(α) : α, Zero(),
backend...)
return projectA(_dA)
end
dα = @thunk begin
_dα = tensorscalar(planarcontract(A, ((), linearize(pA)), _conj(conjA),
ΔC, (trivtuple(pA), ()), :N,
((), ()), One(), backend...))
return projectα(_dα)
end
dβ = @thunk begin
_dβ = tensorscalar(planarcontract(C,
((), trivtuple(TensorOperations.numind(pA))),
:C,
ΔC, (trivtuple(pA), ()), :N,
((), ()), One(), backend...))
return projectβ(_dβ)
end
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dbackend...
end

return C′, planaradd_pullback
end

function ChainRulesCore.rrule(::typeof(TensorKit.planarcontract!),
C::AbstractTensorMap,
A::AbstractTensorMap, pA::Index2Tuple, conjA::Symbol,
B::AbstractTensorMap, pB::Index2Tuple, conjB::Symbol,
pAB::Index2Tuple,
α::Number, β::Number, backend::Backend...)
# indA = (codomainind(A), reverse(domainind(A)))
# indB = (codomainind(B), reverse(domainind(B)))
# pA, pB, pAB = TensorKit.reorder_planar_indices(indA, pA, indB, pB, pAB)
C′ = planarcontract!(copy(C), A, pA, conjA, B, pB, conjB, pAB, α, β, backend...)

projectA = ProjectTo(A)
projectB = ProjectTo(B)
projectC = ProjectTo(C)
projectα = ProjectTo(α)
projectβ = ProjectTo(β)

function planarcontract_pullback(ΔC′)
ΔC = unthunk(ΔC′)
ipAB = invperm(linearize(pAB))
pΔC = (getindices(ipAB, trivtuple(length(pA[1]))),
getindices(ipAB, length(pA[1]) .+ trivtuple(length(pB[2]))))
dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
ipA = _canonicalize(invperm(linearize(pA)), A)
conjΔC = conjA == :C ? :C : :N
conjB′ = conjA == :C ? conjB : _conj(conjB)
_dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B), typeof(α)))
_dA = planarcontract!(_dA, ΔC, pΔC, conjΔC, B, reverse(pB), conjB′, ipA,
conjA == :C ? α : conj(α), Zero(), backend...)
return projectA(_dA)
end
dB = @thunk begin
ipB = _canonicalize((invperm(linearize(pB)), ()), B)
conjΔC = conjB == :C ? :C : :N
conjA′ = conjB == :C ? conjA : _conj(conjA)
_dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A), typeof(α)))
_dB = planarcontract!(_dB,
A, reverse(pA), conjA′,
ΔC, pΔC, conjΔC,
ipB, conjB == :C ? α : conj(α), Zero(), backend...)
return projectB(_dB)
end
dα = @thunk begin
_dα = tensorscalar(planarcontract(planarcontract(A, pA, conjA,
B, pB, conjB,
pAB, One(), backend...),
((), trivtuple(TensorOperations.numind(pAB))),
:C,
ΔC,
(trivtuple(TensorOperations.numind(pAB)), ()),
:N,
((), ()), One(), backend...))
return projectα(_dα)
end
dβ = @thunk begin
p′ = TensorKit.adjointtensorindices(C, trivtuple(pAB))
_dβ = tensorscalar(planarcontract(C', ((), p′),
ΔC, (trivtuple(pAB), ()), ((), ()),
One(), backend...))
return projectβ(_dβ)
end
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(),
NoTangent(),
dα, dβ, dbackend...
end

return C′, planarcontract_pullback
end

function ChainRulesCore.rrule(::typeof(TensorKit.planartrace!),
C::AbstractTensorMap,
A::AbstractTensorMap,
p::Index2Tuple, q::Index2Tuple, conjA::Symbol,
α::Number, β::Number, backend::Backend...)
C′ = planartrace!(copy(C), A, p, q, conjA, α, β, backend...)

function planartrace_pullback(ΔC′)
return ΔC = unthunk(ΔC′)
end

return C′, planartrace_pullback
end

# Convert rrules
#----------------
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap)
Expand Down
7 changes: 5 additions & 2 deletions src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ export OrthogonalFactorizationAlgorithm, QR, QRpos, QL, QLpos, LQ, LQpos, RQ, RQ
SVD, SDD, Polar

# tensor operations
export @tensor, @tensoropt, @ncon, ncon, @planar, @plansor
export @tensor, @tensoropt, @ncon, ncon, @planar, @plansor, plancon
export scalar, add!, contract!

# truncation schemes
Expand All @@ -88,7 +88,7 @@ export notrunc, truncerr, truncdim, truncspace, truncbelow
# Imports
#---------
using TupleTools
using TupleTools: StaticLength
using TupleTools: StaticLength, getindices

using Strided

Expand Down Expand Up @@ -196,12 +196,15 @@ include("tensors/braidingtensor.jl")
# #-----------------------------------------
@nospecialize
using Base.Meta: isexpr
include("planar/indices.jl")
include("planar/analyzers.jl")
include("planar/preprocessors.jl")
include("planar/postprocessors.jl")
include("planar/macros.jl")
@specialize
include("planar/planaroperations.jl")
include("planar/functions.jl")
include("planar/plancon.jl")

# deprecations: to be removed in version 1.0 or sooner
include("auxiliary/deprecate.jl")
Expand Down
12 changes: 8 additions & 4 deletions src/fusiontrees/manipulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,10 @@ function iscyclicpermutation(v1, v2)
length(v1) == length(v2) || return false
return iscyclicpermutation(indexin(v1, v2))
end
function iscyclicpermutation(v1::Tuple, v2::Tuple)
length(v1) == length(v2) || return false
return iscyclicpermutation(TupleTools.indexin(v1, v2))
end

# clockwise cyclic permutation while preserving (N₁, N₂): foldright & bendleft
function cycleclockwise(f₁::FusionTree{I}, f₂::FusionTree{I}) where {I<:Sector}
Expand Down Expand Up @@ -580,11 +584,11 @@ function planar_trace(f₁::FusionTree{I}, f₂::FusionTree{I},
linearindex = (ntuple(identity, Val(length(f₁)))...,
reverse(length(f₁) .+ ntuple(identity, Val(length(f₂))))...)

q1′ = TupleTools.getindices(linearindex, q1)
q2′ = TupleTools.getindices(linearindex, q2)
q1′ = getindices(linearindex, q1)
q2′ = getindices(linearindex, q2)
p1′, p2′ = let q′ = (q1′..., q2′...)
(map(l -> l - count(l .> q′), TupleTools.getindices(linearindex, p1)),
map(l -> l - count(l .> q′), TupleTools.getindices(linearindex, p2)))
(map(l -> l - count(l .> q′), getindices(linearindex, p1)),
map(l -> l - count(l .> q′), getindices(linearindex, p2)))
end

u = one(I)
Expand Down
77 changes: 77 additions & 0 deletions src/planar/functions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# methods/simple.jl
#
# Method-based access to planar operations using simple definitions.

# ------------------------------------------------------------------------------------------

function planarcopy(A, pA::Index2Tuple, conjA::Symbol, α::Number=One(), backend::Backend...)
TC = TO.promote_add(scalartype(A), scalartype(α))
C = tensoralloc_add(TC, pA, A, conjA)
return planaradd!(C, A, pA, conjA, α, Zero(), backend...)
end

# ------------------------------------------------------------------------------------------

function planartrace(A, pA::Index2Tuple, qA::Index2Tuple, conjA::Symbol, α::Number=One(),
backend::Backend...)
TC = TO.promote_contract(scalartype(A), scalartype(α))
C = tensoralloc_add(TC, pA, A, conjA)
return planartrace!(C, A, pA, qA, conjA, α, Zero(), backend...)
end

# ------------------------------------------------------------------------------------------

"""
planarcontract(A, IA, [conjA], B, IB, [conjB], [IC], [α=1])
planarcontract(A, pA::Index2Tuple, conjA, B, pB::Index2Tuple, conjB, pAB::Index2Tuple, α=1, [backend]) # expert mode

Contract indices of tensor `A` with corresponding indices in tensor `B` by assigning
them identical labels in the iterables `IA` and `IB`. The indices of the resulting
tensor correspond to the indices that only appear in either `IA` or `IB` and can be
ordered by specifying the optional argument `IC`. The default is to have all open
indices of `A` followed by all open indices of `B`. Note that inner contractions of an array
should be handled first with `tensortrace`, so that every label can appear only once in `IA`
or `IB` seperately, and once (for an open index) or twice (for a contracted index) in the
union of `IA` and `IB`.

Optionally, the symbols `conjA` and `conjB` can be used to specify that the input tensors
should be conjugated.

See also [`tensorcontract`](@ref).
"""
function planarcontract end

function planarcontract(A, IA::TensorLabels, conjA::Symbol, B, IB::TensorLabels,
conjB::Symbol, IC::TensorLabels,
α::Number=One())
ia = canonicalize_labels(A, IA)
ib = canonicalize_labels(B, IB)
ic = canonicalize_labels(IC)
pA, pB, pAB = planarcontract_indices(ia, ib, ic)
return planarcontract(A, pA, conjA, B, pB, conjB, pAB, α)
end
# default `IC`
function planarcontract(A, IA::TensorLabels, conjA::Symbol, B, IB::TensorLabels,
conjB::Symbol, α::Number=One())
ia = canonicalize_labels(A, IA)
ib = canonicalize_labels(B, IB)
pA, pB, pAB = planarcontract_indices(ia, ib)
return planarcontract(A, pA, conjA, B, pB, conjB, pAB, α)
end
# default `conjA` and `conjB`
function planarcontract(A, IA, B, IB, IC, α::Number=One())
return planarcontract(A, IA, :N, B, IB, :N, IC, α)
end
function planarcontract(A, IA, B, IB, α::Number=One())
return planarcontract(A, IA, :N, B, IB, :N, α)
end

# expert mode
function planarcontract(A, pA::Index2Tuple, conjA::Symbol,
B, pB::Index2Tuple, conjB::Symbol,
pAB::Index2Tuple, α::Number=One(),
backend::Backend...)
TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α))
C = TO.tensoralloc_contract(TC, pAB, A, pA, conjA, B, pB, conjB)
return planarcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, Zero(), backend...)
end
Loading
Loading