-
Notifications
You must be signed in to change notification settings - Fork 226
sample with LogDensityFunction: part 2 - ess.jl
+ mh.jl
#2590
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: py/ldf-hmc
Are you sure you want to change the base?
Changes from all commits
8a0fb57
47e3c2f
b076b44
3e1082e
538f27d
882b844
c4ba82a
6527af6
852bfe7
ec885a4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -104,7 +104,7 @@ | |
``` | ||
|
||
""" | ||
struct MH{P} <: InferenceAlgorithm | ||
struct MH{P} <: AbstractSampler | ||
proposals::P | ||
|
||
function MH(proposals...) | ||
|
@@ -139,18 +139,23 @@ | |
end | ||
end | ||
|
||
# Some of the proposals require working in unconstrained space. | ||
transform_maybe(proposal::AMH.Proposal) = proposal | ||
function transform_maybe(proposal::AMH.RandomWalkProposal) | ||
return AMH.RandomWalkProposal(Bijectors.transformed(proposal.proposal)) | ||
end | ||
|
||
function MH(model::Model; proposal_type=AMH.StaticProposal) | ||
priors = DynamicPPL.extract_priors(model) | ||
props = Tuple([proposal_type(prop) for prop in values(priors)]) | ||
vars = Tuple(map(Symbol, collect(keys(priors)))) | ||
priors = map(transform_maybe, NamedTuple{vars}(props)) | ||
return AMH.MetropolisHastings(priors) | ||
Comment on lines
-142
to
-153
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this code was broken for a long time and nobody bothered to update it. the idea would be that this would return an |
||
# Turing sampler interface | ||
DynamicPPL.initialsampler(::MH) = DynamicPPL.SampleFromPrior() | ||
get_adtype(::MH) = nothing | ||
update_sample_kwargs(::MH, ::Integer, kwargs) = kwargs | ||
requires_unconstrained_space(::MH) = false | ||
requires_unconstrained_space(::MH{<:AdvancedMH.RandomWalkProposal}) = true | ||
# `NamedTuple` of proposals. TODO: It seems, at some point, that there | ||
# was an intent to extract the parameters from the NamedTuple and to only | ||
# link those parameters that corresponded to RandomWalkProposals. See | ||
# https://github.com/TuringLang/Turing.jl/issues/1583. | ||
requires_unconstrained_space(::MH{NamedTuple{(),Tuple{}}}) = false | ||
@generated function requires_unconstrained_space( | ||
::MH{<:NamedTuple{names,props}} | ||
) where {names,props} | ||
# If we have a `NamedTuple` with proposals, we check if all of them are | ||
# `AdvancedMH.RandomWalkProposal`. If so, we need to link. | ||
return all(prop -> prop <: AdvancedMH.RandomWalkProposal, props.parameters) | ||
end | ||
|
||
##################### | ||
|
@@ -188,7 +193,7 @@ | |
|
||
This variant uses the `set_namedtuple!` function to update the `VarInfo`. | ||
""" | ||
const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = | ||
const MHLogDensityFunction{M<:Model,S<:MH,V<:AbstractVarInfo} = | ||
DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD} | ||
|
||
function LogDensityProblems.logdensity(f::MHLogDensityFunction, x::NamedTuple) | ||
|
@@ -219,16 +224,16 @@ | |
end | ||
|
||
""" | ||
dist_val_tuple(spl::Sampler{<:MH}, vi::VarInfo) | ||
dist_val_tuple(spl::MH, vi::VarInfo) | ||
|
||
Return two `NamedTuples`. | ||
|
||
The first `NamedTuple` has symbols as keys and distributions as values. | ||
The second `NamedTuple` has model symbols as keys and their stored values as values. | ||
""" | ||
function dist_val_tuple(spl::Sampler{<:MH}, vi::DynamicPPL.VarInfoOrThreadSafeVarInfo) | ||
function dist_val_tuple(spl::MH, vi::DynamicPPL.VarInfoOrThreadSafeVarInfo) | ||
vns = all_varnames_grouped_by_symbol(vi) | ||
dt = _dist_tuple(spl.alg.proposals, vi, vns) | ||
dt = _dist_tuple(spl.proposals, vi, vns) | ||
vt = _val_tuple(vi, vns) | ||
return dt, vt | ||
end | ||
|
@@ -270,34 +275,9 @@ | |
end | ||
_dist_tuple(::@NamedTuple{}, ::VarInfo, ::Tuple{}) = () | ||
|
||
# Utility functions to link | ||
should_link(varinfo, sampler, proposal) = false | ||
function should_link(varinfo, sampler, proposal::NamedTuple{(),Tuple{}}) | ||
# If it's an empty `NamedTuple`, we're using the priors as proposals | ||
# in which case we shouldn't link. | ||
return false | ||
end | ||
function should_link(varinfo, sampler, proposal::AdvancedMH.RandomWalkProposal) | ||
return true | ||
end | ||
# FIXME: This won't be hit unless `vals` are all the exactly same concrete type of `AdvancedMH.RandomWalkProposal`! | ||
function should_link( | ||
varinfo, sampler, proposal::NamedTuple{names,vals} | ||
) where {names,vals<:NTuple{<:Any,<:AdvancedMH.RandomWalkProposal}} | ||
return true | ||
end | ||
|
||
function maybe_link!!(varinfo, sampler, proposal, model) | ||
return if should_link(varinfo, sampler, proposal) | ||
DynamicPPL.link!!(varinfo, model) | ||
else | ||
varinfo | ||
end | ||
end | ||
|
||
# Make a proposal if we don't have a covariance proposal matrix (the default). | ||
function propose!!( | ||
rng::AbstractRNG, vi::AbstractVarInfo, model::Model, spl::Sampler{<:MH}, proposal | ||
rng::AbstractRNG, vi::AbstractVarInfo, ldf::LogDensityFunction, spl::MH, proposal | ||
) | ||
# Retrieve distribution and value NamedTuples. | ||
dt, vt = dist_val_tuple(spl, vi) | ||
|
@@ -307,16 +287,7 @@ | |
prev_trans = AMH.Transition(vt, getlogp(vi), false) | ||
|
||
# Make a new transition. | ||
densitymodel = AMH.DensityModel( | ||
Base.Fix1( | ||
LogDensityProblems.logdensity, | ||
DynamicPPL.LogDensityFunction( | ||
model, | ||
vi, | ||
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)), | ||
), | ||
), | ||
) | ||
densitymodel = AMH.DensityModel(Base.Fix1(LogDensityProblems.logdensity, ldf)) | ||
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans) | ||
|
||
# TODO: Make this compatible with immutable `VarInfo`. | ||
|
@@ -329,70 +300,47 @@ | |
function propose!!( | ||
rng::AbstractRNG, | ||
vi::AbstractVarInfo, | ||
model::Model, | ||
spl::Sampler{<:MH}, | ||
ldf::LogDensityFunction, | ||
spl::MH, | ||
proposal::AdvancedMH.RandomWalkProposal, | ||
) | ||
# If this is the case, we can just draw directly from the proposal | ||
# matrix. | ||
vals = vi[:] | ||
|
||
# Create a sampler and the previous transition. | ||
mh_sampler = AMH.MetropolisHastings(spl.alg.proposals) | ||
mh_sampler = AMH.MetropolisHastings(spl.proposals) | ||
prev_trans = AMH.Transition(vals, getlogp(vi), false) | ||
|
||
# Make a new transition. | ||
densitymodel = AMH.DensityModel( | ||
Base.Fix1( | ||
LogDensityProblems.logdensity, | ||
DynamicPPL.LogDensityFunction( | ||
model, | ||
vi, | ||
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)), | ||
), | ||
), | ||
) | ||
densitymodel = AMH.DensityModel(Base.Fix1(LogDensityProblems.logdensity, ldf)) | ||
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans) | ||
|
||
return setlogp!!(DynamicPPL.unflatten(vi, trans.params), trans.lp) | ||
end | ||
|
||
function DynamicPPL.initialstep( | ||
rng::AbstractRNG, | ||
model::AbstractModel, | ||
spl::Sampler{<:MH}, | ||
vi::AbstractVarInfo; | ||
kwargs..., | ||
) | ||
# If we're doing random walk with a covariance matrix, | ||
# just link everything before sampling. | ||
vi = maybe_link!!(vi, spl, spl.alg.proposals, model) | ||
|
||
return Transition(model, vi), vi | ||
function AbstractMCMC.step(rng::AbstractRNG, ldf::LogDensityFunction, spl::MH; kwargs...) | ||
vi = ldf.varinfo | ||
return Transition(ldf.model, vi), vi | ||
end | ||
|
||
function AbstractMCMC.step( | ||
rng::AbstractRNG, model::Model, spl::Sampler{<:MH}, vi::AbstractVarInfo; kwargs... | ||
rng::AbstractRNG, ldf::LogDensityFunction, spl::MH, vi::AbstractVarInfo; kwargs... | ||
) | ||
# Cases: | ||
# 1. A covariance proposal matrix | ||
# 2. A bunch of NamedTuples that specify the proposal space | ||
vi = propose!!(rng, vi, model, spl, spl.alg.proposals) | ||
|
||
return Transition(model, vi), vi | ||
vi = propose!!(rng, vi, ldf, spl, spl.proposals) | ||
return Transition(ldf.model, vi), vi | ||
end | ||
|
||
#### | ||
#### Compiler interface, i.e. tilde operators. | ||
#### | ||
function DynamicPPL.assume( | ||
rng::Random.AbstractRNG, spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi | ||
rng::Random.AbstractRNG, ::MH, dist::Distribution, vn::VarName, vi | ||
) | ||
# Just defer to `SampleFromPrior`. | ||
retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi) | ||
return retval | ||
return DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi) | ||
end | ||
|
||
function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi) | ||
function DynamicPPL.observe(::MH, d::Distribution, value, vi) | ||
return DynamicPPL.observe(SampleFromPrior(), d, value, vi) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
otherwise the existing context won't be obeyed