@@ -29,9 +29,11 @@ is_typed_varinfo(::DynamicPPL.AbstractVarInfo) = false
29
29
is_typed_varinfo (varinfo:: DynamicPPL.TypedVarInfo ) = true
30
30
is_typed_varinfo (varinfo:: DynamicPPL.SimpleVarInfo{<:NamedTuple} ) = true
31
31
32
+ const GDEMO_DEFAULT = DynamicPPL. TestUtils. demo_assume_observe_literal ()
33
+
32
34
@testset " model.jl" begin
33
35
@testset " convenience functions" begin
34
- model = gdemo_default # defined in test/test_util.jl
36
+ model = GDEMO_DEFAULT # defined in test/test_util.jl
35
37
36
38
# sample from model and extract variables
37
39
vi = VarInfo (model)
@@ -55,53 +57,26 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
55
57
@test ljoint ≈ lp
56
58
57
59
# ### logprior, logjoint, loglikelihood for MCMC chains ####
58
- for model in DynamicPPL. TestUtils. DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12
59
- var_info = VarInfo (model)
60
- vns = DynamicPPL. TestUtils. varnames (model)
61
- syms = unique (DynamicPPL. getsym .(vns))
62
-
63
- # generate a chain of sample parameter values.
60
+ @testset " $(model. f) " for model in DynamicPPL. TestUtils. DEMO_MODELS
64
61
N = 200
65
- vals_OrderedDict = mapreduce (hcat, 1 : N) do _
66
- rand (OrderedDict, model)
67
- end
68
- vals_mat = mapreduce (hcat, 1 : N) do i
69
- [vals_OrderedDict[i][vn] for vn in vns]
70
- end
71
- i = 1
72
- for col in eachcol (vals_mat)
73
- col_flattened = []
74
- [push! (col_flattened, x... ) for x in col]
75
- if i == 1
76
- chain_mat = Matrix (reshape (col_flattened, 1 , length (col_flattened)))
77
- else
78
- chain_mat = vcat (
79
- chain_mat, reshape (col_flattened, 1 , length (col_flattened))
80
- )
81
- end
82
- i += 1
83
- end
84
- chain_mat = convert (Matrix{Float64}, chain_mat)
85
-
86
- # devise parameter names for chain
87
- sample_values_vec = collect (values (vals_OrderedDict[1 ]))
88
- symbol_names = []
89
- chain_sym_map = Dict ()
90
- for k in 1 : length (keys (var_info))
91
- vn_parent = keys (var_info)[k]
62
+ chain = make_chain_from_prior (model, N)
63
+ logpriors = logprior (model, chain)
64
+ loglikelihoods = loglikelihood (model, chain)
65
+ logjoints = logjoint (model, chain)
66
+
67
+ # Construct mapping of varname symbols to varname-parent symbols.
68
+ # Here, varname_leaves is used to ensure compatibility with the
69
+ # variables stored in the chain
70
+ var_info = VarInfo (model)
71
+ chain_sym_map = Dict {Symbol, Symbol} ()
72
+ for vn_parent in keys (var_info)
92
73
sym = DynamicPPL. getsym (vn_parent)
93
- vn_children = DynamicPPL. varname_leaves (vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl
74
+ vn_children = DynamicPPL. varname_leaves (vn_parent, var_info[vn_parent])
94
75
for vn_child in vn_children
95
76
chain_sym_map[Symbol (vn_child)] = sym
96
- symbol_names = [symbol_names; Symbol (vn_child)]
97
77
end
98
78
end
99
- chain = Chains (chain_mat, symbol_names)
100
79
101
- # calculate the pointwise loglikelihoods for the whole chain using the newly written functions
102
- logpriors = logprior (model, chain)
103
- loglikelihoods = loglikelihood (model, chain)
104
- logjoints = logjoint (model, chain)
105
80
# compare them with true values
106
81
for i in 1 : N
107
82
samples_dict = Dict ()
@@ -115,18 +90,31 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
115
90
samples = (; samples_dict... )
116
91
samples = modify_value_representation (samples) # `modify_value_representation` defined in test/test_util.jl
117
92
@test logpriors[i] ≈
118
- DynamicPPL. TestUtils. logprior_true (model, samples[:s ], samples[:m ])
93
+ DynamicPPL. TestUtils. logprior_true (model, samples[:s ], samples[:m ])
119
94
@test loglikelihoods[i] ≈ DynamicPPL. TestUtils. loglikelihood_true (
120
95
model, samples[:s ], samples[:m ]
121
96
)
122
97
@test logjoints[i] ≈
123
- DynamicPPL. TestUtils. logjoint_true (model, samples[:s ], samples[:m ])
98
+ DynamicPPL. TestUtils. logjoint_true (model, samples[:s ], samples[:m ])
124
99
end
125
100
end
126
101
end
127
102
103
+ @testset " DynamicPPL#684: threadsafe evaluation with multiple types" begin
104
+ @model function multiple_types (x)
105
+ ns ~ filldist (Normal (0 , 2.0 ), 3 )
106
+ m ~ Uniform (0 , 1 )
107
+ return x ~ Normal (m, 1 )
108
+ end
109
+ model = multiple_types (1 )
110
+ chain = make_chain_from_prior (model, 10 )
111
+ loglikelihood (model, chain)
112
+ logprior (model, chain)
113
+ logjoint (model, chain)
114
+ end
115
+
128
116
@testset " rng" begin
129
- model = gdemo_default
117
+ model = GDEMO_DEFAULT
130
118
131
119
for sampler in (SampleFromPrior (), SampleFromUniform ())
132
120
for i in 1 : 10
@@ -144,13 +132,15 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
144
132
end
145
133
146
134
@testset " defaults without VarInfo, Sampler, and Context" begin
147
- model = gdemo_default
135
+ model = GDEMO_DEFAULT
148
136
149
137
Random. seed! (100 )
150
- s, m = model ()
138
+ retval = model ()
151
139
152
140
Random. seed! (100 )
153
- @test model (Random. default_rng ()) == (s, m)
141
+ retval2 = model (Random. default_rng ())
142
+ @test retval2. s == retval. s
143
+ @test retval2. m == retval. m
154
144
end
155
145
156
146
@testset " nameof" begin
@@ -184,7 +174,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
184
174
end
185
175
186
176
@testset " Internal methods" begin
187
- model = gdemo_default
177
+ model = GDEMO_DEFAULT
188
178
189
179
# sample from model and extract variables
190
180
vi = VarInfo (model)
@@ -224,7 +214,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
224
214
end
225
215
226
216
@testset " rand" begin
227
- model = gdemo_default
217
+ model = GDEMO_DEFAULT
228
218
229
219
Random. seed! (1776 )
230
220
s, m = model ()
@@ -293,10 +283,10 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
293
283
# Ensure log-probability computations are implemented.
294
284
@test logprior (model, x) ≈ DynamicPPL. TestUtils. logprior_true (model, x... )
295
285
@test loglikelihood (model, x) ≈
296
- DynamicPPL. TestUtils. loglikelihood_true (model, x... )
286
+ DynamicPPL. TestUtils. loglikelihood_true (model, x... )
297
287
@test logjoint (model, x) ≈ DynamicPPL. TestUtils. logjoint_true (model, x... )
298
288
@test logjoint (model, x) !=
299
- DynamicPPL. TestUtils. logjoint_true_with_logabsdet_jacobian (model, x... )
289
+ DynamicPPL. TestUtils. logjoint_true_with_logabsdet_jacobian (model, x... )
300
290
# Ensure `varnames` is implemented.
301
291
vi = last (
302
292
DynamicPPL. evaluate!! (
@@ -309,7 +299,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
309
299
end
310
300
end
311
301
312
- @testset " generated_quantities on `LKJCholesky`" begin
302
+ @testset " returned() on `LKJCholesky`" begin
313
303
n = 10
314
304
d = 2
315
305
model = DynamicPPL. TestUtils. demo_lkjchol (d)
@@ -333,7 +323,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
333
323
)
334
324
335
325
# Test!
336
- results = generated_quantities (model, chain)
326
+ results = returned (model, chain)
337
327
for (x_true, result) in zip (xs, results)
338
328
@test x_true. UL == result. x. UL
339
329
end
@@ -352,7 +342,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
352
342
info= (varname_to_symbol= vns_to_syms_with_extra,),
353
343
)
354
344
# Test!
355
- results = generated_quantities (model, chain_with_extra)
345
+ results = returned (model, chain_with_extra)
356
346
for (x_true, result) in zip (xs, results)
357
347
@test x_true. UL == result. x. UL
358
348
end
0 commit comments