Skip to content

Commit 0b13164

Browse files
authored
Make subset preserve varname ordering in varinfo (#832)
* Make subsumes preserve varname order * Bump patch version to 0.35.1 * Add subsumes test * Fix subset NamedTuple method
1 parent acda046 commit 0b13164

File tree

6 files changed

+48
-43
lines changed

6 files changed

+48
-43
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.35.0"
3+
version = "0.35.1"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/abstract_varinfo.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,12 +331,7 @@ has_varnamedvector(vi::AbstractVarInfo) = false
331331
332332
Subset a `varinfo` to only contain the variables `vns`.
333333
334-
!!! warning
335-
The ordering of the variables in the resulting `varinfo` is _not_
336-
guaranteed to follow the ordering of the variables in `varinfo`.
337-
Hence care must be taken, in particular when used in conjunction with
338-
other methods which uses the vector-representation of the `varinfo`,
339-
e.g. `getindex(varinfo, sampler)`.
334+
The ordering of variables in the return value will be the same as in `varinfo`.
340335
341336
# Examples
342337
```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL)

src/simple_varinfo.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -416,9 +416,9 @@ end
416416

417417
function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName}
418418
vns_present = collect(keys(x))
419-
vns_found = mapreduce(vcat, vns; init=VN[]) do vn
420-
return filter(Base.Fix1(subsumes, vn), vns_present)
421-
end
419+
vns_found = filter(
420+
vn_present -> any(subsumes(vn, vn_present) for vn in vns), vns_present
421+
)
422422
C = ConstructionBase.constructorof(typeof(x))
423423
if isempty(vns_found)
424424
return C()
@@ -439,7 +439,8 @@ function _subset(x::NamedTuple, vns)
439439
end
440440

441441
syms = map(getsym, vns)
442-
return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix1(getindex, x), syms)))
442+
x_syms = filter(Base.Fix2(in, syms), keys(x))
443+
return NamedTuple{Tuple(x_syms)}(Tuple(map(Base.Fix1(getindex, x), x_syms)))
443444
end
444445

445446
_subset(x::VarNamedVector, vns) = subset(x, vns)

src/varinfo.jl

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -326,43 +326,41 @@ else
326326
_tail(nt::NamedTuple) = Base.tail(nt)
327327
end
328328

329-
function subset(varinfo::UntypedVarInfo, vns::AbstractVector{<:VarName})
329+
function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName})
330330
metadata = subset(varinfo.metadata, vns)
331331
return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce))
332332
end
333333

334-
function subset(varinfo::VectorVarInfo, vns::AbstractVector{<:VarName})
335-
metadata = subset(varinfo.metadata, vns)
336-
return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce))
334+
function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName})
335+
vns_syms = Set(unique(map(getsym, vns)))
336+
syms = filter(Base.Fix2(in, vns_syms), keys(metadata))
337+
metadatas = map(syms) do sym
338+
subset(getfield(metadata, sym), filter(==(sym) getsym, vns))
339+
end
340+
return NamedTuple{syms}(metadatas)
337341
end
338342

339-
function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName{sym}}) where {sym}
340-
# If all the variables are using the same symbol, then we can just extract that field from the metadata.
341-
metadata = subset(getfield(varinfo.metadata, sym), vns)
342-
return VarInfo(
343-
NamedTuple{(sym,)}(tuple(metadata)),
344-
deepcopy(varinfo.logp),
345-
deepcopy(varinfo.num_produce),
346-
)
347-
end
343+
# The above method is type unstable since we don't know which symbols are in `vns`.
344+
# In the below special case, when all `vns` have the same symbol, we can write a type stable
345+
# version.
348346

349-
function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName})
350-
syms = Tuple(unique(map(getsym, vns)))
351-
metadatas = map(syms) do sym
352-
subset(getfield(varinfo.metadata, sym), filter(==(sym) getsym, vns))
347+
@generated function subset(
348+
metadata::NamedTuple{names}, vns::AbstractVector{<:VarName{sym}}
349+
) where {names,sym}
350+
return if (sym in names)
351+
# TODO(mhauru) Note that this could still generate an empty metadata object if none
352+
# of the lenses in `vns` are in `metadata`. Not sure if that's okay. Checking for
353+
# emptiness would make this type unstable again.
354+
:((; $sym=subset(metadata.$sym, vns)))
355+
else
356+
:(NamedTuple{}())
353357
end
354-
355-
return VarInfo(
356-
NamedTuple{syms}(metadatas), deepcopy(varinfo.logp), deepcopy(varinfo.num_produce)
357-
)
358358
end
359359

360360
function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:VarName}
361361
# TODO: Should we error if `vns` contains a variable that is not in `metadata`?
362-
# For each `vn` in `vns`, get the variables subsumed by `vn`.
363-
vns = mapreduce(vcat, vns_given; init=VN[]) do vn
364-
filter(Base.Fix1(subsumes, vn), metadata.vns)
365-
end
362+
# Find all the vns in metadata that are subsumed by one of the given vns.
363+
vns = filter(vn -> any(subsumes(vn_given, vn) for vn_given in vns_given), metadata.vns)
366364
indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns)
367365
indices = if isempty(vns)
368366
Dict{VarName,Int}()

src/varnamedvector.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,8 @@ Return a new `VarNamedVector` containing the values from `vnv` for variables in
11381138
Which variables to include is determined by the `VarName`'s `subsumes` relation, meaning
11391139
that e.g. `subset(vnv, [@varname(x)])` will include variables like `@varname(x.a[1])`.
11401140
1141+
Preserves the order of variables in `vnv`.
1142+
11411143
# Examples
11421144
11431145
```jldoctest varnamedvector-subset
@@ -1151,18 +1153,17 @@ true
11511153
julia> subset(vnv, [@varname(x[2])]) == VarNamedVector(@varname(x[2]) => [2.0])
11521154
true
11531155
"""
1154-
function subset(vnv::VarNamedVector, vns_given::AbstractVector{VN}) where {VN<:VarName}
1156+
function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName})
11551157
# NOTE: This does not specialize types when possible.
1156-
vns = mapreduce(vcat, vns_given; init=VN[]) do vn
1157-
filter(Base.Fix1(subsumes, vn), vnv.varnames)
1158-
end
11591158
vnv_new = similar(vnv)
11601159
# Return early if possible.
11611160
isempty(vnv) && return vnv_new
11621161

1163-
for vn in vns
1164-
insert_internal!(vnv_new, getindex_internal(vnv, vn), vn, gettransform(vnv, vn))
1165-
settrans!(vnv_new, istrans(vnv, vn), vn)
1162+
for vn in vnv.varnames
1163+
if any(subsumes(vn_given, vn) for vn_given in vns_given)
1164+
insert_internal!(vnv_new, getindex_internal(vnv, vn), vn, gettransform(vnv, vn))
1165+
settrans!(vnv_new, istrans(vnv, vn), vn)
1166+
end
11661167
end
11671168

11681169
return vnv_new

test/varinfo.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,16 @@ end
729729
# Values should be the same.
730730
@test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns]
731731
end
732+
733+
@testset "$(convert(Vector{VarName}, vns_subset)) order" for vns_subset in
734+
vns_supported
735+
varinfo_subset = subset(varinfo, vns_subset)
736+
vns_subset_reversed = reverse(vns_subset)
737+
varinfo_subset_reversed = subset(varinfo, vns_subset_reversed)
738+
@test varinfo_subset[:] == varinfo_subset_reversed[:]
739+
ground_truth = [varinfo[vn] for vn in vns_subset]
740+
@test varinfo_subset[:] == ground_truth
741+
end
732742
end
733743

734744
# For certain varinfos we should have errors.

0 commit comments

Comments
 (0)