Skip to content

Commit 685a27d

Browse files
committed
Add TestUtils submodule (#313)
This PR adds a `DynamicPPL.TestUtils` submodule which is meant to include functionality to make it easy to test new samplers, new implementations of `AbstractVarInfo`, etc. As of right now, this is mainly just a collection of models with equivalent marginal posteriors using the different features of DPPL, e.g. some are using `.~`, some are using `@submodel`, etc. Eventually this should be expanded to be of more use, but more immediately this will be useful to test functionality in open PRs, e.g. #269, #309, #295, #292. These models are also already used in Turing.jl's test-suite (https://github.com/TuringLang/Turing.jl/blob/9f52d75c25390b68115624b2e6cf464275a88137/test/test_utils/models.jl#L55-L56), so this PR would avoid the code-duplication + make it easier to keep things up-to-date.
1 parent 86afffa commit 685a27d

File tree

4 files changed

+214
-113
lines changed

4 files changed

+214
-113
lines changed

Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.15.1"
3+
version = "0.15.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -11,6 +11,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1111
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1212
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
14+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1415
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1516

1617
[compat]

src/DynamicPPL.jl

+2
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,6 @@ include("compat/ad.jl")
134134
include("loglikelihoods.jl")
135135
include("submodel_macro.jl")
136136

137+
include("test_utils.jl")
138+
137139
end # module

src/test_utils.jl

+209
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
module TestUtils
2+
3+
using AbstractMCMC
4+
using DynamicPPL
5+
using LinearAlgebra
6+
using Distributions
7+
using Test
8+
9+
# A collection of models for which the mean-of-means for the posterior should
10+
# be same.
11+
@model function demo_dot_assume_dot_observe(
12+
x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
13+
) where {TV}
14+
# `dot_assume` and `observe`
15+
m = TV(undef, length(x))
16+
m .~ Normal()
17+
x ~ MvNormal(m, 0.25 * I)
18+
return (; m=m, x=x, logp=getlogp(__varinfo__))
19+
end
20+
21+
@model function demo_assume_index_observe(
22+
x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
23+
) where {TV}
24+
# `assume` with indexing and `observe`
25+
m = TV(undef, length(x))
26+
for i in eachindex(m)
27+
m[i] ~ Normal()
28+
end
29+
x ~ MvNormal(m, 0.25 * I)
30+
31+
return (; m=m, x=x, logp=getlogp(__varinfo__))
32+
end
33+
34+
@model function demo_assume_multivariate_observe_index(x=[10.0, 10.0])
35+
# Multivariate `assume` and `observe`
36+
m ~ MvNormal(zero(x), I)
37+
x ~ MvNormal(m, 0.25 * I)
38+
39+
return (; m=m, x=x, logp=getlogp(__varinfo__))
40+
end
41+
42+
@model function demo_dot_assume_observe_index(
43+
x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
44+
) where {TV}
45+
# `dot_assume` and `observe` with indexing
46+
m = TV(undef, length(x))
47+
m .~ Normal()
48+
for i in eachindex(x)
49+
x[i] ~ Normal(m[i], 0.5)
50+
end
51+
52+
return (; m=m, x=x, logp=getlogp(__varinfo__))
53+
end
54+
55+
# Using vector of `length` 1 here so the posterior of `m` is the same
56+
# as the others.
57+
@model function demo_assume_dot_observe(x=[10.0])
58+
# `assume` and `dot_observe`
59+
m ~ Normal()
60+
x .~ Normal(m, 0.5)
61+
62+
return (; m=m, x=x, logp=getlogp(__varinfo__))
63+
end
64+
65+
@model function demo_assume_observe_literal()
66+
# `assume` and literal `observe`
67+
m ~ MvNormal(zeros(2), I)
68+
[10.0, 10.0] ~ MvNormal(m, 0.25 * I)
69+
70+
return (; m=m, x=[10.0, 10.0], logp=getlogp(__varinfo__))
71+
end
72+
73+
@model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV}
74+
# `dot_assume` and literal `observe` with indexing
75+
m = TV(undef, 2)
76+
m .~ Normal()
77+
for i in eachindex(m)
78+
10.0 ~ Normal(m[i], 0.5)
79+
end
80+
81+
return (; m=m, x=fill(10.0, length(m)), logp=getlogp(__varinfo__))
82+
end
83+
84+
@model function demo_assume_literal_dot_observe()
85+
# `assume` and literal `dot_observe`
86+
m ~ Normal()
87+
[10.0] .~ Normal(m, 0.5)
88+
89+
return (; m=m, x=[10.0], logp=getlogp(__varinfo__))
90+
end
91+
92+
@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
93+
m = TV(undef, 2)
94+
m .~ Normal()
95+
96+
return m
97+
end
98+
99+
@model function demo_assume_submodel_observe_index_literal()
100+
# Submodel prior
101+
m = @submodel _prior_dot_assume()
102+
for i in eachindex(m)
103+
10.0 ~ Normal(m[i], 0.5)
104+
end
105+
106+
return (; m=m, x=[10.0], logp=getlogp(__varinfo__))
107+
end
108+
109+
@model function _likelihood_dot_observe(m, x)
110+
return x ~ MvNormal(m, 0.25 * I)
111+
end
112+
113+
@model function demo_dot_assume_observe_submodel(
114+
x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
115+
) where {TV}
116+
m = TV(undef, length(x))
117+
m .~ Normal()
118+
119+
# Submodel likelihood
120+
@submodel _likelihood_dot_observe(m, x)
121+
122+
return (; m=m, x=x, logp=getlogp(__varinfo__))
123+
end
124+
125+
@model function demo_dot_assume_dot_observe_matrix(
126+
x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64}
127+
) where {TV}
128+
m = TV(undef, length(x))
129+
m .~ Normal()
130+
131+
# Dotted observe for `Matrix`.
132+
x .~ MvNormal(m, 0.25 * I)
133+
134+
return (; m=m, x=x, logp=getlogp(__varinfo__))
135+
end
136+
137+
const DEMO_MODELS = (
138+
demo_dot_assume_dot_observe(),
139+
demo_assume_index_observe(),
140+
demo_assume_multivariate_observe_index(),
141+
demo_dot_assume_observe_index(),
142+
demo_assume_dot_observe(),
143+
demo_assume_observe_literal(),
144+
demo_dot_assume_observe_index_literal(),
145+
demo_assume_literal_dot_observe(),
146+
demo_assume_submodel_observe_index_literal(),
147+
demo_dot_assume_observe_submodel(),
148+
demo_dot_assume_dot_observe_matrix(),
149+
)
150+
151+
# TODO: Is this really the best/most convenient "default" test method?
152+
"""
153+
test_sampler_demo_models(meanfunction, sampler, args...; kwargs...)
154+
155+
Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`.
156+
157+
In short, this method iterators through `demo_models`, calls `AbstractMCMC.sample` on the
158+
`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain)` against `target`
159+
provided in `kwargs...`.
160+
161+
# Arguments
162+
- `meanfunction`: A callable which computes the mean of the marginal means from the
163+
chain resulting from the `sample` call.
164+
- `sampler`: The `AbstractMCMC.AbstractSampler` to test.
165+
- `args...`: Arguments forwarded to `sample`.
166+
167+
# Keyword arguments
168+
- `target`: Value to compare result of `meanfunction(chain)` to.
169+
- `atol=1e-1`: Absolute tolerance used in `@test`.
170+
- `rtol=1e-3`: Relative tolerance used in `@test`.
171+
- `kwargs...`: Keyword arguments forwarded to `sample`.
172+
"""
173+
function test_sampler_demo_models(
174+
meanfunction,
175+
sampler::AbstractMCMC.AbstractSampler,
176+
args...;
177+
target=8.0,
178+
atol=1e-1,
179+
rtol=1e-3,
180+
kwargs...,
181+
)
182+
@testset "$(nameof(typeof(sampler))) on $(m.name)" for model in DEMO_MODELS
183+
chain = AbstractMCMC.sample(model, sampler, args...; kwargs...)
184+
μ = meanfunction(chain)
185+
@test μ target atol = atol rtol = rtol
186+
end
187+
end
188+
189+
"""
190+
test_sampler_continuous([meanfunction, ]sampler, args...; kwargs...)
191+
192+
Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`.
193+
194+
As of right now, this is just an alias for [`test_sampler_demo_models`](@ref).
195+
"""
196+
function test_sampler_continuous(
197+
meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs...
198+
)
199+
return test_sampler_demo_models(meanfunction, sampler, args...; kwargs...)
200+
end
201+
202+
function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...)
203+
# Default for `MCMCChains.Chains`.
204+
return test_sampler_continuous(sampler, args...; kwargs...) do chain
205+
mean(Array(chain))
206+
end
207+
end
208+
209+
end

test/loglikelihoods.jl

+1-112
Original file line numberDiff line numberDiff line change
@@ -1,116 +1,5 @@
1-
# A collection of models for which the mean-of-means for the posterior should
2-
# be same.
3-
@model function gdemo1(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
4-
# `dot_assume` and `observe`
5-
m = TV(undef, length(x))
6-
m .~ Normal()
7-
return x ~ MvNormal(m, 0.25 * I)
8-
end
9-
10-
@model function gdemo2(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
11-
# `assume` with indexing and `observe`
12-
m = TV(undef, length(x))
13-
for i in eachindex(m)
14-
m[i] ~ Normal()
15-
end
16-
return x ~ MvNormal(m, 0.25 * I)
17-
end
18-
19-
@model function gdemo3(x=[10.0, 10.0])
20-
# Multivariate `assume` and `observe`
21-
m ~ MvNormal(zero(x), I)
22-
return x ~ MvNormal(m, 0.25 * I)
23-
end
24-
25-
@model function gdemo4(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
26-
# `dot_assume` and `observe` with indexing
27-
m = TV(undef, length(x))
28-
m .~ Normal()
29-
for i in eachindex(x)
30-
x[i] ~ Normal(m[i], 0.5)
31-
end
32-
end
33-
34-
# Using vector of `length` 1 here so the posterior of `m` is the same
35-
# as the others.
36-
@model function gdemo5(x=[10.0])
37-
# `assume` and `dot_observe`
38-
m ~ Normal()
39-
return x .~ Normal(m, 0.5)
40-
end
41-
42-
@model function gdemo6(::Type{TV}=Vector{Float64}) where {TV}
43-
# `assume` and literal `observe`
44-
m ~ MvNormal(zeros(2), I)
45-
return [10.0, 10.0] ~ MvNormal(m, 0.25 * I)
46-
end
47-
48-
@model function gdemo7(::Type{TV}=Vector{Float64}) where {TV}
49-
# `dot_assume` and literal `observe` with indexing
50-
m = TV(undef, 2)
51-
m .~ Normal()
52-
for i in eachindex(m)
53-
10.0 ~ Normal(m[i], 0.5)
54-
end
55-
end
56-
57-
@model function gdemo8(::Type{TV}=Vector{Float64}) where {TV}
58-
# `assume` and literal `dot_observe`
59-
m ~ Normal()
60-
return [10.0] .~ Normal(m, 0.5)
61-
end
62-
63-
@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
64-
m = TV(undef, 2)
65-
m .~ Normal()
66-
67-
return m
68-
end
69-
70-
@model function gdemo9()
71-
# Submodel prior
72-
m = @submodel _prior_dot_assume()
73-
for i in eachindex(m)
74-
10.0 ~ Normal(m[i], 0.5)
75-
end
76-
end
77-
78-
@model function _likelihood_dot_observe(m, x)
79-
return x ~ MvNormal(m, 0.25 * I)
80-
end
81-
82-
@model function gdemo10(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
83-
m = TV(undef, length(x))
84-
m .~ Normal()
85-
86-
# Submodel likelihood
87-
@submodel _likelihood_dot_observe(m, x)
88-
end
89-
90-
@model function gdemo11(x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64}) where {TV}
91-
m = TV(undef, length(x))
92-
m .~ Normal()
93-
94-
# Dotted observe for `Matrix`.
95-
return x .~ MvNormal(m, 0.25 * I)
96-
end
97-
98-
const gdemo_models = (
99-
gdemo1(),
100-
gdemo2(),
101-
gdemo3(),
102-
gdemo4(),
103-
gdemo5(),
104-
gdemo6(),
105-
gdemo7(),
106-
gdemo8(),
107-
gdemo9(),
108-
gdemo10(),
109-
gdemo11(),
110-
)
111-
1121
@testset "loglikelihoods.jl" begin
113-
for m in gdemo_models
2+
for m in DynamicPPL.TestUtils.DEMO_MODELS
1143
vi = VarInfo(m)
1154

1165
vns = vi.metadata.m.vns

0 commit comments

Comments
 (0)