@@ -130,6 +130,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
130
130
test_base!! (SimpleVarInfo (Dict ()))
131
131
test_base!! (SimpleVarInfo (DynamicPPL. VarNamedVector ()))
132
132
end
133
+
133
134
@testset " flags" begin
134
135
# Test flag setting:
135
136
# is_flagged, set_flag!, unset_flag!
@@ -187,6 +188,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
187
188
setgid! (vi, gid2, vn)
188
189
@test meta. x. gids[meta. x. idcs[vn]] == Set ([gid1, gid2])
189
190
end
191
+
190
192
@testset " setval! & setval_and_resample!" begin
191
193
@model function testmodel (x)
192
194
n = length (x)
@@ -339,6 +341,52 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
339
341
@test vals_prev == vi. metadata. x. vals
340
342
end
341
343
344
+ @testset " setval! on chain" begin
345
+ # Define a helper function
346
+ """
347
+ test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
348
+
349
+ Test `setval!` on `model` and `chain`.
350
+
351
+ Worth noting that this only supports models containing symbols of the forms
352
+ `m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
353
+ """
354
+ function test_setval! (model, chain; sample_idx= 1 , chain_idx= 1 )
355
+ var_info = VarInfo (model)
356
+ spl = SampleFromPrior ()
357
+ θ_old = var_info[spl]
358
+ DynamicPPL. setval! (var_info, chain, sample_idx, chain_idx)
359
+ θ_new = var_info[spl]
360
+ @test θ_old != θ_new
361
+ vals = DynamicPPL. values_as (var_info, OrderedDict)
362
+ iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
363
+ for (n, v) in mapreduce (collect, vcat, iters)
364
+ n = string (n)
365
+ if Symbol (n) ∉ keys (chain)
366
+ # Assume it's a group
367
+ chain_val = vec (
368
+ MCMCChains. group (chain, Symbol (n)). value[sample_idx, :, chain_idx]
369
+ )
370
+ v_true = vec (v)
371
+ else
372
+ chain_val = chain[sample_idx, n, chain_idx]
373
+ v_true = v
374
+ end
375
+
376
+ @test v_true == chain_val
377
+ end
378
+ end
379
+
380
+ @testset " $model " for model in DynamicPPL. TestUtils. DEMO_MODELS
381
+ chain = make_chain_from_prior (model, 10 )
382
+ # A simple way of checking that the computation is determinstic: run twice and compare.
383
+ res1 = generated_quantities (model, MCMCChains. get_sections (chain, :parameters ))
384
+ res2 = generated_quantities (model, MCMCChains. get_sections (chain, :parameters ))
385
+ @test all (res1 .== res2)
386
+ test_setval! (model, MCMCChains. get_sections (chain, :parameters ))
387
+ end
388
+ end
389
+
342
390
@testset " istrans" begin
343
391
@model demo_constrained () = x ~ truncated (Normal (), 0 , Inf )
344
392
model = demo_constrained ()
0 commit comments