Skip to content

sample with LogDensityFunction: part 2 - ess.jl + mh.jl #2590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: py/ldf-hmc
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@

# Because this is a pain to implement all at once, we do it for one sampler at a time.
# This type tells us which samplers have been 'updated' to the new interface.

const LDFCompatibleSampler = Union{Hamiltonian}
const LDFCompatibleSampler = Union{Hamiltonian,ESS,MH}

"""
sample(
Expand Down Expand Up @@ -80,7 +79,7 @@
ctx = if ldf.context isa SamplingContext
ldf.context
else
SamplingContext(rng, spl)
SamplingContext(rng, spl, ldf.context)

Check warning on line 82 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L82

Added line #L82 was not covered by tests
end
Comment on lines 79 to 83
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise the context already stored in the LDF would be discarded instead of being wrapped in the new `SamplingContext`.

# Note that, in particular, sampling can mutate the variables in the LDF's
# varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
Expand Down Expand Up @@ -164,7 +163,7 @@
ctx = if ldf.context isa SamplingContext
ldf.context
else
SamplingContext(rng, spl)
SamplingContext(rng, spl, ldf.context)

Check warning on line 166 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L166

Added line #L166 was not covered by tests
end
# Note that, in particular, sampling can mutate the variables in the LDF's
# varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
Expand Down Expand Up @@ -282,7 +281,7 @@
initial_params = get(kwargs, :initial_params, nothing)
link = requires_unconstrained_space(spl)
vi = initialise_varinfo(rng, model, spl, initial_params, link)
ctx = SamplingContext(rng, spl)
ctx = SamplingContext(rng, spl, model.context)
ldf = LogDensityFunction(model, vi, ctx; adtype=get_adtype(spl))
# No need to run check_model again
return AbstractMCMC.sample(rng, ldf, spl, N; kwargs..., check_model=false)
Expand Down Expand Up @@ -331,7 +330,7 @@
initial_params = get(kwargs, :initial_params, nothing)
link = requires_unconstrained_space(spl)
vi = initialise_varinfo(rng, model, spl, initial_params, link)
ctx = SamplingContext(rng, spl)
ctx = SamplingContext(rng, spl, model.context)
ldf = LogDensityFunction(model, vi, ctx; adtype=get_adtype(spl))
# No need to run check_model again
return AbstractMCMC.sample(
Expand Down
73 changes: 33 additions & 40 deletions src/mcmc/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,26 @@
│ 1 │ m │ 0.824853 │
```
"""
struct ESS <: InferenceAlgorithm end
struct ESS <: AbstractSampler end

# Sampler-interface traits for `ESS`.
# NOTE(review): `initialsampler` is presumably consulted when the initial varinfo is
# built -- confirm against `initialise_varinfo`.
DynamicPPL.initialsampler(::ESS) = DynamicPPL.SampleFromPrior()
# The keyword arguments are returned unchanged.
update_sample_kwargs(::ESS, ::Integer, kwargs) = kwargs
# ESS reports no AD type (`nothing`).
get_adtype(::ESS) = nothing
# ESS does not ask for the varinfo to be linked.
requires_unconstrained_space(::ESS) = false

Check warning on line 28 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L25-L28

Added lines #L25 - L28 were not covered by tests

# always accept in the first step
function DynamicPPL.initialstep(
rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
)
function AbstractMCMC.step(rng::AbstractRNG, ldf::LogDensityFunction, spl::ESS; kwargs...)
vi = ldf.varinfo

Check warning on line 32 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L31-L32

Added lines #L31 - L32 were not covered by tests
for vn in keys(vi)
dist = getdist(vi, vn)
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
error("ESS only supports Gaussian prior distributions")
end
return Transition(model, vi), vi
return Transition(ldf.model, vi), vi

Check warning on line 38 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L38

Added line #L38 was not covered by tests
end

function AbstractMCMC.step(
rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
rng::AbstractRNG, ldf::LogDensityFunction, spl::ESS, vi::AbstractVarInfo; kwargs...
)
# obtain previous sample
f = vi[:]
Expand All @@ -45,14 +49,13 @@
oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi), nothing)

# compute next state
# Note: `f_loglikelihood` effectively calculates the log-likelihood (not
# log-joint, despite the use of `LDP.logdensity`) because `tilde_assume` is
# overloaded on `SamplingContext(rng, ESS(), ...)` below.
f_loglikelihood = Base.Fix1(LogDensityProblems.logdensity, ldf)

Check warning on line 55 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L55

Added line #L55 was not covered by tests
sample, state = AbstractMCMC.step(
rng,
EllipticalSliceSampling.ESSModel(
ESSPrior(model, spl, vi),
DynamicPPL.LogDensityFunction(
model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
),
),
EllipticalSliceSampling.ESSModel(ESSPrior(ldf.model, spl, vi), f_loglikelihood),
EllipticalSliceSampling.ESS(),
oldstate,
)
Expand All @@ -61,67 +64,57 @@
vi = DynamicPPL.unflatten(vi, sample)
vi = setlogp!!(vi, state.loglikelihood)

return Transition(model, vi), vi
return Transition(ldf.model, vi), vi

Check warning on line 67 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L67

Added line #L67 was not covered by tests
end

# Prior distribution of considered random variable
struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
struct ESSPrior{M<:Model,V<:AbstractVarInfo,T}
model::M
sampler::S
sampler::ESS
varinfo::V
μ::T

function ESSPrior{M,S,V}(
model::M, sampler::S, varinfo::V
) where {M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo}
function ESSPrior{M,V}(

Check warning on line 77 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L77

Added line #L77 was not covered by tests
model::M, sampler::ESS, varinfo::V
) where {M<:Model,V<:AbstractVarInfo}
vns = keys(varinfo)
μ = mapreduce(vcat, vns) do vn
dist = getdist(varinfo, vn)
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
error("[ESS] only supports Gaussian prior distributions")
DynamicPPL.tovec(mean(dist))
end
return new{M,S,V,typeof(μ)}(model, sampler, varinfo, μ)
return new{M,V,typeof(μ)}(model, sampler, varinfo, μ)

Check warning on line 87 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L87

Added line #L87 was not covered by tests
end
end

function ESSPrior(model::Model, sampler::Sampler{<:ESS}, varinfo::AbstractVarInfo)
return ESSPrior{typeof(model),typeof(sampler),typeof(varinfo)}(model, sampler, varinfo)
function ESSPrior(model::Model, sampler::ESS, varinfo::AbstractVarInfo)
return ESSPrior{typeof(model),typeof(varinfo)}(model, sampler, varinfo)

Check warning on line 92 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L91-L92

Added lines #L91 - L92 were not covered by tests
end

# Ensure that the prior is a Gaussian distribution (checked in the constructor)
EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true

# Only define out-of-place sampling
function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
sampler = p.sampler
varinfo = p.varinfo
# TODO: Surely there's a better way of doing this now that we have `SamplingContext`?
vns = keys(varinfo)
for vn in vns
set_flag!(varinfo, vn, "del")
# TODO(penelopeysm): This is ugly -- need to set 'del' flag because
# otherwise DynamicPPL.SampleFromPrior will just use the existing
# parameters in the varinfo. In general SampleFromPrior etc. need to be
# reworked.
for vn in keys(p.varinfo)
set_flag!(p.varinfo, vn, "del")

Check warning on line 105 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L104-L105

Added lines #L104 - L105 were not covered by tests
end
p.model(rng, varinfo, sampler)
return varinfo[:]
_, vi = DynamicPPL.evaluate!!(p.model, p.varinfo, SamplingContext(rng, p.sampler))
return vi[:]

Check warning on line 108 in src/mcmc/ess.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/ess.jl#L107-L108

Added lines #L107 - L108 were not covered by tests
end

# Mean of prior distribution
Distributions.mean(p::ESSPrior) = p.μ

# Evaluate log-likelihood of proposals
const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} =
DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}

(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f)

function DynamicPPL.tilde_assume(
rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
rng::Random.AbstractRNG, ::DefaultContext, ::ESS, right, vn, vi
)
return DynamicPPL.tilde_assume(
rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi
)
end

function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi)
return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi)
end
126 changes: 37 additions & 89 deletions src/mcmc/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
```

"""
struct MH{P} <: InferenceAlgorithm
struct MH{P} <: AbstractSampler
proposals::P

function MH(proposals...)
Expand Down Expand Up @@ -139,18 +139,23 @@
end
end

# Some of the proposals require working in unconstrained space.
"""
    transform_maybe(proposal)

Return `proposal` unchanged for a generic `AMH.Proposal`. For an
`AMH.RandomWalkProposal`, return a new `AMH.RandomWalkProposal` wrapping
`Bijectors.transformed` applied to the inner proposal.
"""
function transform_maybe(proposal::AMH.Proposal)
    return proposal
end
function transform_maybe(proposal::AMH.RandomWalkProposal)
    transformed_inner = Bijectors.transformed(proposal.proposal)
    return AMH.RandomWalkProposal(transformed_inner)
end

function MH(model::Model; proposal_type=AMH.StaticProposal)
priors = DynamicPPL.extract_priors(model)
props = Tuple([proposal_type(prop) for prop in values(priors)])
vars = Tuple(map(Symbol, collect(keys(priors))))
priors = map(transform_maybe, NamedTuple{vars}(props))
return AMH.MetropolisHastings(priors)
Comment on lines -142 to -153
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code had been broken for a long time and nobody had updated it. The idea was that it would return an `AMH.MetropolisHastings`, which then needed to be wrapped in an `ExternalSampler`. However, the external sampler would break (see tests below), so I removed it: it wasn't documented and wasn't working.

# Turing sampler interface
# NOTE(review): `initialsampler` is presumably consulted when the initial varinfo is
# built -- confirm against `initialise_varinfo`.
DynamicPPL.initialsampler(::MH) = DynamicPPL.SampleFromPrior()
# MH reports no AD type (`nothing`).
get_adtype(::MH) = nothing
# The keyword arguments are returned unchanged.
update_sample_kwargs(::MH, ::Integer, kwargs) = kwargs
# Default: MH does not ask for the varinfo to be linked ...
requires_unconstrained_space(::MH) = false
# ... unless its proposal is a single `AdvancedMH.RandomWalkProposal`.
requires_unconstrained_space(::MH{<:AdvancedMH.RandomWalkProposal}) = true

Check warning on line 147 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L146-L147

Added lines #L146 - L147 were not covered by tests
# Case: a `NamedTuple` of proposals. TODO: At some point there seems to have
# been an intent to extract the parameters from the NamedTuple and to link only
# those parameters that correspond to RandomWalkProposals. See
# https://github.com/TuringLang/Turing.jl/issues/1583.
requires_unconstrained_space(::MH{NamedTuple{(),Tuple{}}}) = false
@generated function requires_unconstrained_space(

Check warning on line 153 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L152-L153

Added lines #L152 - L153 were not covered by tests
::MH{<:NamedTuple{names,props}}
) where {names,props}
# If we have a `NamedTuple` with proposals, we check if all of them are
# `AdvancedMH.RandomWalkProposal`. If so, we need to link.
return all(prop -> prop <: AdvancedMH.RandomWalkProposal, props.parameters)

Check warning on line 158 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L158

Added line #L158 was not covered by tests
end

#####################
Expand Down Expand Up @@ -188,7 +193,7 @@

This variant uses the `set_namedtuple!` function to update the `VarInfo`.
"""
const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} =
const MHLogDensityFunction{M<:Model,S<:MH,V<:AbstractVarInfo} =
DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}

function LogDensityProblems.logdensity(f::MHLogDensityFunction, x::NamedTuple)
Expand Down Expand Up @@ -219,16 +224,16 @@
end

"""
dist_val_tuple(spl::Sampler{<:MH}, vi::VarInfo)
dist_val_tuple(spl::MH, vi::VarInfo)

Return two `NamedTuples`.

The first `NamedTuple` has symbols as keys and distributions as values.
The second `NamedTuple` has model symbols as keys and their stored values as values.
"""
function dist_val_tuple(spl::Sampler{<:MH}, vi::DynamicPPL.VarInfoOrThreadSafeVarInfo)
function dist_val_tuple(spl::MH, vi::DynamicPPL.VarInfoOrThreadSafeVarInfo)
vns = all_varnames_grouped_by_symbol(vi)
dt = _dist_tuple(spl.alg.proposals, vi, vns)
dt = _dist_tuple(spl.proposals, vi, vns)
vt = _val_tuple(vi, vns)
return dt, vt
end
Expand Down Expand Up @@ -270,34 +275,9 @@
end
_dist_tuple(::@NamedTuple{}, ::VarInfo, ::Tuple{}) = ()

# Utility functions to link
# Fallback: any proposal without a more specific method does not require linking.
function should_link(varinfo, sampler, proposal)
    return false
end
# An empty `NamedTuple` means the priors are used as proposals, so the varinfo
# must not be linked.
should_link(varinfo, sampler, proposal::NamedTuple{(),Tuple{}}) = false
# A single `AdvancedMH.RandomWalkProposal` always requires linking.
should_link(varinfo, sampler, proposal::AdvancedMH.RandomWalkProposal) = true
# A `NamedTuple` whose proposals are all `AdvancedMH.RandomWalkProposal`s requires
# linking.
#
# The bound on `vals` is the covariant `Tuple{Vararg{AdvancedMH.RandomWalkProposal}}`
# rather than `NTuple{<:Any,<:AdvancedMH.RandomWalkProposal}`. The latter introduces a
# single type variable shared by every element of the tuple, so it only matched when
# all proposals had exactly the same concrete type; proposals with differing type
# parameters silently fell through to the `false` fallback.
#
# The empty `NamedTuple` also satisfies this bound, but it continues to dispatch to the
# more specific `NamedTuple{(),Tuple{}}` method.
function should_link(
    varinfo, sampler, proposal::NamedTuple{names,vals}
) where {names,vals<:Tuple{Vararg{AdvancedMH.RandomWalkProposal}}}
    return true
end

"""
    maybe_link!!(varinfo, sampler, proposal, model)

Return `DynamicPPL.link!!(varinfo, model)` when
`should_link(varinfo, sampler, proposal)` is `true`; otherwise return `varinfo`
as it was passed in.
"""
function maybe_link!!(varinfo, sampler, proposal, model)
    # Guard clause: nothing to do when this proposal does not call for linking.
    should_link(varinfo, sampler, proposal) || return varinfo
    return DynamicPPL.link!!(varinfo, model)
end

# Make a proposal if we don't have a covariance proposal matrix (the default).
function propose!!(
rng::AbstractRNG, vi::AbstractVarInfo, model::Model, spl::Sampler{<:MH}, proposal
rng::AbstractRNG, vi::AbstractVarInfo, ldf::LogDensityFunction, spl::MH, proposal
)
# Retrieve distribution and value NamedTuples.
dt, vt = dist_val_tuple(spl, vi)
Expand All @@ -307,16 +287,7 @@
prev_trans = AMH.Transition(vt, getlogp(vi), false)

# Make a new transition.
densitymodel = AMH.DensityModel(
Base.Fix1(
LogDensityProblems.logdensity,
DynamicPPL.LogDensityFunction(
model,
vi,
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)),
),
),
)
densitymodel = AMH.DensityModel(Base.Fix1(LogDensityProblems.logdensity, ldf))
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)

# TODO: Make this compatible with immutable `VarInfo`.
Expand All @@ -329,70 +300,47 @@
function propose!!(
rng::AbstractRNG,
vi::AbstractVarInfo,
model::Model,
spl::Sampler{<:MH},
ldf::LogDensityFunction,
spl::MH,
proposal::AdvancedMH.RandomWalkProposal,
)
# If this is the case, we can just draw directly from the proposal
# matrix.
vals = vi[:]

# Create a sampler and the previous transition.
mh_sampler = AMH.MetropolisHastings(spl.alg.proposals)
mh_sampler = AMH.MetropolisHastings(spl.proposals)

Check warning on line 312 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L312

Added line #L312 was not covered by tests
prev_trans = AMH.Transition(vals, getlogp(vi), false)

# Make a new transition.
densitymodel = AMH.DensityModel(
Base.Fix1(
LogDensityProblems.logdensity,
DynamicPPL.LogDensityFunction(
model,
vi,
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)),
),
),
)
densitymodel = AMH.DensityModel(Base.Fix1(LogDensityProblems.logdensity, ldf))

Check warning on line 316 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L316

Added line #L316 was not covered by tests
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)

return setlogp!!(DynamicPPL.unflatten(vi, trans.params), trans.lp)
end

function DynamicPPL.initialstep(
rng::AbstractRNG,
model::AbstractModel,
spl::Sampler{<:MH},
vi::AbstractVarInfo;
kwargs...,
)
# If we're doing random walk with a covariance matrix,
# just link everything before sampling.
vi = maybe_link!!(vi, spl, spl.alg.proposals, model)

return Transition(model, vi), vi
function AbstractMCMC.step(rng::AbstractRNG, ldf::LogDensityFunction, spl::MH; kwargs...)
vi = ldf.varinfo
return Transition(ldf.model, vi), vi
end

function AbstractMCMC.step(
rng::AbstractRNG, model::Model, spl::Sampler{<:MH}, vi::AbstractVarInfo; kwargs...
rng::AbstractRNG, ldf::LogDensityFunction, spl::MH, vi::AbstractVarInfo; kwargs...
)
# Cases:
# 1. A covariance proposal matrix
# 2. A bunch of NamedTuples that specify the proposal space
vi = propose!!(rng, vi, model, spl, spl.alg.proposals)

return Transition(model, vi), vi
vi = propose!!(rng, vi, ldf, spl, spl.proposals)
return Transition(ldf.model, vi), vi
end

####
#### Compiler interface, i.e. tilde operators.
####
function DynamicPPL.assume(
rng::Random.AbstractRNG, spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi
rng::Random.AbstractRNG, ::MH, dist::Distribution, vn::VarName, vi
)
# Just defer to `SampleFromPrior`.
retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
return retval
return DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
end

function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi)
function DynamicPPL.observe(::MH, d::Distribution, value, vi)
return DynamicPPL.observe(SampleFromPrior(), d, value, vi)
end
Loading
Loading