Skip to content

Commit ae7abde

Browse files
committed
Update AbstractMCMC interface for Hamiltonian samplers
1 parent ad48370 commit ae7abde

File tree

3 files changed

+95
-171
lines changed

3 files changed

+95
-171
lines changed

ext/TuringDynamicHMCExt.jl

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ State of the [`DynamicNUTS`](@ref) sampler.
3535
# Fields
3636
$(TYPEDFIELDS)
3737
"""
38-
struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
39-
logdensity::L
38+
struct DynamicNUTSState{V<:DynamicPPL.AbstractVarInfo,C,M,S}
4039
vi::V
4140
"Cache of sample, log density, and gradient of log density evaluation."
4241
cache::C
@@ -48,30 +47,17 @@ function DynamicPPL.initialsampler(::DynamicPPL.Sampler{<:DynamicNUTS})
4847
return DynamicPPL.SampleFromUniform()
4948
end
5049

51-
function DynamicPPL.initialstep(
50+
function AbstractMCMC.step(
5251
rng::Random.AbstractRNG,
53-
model::DynamicPPL.Model,
54-
spl::DynamicPPL.Sampler{<:DynamicNUTS},
55-
vi::DynamicPPL.AbstractVarInfo;
52+
ldf::DynamicPPL.LogDensityFunction,
53+
spl::DynamicPPL.Sampler{<:DynamicNUTS};
5654
kwargs...,
5755
)
58-
# Ensure that initial sample is in unconstrained space.
59-
if !DynamicPPL.islinked(vi)
60-
vi = DynamicPPL.link!!(vi, model)
61-
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
62-
end
63-
64-
# Define log-density function.
65-
ℓ = DynamicPPL.LogDensityFunction(
66-
model,
67-
vi,
68-
DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
69-
adtype=spl.alg.adtype,
70-
)
56+
vi = ldf.varinfo
7157

7258
# Perform initial step.
7359
results = DynamicHMC.mcmc_keep_warmup(
74-
rng, ℓ, 0; initialization=(q=vi[:],), reporter=DynamicHMC.NoProgressReport()
60+
rng, ldf, 0; initialization=(q=vi[:],), reporter=DynamicHMC.NoProgressReport()
7561
)
7662
steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
7763
Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)
@@ -81,32 +67,31 @@ function DynamicPPL.initialstep(
8167
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
8268

8369
# Create first sample and state.
84-
sample = Turing.Inference.Transition(model, vi)
85-
state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)
70+
sample = Turing.Inference.Transition(ldf.model, vi)
71+
state = DynamicNUTSState(vi, Q, steps.H.κ, steps.ϵ)
8672

8773
return sample, state
8874
end
8975

9076
function AbstractMCMC.step(
9177
rng::Random.AbstractRNG,
92-
model::DynamicPPL.Model,
78+
ldf::DynamicPPL.LogDensityFunction,
9379
spl::DynamicPPL.Sampler{<:DynamicNUTS},
9480
state::DynamicNUTSState;
9581
kwargs...,
9682
)
9783
# Compute next sample.
9884
vi = state.vi
99-
ℓ = state.logdensity
100-
steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
85+
steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ldf, state.stepsize)
10186
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)
10287

10388
# Update the variables.
10489
vi = DynamicPPL.unflatten(vi, Q.q)
10590
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
10691

10792
# Create next sample and state.
108-
sample = Turing.Inference.Transition(model, vi)
109-
newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)
93+
sample = Turing.Inference.Transition(ldf.model, vi)
94+
newstate = DynamicNUTSState(vi, Q, state.metric, state.stepsize)
11095

11196
return sample, newstate
11297
end

src/mcmc/hmc.jl

Lines changed: 55 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,44 @@
1+
# InferenceAlgorithm interface
2+
13
abstract type Hamiltonian <: InferenceAlgorithm end
4+
5+
DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform()
6+
requires_unconstrained_space(::Hamiltonian) = true
7+
# TODO(penelopeysm): This is really quite dangerous code because it implicitly
8+
# assumes that any concrete type that subtypes `Hamiltonian` has an adtype
9+
# field.
10+
get_adtype(alg::Hamiltonian) = alg.adtype
11+
212
abstract type StaticHamiltonian <: Hamiltonian end
313
abstract type AdaptiveHamiltonian <: Hamiltonian end
414

15+
function update_sample_kwargs(alg::AdaptiveHamiltonian, N::Integer, kwargs)
16+
resume_from = get(kwargs, :resume_from, nothing)
17+
nadapts = get(kwargs, :nadapts, alg.n_adapts)
18+
discard_adapt = get(kwargs, :discard_adapt, true)
19+
discard_initial = get(kwargs, :discard_initial, -1)
20+
21+
return if resume_from === nothing
22+
# If `nadapts` is `-1`, then the user called a convenience constructor
23+
# like `NUTS()` or `NUTS(0.65)`, and we should set a default for them.
24+
if nadapts == -1
25+
_nadapts = min(1000, N ÷ 2) # Default to 1000 if not specified
26+
else
27+
_nadapts = nadapts
28+
end
29+
# If `discard_initial` is `-1`, then users did not specify the keyword argument.
30+
if discard_initial == -1
31+
_discard_initial = discard_adapt ? _nadapts : 0
32+
else
33+
_discard_initial = discard_initial
34+
end
35+
36+
(nadapts=_nadapts, discard_initial=_discard_initial, kwargs...)
37+
else
38+
(nadapts=0, discard_adapt=false, discard_initial=0, kwargs...)
39+
end
40+
end
41+
542
###
643
### Sampler states
744
###
@@ -80,68 +117,6 @@ function HMC(
80117
return HMC(ϵ, n_leapfrog, metricT; adtype=adtype)
81118
end
82119

83-
DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform()
84-
85-
# Handle setting `nadapts` and `discard_initial`
86-
function AbstractMCMC.sample(
87-
rng::AbstractRNG,
88-
model::DynamicPPL.Model,
89-
sampler::Sampler{<:AdaptiveHamiltonian},
90-
N::Integer;
91-
chain_type=DynamicPPL.default_chain_type(sampler),
92-
resume_from=nothing,
93-
initial_state=DynamicPPL.loadstate(resume_from),
94-
progress=PROGRESS[],
95-
nadapts=sampler.alg.n_adapts,
96-
discard_adapt=true,
97-
discard_initial=-1,
98-
kwargs...,
99-
)
100-
if resume_from === nothing
101-
# If `nadapts` is `-1`, then the user called a convenience
102-
# constructor like `NUTS()` or `NUTS(0.65)`,
103-
# and we should set a default for them.
104-
if nadapts == -1
105-
_nadapts = min(1000, N ÷ 2)
106-
else
107-
_nadapts = nadapts
108-
end
109-
110-
# If `discard_initial` is `-1`, then users did not specify the keyword argument.
111-
if discard_initial == -1
112-
_discard_initial = discard_adapt ? _nadapts : 0
113-
else
114-
_discard_initial = discard_initial
115-
end
116-
117-
return AbstractMCMC.mcmcsample(
118-
rng,
119-
model,
120-
sampler,
121-
N;
122-
chain_type=chain_type,
123-
progress=progress,
124-
nadapts=_nadapts,
125-
discard_initial=_discard_initial,
126-
kwargs...,
127-
)
128-
else
129-
return AbstractMCMC.mcmcsample(
130-
rng,
131-
model,
132-
sampler,
133-
N;
134-
chain_type=chain_type,
135-
initial_state=initial_state,
136-
progress=progress,
137-
nadapts=0,
138-
discard_adapt=false,
139-
discard_initial=0,
140-
kwargs...,
141-
)
142-
end
143-
end
144-
145120
function find_initial_params(
146121
rng::Random.AbstractRNG,
147122
model::DynamicPPL.Model,
@@ -172,44 +147,34 @@ function find_initial_params(
172147
)
173148
end
174149

175-
function DynamicPPL.initialstep(
150+
function AbstractMCMC.step(
176151
rng::AbstractRNG,
177-
model::AbstractModel,
178-
spl::Sampler{<:Hamiltonian},
179-
vi_original::AbstractVarInfo;
152+
ldf::LogDensityFunction,
153+
spl::Sampler{<:Hamiltonian};
180154
initial_params=nothing,
181155
nadapts=0,
182156
kwargs...,
183157
)
184-
# Transform the samples to unconstrained space and compute the joint log probability.
185-
vi = DynamicPPL.link(vi_original, model)
158+
ldf.adtype === nothing &&
159+
error("Hamiltonian sampler received a LogDensityFunction without an AD backend")
186160

187-
# Extract parameters.
188-
theta = vi[:]
161+
theta = ldf.varinfo[:]
162+
163+
has_initial_params = initial_params !== nothing
189164

190165
# Create a Hamiltonian.
191166
metricT = getmetricT(spl.alg)
192167
metric = metricT(length(theta))
193-
ldf = DynamicPPL.LogDensityFunction(
194-
model,
195-
vi,
196-
# TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we
197-
# need to pass in the sampler? (In fact LogDensityFunction defaults to
198-
# using leafcontext(model.context) so could we just remove the argument
199-
# entirely?)
200-
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context));
201-
adtype=spl.alg.adtype,
202-
)
203168
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
204169
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
205170
hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
206171

207172
# If no initial parameters are provided, resample until the log probability
208173
# and its gradient are finite. Otherwise, just use the existing parameters.
209174
vi, z = if initial_params === nothing
210-
find_initial_params(rng, model, vi, hamiltonian)
175+
find_initial_params(rng, ldf.model, ldf.varinfo, hamiltonian)
211176
else
212-
vi, AHMC.phasepoint(rng, theta, hamiltonian)
177+
ldf.varinfo, AHMC.phasepoint(rng, theta, hamiltonian)
213178
end
214179
theta = vi[:]
215180

@@ -248,23 +213,20 @@ function DynamicPPL.initialstep(
248213
vi = setlogp!!(vi, log_density_old)
249214
end
250215

251-
transition = Transition(model, vi, t)
216+
transition = Transition(ldf.model, vi, t)
252217
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)
253218

254219
return transition, state
255220
end
256221

257222
function AbstractMCMC.step(
258223
rng::Random.AbstractRNG,
259-
model::Model,
224+
ldf::LogDensityFunction,
260225
spl::Sampler{<:Hamiltonian},
261226
state::HMCState;
262227
nadapts=0,
263228
kwargs...,
264229
)
265-
# Get step size
266-
@debug "current ϵ" getstepsize(spl, state)
267-
268230
# Compute transition.
269231
hamiltonian = state.hamiltonian
270232
z = state.z
@@ -294,13 +256,15 @@ function AbstractMCMC.step(
294256
end
295257

296258
# Compute next transition and state.
297-
transition = Transition(model, vi, t)
259+
transition = Transition(ldf.model, vi, t)
298260
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor)
299261

300262
return transition, newstate
301263
end
302264

303265
function get_hamiltonian(model, spl, vi, state, n)
266+
# TODO(penelopeysm): This is used by the Gibbs sampler, we can
267+
# simplify it to use LDF when Gibbs is reworked
304268
metric = gen_metric(n, spl, state)
305269
ldf = DynamicPPL.LogDensityFunction(
306270
model,
@@ -467,9 +431,9 @@ function NUTS(; kwargs...)
467431
return NUTS(-1, 0.65; kwargs...)
468432
end
469433

470-
for alg in (:HMC, :HMCDA, :NUTS)
471-
@eval getmetricT(::$alg{<:Any,metricT}) where {metricT} = metricT
472-
end
434+
getmetricT(::HMC{<:Any,metricT}) where {metricT} = metricT
435+
getmetricT(::HMCDA{<:Any,metricT}) where {metricT} = metricT
436+
getmetricT(::NUTS{<:Any,metricT}) where {metricT} = metricT
473437

474438
#####
475439
##### HMC core functions

0 commit comments

Comments
 (0)