Skip to content

Commit 2fbf57b

Browse files
committed
Implement make_chain_from_prior([rng,] model, n_iters)
1 parent 639db16 commit 2fbf57b

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

test/test_util.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,32 @@ function modify_value_representation(nt::NamedTuple)
110110
end
111111
return modified_nt
112112
end
113+
114+
"""
115+
make_chain_from_prior([rng,] model, n_iters)
116+
117+
Construct an MCMCChains.Chains object by sampling from the prior of `model` for
118+
`n_iters` iterations.
119+
"""
120+
function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int)
121+
# Sample from the prior
122+
varinfos = [VarInfo(rng, model) for _ in 1:n_iters]
123+
# Convert each varinfo into an OrderedDict of vns => params.
124+
# We have to use varname_and_value_leaves so that each parameter is a scalar
125+
dicts = map(varinfos) do t
126+
vals = DynamicPPL.values_as(t, OrderedDict)
127+
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
128+
tuples = mapreduce(collect, vcat, iters)
129+
OrderedDict(tuples)
130+
end
131+
# Extract all varnames found in any dictionary. Doing it this way guards
132+
# against the possibility of having different varnames in different
133+
# dictionaries, e.g. for models that have dynamic variables / array sizes
134+
all_varnames = collect(union(map(keys, dicts)...))
135+
vals = [get(dict, vn, missing) for dict in dicts, vn in all_varnames]
136+
# Construct and return the Chains object
137+
return Chains(vals, all_varnames)
138+
end
139+
function make_chain_from_prior(model::Model, n_iters::Int)
140+
return make_chain_from_prior(Random.default_rng(), model, n_iters)
141+
end

0 commit comments

Comments
 (0)