
sample with LogDensityFunction: part 1 - hmc.jl, sghmc.jl, DynamicHMCExt #2588


Status: Open. Wants to merge 15 commits into base branch: sample-ldf.
4 changes: 3 additions & 1 deletion Project.toml
@@ -5,6 +5,7 @@ version = "0.39.1"
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
@@ -49,6 +50,7 @@ TuringOptimExt = "Optim"
[compat]
ADTypes = "1.9"
AbstractMCMC = "5.5"
AbstractPPL = "0.11.0"
Accessors = "0.1"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7, 0.8"
AdvancedMH = "0.8"
@@ -62,7 +64,7 @@ Distributions = "0.25.77"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.36.3"
DynamicPPL = "0.36.8"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3, 1"
Libtask = "0.8.8"
44 changes: 13 additions & 31 deletions ext/TuringDynamicHMCExt.jl
Comment (Member Author):
Many of the changes in sghmc.jl are quite similar to the ones in this file, so I added some comments explaining.

@@ -35,43 +35,26 @@ State of the [`DynamicNUTS`](@ref) sampler.
# Fields
$(TYPEDFIELDS)
"""
struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
logdensity::L
struct DynamicNUTSState{V<:DynamicPPL.AbstractVarInfo,C,M,S}
Comment (Member Author):

Traditionally the sampler state has included the LogDensityFunction as a field so that it does not need to be reconstructed from the model and varinfo on each iteration. This is no longer necessary, because the LDF is itself an argument to AbstractMCMC.step (a runnable sketch follows this struct definition).

vi::V
"Cache of sample, log density, and gradient of log density evaluation."
cache::C
metric::M
stepsize::S
end
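A minimal, runnable sketch of the interface described in the comment above (illustrative only, not part of this diff). The model `demo` and the `AutoForwardDiff` choice are assumptions made for the example; the point is simply that the LogDensityFunction is handed to `AbstractMCMC.step` on every call, so the state only needs to carry per-iteration quantities.

```julia
# Illustrative sketch only; not part of the PR diff.
using Turing, DynamicHMC, Random
import ADTypes, AbstractMCMC, DynamicPPL

# A hypothetical one-parameter model.
@model demo() = x ~ Normal()

model = demo()
# Link the VarInfo so sampling happens in unconstrained space (on the
# sample-ldf base branch this is done inside AbstractMCMC.sample).
vi = DynamicPPL.link!!(DynamicPPL.VarInfo(model), model)
ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=ADTypes.AutoForwardDiff())

rng = Random.default_rng()
# Initial step: the LDF takes the place of the model argument.
transition, state = AbstractMCMC.step(rng, ldf, DynamicNUTS())
# Subsequent steps receive the same LDF again alongside the state, which is
# why DynamicNUTSState no longer stores it.
transition, state = AbstractMCMC.step(rng, ldf, DynamicNUTS(), state)
```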

function DynamicPPL.initialsampler(::DynamicPPL.Sampler{<:DynamicNUTS})
function DynamicPPL.initialsampler(::DynamicNUTS)
return DynamicPPL.SampleFromUniform()
end

function DynamicPPL.initialstep(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:DynamicNUTS},
vi::DynamicPPL.AbstractVarInfo;
kwargs...,
function AbstractMCMC.step(
rng::Random.AbstractRNG, ldf::DynamicPPL.LogDensityFunction, spl::DynamicNUTS; kwargs...
)
# Ensure that initial sample is in unconstrained space.
if !DynamicPPL.islinked(vi)
vi = DynamicPPL.link!!(vi, model)
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
end

# Define log-density function.
ℓ = DynamicPPL.LogDensityFunction(
model,
vi,
DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
adtype=spl.alg.adtype,
)
Comment on lines -58 to -70 (Member Author):

All of this is now handled inside AbstractMCMC.sample(), so there is no longer any need to duplicate this code in every initialstep method (see the usage sketch after this function).

vi = ldf.varinfo

# Perform initial step.
results = DynamicHMC.mcmc_keep_warmup(
rng, ℓ, 0; initialization=(q=vi[:],), reporter=DynamicHMC.NoProgressReport()
rng, ldf, 0; initialization=(q=vi[:],), reporter=DynamicHMC.NoProgressReport()
)
steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)
@@ -81,32 +64,31 @@ function DynamicPPL.initialstep(
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)

# Create first sample and state.
sample = Turing.Inference.Transition(model, vi)
state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)
sample = Turing.Inference.Transition(ldf.model, vi)
state = DynamicNUTSState(vi, Q, steps.H.κ, steps.ϵ)

return sample, state
end
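To make the comment above concrete, here is a hedged usage sketch (not part of this diff). It assumes the `sample-ldf` base branch provides a `sample` method for `DynamicPPL.LogDensityFunction` that performs the linking and initialization removed above; the model `demo` is made up for the example.

```julia
# Illustrative sketch only; assumes the base branch's
# sample(rng, ::LogDensityFunction, sampler, N) entry point.
using Turing, DynamicHMC, Random
import ADTypes, DynamicPPL

@model function demo()
    s ~ InverseGamma(2, 3)
    x ~ Normal(0, sqrt(s))
end

ldf = DynamicPPL.LogDensityFunction(demo(); adtype=ADTypes.AutoForwardDiff())
# Linking to unconstrained space and drawing the initial point now happen
# inside sample, not in each sampler's first step.
chain = sample(Random.default_rng(), ldf, DynamicNUTS(), 500)
```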

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:DynamicNUTS},
ldf::DynamicPPL.LogDensityFunction,
spl::DynamicNUTS,
state::DynamicNUTSState;
kwargs...,
)
# Compute next sample.
vi = state.vi
ℓ = state.logdensity
steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
steps = DynamicHMC.mcmc_steps(rng, spl.sampler, state.metric, ldf, state.stepsize)
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)

# Update the variables.
vi = DynamicPPL.unflatten(vi, Q.q)
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)

# Create next sample and state.
sample = Turing.Inference.Transition(model, vi)
newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)
sample = Turing.Inference.Transition(ldf.model, vi)
newstate = DynamicNUTSState(vi, Q, state.metric, state.stepsize)

return sample, newstate
end
29 changes: 21 additions & 8 deletions src/mcmc/Inference.jl
@@ -1,9 +1,13 @@
module Inference

using DynamicPPL:
DynamicPPL,
@model,
Metadata,
VarInfo,
LogDensityFunction,
SimpleVarInfo,
AbstractVarInfo,
Comment (Member):
Being pedantic and annoying, I think it'll be better if we reorder so that the *VarInfos stay together.

# TODO(mhauru) all_varnames_grouped_by_symbol isn't exported by DPPL, because it is only
# implemented for NTVarInfo. It is used by mh.jl. Either refactor mh.jl to not use it
# or implement it for other VarInfo types and export it from DPPL.
@@ -24,6 +28,7 @@
DefaultContext,
PriorContext,
LikelihoodContext,
SamplingContext,
set_flag!,
unset_flag!
using Distributions, Libtask, Bijectors
@@ -32,14 +37,14 @@
using ..Turing: PROGRESS, Turing
using StatsFuns: logsumexp
using Random: AbstractRNG
using DynamicPPL
using AbstractMCMC: AbstractModel, AbstractSampler
using DocStringExtensions: FIELDS, TYPEDEF, TYPEDFIELDS
using DataStructures: OrderedSet
using DataStructures: OrderedSet, OrderedDict
using Accessors: Accessors

import ADTypes
import AbstractMCMC
import AbstractPPL
import AdvancedHMC
const AHMC = AdvancedHMC
import AdvancedMH
@@ -74,8 +79,6 @@
PG,
RepeatSampler,
Prior,
assume,
observe,
predict,
externalsampler

@@ -244,8 +247,8 @@
# This is type piracy (at least for SampleFromPrior).
function AbstractMCMC.bundle_samples(
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
model_or_ldf::Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction},
spl::AbstractSampler,
state,
chain_type::Type{MCMCChains.Chains};
save_state=false,
@@ -256,6 +259,11 @@
thinning=1,
kwargs...,
)
model = if model_or_ldf isa DynamicPPL.LogDensityFunction
model_or_ldf.model
else
model_or_ldf
end
# Convert transitions to array format.
# Also retrieve the variable names.
varnames, vals = _params_to_array(model, ts)
@@ -307,12 +315,17 @@
# This is type piracy (for SampleFromPrior).
function AbstractMCMC.bundle_samples(
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
model_or_ldf::Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction},
spl::AbstractSampler,
state,
chain_type::Type{Vector{NamedTuple}};
kwargs...,
)
model = if model_or_ldf isa DynamicPPL.LogDensityFunction
model_or_ldf.model

Codecov (codecov/patch) warning on src/mcmc/Inference.jl#L325: Added line #L325 was not covered by tests.
else
model_or_ldf
end
return map(ts) do t
# Construct a dictionary of pairs `vn => value`.
params = OrderedDict(getparams(model, t))
Expand Down