Skip to content

Commit 8143144

Browse files
committed
Add TestUtils submodule (#313)
This PR adds a `DynamicPPL.TestUtils` submodule which is meant to include functionality to make it easy to test new samplers, new implementations of `AbstractVarInfo`, etc. As of right now, this is mainly just a collection of models with equivalent marginal posteriors using the different features of DPPL, e.g. some are using `.~`, some are using `@submodel`, etc. Eventually this should be expanded to be of more use, but more immediately this will be useful to test functionality in open PRs, e.g. #269, #309, #295, #292. These models are also already used in Turing.jl's test-suite (https://github.com/TuringLang/Turing.jl/blob/9f52d75c25390b68115624b2e6cf464275a88137/test/test_utils/models.jl#L55-L56), so this PR would avoid the code-duplication + make it easier to keep things up-to-date.
1 parent 86afffa commit 8143144

File tree

4 files changed

+213
-113
lines changed

4 files changed

+213
-113
lines changed

Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.15.1"
3+
version = "0.15.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -11,6 +11,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1111
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1212
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
14+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1415
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1516

1617
[compat]

src/DynamicPPL.jl

+2
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,6 @@ include("compat/ad.jl")
134134
include("loglikelihoods.jl")
135135
include("submodel_macro.jl")
136136

137+
include("test_utils.jl")
138+
137139
end # module

src/test_utils.jl

+208
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
module TestUtils
2+
3+
using AbstractMCMC
4+
using DynamicPPL
5+
using Distributions
6+
using Test
7+
8+
# A collection of models for which the mean-of-means for the posterior should
9+
# be same.
10+
@model function demo_dot_assume_dot_observe(
11+
x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
12+
) where {TV}
13+
# `dot_assume` and `observe`
14+
m = TV(undef, length(x))
15+
m .~ Normal()
16+
x ~ MvNormal(m, 0.25 * I)
17+
return (; m=m, x=x, logp=getlogp(__varinfo__))
18+
end
19+
20+
@model function demo_assume_index_observe(
21+
x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
22+
) where {TV}
23+
# `assume` with indexing and `observe`
24+
m = TV(undef, length(x))
25+
for i in eachindex(m)
26+
m[i] ~ Normal()
27+
end
28+
x ~ MvNormal(m, 0.25 * I)
29+
30+
return (; m=m, x=x, logp=getlogp(__varinfo__))
31+
end
32+
33+
@model function demo_assume_multivariate_observe_index(x=[10.0, 10.0])
34+
# Multivariate `assume` and `observe`
35+
m ~ MvNormal(zero(x), I)
36+
x ~ MvNormal(m, 0.25 * I)
37+
38+
return (; m=m, x=x, logp=getlogp(__varinfo__))
39+
end
40+
41+
@model function demo_dot_assume_observe_index(
42+
x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
43+
) where {TV}
44+
# `dot_assume` and `observe` with indexing
45+
m = TV(undef, length(x))
46+
m .~ Normal()
47+
for i in eachindex(x)
48+
x[i] ~ Normal(m[i], 0.5)
49+
end
50+
51+
return (; m=m, x=x, logp=getlogp(__varinfo__))
52+
end
53+
54+
# Using vector of `length` 1 here so the posterior of `m` is the same
55+
# as the others.
56+
@model function demo_assume_dot_observe(x=[10.0])
57+
# `assume` and `dot_observe`
58+
m ~ Normal()
59+
x .~ Normal(m, 0.5)
60+
61+
return (; m=m, x=x, logp=getlogp(__varinfo__))
62+
end
63+
64+
@model function demo_assume_observe_literal()
65+
# `assume` and literal `observe`
66+
m ~ MvNormal(zeros(2), I)
67+
[10.0, 10.0] ~ MvNormal(m, 0.25 * I)
68+
69+
return (; m=m, x=[10.0, 10.0], logp=getlogp(__varinfo__))
70+
end
71+
72+
@model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV}
73+
# `dot_assume` and literal `observe` with indexing
74+
m = TV(undef, 2)
75+
m .~ Normal()
76+
for i in eachindex(m)
77+
10.0 ~ Normal(m[i], 0.5)
78+
end
79+
80+
return (; m=m, x=fill(10.0, length(m)), logp=getlogp(__varinfo__))
81+
end
82+
83+
@model function demo_assume_literal_dot_observe()
84+
# `assume` and literal `dot_observe`
85+
m ~ Normal()
86+
[10.0] .~ Normal(m, 0.5)
87+
88+
return (; m=m, x=[10.0], logp=getlogp(__varinfo__))
89+
end
90+
91+
@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
92+
m = TV(undef, 2)
93+
m .~ Normal()
94+
95+
return m
96+
end
97+
98+
@model function demo_assume_submodel_observe_index_literal()
99+
# Submodel prior
100+
m = @submodel _prior_dot_assume()
101+
for i in eachindex(m)
102+
10.0 ~ Normal(m[i], 0.5)
103+
end
104+
105+
return (; m=m, x=[10.0], logp=getlogp(__varinfo__))
106+
end
107+
108+
@model function _likelihood_dot_observe(m, x)
109+
return x ~ MvNormal(m, 0.25 * I)
110+
end
111+
112+
@model function demo_dot_assume_observe_submodel(
113+
x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
114+
) where {TV}
115+
m = TV(undef, length(x))
116+
m .~ Normal()
117+
118+
# Submodel likelihood
119+
@submodel _likelihood_dot_observe(m, x)
120+
121+
return (; m=m, x=x, logp=getlogp(__varinfo__))
122+
end
123+
124+
@model function demo_dot_assume_dot_observe_matrix(
125+
x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64}
126+
) where {TV}
127+
m = TV(undef, length(x))
128+
m .~ Normal()
129+
130+
# Dotted observe for `Matrix`.
131+
x .~ MvNormal(m, 0.25 * I)
132+
133+
return (; m=m, x=x, logp=getlogp(__varinfo__))
134+
end
135+
136+
const DEMO_MODELS = (
137+
demo_dot_assume_dot_observe(),
138+
demo_assume_index_observe(),
139+
demo_assume_multivariate_observe_index(),
140+
demo_dot_assume_observe_index(),
141+
demo_assume_dot_observe(),
142+
demo_assume_observe_literal(),
143+
demo_dot_assume_observe_index_literal(),
144+
demo_assume_literal_dot_observe(),
145+
demo_assume_submodel_observe_index_literal(),
146+
demo_dot_assume_observe_submodel(),
147+
demo_dot_assume_dot_observe_matrix(),
148+
)
149+
150+
# TODO: Is this really the best/most convenient "default" test method?
151+
"""
152+
test_sampler_demo_models(meanfunction, sampler, args...; kwargs...)
153+
154+
Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`.
155+
156+
In short, this method iterators through `demo_models`, calls `AbstractMCMC.sample` on the
157+
`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain)` against `target`
158+
provided in `kwargs...`.
159+
160+
# Arguments
161+
- `meanfunction`: A callable which computes the mean of the marginal means from the
162+
chain resulting from the `sample` call.
163+
- `sampler`: The `AbstractMCMC.AbstractSampler` to test.
164+
- `args...`: Arguments forwarded to `sample`.
165+
166+
# Keyword arguments
167+
- `target`: Value to compare result of `meanfunction(chain)` to.
168+
- `atol=1e-1`: Absolute tolerance used in `@test`.
169+
- `rtol=1e-3`: Relative tolerance used in `@test`.
170+
- `kwargs...`: Keyword arguments forwarded to `sample`.
171+
"""
172+
function test_sampler_demo_models(
173+
meanfunction,
174+
sampler::AbstractMCMC.AbstractSampler,
175+
args...;
176+
target=8.0,
177+
atol=1e-1,
178+
rtol=1e-3,
179+
kwargs...,
180+
)
181+
@testset "$(nameof(typeof(sampler))) on $(m.name)" for model in DEMO_MODELS
182+
chain = AbstractMCMC.sample(model, sampler, args...; kwargs...)
183+
μ = meanfunction(chain)
184+
@test μ target atol = atol rtol = rtol
185+
end
186+
end
187+
188+
"""
189+
test_sampler_continuous([meanfunction, ]sampler, args...; kwargs...)
190+
191+
Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`.
192+
193+
As of right now, this is just an alias for [`test_sampler_demo_models`](@ref).
194+
"""
195+
function test_sampler_continuous(
196+
meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs...
197+
)
198+
return test_sampler_demo_models(meanfunction, sampler, args...; kwargs...)
199+
end
200+
201+
function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...)
202+
# Default for `MCMCChains.Chains`.
203+
return test_sampler_continuous(sampler, args...; kwargs...) do chain
204+
mean(Array(chain))
205+
end
206+
end
207+
208+
end

test/loglikelihoods.jl

+1-112
Original file line numberDiff line numberDiff line change
@@ -1,116 +1,5 @@
1-
# A collection of models for which the mean-of-means for the posterior should
2-
# be same.
3-
@model function gdemo1(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
4-
# `dot_assume` and `observe`
5-
m = TV(undef, length(x))
6-
m .~ Normal()
7-
return x ~ MvNormal(m, 0.25 * I)
8-
end
9-
10-
@model function gdemo2(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
11-
# `assume` with indexing and `observe`
12-
m = TV(undef, length(x))
13-
for i in eachindex(m)
14-
m[i] ~ Normal()
15-
end
16-
return x ~ MvNormal(m, 0.25 * I)
17-
end
18-
19-
@model function gdemo3(x=[10.0, 10.0])
20-
# Multivariate `assume` and `observe`
21-
m ~ MvNormal(zero(x), I)
22-
return x ~ MvNormal(m, 0.25 * I)
23-
end
24-
25-
@model function gdemo4(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
26-
# `dot_assume` and `observe` with indexing
27-
m = TV(undef, length(x))
28-
m .~ Normal()
29-
for i in eachindex(x)
30-
x[i] ~ Normal(m[i], 0.5)
31-
end
32-
end
33-
34-
# Using vector of `length` 1 here so the posterior of `m` is the same
35-
# as the others.
36-
@model function gdemo5(x=[10.0])
37-
# `assume` and `dot_observe`
38-
m ~ Normal()
39-
return x .~ Normal(m, 0.5)
40-
end
41-
42-
@model function gdemo6(::Type{TV}=Vector{Float64}) where {TV}
43-
# `assume` and literal `observe`
44-
m ~ MvNormal(zeros(2), I)
45-
return [10.0, 10.0] ~ MvNormal(m, 0.25 * I)
46-
end
47-
48-
@model function gdemo7(::Type{TV}=Vector{Float64}) where {TV}
49-
# `dot_assume` and literal `observe` with indexing
50-
m = TV(undef, 2)
51-
m .~ Normal()
52-
for i in eachindex(m)
53-
10.0 ~ Normal(m[i], 0.5)
54-
end
55-
end
56-
57-
@model function gdemo8(::Type{TV}=Vector{Float64}) where {TV}
58-
# `assume` and literal `dot_observe`
59-
m ~ Normal()
60-
return [10.0] .~ Normal(m, 0.5)
61-
end
62-
63-
@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
64-
m = TV(undef, 2)
65-
m .~ Normal()
66-
67-
return m
68-
end
69-
70-
@model function gdemo9()
71-
# Submodel prior
72-
m = @submodel _prior_dot_assume()
73-
for i in eachindex(m)
74-
10.0 ~ Normal(m[i], 0.5)
75-
end
76-
end
77-
78-
@model function _likelihood_dot_observe(m, x)
79-
return x ~ MvNormal(m, 0.25 * I)
80-
end
81-
82-
@model function gdemo10(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
83-
m = TV(undef, length(x))
84-
m .~ Normal()
85-
86-
# Submodel likelihood
87-
@submodel _likelihood_dot_observe(m, x)
88-
end
89-
90-
@model function gdemo11(x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64}) where {TV}
91-
m = TV(undef, length(x))
92-
m .~ Normal()
93-
94-
# Dotted observe for `Matrix`.
95-
return x .~ MvNormal(m, 0.25 * I)
96-
end
97-
98-
const gdemo_models = (
99-
gdemo1(),
100-
gdemo2(),
101-
gdemo3(),
102-
gdemo4(),
103-
gdemo5(),
104-
gdemo6(),
105-
gdemo7(),
106-
gdemo8(),
107-
gdemo9(),
108-
gdemo10(),
109-
gdemo11(),
110-
)
111-
1121
@testset "loglikelihoods.jl" begin
113-
for m in gdemo_models
2+
for m in DynamicPPL.TestUtils.DEMO_MODELS
1143
vi = VarInfo(m)
1154

1165
vns = vi.metadata.m.vns

0 commit comments

Comments
 (0)