Skip to content

Commit 2cc15b1

Browse files
committed
Implement make_chain_from_prior([rng,] model, n_iters)
1 parent 2649a30 commit 2cc15b1

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1717
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
1818
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1919
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
20+
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2021
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2122
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2223
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

test/test_util.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,36 @@ function modify_value_representation(nt::NamedTuple)
110110
end
111111
return modified_nt
112112
end
113+
114+
"""
115+
make_chain_from_prior([rng,] model, n_iters)
116+
117+
Construct an MCMCChains.Chains object by sampling from the prior of `model` for
118+
`n_iters` iterations.
119+
"""
120+
function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int)
121+
# Sample from the prior
122+
varinfos = [VarInfo(rng, model) for _ in 1:n_iters]
123+
# Extract all varnames found in any dictionary. Doing it this way guards
124+
# against the possibility of having different varnames in different
125+
# dictionaries, e.g. for models that have dynamic variables / array sizes
126+
varnames = OrderedSet{VarName}()
127+
# Convert each varinfo into an OrderedDict of vns => params.
128+
# We have to use varname_and_value_leaves so that each parameter is a scalar
129+
dicts = map(varinfos) do t
130+
vals = DynamicPPL.values_as(t, OrderedDict)
131+
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
132+
tuples = mapreduce(collect, vcat, iters)
133+
push!(varnames, map(first, tuples)...)
134+
OrderedDict(tuples)
135+
end
136+
# Convert back to list
137+
varnames = collect(varnames)
138+
# Construct matrix of values
139+
vals = [get(dict, vn, missing) for dict in dicts, vn in varnames]
140+
# Construct and return the Chains object
141+
return Chains(vals, varnames)
142+
end
143+
function make_chain_from_prior(model::Model, n_iters::Int)
144+
return make_chain_from_prior(Random.default_rng(), model, n_iters)
145+
end

0 commit comments

Comments
 (0)