sample with LogDensityFunction: part 1 - hmc.jl, sghmc.jl, DynamicHMCExt #2588
base: sample-ldf
Conversation
Turing.jl documentation for PR #2588 is available at:
Codecov Report

Attention: Patch coverage is …

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           sample-ldf    #2588      +/-   ##
==============================================
- Coverage      85.50%    80.49%    -5.02%
==============================================
  Files             22        22
  Lines           1456      1507       +51
==============================================
- Hits            1245      1213       -32
- Misses           211       294       +83
```

☔ View full report in Codecov by Sentry.
Many of the changes in `sghmc.jl` are quite similar to the ones in this file, so I added some comments explaining them.
```diff
-struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
-    logdensity::L
+struct DynamicNUTSState{V<:DynamicPPL.AbstractVarInfo,C,M,S}
```
The sampler state, traditionally, has included the `LogDensityFunction` as a field so that it doesn't need to be re-constructed on each iteration from the model + varinfo. This is no longer necessary because the LDF is itself an argument to `AbstractMCMC.step`.
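To illustrate, a `step` method under the new scheme can pull the LDF from its arguments rather than from the state. This is a sketch only; the argument names and exact signature are assumptions, not this PR's code:

```julia
using AbstractMCMC, DynamicPPL, Random
# (`DynamicNUTS`/`DynamicNUTSState` live in Turing's DynamicHMC extension.)

# Sketch: the LDF is now an argument to `step`, so the state no longer needs
# a `logdensity` field. (Assumed shape; not the actual implementation.)
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    ldf::DynamicPPL.LogDensityFunction,  # replaces the old `state.logdensity`
    spl::DynamicNUTS,
    state::DynamicNUTSState;
    kwargs...,
)
    # ... use `ldf` directly instead of rebuilding it from model + varinfo ...
end
```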
```diff
-# Ensure that initial sample is in unconstrained space.
-if !DynamicPPL.islinked(vi)
-    vi = DynamicPPL.link!!(vi, model)
-    vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
-end
-
-# Define log-density function.
-ℓ = DynamicPPL.LogDensityFunction(
-    model,
-    vi,
-    DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
-    adtype=spl.alg.adtype,
-)
```
All of this stuff is now handled inside `AbstractMCMC.sample()`, so there's no longer a need to duplicate this code inside every `initialstep` method.
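As a rough sketch of that centralised logic (`sample_sketch` is an illustrative name, and the real implementation differs in detail), `sample` now links the varinfo and builds the LDF once, up front:

```julia
using Random, AbstractMCMC, DynamicPPL

# Sketch of the linking + LDF construction that `AbstractMCMC.sample` now
# performs once, instead of each `initialstep` method doing it.
function sample_sketch(rng::Random.AbstractRNG, model::DynamicPPL.Model, spl, N; kwargs...)
    vi = DynamicPPL.VarInfo(rng, model)
    # Move to unconstrained space only if the sampler needs it.
    if requires_unconstrained_space(spl)
        vi = DynamicPPL.link(vi, model)
    end
    # Build the LDF once, with the sampler's AD type.
    ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=get_adtype(spl))
    return AbstractMCMC.sample(rng, ldf, spl, N; kwargs...)
end
```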
src/mcmc/Inference.jl
```diff
 function AbstractMCMC.bundle_samples(
     ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
-    model::AbstractModel,
+    model_or_ldf::Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction},
     spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
     state,
     chain_type::Type{MCMCChains.Chains};
```
This signature is not super ideal, but it minimises breakage right now. Eventually, when everything is fixed, we can change the `Union` to just `LogDensityFunction`.
```diff
-# Handle setting `nadapts` and `discard_initial`
-function AbstractMCMC.sample(
-    rng::AbstractRNG,
-    model::DynamicPPL.Model,
-    sampler::Sampler{<:AdaptiveHamiltonian},
-    N::Integer;
-    chain_type=DynamicPPL.default_chain_type(sampler),
-    resume_from=nothing,
-    initial_state=DynamicPPL.loadstate(resume_from),
-    progress=PROGRESS[],
-    nadapts=sampler.alg.n_adapts,
-    discard_adapt=true,
-    discard_initial=-1,
-    kwargs...,
-)
```
The purpose of this overload was purely to modify the kwargs to `sample`. I ditched it in favour of adding a new hook, `update_sample_kwargs`, which does the same thing without abusing multiple dispatch. It's of course quite hard to prove that the behaviour is identical, but separating it into its own function does allow us to write unit tests (now in `test/mcmc/hmc.jl`) to make sure that it's doing the right thing, so that's another benefit.

(Overloading `sample()` for individual samplers like this is quite precarious, because we can't recursively call `AbstractMCMC.sample` or we will end up with infinite recursion -- it has to call `mcmcsample`. So, there's no way to 'extend' this with extra behaviour by e.g. calling another method of `sample` before calling `mcmcsample`.)
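For concreteness, here is a hedged sketch of what such a hook could look like for an adaptive sampler (the actual method in this PR may differ in its details):

```julia
# Illustrative `update_sample_kwargs` method for adaptive HMC: reproduce the
# effect of the deleted `sample` overload without overloading `sample` itself.
function update_sample_kwargs(spl::Sampler{<:AdaptiveHamiltonian}, N::Integer, kwargs)
    nadapts = get(kwargs, :nadapts, spl.alg.n_adapts)
    discard_adapt = get(kwargs, :discard_adapt, true)
    # By default, drop the adaptation steps from the returned chain.
    discard_initial = get(kwargs, :discard_initial, discard_adapt ? nadapts : 0)
    return merge(kwargs, (; nadapts, discard_initial))
end
```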
```diff
-for alg in (:HMC, :HMCDA, :NUTS)
-    @eval getmetricT(::$alg{<:Any,metricT}) where {metricT} = metricT
-end
+getmetricT(::HMC{<:Any,metricT}) where {metricT} = metricT
+getmetricT(::HMCDA{<:Any,metricT}) where {metricT} = metricT
+getmetricT(::NUTS{<:Any,metricT}) where {metricT} = metricT
```
Metaprogramming is cool and all, but this wasn't really necessary, imo.
test/mcmc/hmc.jl
```diff
 @testset "$(alg)" for alg in algs
     # Construct a HMC state by taking a single step
     vi = DynamicPPL.VarInfo(gdemo_default)
+    vi = DynamicPPL.link(vi, gdemo_default)
+    ldf = LogDensityFunction(gdemo_default, vi; adtype=Turing.DEFAULT_ADTYPE)
     spl = Sampler(alg)
-    hmc_state = DynamicPPL.initialstep(
-        Random.default_rng(), gdemo_default, spl, DynamicPPL.VarInfo(gdemo_default)
-    )[2]
+    _, hmc_state = AbstractMCMC.step(Random.default_rng(), ldf, spl)
```
Finally, I think this test reveals one drawback of the current proposal: it becomes more annoying to directly call the AbstractMCMC interface. Let's say we want to benchmark the first step of a given sampler (for example, we were doing this the other day on the Gibbs sampler). Previously, we'd do:

```julia
rng = Random.default_rng()
model = ...
spl = ...
@be AbstractMCMC.step(rng, model, spl)
```

Now, we have to do:

```julia
rng = Random.default_rng()
model = ...
vi = link(VarInfo(model), model)
ldf = LogDensityFunction(model, vi; adtype=AutoForwardDiff())
spl = ...
@be AbstractMCMC.step(rng, ldf, spl)
```
I think this is a fairly small price to pay, because the occasions where we reach directly for the AbstractMCMC interface are quite few, and the code simplification is more important. But I thought it was worth noting. This would be less problematic if we introduced more convenient constructors for the LDF (TuringLang/DynamicPPL.jl#863), so it might be worth keeping that in mind.
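For instance, with a hypothetical convenience constructor (not an existing API; just the kind of shape TuringLang/DynamicPPL.jl#863 gestures at), the setup could shrink back down:

```julia
# Hypothetical constructor that links and picks the adtype in one call:
ldf = LogDensityFunction(model; linked=true, adtype=AutoForwardDiff())
@be AbstractMCMC.step(rng, ldf, spl)
```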
```diff
 @model,
 Metadata,
 VarInfo,
+LogDensityFunction,
 SimpleVarInfo,
 AbstractVarInfo,
```
Being pedantic and annoying, I think it'll be better if we reorder so that the `*VarInfo`s stay together.

```julia
"""
    update_sample_kwargs(spl::AbstractSampler, N::Integer, kwargs)

Some samplers carry additional information about the keyword arguments that
should be passed to `AbstractMCMC.sample`. This function provides a hook for
them to update the default keyword arguments. The default implementation is for
no changes to be made to `kwargs`.
"""
```
I think it's worth adding that this function returns a `NamedTuple` (if I understand correctly).
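Given the docstring's description, the default method is presumably just a pass-through returning the `kwargs` unchanged, something like:

```julia
# Default: no changes are made to the keyword arguments.
update_sample_kwargs(::AbstractSampler, ::Integer, kwargs) = kwargs
```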
Also, it's probably worth mentioning what `N` is.
```julia
get_adtype(::AbstractSampler) = nothing

"""
    requires_unconstrained_space(sampler::AbstractSampler)
```
Maybe this is an unimportant point: for something like MH, both constrained and unconstrained spaces might be fine.
On a high level, I do think "requires_unconstrained_space"-ness is a sampler property, but it's not a sufficient condition to control linking.
I don't know if what I wrote makes sense.
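To illustrate the point, per-sampler overrides are one-liners; the `MH` method below is hypothetical, following the comment above rather than anything in this PR:

```julia
# Default from this PR: assume samplers want unconstrained space.
requires_unconstrained_space(::AbstractSampler) = true

# Hypothetical: MH works in either space, so it could opt out of linking.
requires_unconstrained_space(::Sampler{<:MH}) = false
```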
```diff
-abstract type Hamiltonian <: InferenceAlgorithm end
+# AbstractSampler interface for Turing
+
+abstract type Hamiltonian <: AbstractMCMC.AbstractSampler end
```
I want to tag @yebai for visibility. I can see the motivation for `InferenceAlgorithm`, but I think this makes things a lot cleaner.
```diff
     initial_params=nothing,
     nadapts=0,
     kwargs...,
 )
-    # Transform the samples to unconstrained space and compute the joint log probability.
-    vi = DynamicPPL.link(vi_original, model)
+    ldf.adtype === nothing &&
```
This is a good sanity check. Does it make sense to force `ldf` and `spl` to have the same adtype?
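If we did want to enforce that, a sketch of what the check might look like (assumed form, not code from this PR):

```julia
# Possible consistency check between the LDF's adtype and the sampler's:
if ldf.adtype !== get_adtype(spl)
    throw(ArgumentError("LogDensityFunction and sampler must use the same adtype"))
end
```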
```julia
# This file contains the basic methods for `AbstractMCMC.sample`.
# The overall aim is that users can call
#
#     sample(::Model, ::InferenceAlgorithm, N)
```
I assume `InferenceAlgorithm` here is still needed until we update all of the interface?
This PR moves in the general direction of #2555.
It will take a long time to get everything to work, so I am trying to do this incrementally.
## Summary
The fundamental idea (see #2555) is that we want

```julia
sample(::Union{Model,LogDensityFunction}, ::Union{InferenceAlgorithm,Sampler{<:InferenceAlgorithm}}, N)
```

to always forward to

```julia
sample(::LogDensityFunction, ::Sampler{<:InferenceAlgorithm}, N)
```

(along the way, we construct the LDF if we need to, and also construct the sampler if we need to).

Then, the concrete AbstractMCMC interface functions (i.e. `mcmcsample` and `step`) do not ever see a `Model`; they only see an LDF.
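In code, the intended forwarding might look roughly like this (a sketch: `make_ldf` is a hypothetical helper standing in for the LDF construction, and the real methods dispatch on the transitional types described below):

```julia
# Sketch of the dispatch chain (assumed shape, not the PR's exact methods):
# a Model-based call wraps the algorithm in a Sampler and the model in an LDF,
sample(model::Model, alg::InferenceAlgorithm, N) =
    sample(make_ldf(model, alg), Sampler(alg), N)
# so that only the LDF-based method ever reaches the concrete machinery.
sample(ldf::LogDensityFunction, spl::Sampler, N) =
    mcmcsample(Random.default_rng(), ldf, spl, N)
```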
## The future of `DynamicPPL.initialstep`
Note that this allows us to sidestep (and eventually, completely remove) `DynamicPPL.initialstep`. The reason why that function exists is because `AbstractMCMC.step` would do two things: first, generate the VarInfo that would eventually go into the LDF, and secondly, call `initialstep` (which was sampler-specific). Since the VarInfo generation bit is now handled in the LDF construction, it means that instead of having an extra function, we can just go back to implementing the two basic `AbstractMCMC.step` methods, which is a nice bonus.

## Changes in this PR
Changing this all at once is bound to be not only impossible to do but also impossible to review. Thus, I've decided to (try to) implement this in a few stages. This PR is probably the one that makes the most sweeping changes, and also establishes the interface required. It:
- Establishes the desired method dispatch behaviour for `sample` (see `src/mcmc/abstractmcmc.jl`). Because we aren't ready to extend this to every sampler and inference algorithm yet, these methods dispatch only on `LDFCompatibleAlgorithm` or `LDFCompatibleSampler`, which are defined at the top of the file. The idea is that we'll add samplers as we go along, and one day we'll eventually be ready to remove this type and just use `InferenceAlgorithm`.
- When automatically constructing the LDF, there are a few things that we need to know to construct it properly (in particular, whether the sampler needs the parameters in unconstrained space, and which AD type it uses). This PR therefore also introduces interface functions that all (LDF-compatible) samplers must conform to, namely `requires_unconstrained_space(::AbstractSampler)` and `get_adtype(::AbstractSampler)`. Sensible defaults of `true` and `nothing` are given. Note that these functions were already floating around the Turing codebase, so all I've really done is to bring them together and actually write docstrings for them.
- Finally, there is an `update_sample_kwargs` function which samplers can use as a hook to modify the keyword arguments sent to `sample()`. See the comments below for more details.

## Fortunately
This doesn't actually require any changes to DynamicPPL, which I found to be a huge relief!
It's likely that some of the code in this PR will eventually be moved to DynamicPPL, as it doesn't have any non-DynamicPPL dependencies. But that can be handled very easily at a later stage, once we're confident that this all works.
## Unfortunately
Changing the interface one sampler at a time completely breaks Gibbs, because for Gibbs to work, it requires all of its component samplers to be updated. So we may have to live with the Gibbs tests being broken for a while, and rely on me promising that I'll fix it at some point in time. In this PR, I've disabled the Gibbs tests that live outside `test/mcmc/gibbs.jl`.

Because I don't know how long this will take me, I don't even want to merge this into `breaking`, as I don't want to have a new release held up by the fact that only half the changes have been done. I've created a new base branch, `sample-ldf`, to collect all the work on this. When we're happy with it, we can merge that into `breaking`.