|
# InferenceAlgorithm interface

"""
    Hamiltonian

Abstract supertype for Hamiltonian Monte Carlo inference algorithms.
Subtypes are further split into [`StaticHamiltonian`](@ref) and
[`AdaptiveHamiltonian`](@ref).
"""
abstract type Hamiltonian <: InferenceAlgorithm end
|
| 4 | + |
# Initial parameter values for Hamiltonian samplers are drawn via
# `SampleFromUniform`.
DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform()

# Gradient-based Hamiltonian samplers operate on an unconstrained space.
requires_unconstrained_space(::Hamiltonian) = true

"""
    get_adtype(alg::Hamiltonian)

Return the automatic-differentiation backend (`adtype`) carried by `alg`.

TODO(penelopeysm): This is really quite dangerous code because it implicitly
assumes that any concrete type that subtypes `Hamiltonian` has an adtype
field.
"""
function get_adtype(alg::Hamiltonian)
    return alg.adtype
end
| 11 | + |
"""
    StaticHamiltonian

Abstract supertype for Hamiltonian samplers whose integrator settings are
fixed up front and not tuned during sampling.
"""
abstract type StaticHamiltonian <: Hamiltonian end

"""
    AdaptiveHamiltonian

Abstract supertype for Hamiltonian samplers that perform a warm-up adaptation
phase (e.g. the number of adaptation steps handled by
`update_sample_kwargs`), such as `NUTS`.
"""
abstract type AdaptiveHamiltonian <: Hamiltonian end
|
4 | 14 |
|
"""
    update_sample_kwargs(alg::AdaptiveHamiltonian, N::Integer, kwargs)

Resolve the `nadapts` and `discard_initial` keyword arguments for an adaptive
Hamiltonian sampler and return an updated keyword-argument `NamedTuple`.

When `resume_from` is present in `kwargs`, adaptation has already taken place
in the earlier run, so adaptation is forced off. Otherwise, sentinel values
(`nadapts == -1`, `discard_initial == -1`) are replaced with defaults derived
from the total number of samples `N` and `discard_adapt`.

# Arguments
- `alg`: the adaptive Hamiltonian algorithm; assumed to carry an `n_adapts`
  field (used as the default when `:nadapts` is absent from `kwargs`).
- `N`: total number of samples requested.
- `kwargs`: the keyword arguments passed to `sample`, as a key-value
  collection supporting `get`.
"""
function update_sample_kwargs(alg::AdaptiveHamiltonian, N::Integer, kwargs)
    resume_from = get(kwargs, :resume_from, nothing)
    nadapts = get(kwargs, :nadapts, alg.n_adapts)
    discard_adapt = get(kwargs, :discard_adapt, true)
    discard_initial = get(kwargs, :discard_initial, -1)

    return if resume_from === nothing
        # `nadapts == -1` means the user called a convenience constructor
        # like `NUTS()` or `NUTS(0.65)` without specifying the number of
        # adaptation steps, so pick a sensible default.
        _nadapts = nadapts == -1 ? min(1000, N ÷ 2) : nadapts
        # `discard_initial == -1` means the keyword was not specified; by
        # default the adaptation phase is discarded (unless opted out via
        # `discard_adapt=false`).
        _discard_initial = if discard_initial == -1
            discard_adapt ? _nadapts : 0
        else
            discard_initial
        end
        # Splat `kwargs` FIRST: in a NamedTuple later duplicate keys win, so
        # the resolved values below take precedence over raw user-supplied
        # ones (otherwise a sentinel `-1` would leak through to the sampler).
        (kwargs..., nadapts=_nadapts, discard_initial=_discard_initial)
    else
        # Resuming from a previous chain: adaptation already happened, so
        # force it off regardless of what the user passed.
        (kwargs..., nadapts=0, discard_adapt=false, discard_initial=0)
    end
end
| 41 | + |
5 | 42 | ###
|
6 | 43 | ### Sampler states
|
7 | 44 | ###
|
@@ -80,68 +117,6 @@ function HMC(
|
80 | 117 | return HMC(ϵ, n_leapfrog, metricT; adtype=adtype)
|
81 | 118 | end
|
82 | 119 |
|
83 |
| -DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform() |
84 |
| - |
85 |
| -# Handle setting `nadapts` and `discard_initial` |
86 |
| -function AbstractMCMC.sample( |
87 |
| - rng::AbstractRNG, |
88 |
| - model::DynamicPPL.Model, |
89 |
| - sampler::Sampler{<:AdaptiveHamiltonian}, |
90 |
| - N::Integer; |
91 |
| - chain_type=DynamicPPL.default_chain_type(sampler), |
92 |
| - resume_from=nothing, |
93 |
| - initial_state=DynamicPPL.loadstate(resume_from), |
94 |
| - progress=PROGRESS[], |
95 |
| - nadapts=sampler.alg.n_adapts, |
96 |
| - discard_adapt=true, |
97 |
| - discard_initial=-1, |
98 |
| - kwargs..., |
99 |
| -) |
100 |
| - if resume_from === nothing |
101 |
| - # If `nadapts` is `-1`, then the user called a convenience |
102 |
| - # constructor like `NUTS()` or `NUTS(0.65)`, |
103 |
| - # and we should set a default for them. |
104 |
| - if nadapts == -1 |
105 |
| - _nadapts = min(1000, N ÷ 2) |
106 |
| - else |
107 |
| - _nadapts = nadapts |
108 |
| - end |
109 |
| - |
110 |
| - # If `discard_initial` is `-1`, then users did not specify the keyword argument. |
111 |
| - if discard_initial == -1 |
112 |
| - _discard_initial = discard_adapt ? _nadapts : 0 |
113 |
| - else |
114 |
| - _discard_initial = discard_initial |
115 |
| - end |
116 |
| - |
117 |
| - return AbstractMCMC.mcmcsample( |
118 |
| - rng, |
119 |
| - model, |
120 |
| - sampler, |
121 |
| - N; |
122 |
| - chain_type=chain_type, |
123 |
| - progress=progress, |
124 |
| - nadapts=_nadapts, |
125 |
| - discard_initial=_discard_initial, |
126 |
| - kwargs..., |
127 |
| - ) |
128 |
| - else |
129 |
| - return AbstractMCMC.mcmcsample( |
130 |
| - rng, |
131 |
| - model, |
132 |
| - sampler, |
133 |
| - N; |
134 |
| - chain_type=chain_type, |
135 |
| - initial_state=initial_state, |
136 |
| - progress=progress, |
137 |
| - nadapts=0, |
138 |
| - discard_adapt=false, |
139 |
| - discard_initial=0, |
140 |
| - kwargs..., |
141 |
| - ) |
142 |
| - end |
143 |
| -end |
144 |
| - |
145 | 120 | function find_initial_params(
|
146 | 121 | rng::Random.AbstractRNG,
|
147 | 122 | model::DynamicPPL.Model,
|
@@ -172,44 +147,34 @@ function find_initial_params(
|
172 | 147 | )
|
173 | 148 | end
|
174 | 149 |
|
175 |
| -function DynamicPPL.initialstep( |
| 150 | +function AbstractMCMC.step( |
176 | 151 | rng::AbstractRNG,
|
177 |
| - model::AbstractModel, |
178 |
| - spl::Sampler{<:Hamiltonian}, |
179 |
| - vi_original::AbstractVarInfo; |
| 152 | + ldf::LogDensityFunction, |
| 153 | + spl::Sampler{<:Hamiltonian}; |
180 | 154 | initial_params=nothing,
|
181 | 155 | nadapts=0,
|
182 | 156 | kwargs...,
|
183 | 157 | )
|
184 |
| - # Transform the samples to unconstrained space and compute the joint log probability. |
185 |
| - vi = DynamicPPL.link(vi_original, model) |
| 158 | + ldf.adtype === nothing && |
| 159 | + error("Hamiltonian sampler received a LogDensityFunction without an AD backend") |
186 | 160 |
|
187 |
| - # Extract parameters. |
188 |
| - theta = vi[:] |
| 161 | + theta = ldf.varinfo[:] |
| 162 | + |
| 163 | + has_initial_params = initial_params !== nothing |
189 | 164 |
|
190 | 165 | # Create a Hamiltonian.
|
191 | 166 | metricT = getmetricT(spl.alg)
|
192 | 167 | metric = metricT(length(theta))
|
193 |
| - ldf = DynamicPPL.LogDensityFunction( |
194 |
| - model, |
195 |
| - vi, |
196 |
| - # TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we |
197 |
| - # need to pass in the sampler? (In fact LogDensityFunction defaults to |
198 |
| - # using leafcontext(model.context) so could we just remove the argument |
199 |
| - # entirely?) |
200 |
| - DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)); |
201 |
| - adtype=spl.alg.adtype, |
202 |
| - ) |
203 | 168 | lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
|
204 | 169 | lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
|
205 | 170 | hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
|
206 | 171 |
|
207 | 172 | # If no initial parameters are provided, resample until the log probability
|
208 | 173 | # and its gradient are finite. Otherwise, just use the existing parameters.
|
209 | 174 | vi, z = if initial_params === nothing
|
210 |
| - find_initial_params(rng, model, vi, hamiltonian) |
| 175 | + find_initial_params(rng, ldf.model, ldf.varinfo, hamiltonian) |
211 | 176 | else
|
212 |
| - vi, AHMC.phasepoint(rng, theta, hamiltonian) |
| 177 | + ldf.varinfo, AHMC.phasepoint(rng, theta, hamiltonian) |
213 | 178 | end
|
214 | 179 | theta = vi[:]
|
215 | 180 |
|
@@ -248,23 +213,20 @@ function DynamicPPL.initialstep(
|
248 | 213 | vi = setlogp!!(vi, log_density_old)
|
249 | 214 | end
|
250 | 215 |
|
251 |
| - transition = Transition(model, vi, t) |
| 216 | + transition = Transition(ldf.model, vi, t) |
252 | 217 | state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)
|
253 | 218 |
|
254 | 219 | return transition, state
|
255 | 220 | end
|
256 | 221 |
|
257 | 222 | function AbstractMCMC.step(
|
258 | 223 | rng::Random.AbstractRNG,
|
259 |
| - model::Model, |
| 224 | + ldf::LogDensityFunction, |
260 | 225 | spl::Sampler{<:Hamiltonian},
|
261 | 226 | state::HMCState;
|
262 | 227 | nadapts=0,
|
263 | 228 | kwargs...,
|
264 | 229 | )
|
265 |
| - # Get step size |
266 |
| - @debug "current ϵ" getstepsize(spl, state) |
267 |
| - |
268 | 230 | # Compute transition.
|
269 | 231 | hamiltonian = state.hamiltonian
|
270 | 232 | z = state.z
|
@@ -294,13 +256,15 @@ function AbstractMCMC.step(
|
294 | 256 | end
|
295 | 257 |
|
296 | 258 | # Compute next transition and state.
|
297 |
| - transition = Transition(model, vi, t) |
| 259 | + transition = Transition(ldf.model, vi, t) |
298 | 260 | newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor)
|
299 | 261 |
|
300 | 262 | return transition, newstate
|
301 | 263 | end
|
302 | 264 |
|
303 | 265 | function get_hamiltonian(model, spl, vi, state, n)
|
| 266 | + # TODO(penelopeysm): This is used by the Gibbs sampler, we can |
| 267 | + # simplify it to use LDF when Gibbs is reworked |
304 | 268 | metric = gen_metric(n, spl, state)
|
305 | 269 | ldf = DynamicPPL.LogDensityFunction(
|
306 | 270 | model,
|
@@ -467,9 +431,9 @@ function NUTS(; kwargs...)
|
467 | 431 | return NUTS(-1, 0.65; kwargs...)
|
468 | 432 | end
|
469 | 433 |
|
470 |
| -for alg in (:HMC, :HMCDA, :NUTS) |
471 |
| - @eval getmetricT(::$alg{<:Any,metricT}) where {metricT} = metricT |
472 |
| -end |
# Extract the metric type (the `metricT` type parameter) from an algorithm
# instance; used to construct the AdvancedHMC metric of matching dimension.
getmetricT(::HMC{<:Any,metricT}) where {metricT} = metricT
getmetricT(::HMCDA{<:Any,metricT}) where {metricT} = metricT
getmetricT(::NUTS{<:Any,metricT}) where {metricT} = metricT
473 | 437 |
|
474 | 438 | #####
|
475 | 439 | ##### HMC core functions
|
|
0 commit comments