Skip to content

Commit b52aa73

Browse files
committed
Re-add test_setval! from test/turing/model.jl
1 parent 2fbf57b commit b52aa73

File tree

2 files changed

+48
-34
lines changed

2 files changed

+48
-34
lines changed

test/test_util.jl

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -43,40 +43,6 @@ function test_model_ad(model, logp_manual)
4343
@test back(1)[1] grad
4444
end
4545

46-
"""
47-
test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
48-
49-
Test `setval!` on `model` and `chain`.
50-
51-
Worth noting that this only supports models containing symbols of the forms
52-
`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
53-
"""
54-
function test_setval!(model, chain; sample_idx=1, chain_idx=1)
55-
var_info = VarInfo(model)
56-
spl = SampleFromPrior()
57-
θ_old = var_info[spl]
58-
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
59-
θ_new = var_info[spl]
60-
@test θ_old != θ_new
61-
vals = DynamicPPL.values_as(var_info, OrderedDict)
62-
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
63-
for (n, v) in mapreduce(collect, vcat, iters)
64-
n = string(n)
65-
if Symbol(n) keys(chain)
66-
# Assume it's a group
67-
chain_val = vec(
68-
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
69-
)
70-
v_true = vec(v)
71-
else
72-
chain_val = chain[sample_idx, n, chain_idx]
73-
v_true = v
74-
end
75-
76-
@test v_true == chain_val
77-
end
78-
end
79-
8046
"""
8147
short_varinfo_name(vi::AbstractVarInfo)
8248

test/varinfo.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
130130
test_base!!(SimpleVarInfo(Dict()))
131131
test_base!!(SimpleVarInfo(DynamicPPL.VarNamedVector()))
132132
end
133+
133134
@testset "flags" begin
134135
# Test flag setting:
135136
# is_flagged, set_flag!, unset_flag!
@@ -187,6 +188,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
187188
setgid!(vi, gid2, vn)
188189
@test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2])
189190
end
191+
190192
@testset "setval! & setval_and_resample!" begin
191193
@model function testmodel(x)
192194
n = length(x)
@@ -339,6 +341,52 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
339341
@test vals_prev == vi.metadata.x.vals
340342
end
341343

344+
@testset "setval! on chain" begin
345+
# Define a helper function
346+
"""
347+
test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
348+
349+
Test `setval!` on `model` and `chain`.
350+
351+
Worth noting that this only supports models containing symbols of the forms
352+
`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
353+
"""
354+
function test_setval!(model, chain; sample_idx=1, chain_idx=1)
355+
var_info = VarInfo(model)
356+
spl = SampleFromPrior()
357+
θ_old = var_info[spl]
358+
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
359+
θ_new = var_info[spl]
360+
@test θ_old != θ_new
361+
vals = DynamicPPL.values_as(var_info, OrderedDict)
362+
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
363+
for (n, v) in mapreduce(collect, vcat, iters)
364+
n = string(n)
365+
if Symbol(n) keys(chain)
366+
# Assume it's a group
367+
chain_val = vec(
368+
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
369+
)
370+
v_true = vec(v)
371+
else
372+
chain_val = chain[sample_idx, n, chain_idx]
373+
v_true = v
374+
end
375+
376+
@test v_true == chain_val
377+
end
378+
end
379+
380+
@testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS
381+
chain = make_chain_from_prior(model, 10)
382+
# A simple way of checking that the computation is determinstic: run twice and compare.
383+
res1 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters))
384+
res2 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters))
385+
@test all(res1 .== res2)
386+
test_setval!(model, MCMCChains.get_sections(chain, :parameters))
387+
end
388+
end
389+
342390
@testset "istrans" begin
343391
@model demo_constrained() = x ~ truncated(Normal(), 0, Inf)
344392
model = demo_constrained()

0 commit comments

Comments
 (0)