# sample with LogDensityFunction: part 1 - hmc.jl, sghmc.jl, DynamicHMCExt (#2588)
Base branch: `sample-ldf`
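For context, here is a minimal sketch of the interface this PR series works towards: wrapping a model in a `DynamicPPL.LogDensityFunction` and passing that to `sample` in place of the model. The entry point shown is an assumption based on this diff, not the confirmed final API.

```julia
# Sketch of the intended usage (names follow this diff; final API may differ).
using Turing, DynamicPPL, DynamicHMC, Random

@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

# Wrap the model in a LogDensityFunction up front...
ldf = DynamicPPL.LogDensityFunction(demo())

# ...and pass it, rather than the model, to the sampler.
chain = sample(Random.default_rng(), ldf, DynamicNUTS(), 1_000)
```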
Changes from all commits: 8be9094, 480fb69, 6599cfa, ad48370, ae7abde, 2c9f970, 6923a72, 4606bdd, 17a1f00, 71a8cf2, 85b1997, 49f6988, e4cb590, e08f548, 8a0fb57
**File: DynamicHMCExt (the DynamicHMC extension)**
```diff
@@ -35,43 +35,26 @@ State of the [`DynamicNUTS`](@ref) sampler.
 # Fields
 $(TYPEDFIELDS)
 """
-struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
-    logdensity::L
+struct DynamicNUTSState{V<:DynamicPPL.AbstractVarInfo,C,M,S}
     vi::V
     "Cache of sample, log density, and gradient of log density evaluation."
     cache::C
     metric::M
     stepsize::S
 end
```

> **Review comment** (on the `DynamicNUTSState` change): The sampler state, traditionally, has included the log density function […]
```diff
 
-function DynamicPPL.initialsampler(::DynamicPPL.Sampler{<:DynamicNUTS})
+function DynamicPPL.initialsampler(::DynamicNUTS)
     return DynamicPPL.SampleFromUniform()
 end
 
-function DynamicPPL.initialstep(
-    rng::Random.AbstractRNG,
-    model::DynamicPPL.Model,
-    spl::DynamicPPL.Sampler{<:DynamicNUTS},
-    vi::DynamicPPL.AbstractVarInfo;
-    kwargs...,
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG, ldf::DynamicPPL.LogDensityFunction, spl::DynamicNUTS; kwargs...
 )
-    # Ensure that initial sample is in unconstrained space.
-    if !DynamicPPL.islinked(vi)
-        vi = DynamicPPL.link!!(vi, model)
-        vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
-    end
-
-    # Define log-density function.
-    ℓ = DynamicPPL.LogDensityFunction(
-        model,
-        vi,
-        DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
-        adtype=spl.alg.adtype,
-    )
+    vi = ldf.varinfo
 
     # Perform initial step.
     results = DynamicHMC.mcmc_keep_warmup(
-        rng, ℓ, 0; initialization=(q=vi[:],), reporter=DynamicHMC.NoProgressReport()
+        rng, ldf, 0; initialization=(q=vi[:],), reporter=DynamicHMC.NoProgressReport()
     )
     steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
     Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)
```

> **Review comment** (on lines -58 to -70): All of this stuff is now handled inside […]
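The removed linking and `LogDensityFunction` construction now happen before the sampler is invoked. A minimal sketch of doing this by hand, using the same DynamicPPL calls the removed code used (the `demo` model is a hypothetical stand-in; whether callers must link manually in the final design is an assumption):

```julia
# Sketch: prepare an unconstrained VarInfo and wrap it in a
# LogDensityFunction, mirroring what the removed initialstep code did.
using DynamicPPL

model = demo()                                  # any DynamicPPL model
vi = DynamicPPL.VarInfo(model)
vi = DynamicPPL.link!!(vi, model)               # move to unconstrained space
ldf = DynamicPPL.LogDensityFunction(model, vi)
```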
```diff
@@ -81,32 +64,31 @@ function DynamicPPL.initialstep(
     vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
 
     # Create first sample and state.
-    sample = Turing.Inference.Transition(model, vi)
-    state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)
+    sample = Turing.Inference.Transition(ldf.model, vi)
+    state = DynamicNUTSState(vi, Q, steps.H.κ, steps.ϵ)
 
     return sample, state
 end
 
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
-    model::DynamicPPL.Model,
-    spl::DynamicPPL.Sampler{<:DynamicNUTS},
+    ldf::DynamicPPL.LogDensityFunction,
+    spl::DynamicNUTS,
     state::DynamicNUTSState;
     kwargs...,
 )
     # Compute next sample.
     vi = state.vi
-    ℓ = state.logdensity
-    steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
+    steps = DynamicHMC.mcmc_steps(rng, spl.sampler, state.metric, ldf, state.stepsize)
     Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)
 
     # Update the variables.
     vi = DynamicPPL.unflatten(vi, Q.q)
     vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
 
     # Create next sample and state.
-    sample = Turing.Inference.Transition(model, vi)
-    newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)
+    sample = Turing.Inference.Transition(ldf.model, vi)
+    newstate = DynamicNUTSState(vi, Q, state.metric, state.stepsize)
 
     return sample, newstate
 end
```
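Putting the two `AbstractMCMC.step` methods above together, here is a sketch of driving the sampler by hand. It assumes `ldf` is set up as in the earlier sketch and that Turing, DynamicHMC, and DynamicPPL are loaded; the signatures are taken directly from this diff.

```julia
# Hand-driving the new step interface from the diff above (a sketch).
using Random, AbstractMCMC

rng = Random.default_rng()
spl = DynamicNUTS()

# Initial step: no state argument; q0 comes from ldf.varinfo...
transition, state = AbstractMCMC.step(rng, ldf, spl)

# ...then subsequent steps thread the DynamicNUTSState through.
transition, state = AbstractMCMC.step(rng, ldf, spl, state)
```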
**File: the `Inference` module**
```diff
@@ -1,9 +1,13 @@
 module Inference
 
 using DynamicPPL:
     DynamicPPL,
     @model,
     Metadata,
     VarInfo,
+    LogDensityFunction,
     SimpleVarInfo,
     AbstractVarInfo,
+# TODO(mhauru) all_varnames_grouped_by_symbol isn't exported by DPPL, because it is only
+# implemented for NTVarInfo. It is used by mh.jl. Either refactor mh.jl to not use it
+# or implement it for other VarInfo types and export it from DPPL.
```

> **Review comment** (on the `using DynamicPPL:` list): Being pedantic and annoying, I think it'll be better if we reorder so that the […]
```diff
@@ -24,6 +28,7 @@
     DefaultContext,
     PriorContext,
     LikelihoodContext,
+    SamplingContext,
     set_flag!,
     unset_flag!
 using Distributions, Libtask, Bijectors
```
```diff
@@ -32,14 +37,14 @@
 using ..Turing: PROGRESS, Turing
 using StatsFuns: logsumexp
 using Random: AbstractRNG
-using DynamicPPL
 using AbstractMCMC: AbstractModel, AbstractSampler
 using DocStringExtensions: FIELDS, TYPEDEF, TYPEDFIELDS
-using DataStructures: OrderedSet
+using DataStructures: OrderedSet, OrderedDict
 using Accessors: Accessors
 
 import ADTypes
 import AbstractMCMC
+import AbstractPPL
 import AdvancedHMC
 const AHMC = AdvancedHMC
 import AdvancedMH
```
```diff
@@ -74,8 +79,6 @@
     PG,
     RepeatSampler,
     Prior,
-    assume,
-    observe,
     predict,
     externalsampler
 
```
```diff
@@ -244,8 +247,8 @@
 # This is type piracy (at least for SampleFromPrior).
 function AbstractMCMC.bundle_samples(
     ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
-    model::AbstractModel,
-    spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
+    model_or_ldf::Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction},
+    spl::AbstractSampler,
     state,
     chain_type::Type{MCMCChains.Chains};
     save_state=false,
@@ -256,6 +259,11 @@
     thinning=1,
     kwargs...,
 )
+    model = if model_or_ldf isa DynamicPPL.LogDensityFunction
+        model_or_ldf.model
+    else
+        model_or_ldf
+    end
     # Convert transitions to array format.
     # Also retrieve the variable names.
     varnames, vals = _params_to_array(model, ts)
```
```diff
@@ -307,12 +315,17 @@
 # This is type piracy (for SampleFromPrior).
 function AbstractMCMC.bundle_samples(
     ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
-    model::AbstractModel,
-    spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
+    model_or_ldf::Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction},
+    spl::AbstractSampler,
     state,
     chain_type::Type{Vector{NamedTuple}};
     kwargs...,
 )
+    model = if model_or_ldf isa DynamicPPL.LogDensityFunction
+        model_or_ldf.model
+    else
+        model_or_ldf
+    end
     return map(ts) do t
         # Construct a dictionary of pairs `vn => value`.
         params = OrderedDict(getparams(model, t))
```
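The same `model_or_ldf` unwrapping appears in both `bundle_samples` methods above. One way to avoid the duplication, sketched here with a hypothetical helper name that is not part of the PR, would be to dispatch on the argument type:

```julia
# Hypothetical helper (not in the PR) that both bundle_samples methods
# could call instead of repeating the if/else.
_get_model(model::DynamicPPL.Model) = model
_get_model(ldf::DynamicPPL.LogDensityFunction) = ldf.model

# e.g. inside bundle_samples:
# model = _get_model(model_or_ldf)
```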
> **Review comment**: Many of the changes in sghmc.jl are quite similar to the ones in this file, so I added some comments explaining.