Skip to content

Commit 543a97b

Browse files
committed
Re-add almost all integration tests into DPPL test suite proper
1 parent ae14716 commit 543a97b

File tree

6 files changed

+386
-93
lines changed

6 files changed

+386
-93
lines changed

test/ad.jl

+39
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,43 @@
2525
end
2626
end
2727
end
28+
29+
@testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin
30+
# Failing model
31+
t = 1:0.05:8
32+
σ = 0.3
33+
y = @. rand(sin(t) + Normal(0, σ))
34+
@model function state_space(y, TT, ::Type{T}=Float64) where {T}
35+
# Priors
36+
α ~ Normal(y[1], 0.001)
37+
τ ~ Exponential(1)
38+
η ~ filldist(Normal(0, 1), TT - 1)
39+
σ ~ Exponential(1)
40+
# create latent variable
41+
x = Vector{T}(undef, TT)
42+
x[1] = α
43+
for t in 2:TT
44+
x[t] = x[t - 1] + η[t - 1] * τ
45+
end
46+
# measurement model
47+
y ~ MvNormal(x, σ^2 * I)
48+
return x
49+
end
50+
model = state_space(y, length(t))
51+
52+
# Dummy sampling algorithm for testing. The test case can only be replicated
53+
# with a custom sampler, it doesn't work with SampleFromPrior(). We need to
54+
# overload assume so that model evaluation doesn't fail due to a lack
55+
# of implementation
56+
struct MyEmptyAlg end
57+
DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = ()
58+
DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) =
59+
DynamicPPL.assume(dist, vn, vi)
60+
61+
# Compiling the ReverseDiff tape used to fail here
62+
spl = Sampler(MyEmptyAlg())
63+
vi = VarInfo(model)
64+
ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl))
65+
@test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any
66+
end
2867
end

test/model.jl

+40-50
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ is_typed_varinfo(::DynamicPPL.AbstractVarInfo) = false
2929
is_typed_varinfo(varinfo::DynamicPPL.TypedVarInfo) = true
3030
is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
3131

32+
const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
33+
3234
@testset "model.jl" begin
3335
@testset "convenience functions" begin
34-
model = gdemo_default # defined in test/test_util.jl
36+
model = GDEMO_DEFAULT # defined in test/test_util.jl
3537

3638
# sample from model and extract variables
3739
vi = VarInfo(model)
@@ -55,53 +57,26 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
5557
@test ljoint lp
5658

5759
#### logprior, logjoint, loglikelihood for MCMC chains ####
58-
for model in DynamicPPL.TestUtils.DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12
59-
var_info = VarInfo(model)
60-
vns = DynamicPPL.TestUtils.varnames(model)
61-
syms = unique(DynamicPPL.getsym.(vns))
62-
63-
# generate a chain of sample parameter values.
60+
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
6461
N = 200
65-
vals_OrderedDict = mapreduce(hcat, 1:N) do _
66-
rand(OrderedDict, model)
67-
end
68-
vals_mat = mapreduce(hcat, 1:N) do i
69-
[vals_OrderedDict[i][vn] for vn in vns]
70-
end
71-
i = 1
72-
for col in eachcol(vals_mat)
73-
col_flattened = []
74-
[push!(col_flattened, x...) for x in col]
75-
if i == 1
76-
chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened)))
77-
else
78-
chain_mat = vcat(
79-
chain_mat, reshape(col_flattened, 1, length(col_flattened))
80-
)
81-
end
82-
i += 1
83-
end
84-
chain_mat = convert(Matrix{Float64}, chain_mat)
85-
86-
# devise parameter names for chain
87-
sample_values_vec = collect(values(vals_OrderedDict[1]))
88-
symbol_names = []
89-
chain_sym_map = Dict()
90-
for k in 1:length(keys(var_info))
91-
vn_parent = keys(var_info)[k]
62+
chain = make_chain_from_prior(model, N)
63+
logpriors = logprior(model, chain)
64+
loglikelihoods = loglikelihood(model, chain)
65+
logjoints = logjoint(model, chain)
66+
67+
# Construct mapping of varname symbols to varname-parent symbols.
68+
# Here, varname_leaves is used to ensure compatibility with the
69+
# variables stored in the chain
70+
var_info = VarInfo(model)
71+
chain_sym_map = Dict{Symbol,Symbol}()
72+
for vn_parent in keys(var_info)
9273
sym = DynamicPPL.getsym(vn_parent)
93-
vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl
74+
vn_children = DynamicPPL.varname_leaves(vn_parent, var_info[vn_parent])
9475
for vn_child in vn_children
9576
chain_sym_map[Symbol(vn_child)] = sym
96-
symbol_names = [symbol_names; Symbol(vn_child)]
9777
end
9878
end
99-
chain = Chains(chain_mat, symbol_names)
10079

101-
# calculate the pointwise loglikelihoods for the whole chain using the newly written functions
102-
logpriors = logprior(model, chain)
103-
loglikelihoods = loglikelihood(model, chain)
104-
logjoints = logjoint(model, chain)
10580
# compare them with true values
10681
for i in 1:N
10782
samples_dict = Dict()
@@ -125,8 +100,21 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
125100
end
126101
end
127102

103+
@testset "DynamicPPL#684: threadsafe evaluation with multiple types" begin
104+
@model function multiple_types(x)
105+
ns ~ filldist(Normal(0, 2.0), 3)
106+
m ~ Uniform(0, 1)
107+
return x ~ Normal(m, 1)
108+
end
109+
model = multiple_types(1)
110+
chain = make_chain_from_prior(model, 10)
111+
loglikelihood(model, chain)
112+
logprior(model, chain)
113+
logjoint(model, chain)
114+
end
115+
128116
@testset "rng" begin
129-
model = gdemo_default
117+
model = GDEMO_DEFAULT
130118

131119
for sampler in (SampleFromPrior(), SampleFromUniform())
132120
for i in 1:10
@@ -144,13 +132,15 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
144132
end
145133

146134
@testset "defaults without VarInfo, Sampler, and Context" begin
147-
model = gdemo_default
135+
model = GDEMO_DEFAULT
148136

149137
Random.seed!(100)
150-
s, m = model()
138+
retval = model()
151139

152140
Random.seed!(100)
153-
@test model(Random.default_rng()) == (s, m)
141+
retval2 = model(Random.default_rng())
142+
@test retval2.s == retval.s
143+
@test retval2.m == retval.m
154144
end
155145

156146
@testset "nameof" begin
@@ -184,7 +174,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
184174
end
185175

186176
@testset "Internal methods" begin
187-
model = gdemo_default
177+
model = GDEMO_DEFAULT
188178

189179
# sample from model and extract variables
190180
vi = VarInfo(model)
@@ -224,7 +214,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
224214
end
225215

226216
@testset "rand" begin
227-
model = gdemo_default
217+
model = GDEMO_DEFAULT
228218

229219
Random.seed!(1776)
230220
s, m = model()
@@ -309,7 +299,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
309299
end
310300
end
311301

312-
@testset "generated_quantities on `LKJCholesky`" begin
302+
@testset "returned() on `LKJCholesky`" begin
313303
n = 10
314304
d = 2
315305
model = DynamicPPL.TestUtils.demo_lkjchol(d)
@@ -333,7 +323,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
333323
)
334324

335325
# Test!
336-
results = generated_quantities(model, chain)
326+
results = returned(model, chain)
337327
for (x_true, result) in zip(xs, results)
338328
@test x_true.UL == result.x.UL
339329
end
@@ -352,7 +342,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
352342
info=(varname_to_symbol=vns_to_syms_with_extra,),
353343
)
354344
# Test!
355-
results = generated_quantities(model, chain_with_extra)
345+
results = returned(model, chain_with_extra)
356346
for (x_true, result) in zip(xs, results)
357347
@test x_true.UL == result.x.UL
358348
end

test/model_utils.jl

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
@testset "model_utils.jl" begin
2+
@testset "value_iterator_from_chain" begin
3+
@testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS
4+
chain = make_chain_from_prior(model, 10)
5+
for (i, d) in enumerate(value_iterator_from_chain(model, chain))
6+
for vn in keys(d)
7+
val = DynamicPPL.getvalue(d, vn)
8+
for vn_leaf in DynamicPPL.varname_leaves(vn, val)
9+
val_leaf = DynamicPPL.getvalue(d, vn_leaf)
10+
@test val_leaf == chain[i, Symbol(vn_leaf), 1]
11+
end
12+
end
13+
end
14+
end
15+
end
16+
end

test/runtests.jl

+1-5
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ using Distributions
2626
using LinearAlgebra # Diagonal
2727

2828
using Combinatorics: combinations
29+
using OrderedCollections: OrderedSet
2930

3031
using DynamicPPL: getargs_dottilde, getargs_tilde, Selector
3132

@@ -48,15 +49,10 @@ include("test_util.jl")
4849
include("context_implementations.jl")
4950
include("logdensityfunction.jl")
5051
include("linking.jl")
51-
5252
include("threadsafe.jl")
53-
5453
include("serialization.jl")
55-
5654
include("pointwise_logdensities.jl")
57-
5855
include("lkj.jl")
59-
6056
include("debug_utils.jl")
6157
end
6258

test/test_util.jl

-34
Original file line numberDiff line numberDiff line change
@@ -43,40 +43,6 @@ function test_model_ad(model, logp_manual)
4343
@test back(1)[1] grad
4444
end
4545

46-
"""
47-
test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
48-
49-
Test `setval!` on `model` and `chain`.
50-
51-
Worth noting that this only supports models containing symbols of the forms
52-
`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
53-
"""
54-
function test_setval!(model, chain; sample_idx=1, chain_idx=1)
55-
var_info = VarInfo(model)
56-
spl = SampleFromPrior()
57-
θ_old = var_info[spl]
58-
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
59-
θ_new = var_info[spl]
60-
@test θ_old != θ_new
61-
vals = DynamicPPL.values_as(var_info, OrderedDict)
62-
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
63-
for (n, v) in mapreduce(collect, vcat, iters)
64-
n = string(n)
65-
if Symbol(n) keys(chain)
66-
# Assume it's a group
67-
chain_val = vec(
68-
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
69-
)
70-
v_true = vec(v)
71-
else
72-
chain_val = chain[sample_idx, n, chain_idx]
73-
v_true = v
74-
end
75-
76-
@test v_true == chain_val
77-
end
78-
end
79-
8046
"""
8147
short_varinfo_name(vi::AbstractVarInfo)
8248

0 commit comments

Comments
 (0)