Skip to content

Commit ffe9272

Browse files
torfjeldegithub-actions[bot]devmotionsunxd3
authored
Fix for generated_quantities (#534)
* added method for extracting the child lens from a varname subsumed by another varname * added nested_getindex and nested_setindex! for VarInfo * added ConstructionBase.setproperties implementation for `Cholesky` * fixed minor formatting issue * added `supports_varname_indexing` for chains and use this in generated_quantities * use a private method rather than overloading getindex for Chains * removed getindex overloads in nested_index testing * moved generated_quantities tests to test/model.jl * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * will now also correctly set variables to be resampled, etc. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added Compat as a test dep so we can methods such as stack * improved overload of ConstructionBase.setproperties * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added docstring to remove_parent_lens * removed methods which are not useful for the purpose of this PR * noticed we're incorrectly using chain rather than chain_params in generated_quantities * Update ext/DynamicPPLMCMCChainsExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed doctests * added Requires.jl * Update src/DynamicPPL.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * bump patch version * Update src/DynamicPPL.jl Co-authored-by: David Widmann <[email protected]> * moved new generated_quantities functionality into setval_and_resample! so we can make use of this also for Turing.predict, etc. * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update ext/DynamicPPLMCMCChainsExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/chains.jl Co-authored-by: Xianda Sun <[email protected]> * bump compat entry for ConstructionBase.jl --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann <[email protected]> Co-authored-by: Xianda Sun <[email protected]>
1 parent 52cd7f9 commit ffe9272

9 files changed

+238
-20
lines changed

Project.toml

+14-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.23.16"
3+
version = "0.23.17"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -16,32 +16,34 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1616
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1717
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1818
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
19+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1920
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2021
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2122
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2223

23-
[weakdeps]
24-
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
25-
26-
[extensions]
27-
DynamicPPLMCMCChainsExt = ["MCMCChains"]
28-
29-
[extras]
30-
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
31-
3224
[compat]
3325
AbstractMCMC = "2, 3.0, 4"
3426
AbstractPPL = "0.6"
3527
BangBang = "0.3"
3628
Bijectors = "0.13"
3729
ChainRulesCore = "0.9.7, 0.10, 1"
38-
ConstructionBase = "1"
30+
ConstructionBase = "1.5.4"
3931
Distributions = "0.23.8, 0.24, 0.25"
4032
DocStringExtensions = "0.8, 0.9"
4133
LogDensityProblems = "2"
42-
MacroTools = "0.5.6"
4334
MCMCChains = "6"
35+
MacroTools = "0.5.6"
4436
OrderedCollections = "1"
37+
Requires = "1"
4538
Setfield = "0.7.1, 0.8, 1"
4639
ZygoteRules = "0.2"
4740
julia = "1.6"
41+
42+
[extensions]
43+
DynamicPPLMCMCChainsExt = ["MCMCChains"]
44+
45+
[extras]
46+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
47+
48+
[weakdeps]
49+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

ext/DynamicPPLMCMCChainsExt.jl

+39-6
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,48 @@
11
module DynamicPPLMCMCChainsExt
22

3-
using DynamicPPL: DynamicPPL
4-
using MCMCChains: MCMCChains
3+
if isdefined(Base, :get_extension)
4+
using DynamicPPL: DynamicPPL
5+
using MCMCChains: MCMCChains
6+
else
7+
using ..DynamicPPL: DynamicPPL
8+
using ..MCMCChains: MCMCChains
9+
end
10+
11+
_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
12+
function _check_varname_indexing(c::MCMCChains.Chains)
13+
return DynamicPPL.supports_varname_indexing(c) ||
14+
error("Chains do not support indexing using $vn.")
15+
end
16+
17+
# A few methods needed.
18+
function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
19+
return _has_varname_to_symbol(chain.info)
20+
end
21+
function DynamicPPL.getindex_varname(
22+
c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx
23+
)
24+
_check_varname_indexing(c)
25+
return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx]
26+
end
27+
function DynamicPPL.varnames(c::MCMCChains.Chains)
28+
_check_varname_indexing(c)
29+
return keys(c.info.varname_to_symbol)
30+
end
531

6-
function DynamicPPL.generated_quantities(model::DynamicPPL.Model, chain::MCMCChains.Chains)
7-
chain_parameters = MCMCChains.get_sections(chain, :parameters)
32+
function DynamicPPL.generated_quantities(
33+
model::DynamicPPL.Model, chain_full::MCMCChains.Chains
34+
)
35+
chain = MCMCChains.get_sections(chain_full, :parameters)
836
varinfo = DynamicPPL.VarInfo(model)
937
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
1038
return map(iters) do (sample_idx, chain_idx)
11-
DynamicPPL.setval_and_resample!(varinfo, chain_parameters, sample_idx, chain_idx)
12-
model(varinfo)
39+
# Update the varinfo with the current sample and make variables not present in `chain`
40+
# to be sampled.
41+
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
42+
43+
# TODO: Some of the variables can be a view into the `varinfo`, so we need to
44+
# `deepcopy` the `varinfo` before passing it to `model`.
45+
model(deepcopy(varinfo))
1346
end
1447
end
1548

src/DynamicPPL.jl

+13
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ const LEGACY_WARNING = """
154154
# Necessary forward declarations
155155
include("utils.jl")
156156
include("selector.jl")
157+
include("chains.jl")
157158
include("model.jl")
158159
include("sampler.jl")
159160
include("varname.jl")
@@ -175,4 +176,16 @@ include("logdensityfunction.jl")
175176
include("model_utils.jl")
176177
include("extract_priors.jl")
177178

179+
if !isdefined(Base, :get_extension)
180+
using Requires
181+
end
182+
183+
@static if !isdefined(Base, :get_extension)
184+
function __init__()
185+
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
186+
"../ext/DynamicPPLMCMCChainsExt.jl"
187+
)
188+
end
189+
end
190+
178191
end # module

src/chains.jl

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""
2+
supports_varname_indexing(chain::AbstractChains)
3+
4+
Return `true` if `chain` supports indexing using `VarName` in place of the
5+
variable name index.
6+
"""
7+
supports_varname_indexing(::AbstractChains) = false
8+
9+
"""
10+
getindex_varname(chain::AbstractChains, sample_idx, varname::VarName, chain_idx)
11+
12+
Return the value of `varname` in `chain` at `sample_idx` and `chain_idx`.
13+
14+
Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref).
15+
"""
16+
function getindex_varname end
17+
18+
"""
19+
varnames(chains::AbstractChains)
20+
21+
Return an iterator over the varnames present in `chains`.
22+
23+
Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref).
24+
"""
25+
function varnames end

src/utils.jl

+36
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,42 @@ function splitlens(condition, lens)
501501
return current_parent, current_child, condition(current_parent)
502502
end
503503

504+
"""
505+
remove_parent_lens(vn_parent::VarName, vn_child::VarName)
506+
507+
Remove the parent lens `vn_parent` from `vn_child`.
508+
509+
# Examples
510+
```jldoctest
511+
julia> DynamicPPL.remove_parent_lens(@varname(x), @varname(x.a))
512+
(@lens _.a)
513+
514+
julia> DynamicPPL.remove_parent_lens(@varname(x), @varname(x.a[1]))
515+
(@lens _.a[1])
516+
517+
julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a[1]))
518+
(@lens _[1])
519+
520+
julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a[1].b))
521+
(@lens _[1].b)
522+
523+
julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a))
524+
ERROR: Could not find x.a in x.a
525+
526+
julia> DynamicPPL.remove_parent_lens(@varname(x.a[2]), @varname(x.a[1]))
527+
ERROR: Could not find x.a[2] in x.a[1]
528+
```
529+
"""
530+
function remove_parent_lens(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym}
531+
_, child, issuccess = splitlens(getlens(vn_child)) do lens
532+
l = lens === nothing ? Setfield.IdentityLens() : lens
533+
VarName(vn_child, l) == vn_parent
534+
end
535+
536+
issuccess || error("Could not find $vn_parent in $vn_child")
537+
return child
538+
end
539+
504540
# HACK: All of these are related to https://github.com/JuliaFolds/BangBang.jl/issues/233
505541
# and https://github.com/JuliaFolds/BangBang.jl/pull/238.
506542
# HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`.

src/varinfo.jl

+57-2
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,41 @@ end
10641064
return Expr(:||, false, out...)
10651065
end
10661066

1067+
function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName)
1068+
return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)
1069+
end
1070+
function nested_setindex_maybe!(
1071+
vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym}
1072+
) where {names,sym}
1073+
return if sym in names
1074+
_nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)
1075+
else
1076+
nothing
1077+
end
1078+
end
1079+
function _nested_setindex_maybe!(vi::VarInfo, md::Metadata, val, vn::VarName)
1080+
# If `vn` is in `vns`, then we can just use the standard `setindex!`.
1081+
vns = md.vns
1082+
if vn in vns
1083+
setindex!(vi, val, vn)
1084+
return vn
1085+
end
1086+
1087+
# Otherwise, we need to check if either of the `vns` subsumes `vn`.
1088+
i = findfirst(Base.Fix2(subsumes, vn), vns)
1089+
i === nothing && return nothing
1090+
1091+
vn_parent = vns[i]
1092+
dist = getdist(md, vn_parent)
1093+
val_parent = getindex(vi, vn_parent, dist) # TODO: Ensure that we're working with a view here.
1094+
# Split the varname into its tail lens.
1095+
lens = remove_parent_lens(vn_parent, vn)
1096+
# Update the value for the parent.
1097+
val_parent_updated = set!!(val_parent, lens, val)
1098+
setindex!(vi, val_parent_updated, vn_parent)
1099+
return vn_parent
1100+
end
1101+
10671102
# The default getindex & setindex!() for get & set values
10681103
# NOTE: vi[vn] will always transform the variable to its original space and Julia type
10691104
getindex(vi::VarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn))
@@ -1131,7 +1166,8 @@ The value(s) may or may not be transformed to Euclidean space.
11311166
"""
11321167
setindex!(vi::VarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi)
11331168
function BangBang.setindex!!(vi::VarInfo, val, vn::VarName)
1134-
return (setindex!(vi, val, vn); return vi)
1169+
setindex!(vi, val, vn)
1170+
return vi
11351171
end
11361172

11371173
"""
@@ -1600,7 +1636,26 @@ end
16001636
function setval_and_resample!(
16011637
vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int
16021638
)
1603-
return setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
1639+
if supports_varname_indexing(chains)
1640+
# First we need to set every variable to be resampled.
1641+
for vn in keys(vi)
1642+
set_flag!(vi, vn, "del")
1643+
end
1644+
# Then we set the variables in `varinfo` from `chain`.
1645+
for vn in varnames(chains)
1646+
vn_updated = nested_setindex_maybe!(
1647+
vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn
1648+
)
1649+
1650+
# Unset the `del` flag if we found something.
1651+
if vn_updated !== nothing
1652+
# NOTE: This will be triggered even if only a subset of a variable has been set!
1653+
unset_flag!(vi, vn_updated, "del")
1654+
end
1655+
end
1656+
else
1657+
setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
1658+
end
16041659
end
16051660

16061661
function _setval_and_resample_kernel!(

test/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
44
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
5+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
56
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
67
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
78
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
@@ -24,6 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2425
AbstractMCMC = "2.1, 3.0, 4"
2526
AbstractPPL = "0.6"
2627
Bijectors = "0.13"
28+
Compat = "4.3.0"
2729
Distributions = "0.25"
2830
DistributionsAD = "0.6.3"
2931
Documenter = "0.26.1, 0.27"

test/model.jl

+51
Original file line numberDiff line numberDiff line change
@@ -278,4 +278,55 @@ end
278278
@test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x)
279279
end
280280
end
281+
282+
@testset "generated_quantities on `LKJCholesky`" begin
283+
n = 10
284+
d = 2
285+
model = DynamicPPL.TestUtils.demo_lkjchol(d)
286+
xs = [model().x for _ in 1:n]
287+
288+
# Extract varnames and values.
289+
vns_and_vals_xs = map(
290+
collect Base.Fix1(DynamicPPL.varname_and_value_leaves, @varname(x)), xs
291+
)
292+
vns = map(first, first(vns_and_vals_xs))
293+
vals = map(vns_and_vals_xs) do vns_and_vals
294+
map(last, vns_and_vals)
295+
end
296+
297+
# Construct the chain.
298+
syms = map(Symbol, vns)
299+
vns_to_syms = OrderedDict{VarName,Any}(zip(vns, syms))
300+
301+
chain = MCMCChains.Chains(
302+
permutedims(stack(vals)), syms; info=(varname_to_symbol=vns_to_syms,)
303+
)
304+
display(chain)
305+
306+
# Test!
307+
results = generated_quantities(model, chain)
308+
for (x_true, result) in zip(xs, results)
309+
@test x_true.UL == result.x.UL
310+
end
311+
312+
# With variables that aren't in the `model`.
313+
vns_to_syms_with_extra = let d = deepcopy(vns_to_syms)
314+
d[@varname(y)] = :y
315+
d
316+
end
317+
vals_with_extra = map(enumerate(vals)) do (i, v)
318+
vcat(v, i)
319+
end
320+
chain_with_extra = MCMCChains.Chains(
321+
permutedims(stack(vals_with_extra)),
322+
vcat(syms, [:y]);
323+
info=(varname_to_symbol=vns_to_syms_with_extra,),
324+
)
325+
display(chain_with_extra)
326+
# Test!
327+
results = generated_quantities(model, chain_with_extra)
328+
for (x_true, result) in zip(xs, results)
329+
@test x_true.UL == result.x.UL
330+
end
331+
end
281332
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using MCMCChains
1111
using Tracker
1212
using Zygote
1313
using Setfield
14+
using Compat
1415

1516
using Distributed
1617
using LinearAlgebra

0 commit comments

Comments
 (0)