Skip to content

Commit 2c2800b

Browse files
committed
Add DynamicPPL integration tests
1 parent 5b24ceb commit 2c2800b

File tree

6 files changed

+859
-1
lines changed

6 files changed

+859
-1
lines changed

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
[deps]
22
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
3+
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
34
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
45
AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc"
56
AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
67
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
8+
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
79
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
810
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
911
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"

test/dynamicppl/compiler.jl

Lines changed: 370 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,370 @@
1+
module DynamicPPLCompilerTests
2+
3+
using ..NumericalTests: check_numerical
4+
using LinearAlgebra: I
5+
using Test: @test, @testset, @test_throws
6+
using Turing
7+
8+
# TODO(penelopeysm): Move this to a DynamicPPL Test Utils module
9+
# We use this a lot!
10+
@model function gdemo_d()
11+
s ~ InverseGamma(2, 3)
12+
m ~ Normal(0, sqrt(s))
13+
1.5 ~ Normal(m, sqrt(s))
14+
2.0 ~ Normal(m, sqrt(s))
15+
return s, m
16+
end
17+
const gdemo_default = gdemo_d()
18+
19+
@testset "compiler.jl" begin
20+
@testset "assume" begin
21+
@model function test_assume()
22+
x ~ Bernoulli(1)
23+
y ~ Bernoulli(x / 2)
24+
return x, y
25+
end
26+
27+
smc = SMC()
28+
pg = PG(10)
29+
30+
res1 = sample(test_assume(), smc, 1000)
31+
res2 = sample(test_assume(), pg, 1000)
32+
33+
check_numerical(res1, [:y], [0.5]; atol=0.1)
34+
check_numerical(res2, [:y], [0.5]; atol=0.1)
35+
36+
# Check that all xs are 1.
37+
@test all(isone, res1[:x])
38+
@test all(isone, res2[:x])
39+
end
40+
@testset "beta binomial" begin
41+
prior = Beta(2, 2)
42+
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]
43+
exact = Beta(prior.α + sum(obs), prior.β + length(obs) - sum(obs))
44+
meanp = exact.α / (exact.α + exact.β)
45+
46+
@model function testbb(obs)
47+
p ~ Beta(2, 2)
48+
x ~ Bernoulli(p)
49+
for i in 1:length(obs)
50+
obs[i] ~ Bernoulli(p)
51+
end
52+
return p, x
53+
end
54+
55+
smc = SMC()
56+
pg = PG(10)
57+
gibbs = Gibbs(HMC(0.2, 3, :p), PG(10, :x))
58+
59+
chn_s = sample(testbb(obs), smc, 1000)
60+
chn_p = sample(testbb(obs), pg, 2000)
61+
chn_g = sample(testbb(obs), gibbs, 1500)
62+
63+
check_numerical(chn_s, [:p], [meanp]; atol=0.05)
64+
check_numerical(chn_p, [:x], [meanp]; atol=0.1)
65+
check_numerical(chn_g, [:x], [meanp]; atol=0.1)
66+
end
67+
@testset "forbid global" begin
68+
xs = [1.5 2.0]
69+
# xx = 1
70+
71+
@model function fggibbstest(xs)
72+
s ~ InverseGamma(2, 3)
73+
m ~ Normal(0, sqrt(s))
74+
# xx ~ Normal(m, sqrt(s)) # this is illegal
75+
76+
for i in 1:length(xs)
77+
xs[i] ~ Normal(m, sqrt(s))
78+
# for xx in xs
79+
# xx ~ Normal(m, sqrt(s))
80+
end
81+
return s, m
82+
end
83+
84+
gibbs = Gibbs(PG(10, :s), HMC(0.4, 8, :m))
85+
chain = sample(fggibbstest(xs), gibbs, 2)
86+
end
87+
@testset "new grammar" begin
88+
x = Float64[1 2]
89+
90+
@model function gauss(x)
91+
priors = Array{Float64}(undef, 2)
92+
priors[1] ~ InverseGamma(2, 3) # s
93+
priors[2] ~ Normal(0, sqrt(priors[1])) # m
94+
for i in 1:length(x)
95+
x[i] ~ Normal(priors[2], sqrt(priors[1]))
96+
end
97+
return priors
98+
end
99+
100+
chain = sample(gauss(x), PG(10), 10)
101+
chain = sample(gauss(x), SMC(), 10)
102+
103+
@model function gauss2(::Type{TV}=Vector{Float64}; x) where {TV}
104+
priors = TV(undef, 2)
105+
priors[1] ~ InverseGamma(2, 3) # s
106+
priors[2] ~ Normal(0, sqrt(priors[1])) # m
107+
for i in 1:length(x)
108+
x[i] ~ Normal(priors[2], sqrt(priors[1]))
109+
end
110+
return priors
111+
end
112+
113+
@test_throws ErrorException chain = sample(gauss2(; x=x), PG(10), 10)
114+
@test_throws ErrorException chain = sample(gauss2(; x=x), SMC(), 10)
115+
116+
@test_throws ErrorException chain = sample(
117+
gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), PG(10), 10
118+
)
119+
@test_throws ErrorException chain = sample(
120+
gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), SMC(), 10
121+
)
122+
end
123+
@testset "new interface" begin
124+
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]
125+
126+
@model function newinterface(obs)
127+
p ~ Beta(2, 2)
128+
for i in 1:length(obs)
129+
obs[i] ~ Bernoulli(p)
130+
end
131+
return p
132+
end
133+
134+
chain = sample(
135+
newinterface(obs),
136+
HMC(0.75, 3, :p, :x; adtype=AutoForwardDiff(; chunksize=2)),
137+
100,
138+
)
139+
end
140+
@testset "no return" begin
141+
@model function noreturn(x)
142+
s ~ InverseGamma(2, 3)
143+
m ~ Normal(0, sqrt(s))
144+
for i in 1:length(x)
145+
x[i] ~ Normal(m, sqrt(s))
146+
end
147+
end
148+
149+
chain = sample(noreturn([1.5 2.0]), HMC(0.15, 6), 1000)
150+
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6])
151+
end
152+
@testset "observe" begin
153+
@model function test()
154+
z ~ Normal(0, 1)
155+
x ~ Bernoulli(1)
156+
1 ~ Bernoulli(x / 2)
157+
0 ~ Bernoulli(x / 2)
158+
return x
159+
end
160+
161+
is = IS()
162+
smc = SMC()
163+
pg = PG(10)
164+
165+
res_is = sample(test(), is, 10000)
166+
res_smc = sample(test(), smc, 1000)
167+
res_pg = sample(test(), pg, 100)
168+
169+
@test all(isone, res_is[:x])
170+
@test res_is.logevidence 2 * log(0.5)
171+
172+
@test all(isone, res_smc[:x])
173+
@test res_smc.logevidence 2 * log(0.5)
174+
175+
@test all(isone, res_pg[:x])
176+
end
177+
178+
@testset "sample" begin
179+
alg = Gibbs(HMC(0.2, 3, :m), PG(10, :s))
180+
chn = sample(gdemo_default, alg, 1000)
181+
end
182+
183+
@testset "vectorization @." begin
184+
@model function vdemo1(x)
185+
s ~ InverseGamma(2, 3)
186+
m ~ Normal(0, sqrt(s))
187+
@. x ~ Normal(m, sqrt(s))
188+
return s, m
189+
end
190+
191+
alg = HMC(0.01, 5)
192+
x = randn(100)
193+
res = sample(vdemo1(x), alg, 250)
194+
195+
@model function vdemo1b(x)
196+
s ~ InverseGamma(2, 3)
197+
m ~ Normal(0, sqrt(s))
198+
@. x ~ Normal(m, $(sqrt(s)))
199+
return s, m
200+
end
201+
202+
res = sample(vdemo1b(x), alg, 250)
203+
204+
@model function vdemo2(x)
205+
μ ~ MvNormal(zeros(size(x, 1)), I)
206+
@. x ~ $(MvNormal(μ, I))
207+
end
208+
209+
D = 2
210+
alg = HMC(0.01, 5)
211+
res = sample(vdemo2(randn(D, 100)), alg, 250)
212+
213+
# Vector assumptions
214+
N = 10
215+
alg = HMC(0.2, 4; adtype=AutoForwardDiff(; chunksize=N))
216+
217+
@model function vdemo3()
218+
x = Vector{Real}(undef, N)
219+
for i in 1:N
220+
x[i] ~ Normal(0, sqrt(4))
221+
end
222+
end
223+
224+
t_loop = @elapsed res = sample(vdemo3(), alg, 1000)
225+
226+
# Test for vectorize UnivariateDistribution
227+
@model function vdemo4()
228+
x = Vector{Real}(undef, N)
229+
@. x ~ Normal(0, 2)
230+
end
231+
232+
t_vec = @elapsed res = sample(vdemo4(), alg, 1000)
233+
234+
@model vdemo5() = x ~ MvNormal(zeros(N), 4 * I)
235+
236+
t_mv = @elapsed res = sample(vdemo5(), alg, 1000)
237+
238+
println("Time for")
239+
println(" Loop : ", t_loop)
240+
println(" Vec : ", t_vec)
241+
println(" Mv : ", t_mv)
242+
243+
# Transformed test
244+
@model function vdemo6()
245+
x = Vector{Real}(undef, N)
246+
@. x ~ InverseGamma(2, 3)
247+
end
248+
249+
sample(vdemo6(), alg, 1000)
250+
251+
N = 3
252+
@model function vdemo7()
253+
x = Array{Real}(undef, N, N)
254+
@. x ~ [InverseGamma(2, 3) for i in 1:N]
255+
end
256+
257+
sample(vdemo7(), alg, 1000)
258+
end
259+
@testset "vectorization .~" begin
260+
@model function vdemo1(x)
261+
s ~ InverseGamma(2, 3)
262+
m ~ Normal(0, sqrt(s))
263+
x .~ Normal(m, sqrt(s))
264+
return s, m
265+
end
266+
267+
alg = HMC(0.01, 5)
268+
x = randn(100)
269+
res = sample(vdemo1(x), alg, 250)
270+
271+
@model function vdemo2(x)
272+
μ ~ MvNormal(zeros(size(x, 1)), I)
273+
return x .~ MvNormal(μ, I)
274+
end
275+
276+
D = 2
277+
alg = HMC(0.01, 5)
278+
res = sample(vdemo2(randn(D, 100)), alg, 250)
279+
280+
# Vector assumptions
281+
N = 10
282+
alg = HMC(0.2, 4; adtype=AutoForwardDiff(; chunksize=N))
283+
284+
@model function vdemo3()
285+
x = Vector{Real}(undef, N)
286+
for i in 1:N
287+
x[i] ~ Normal(0, sqrt(4))
288+
end
289+
end
290+
291+
t_loop = @elapsed res = sample(vdemo3(), alg, 1000)
292+
293+
# Test for vectorize UnivariateDistribution
294+
@model function vdemo4()
295+
x = Vector{Real}(undef, N)
296+
return x .~ Normal(0, 2)
297+
end
298+
299+
t_vec = @elapsed res = sample(vdemo4(), alg, 1000)
300+
301+
@model vdemo5() = x ~ MvNormal(zeros(N), 4 * I)
302+
303+
t_mv = @elapsed res = sample(vdemo5(), alg, 1000)
304+
305+
println("Time for")
306+
println(" Loop : ", t_loop)
307+
println(" Vec : ", t_vec)
308+
println(" Mv : ", t_mv)
309+
310+
# Transformed test
311+
@model function vdemo6()
312+
x = Vector{Real}(undef, N)
313+
return x .~ InverseGamma(2, 3)
314+
end
315+
316+
sample(vdemo6(), alg, 1000)
317+
318+
@model function vdemo7()
319+
x = Array{Real}(undef, N, N)
320+
return x .~ [InverseGamma(2, 3) for i in 1:N]
321+
end
322+
323+
sample(vdemo7(), alg, 1000)
324+
end
325+
@testset "Type parameters" begin
326+
N = 10
327+
alg = HMC(0.01, 5; adtype=AutoForwardDiff(; chunksize=N))
328+
x = randn(1000)
329+
@model function vdemo1(::Type{T}=Float64) where {T}
330+
x = Vector{T}(undef, N)
331+
for i in 1:N
332+
x[i] ~ Normal(0, sqrt(4))
333+
end
334+
end
335+
336+
t_loop = @elapsed res = sample(vdemo1(), alg, 250)
337+
t_loop = @elapsed res = sample(vdemo1(DynamicPPL.TypeWrap{Float64}()), alg, 250)
338+
339+
vdemo1kw(; T) = vdemo1(T)
340+
t_loop = @elapsed res = sample(
341+
vdemo1kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 250
342+
)
343+
344+
@model function vdemo2(::Type{T}=Float64) where {T<:Real}
345+
x = Vector{T}(undef, N)
346+
@. x ~ Normal(0, 2)
347+
end
348+
349+
t_vec = @elapsed res = sample(vdemo2(), alg, 250)
350+
t_vec = @elapsed res = sample(vdemo2(DynamicPPL.TypeWrap{Float64}()), alg, 250)
351+
352+
vdemo2kw(; T) = vdemo2(T)
353+
t_vec = @elapsed res = sample(
354+
vdemo2kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 250
355+
)
356+
357+
@model function vdemo3(::Type{TV}=Vector{Float64}) where {TV<:AbstractVector}
358+
x = TV(undef, N)
359+
@. x ~ InverseGamma(2, 3)
360+
end
361+
362+
sample(vdemo3(), alg, 250)
363+
sample(vdemo3(DynamicPPL.TypeWrap{Vector{Float64}}()), alg, 250)
364+
365+
vdemo3kw(; T) = vdemo3(T)
366+
sample(vdemo3kw(; T=DynamicPPL.TypeWrap{Vector{Float64}}()), alg, 250)
367+
end
368+
end
369+
370+
end # module

0 commit comments

Comments
 (0)