Skip to content

Commit 5bccb9e

Browse files
committed
Re-add almost all integration tests into DPPL test suite proper
1 parent 2cc15b1 commit 5bccb9e

File tree

6 files changed

+359
-87
lines changed

6 files changed

+359
-87
lines changed

test/ad.jl

+39
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,43 @@
2525
end
2626
end
2727
end
28+
29+
@testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin
30+
# Failing model
31+
t = 1:0.05:8
32+
σ = 0.3
33+
y = @. rand(sin(t) + Normal(0, σ))
34+
@model function state_space(y, TT, ::Type{T}=Float64) where {T}
35+
# Priors
36+
α ~ Normal(y[1], 0.001)
37+
τ ~ Exponential(1)
38+
η ~ filldist(Normal(0, 1), TT - 1)
39+
σ ~ Exponential(1)
40+
# create latent variable
41+
x = Vector{T}(undef, TT)
42+
x[1] = α
43+
for t in 2:TT
44+
x[t] = x[t - 1] + η[t - 1] * τ
45+
end
46+
# measurement model
47+
y ~ MvNormal(x, σ^2 * I)
48+
return x
49+
end
50+
model = state_space(y, length(t))
51+
52+
# Dummy sampling algorithm for testing. The test case can only be replicated
53+
# with a custom sampler, it doesn't work with SampleFromPrior(). We need to
54+
# overload assume so that model evaluation doesn't fail due to a lack
55+
# of implementation
56+
struct MyAlg end
57+
DynamicPPL.getspace(::DynamicPPL.Sampler{MyAlg}) = ()
58+
DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyAlg}, dist, vn, vi) =
59+
DynamicPPL.assume(dist, vn, vi)
60+
61+
# Compiling the ReverseDiff tape used to fail here
62+
spl = Sampler(MyAlg())
63+
vi = VarInfo(model)
64+
ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl))
65+
@test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any
66+
end
2867
end

test/model.jl

+15-44
Original file line numberDiff line numberDiff line change
@@ -55,50 +55,8 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
5555
@test ljoint lp
5656

5757
#### logprior, logjoint, loglikelihood for MCMC chains ####
58-
for model in DynamicPPL.TestUtils.DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12
59-
var_info = VarInfo(model)
60-
vns = DynamicPPL.TestUtils.varnames(model)
61-
syms = unique(DynamicPPL.getsym.(vns))
62-
63-
# generate a chain of sample parameter values.
64-
N = 200
65-
vals_OrderedDict = mapreduce(hcat, 1:N) do _
66-
rand(OrderedDict, model)
67-
end
68-
vals_mat = mapreduce(hcat, 1:N) do i
69-
[vals_OrderedDict[i][vn] for vn in vns]
70-
end
71-
i = 1
72-
for col in eachcol(vals_mat)
73-
col_flattened = []
74-
[push!(col_flattened, x...) for x in col]
75-
if i == 1
76-
chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened)))
77-
else
78-
chain_mat = vcat(
79-
chain_mat, reshape(col_flattened, 1, length(col_flattened))
80-
)
81-
end
82-
i += 1
83-
end
84-
chain_mat = convert(Matrix{Float64}, chain_mat)
85-
86-
# devise parameter names for chain
87-
sample_values_vec = collect(values(vals_OrderedDict[1]))
88-
symbol_names = []
89-
chain_sym_map = Dict()
90-
for k in 1:length(keys(var_info))
91-
vn_parent = keys(var_info)[k]
92-
sym = DynamicPPL.getsym(vn_parent)
93-
vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl
94-
for vn_child in vn_children
95-
chain_sym_map[Symbol(vn_child)] = sym
96-
symbol_names = [symbol_names; Symbol(vn_child)]
97-
end
98-
end
99-
chain = Chains(chain_mat, symbol_names)
100-
101-
# calculate the pointwise loglikelihoods for the whole chain using the newly written functions
58+
for model in DynamicPPL.TestUtils.DEMO_MODELS
59+
chain = make_chain_from_prior(model, 200)
10260
logpriors = logprior(model, chain)
10361
loglikelihoods = loglikelihood(model, chain)
10462
logjoints = logjoint(model, chain)
@@ -125,6 +83,19 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
12583
end
12684
end
12785

86+
@testset "DynamicPPL#684: threadsafe evaluation with multiple types" begin
87+
@model function multiple_types(x)
88+
ns ~ filldist(Normal(0, 2.0), 3)
89+
m ~ Uniform(0, 1)
90+
return x ~ Normal(m, 1)
91+
end
92+
model = multiple_types(1)
93+
chain = make_chain_from_prior(model, 10)
94+
loglikelihood(model, chain)
95+
logprior(model, chain)
96+
logjoint(model, chain)
97+
end
98+
12899
@testset "rng" begin
129100
model = gdemo_default
130101

test/model_utils.jl

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
@testset "model_utils.jl" begin
2+
@testset "value_iterator_from_chain" begin
3+
@testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS
4+
chain = make_chain_from_prior(model, 10)
5+
for (i, d) in enumerate(value_iterator_from_chain(model, chain))
6+
for vn in keys(d)
7+
val = DynamicPPL.getvalue(d, vn)
8+
for vn_leaf in DynamicPPL.varname_leaves(vn, val)
9+
val_leaf = DynamicPPL.getvalue(d, vn_leaf)
10+
@test val_leaf == chain[i, Symbol(vn_leaf), 1]
11+
end
12+
end
13+
end
14+
end
15+
end
16+
end

test/runtests.jl

+1-5
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ using Distributions
2626
using LinearAlgebra # Diagonal
2727

2828
using Combinatorics: combinations
29+
using OrderedCollections: OrderedSet
2930

3031
using DynamicPPL: getargs_dottilde, getargs_tilde, Selector
3132

@@ -48,15 +49,10 @@ include("test_util.jl")
4849
include("context_implementations.jl")
4950
include("logdensityfunction.jl")
5051
include("linking.jl")
51-
5252
include("threadsafe.jl")
53-
5453
include("serialization.jl")
55-
5654
include("pointwise_logdensities.jl")
57-
5855
include("lkj.jl")
59-
6056
include("debug_utils.jl")
6157
end
6258

test/test_util.jl

-34
Original file line numberDiff line numberDiff line change
@@ -43,40 +43,6 @@ function test_model_ad(model, logp_manual)
4343
@test back(1)[1] grad
4444
end
4545

46-
"""
47-
test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
48-
49-
Test `setval!` on `model` and `chain`.
50-
51-
Worth noting that this only supports models containing symbols of the forms
52-
`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
53-
"""
54-
function test_setval!(model, chain; sample_idx=1, chain_idx=1)
55-
var_info = VarInfo(model)
56-
spl = SampleFromPrior()
57-
θ_old = var_info[spl]
58-
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
59-
θ_new = var_info[spl]
60-
@test θ_old != θ_new
61-
vals = DynamicPPL.values_as(var_info, OrderedDict)
62-
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
63-
for (n, v) in mapreduce(collect, vcat, iters)
64-
n = string(n)
65-
if Symbol(n) keys(chain)
66-
# Assume it's a group
67-
chain_val = vec(
68-
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
69-
)
70-
v_true = vec(v)
71-
else
72-
chain_val = chain[sample_idx, n, chain_idx]
73-
v_true = v
74-
end
75-
76-
@test v_true == chain_val
77-
end
78-
end
79-
8046
"""
8147
short_varinfo_name(vi::AbstractVarInfo)
8248

0 commit comments

Comments
 (0)