
Commit fa0f068

Update ESS to use LDF
1 parent 49f6988 commit fa0f068
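In short: the ESS sampler now goes through the new `LogDensityFunction` (LDF) code path, so it can be invoked on either a `DynamicPPL.Model` or a `LogDensityFunction` wrapping one. A minimal usage sketch, adapted from the tests added in this commit (the model and seed are illustrative):

using Turing, Random

@model function demo_normal(x)
    a ~ Normal()
    return x ~ Normal(a)
end

model = demo_normal(2.0)
ldf = LogDensityFunction(model)

# Both objects are now valid targets for `sample` with ESS:
chn_model = sample(Random.Xoshiro(468), model, ESS(), 5)
chn_ldf = sample(Random.Xoshiro(468), ldf, ESS(), 5)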

3 files changed (+69 additions, −42 deletions)


src/mcmc/abstractmcmc.jl

Lines changed: 1 addition & 2 deletions
@@ -21,8 +21,7 @@
 
 # Because this is a pain to implement all at once, we do it for one sampler at a time.
 # This type tells us which samplers have been 'updated' to the new interface.
-
-const LDFCompatibleSampler = Union{Hamiltonian}
+const LDFCompatibleSampler = Union{Hamiltonian,ESS}
 
 """
     sample(
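For context, the `Union` alias gates method dispatch: only samplers listed in `LDFCompatibleSampler` are routed to the new LDF-based `sample`/`step` methods. A self-contained sketch of the pattern (the types and trait function below are hypothetical stand-ins, not Turing's actual internals):

# Hypothetical illustration of gating dispatch with a Union alias.
abstract type SomeSampler end
struct Hamiltonian <: SomeSampler end
struct ESS <: SomeSampler end
struct NotYetMigrated <: SomeSampler end

const LDFCompatibleSampler = Union{Hamiltonian,ESS}

supports_ldf(::LDFCompatibleSampler) = true  # migrated samplers
supports_ldf(::SomeSampler) = false          # everything else

supports_ldf(ESS())             # true
supports_ldf(NotYetMigrated())  # false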

src/mcmc/ess.jl

Lines changed: 33 additions & 40 deletions
@@ -20,22 +20,26 @@ Mean
 │ 1 │ m │ 0.824853 │
 ```
 """
-struct ESS <: InferenceAlgorithm end
+struct ESS <: AbstractSampler end
+
+DynamicPPL.initialsampler(::ESS) = DynamicPPL.SampleFromPrior()
+update_sample_kwargs(::ESS, ::Integer, kwargs) = kwargs
+get_adtype(::ESS) = nothing
+requires_unconstrained_space(::ESS) = false
 
 # always accept in the first step
-function DynamicPPL.initialstep(
-    rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
-)
+function AbstractMCMC.step(rng::AbstractRNG, ldf::LogDensityFunction, spl::ESS; kwargs...)
+    vi = ldf.varinfo
     for vn in keys(vi)
         dist = getdist(vi, vn)
         EllipticalSliceSampling.isgaussian(typeof(dist)) ||
             error("ESS only supports Gaussian prior distributions")
     end
-    return Transition(model, vi), vi
+    return Transition(ldf.model, vi), vi
 end
 
 function AbstractMCMC.step(
-    rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
+    rng::AbstractRNG, ldf::LogDensityFunction, spl::ESS, vi::AbstractVarInfo; kwargs...
 )
     # obtain previous sample
     f = vi[:]
@@ -45,14 +49,13 @@ function AbstractMCMC.step(
     oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi), nothing)
 
     # compute next state
+    # Note: `f_loglikelihood` effectively calculates the log-likelihood (not
+    # log-joint, despite the use of `LDP.logdensity`) because `tilde_assume` is
+    # overloaded on `SamplingContext(rng, ESS(), ...)` below.
+    f_loglikelihood = Base.Fix1(LogDensityProblems.logdensity, ldf)
     sample, state = AbstractMCMC.step(
         rng,
-        EllipticalSliceSampling.ESSModel(
-            ESSPrior(model, spl, vi),
-            DynamicPPL.LogDensityFunction(
-                model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
-            ),
-        ),
+        EllipticalSliceSampling.ESSModel(ESSPrior(ldf.model, spl, vi), f_loglikelihood),
         EllipticalSliceSampling.ESS(),
         oldstate,
     )
@@ -61,67 +64,57 @@
     vi = DynamicPPL.unflatten(vi, sample)
     vi = setlogp!!(vi, state.loglikelihood)
 
-    return Transition(model, vi), vi
+    return Transition(ldf.model, vi), vi
 end
 
 # Prior distribution of considered random variable
-struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
+struct ESSPrior{M<:Model,V<:AbstractVarInfo,T}
     model::M
-    sampler::S
+    sampler::ESS
    varinfo::V
     μ::T
 
-    function ESSPrior{M,S,V}(
-        model::M, sampler::S, varinfo::V
-    ) where {M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo}
+    function ESSPrior{M,V}(
+        model::M, sampler::ESS, varinfo::V
+    ) where {M<:Model,V<:AbstractVarInfo}
         vns = keys(varinfo)
         μ = mapreduce(vcat, vns) do vn
             dist = getdist(varinfo, vn)
             EllipticalSliceSampling.isgaussian(typeof(dist)) ||
                 error("[ESS] only supports Gaussian prior distributions")
             DynamicPPL.tovec(mean(dist))
         end
-        return new{M,S,V,typeof(μ)}(model, sampler, varinfo, μ)
+        return new{M,V,typeof(μ)}(model, sampler, varinfo, μ)
     end
 end
 
-function ESSPrior(model::Model, sampler::Sampler{<:ESS}, varinfo::AbstractVarInfo)
-    return ESSPrior{typeof(model),typeof(sampler),typeof(varinfo)}(model, sampler, varinfo)
+function ESSPrior(model::Model, sampler::ESS, varinfo::AbstractVarInfo)
+    return ESSPrior{typeof(model),typeof(varinfo)}(model, sampler, varinfo)
 end
 
 # Ensure that the prior is a Gaussian distribution (checked in the constructor)
 EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true
 
 # Only define out-of-place sampling
 function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
-    sampler = p.sampler
-    varinfo = p.varinfo
-    # TODO: Surely there's a better way of doing this now that we have `SamplingContext`?
-    vns = keys(varinfo)
-    for vn in vns
-        set_flag!(varinfo, vn, "del")
+    # TODO(penelopeysm): This is ugly -- need to set 'del' flag because
+    # otherwise DynamicPPL.SampleWithPrior will just use the existing
+    # parameters in the varinfo. In general SampleWithPrior etc. need to be
+    # reworked.
+    for vn in keys(p.varinfo)
+        set_flag!(p.varinfo, vn, "del")
     end
-    p.model(rng, varinfo, sampler)
-    return varinfo[:]
+    _, vi = DynamicPPL.evaluate!!(p.model, p.varinfo, SamplingContext(rng, p.sampler))
+    return vi[:]
 end
 
 # Mean of prior distribution
 Distributions.mean(p::ESSPrior) = p.μ
 
-# Evaluate log-likelihood of proposals
-const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} =
-    DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}
-
-(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f)
-
 function DynamicPPL.tilde_assume(
-    rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
+    rng::Random.AbstractRNG, ::DefaultContext, ::ESS, right, vn, vi
 )
     return DynamicPPL.tilde_assume(
         rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi
     )
 end
-
-function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi)
-    return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi)
-end
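The second `step` method above hands `EllipticalSliceSampling.ESSModel` exactly the two ingredients elliptical slice sampling needs: a prior it can `rand` from (`ESSPrior`) and a log-likelihood callable (`f_loglikelihood`). For intuition, here is a minimal single-transition sketch of the algorithm (Murray, Adams & MacKay, 2010) for a zero-mean Gaussian prior; an illustrative standalone helper, not the package's implementation:

using Random

# One elliptical slice sampling transition. `f` is the current state,
# `loglik` maps a state vector to its log-likelihood, prior is N(0, I).
function ess_transition(rng::AbstractRNG, f::AbstractVector, loglik)
    ν = randn(rng, length(f))          # auxiliary draw from the prior
    logy = loglik(f) + log(rand(rng))  # log height of the slice
    θ = 2π * rand(rng)                 # initial angle on the ellipse
    θmin, θmax = θ - 2π, θ
    while true
        f_new = f .* cos(θ) .+ ν .* sin(θ)    # candidate on the ellipse
        loglik(f_new) > logy && return f_new  # accept once above the slice
        # otherwise shrink the bracket towards θ = 0 and retry
        θ < 0 ? (θmin = θ) : (θmax = θ)
        θ = θmin + (θmax - θmin) * rand(rng)
    end
end

The loop always terminates: as the bracket collapses onto θ = 0 the candidate approaches the current state, whose log-likelihood exceeds the slice height by construction, which is why the sampler needs no separate accept/reject bookkeeping.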

test/mcmc/ess.jl

Lines changed: 35 additions & 0 deletions
@@ -13,6 +13,41 @@ using Turing
 @testset "ESS" begin
     @info "Starting ESS tests"
 
+    @testset "InferenceAlgorithm interface" begin
+        alg = ESS()
+        @test Turing.Inference.get_adtype(alg) === nothing
+        @test !Turing.Inference.requires_unconstrained_space(alg)
+        kwargs = (; _foo="bar")
+        @test Turing.Inference.update_sample_kwargs(alg, 1000, kwargs) == kwargs
+    end
+
+    @testset "sample() interface" begin
+        @model function demo_normal(x)
+            a ~ Normal()
+            return x ~ Normal(a)
+        end
+        model = demo_normal(2.0)
+        ldf = LogDensityFunction(model)
+        sampling_objects = Dict("DynamicPPL.Model" => model, "LogDensityFunction" => ldf)
+        seed = 468
+
+        @testset "sampling with $name" for (name, model_or_ldf) in sampling_objects
+            spl = ESS()
+            # check sampling works without rng
+            @test sample(model_or_ldf, spl, 5) isa Chains
+            # check reproducibility with rng
+            chn1 = sample(Random.Xoshiro(seed), model_or_ldf, spl, 5)
+            chn2 = sample(Random.Xoshiro(seed), model_or_ldf, spl, 5)
+            @test mean(chn1[:a]) == mean(chn2[:a])
+        end
+
+        @testset "check that initial_params are respected" begin
+            a0 = 5.0
+            chn = sample(model, ESS(), 5; initial_params=[a0])
+            @test chn[:a][1] == a0
+        end
+    end
+
     @model function demo(x)
         m ~ Normal()
         return x ~ Normal(m, 0.5)
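As the tests suggest, `LogDensityFunction` is what bridges a DynamicPPL model and generic log-density consumers such as `LogDensityProblems.logdensity` (used via `Base.Fix1` in `src/mcmc/ess.jl` above). A short usage sketch reusing the `demo_normal` model from the tests (the probe point is arbitrary):

using Turing, LogDensityProblems

@model function demo_normal(x)
    a ~ Normal()
    return x ~ Normal(a)
end

ldf = LogDensityFunction(demo_normal(2.0))

LogDensityProblems.dimension(ldf)          # 1, since `a` is the only parameter
LogDensityProblems.logdensity(ldf, [0.5])  # log joint density at a = 0.5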
