Skip to content

Commit f7eeaab

Browse files
committed
Implement TruncationKeepSorted
1 parent 4618584 commit f7eeaab

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

src/tensors/factorizations/truncation.jl

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,23 @@ function _compute_truncerr(Σdata, truncdim, p=2)
119119
p, zero(S))
120120
end
121121

122-
function _findnexttruncvalue(S, truncdim::SectorDict{I,Int}) where {I<:Sector}
122+
function _findnexttruncvalue(S, truncdim::SectorDict{I,Int}; by=identity,
123+
rev::Bool=true) where {I<:Sector}
123124
# early return
124125
(isempty(S) || all(iszero, values(truncdim))) && return nothing
125-
σmin, imin = findmin(keys(truncdim)) do c
126-
d = truncdim[c]
127-
return S[c][d]
126+
if rev
127+
σmin, imin = findmin(keys(truncdim)) do c
128+
d = truncdim[c]
129+
return by(S[c][d])
130+
end
131+
return σmin, keys(truncdim)[imin]
132+
else
133+
σmax, imax = findmax(keys(truncdim)) do c
134+
d = truncdim[c]
135+
return by(S[c][d])
136+
end
137+
return σmax, keys(truncdim)[imax]
128138
end
129-
return σmin, keys(truncdim)[imin]
130139
end
131140

132141
# implementations
@@ -173,12 +182,18 @@ function findtruncated_sorted(Sd::SectorDict, strategy::TruncationError)
173182
end
174183

175184
function findtruncated_sorted(Sd::SectorDict, strategy::TruncationKeepSorted)
176-
@assert strategy.by === abs && strategy.rev == true "Not implemented"
185+
return findtruncated(Sd, strategy)
186+
end
187+
function findtruncated(Sd::SectorDict, strategy::TruncationKeepSorted)
188+
permutations = SectorDict(c => (sortperm(d; strategy.by, strategy.rev))
189+
for (c, d) in Sd)
190+
Sd = SectorDict(c => sort(d; strategy.by, strategy.rev) for (c, d) in Sd)
191+
177192
I = keytype(Sd)
178193
truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in Sd)
179194
totaldim = sum(dim(c) * d for (c, d) in truncdim; init=0)
180195
while true
181-
next = _findnexttruncvalue(Sd, truncdim)
196+
next = _findnexttruncvalue(Sd, truncdim; strategy.by, strategy.rev)
182197
isnothing(next) && break
183198
_, cmin = next
184199
truncdim[cmin] -= 1
@@ -191,7 +206,7 @@ function findtruncated_sorted(Sd::SectorDict, strategy::TruncationKeepSorted)
191206
delete!(truncdim, cmin)
192207
end
193208
end
194-
return SectorDict{I,Base.OneTo{Int}}(c => Base.OneTo(d) for (c, d) in truncdim)
209+
return SectorDict(c => permutations[c][Base.OneTo(d)] for (c, d) in truncdim)
195210
end
196211

197212
function findtruncated_sorted(Sd::SectorDict, strategy::TruncationSpace)

0 commit comments

Comments
 (0)