Support DPPL 0.37 #2550

Draft: wants to merge 2 commits into base: main
2 changes: 1 addition & 1 deletion Project.toml
@@ -62,7 +62,7 @@ Distributions = "0.25.77"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.36"
DynamicPPL = "0.37"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3"
Libtask = "0.8.8"
27 changes: 13 additions & 14 deletions ext/TuringOptimExt.jl
@@ -34,8 +34,7 @@ function Optim.optimize(
options::Optim.Options=Optim.Options();
kwargs...,
)
-ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
-f = Optimisation.OptimLogDensity(model, ctx)
+f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
init_vals = DynamicPPL.getparams(f.ldf)
optimizer = Optim.LBFGS()
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -57,8 +56,7 @@ function Optim.optimize(
options::Optim.Options=Optim.Options();
kwargs...,
)
-ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
-f = Optimisation.OptimLogDensity(model, ctx)
+f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
init_vals = DynamicPPL.getparams(f.ldf)
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
end
@@ -74,8 +72,8 @@ function Optim.optimize(
end

function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...)
-ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
-return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
+f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
+return _optimize(f, args...; kwargs...)
end

"""
@@ -104,8 +102,7 @@ function Optim.optimize(
options::Optim.Options=Optim.Options();
kwargs...,
)
-ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
-f = Optimisation.OptimLogDensity(model, ctx)
+f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
init_vals = DynamicPPL.getparams(f.ldf)
optimizer = Optim.LBFGS()
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -127,8 +124,7 @@ function Optim.optimize(
options::Optim.Options=Optim.Options();
kwargs...,
)
-ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
-f = Optimisation.OptimLogDensity(model, ctx)
+f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
init_vals = DynamicPPL.getparams(f.ldf)
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
end
@@ -144,9 +140,10 @@ function Optim.optimize(
end

function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
-ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
-return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
+f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
+return _optimize(f, args...; kwargs...)
end

"""
_optimize(f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)

@@ -166,7 +163,9 @@ function _optimize(
# whether initialisation is really necessary at all
vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals)
vi = DynamicPPL.link(vi, f.ldf.model)
-f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype)
+f = Optimisation.OptimLogDensity(
+f.ldf.model, f.ldf.getlogdensity, vi; adtype=f.ldf.adtype
+)
init_vals = DynamicPPL.getparams(f.ldf)

# Optimize!
@@ -184,7 +183,7 @@ function _optimize(
vi = f.ldf.varinfo
vi_optimum = DynamicPPL.unflatten(vi, M.minimizer)
logdensity_optimum = Optimisation.OptimLogDensity(
-f.ldf.model, vi_optimum, f.ldf.context
+f.ldf.model, f.ldf.getlogdensity, vi_optimum; adtype=f.ldf.adtype
)
vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum)
varnames = map(Symbol ∘ first, vns_vals_iter)
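Note: the recurring change in this file is that OptimLogDensity is built from a log-density accessor (DynamicPPL.getloglikelihood for MLE, Optimisation.getlogjoint_without_jacobian for MAP) rather than from an OptimizationContext. A minimal sketch of the user-facing behaviour this presumably preserves; the model and data are made up, and the accessor names are taken on trust from the hunks above:

    using Turing, Optim

    @model function coinflip(y)
        p ~ Beta(1, 1)
        for i in eachindex(y)
            y[i] ~ Bernoulli(p)
        end
    end

    model = coinflip([1, 1, 0, 1])
    mle_est = optimize(model, MLE())   # objective: log-likelihood only
    map_est = optimize(model, MAP())   # objective: log-joint, without the link-transform Jacobian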
7 changes: 5 additions & 2 deletions src/essential/container.jl
@@ -28,7 +28,8 @@ function AdvancedPS.advance!(
trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false
)
# Make sure we load/reset the rng in the new replaying mechanism
-DynamicPPL.increment_num_produce!(trace.model.f.varinfo)
+# TODO(mhauru) Stop ignoring the return value.
+DynamicPPL.increment_num_produce!!(trace.model.f.varinfo)
isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
score = consume(trace.model.ctask)
if score === nothing
@@ -44,11 +45,13 @@ function AdvancedPS.delete_retained!(trace::TracedModel)
end

function AdvancedPS.reset_model(trace::TracedModel)
-DynamicPPL.reset_num_produce!(trace.varinfo)
+new_vi = DynamicPPL.reset_num_produce!!(trace.varinfo)
+trace = TracedModel(trace.model, trace.sampler, new_vi, trace.evaluator)
return trace
end

function AdvancedPS.reset_logprob!(trace::TracedModel)
+# TODO(mhauru) Stop ignoring the return value.
DynamicPPL.resetlogp!!(trace.model.varinfo)
return trace
end
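Note: these call sites follow the double-bang convention the new DynamicPPL functions are assumed to use: they may return an updated VarInfo rather than mutating in place, so the result has to be rebound, and the TODOs mark the spots where it is still dropped. Schematically:

    # Sketch of the assumed convention, not a verified API:
    vi = DynamicPPL.reset_num_produce!!(vi)   # rebinding keeps the update
    DynamicPPL.increment_num_produce!!(vi)    # return value ignored: the update may be lost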
22 changes: 9 additions & 13 deletions src/mcmc/Inference.jl
@@ -14,6 +14,7 @@ using DynamicPPL:
push!!,
setlogp!!,
getlogp,
+getlogjoint,
VarName,
getsym,
getdist,
@@ -22,8 +23,6 @@ using DynamicPPL:
SampleFromPrior,
SampleFromUniform,
DefaultContext,
-PriorContext,
-LikelihoodContext,
set_flag!,
unset_flag!
using Distributions, Libtask, Bijectors
@@ -75,7 +74,6 @@ export InferenceAlgorithm,
RepeatSampler,
Prior,
assume,
-observe,
predict,
externalsampler

@@ -182,12 +180,10 @@ function AbstractMCMC.step(
state=nothing;
kwargs...,
)
+vi = VarInfo()
+vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.LogPriorAccumulator(),))
vi = last(
-DynamicPPL.evaluate!!(
-model,
-VarInfo(),
-SamplingContext(rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext()),
-),
+DynamicPPL.evaluate!!(model, vi, SamplingContext(rng, DynamicPPL.SampleFromPrior()))
)
return vi, nothing
end
Expand Down Expand Up @@ -228,7 +224,7 @@ end
Transition(θ, lp) = Transition(θ, lp, nothing)
function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t)
θ = getparams(model, vi)
-lp = getlogp(vi)
+lp = getlogjoint(vi)
return Transition(θ, lp, getstats(t))
end

@@ -241,10 +237,10 @@ function metadata(t::Transition)
end
end

-DynamicPPL.getlogp(t::Transition) = t.lp
+DynamicPPL.getlogjoint(t::Transition) = t.lp

# Metadata of VarInfo object
-metadata(vi::AbstractVarInfo) = (lp=getlogp(vi),)
+metadata(vi::AbstractVarInfo) = (lp=getlogjoint(vi),)

# TODO: Implement additional checks for certain samplers, e.g.
# HMC not supporting discrete parameters.
@@ -381,7 +377,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
end

function get_transition_extras(ts::AbstractVector{<:VarInfo})
-valmat = reshape([getlogp(t) for t in ts], :, 1)
+valmat = reshape([getlogjoint(t) for t in ts], :, 1)
return [:lp], valmat
end

@@ -594,7 +590,7 @@ julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`

julia> transitions = Turing.Inference.transitions_from_chain(m, chain);

-julia> [Turing.Inference.getlogp(t) for t in transitions] # extract the logjoints
+julia> [Turing.Inference.getlogjoint(t) for t in transitions] # extract the logjoints
2-element Array{Float64,1}:
-3.6294991938628374
-2.5697948166987845
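Note: the common thread in these hunks is that the single getlogp/setlogp!! pair is split into per-component accessors, with getlogjoint returning the prior-plus-likelihood total that Transition now stores as lp. A small hypothetical illustration (made-up model; accessor names as used above):

    using Turing, DynamicPPL

    @model function demo()
        m ~ Normal()        # contributes to the log-prior
        1.0 ~ Normal(m)     # literal observation, contributes to the log-likelihood
    end

    vi = DynamicPPL.VarInfo(demo())
    lp = DynamicPPL.getlogjoint(vi)   # replaces getlogp(vi): log-prior + log-likelihood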
22 changes: 8 additions & 14 deletions src/mcmc/ess.jl
@@ -49,7 +49,7 @@ function AbstractMCMC.step(
rng,
EllipticalSliceSampling.ESSModel(
ESSPrior(model, spl, vi),
-DynamicPPL.LogDensityFunction(
+DynamicPPL.LogDensityFunction{:LogLikelihood}(
model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
),
),
@@ -59,7 +59,7 @@

# update sample and log-likelihood
vi = DynamicPPL.unflatten(vi, sample)
-vi = setlogp!!(vi, state.loglikelihood)
+vi = setloglikelihood!!(vi, state.loglikelihood)

return Transition(model, vi), vi
end
@@ -108,20 +108,14 @@ end
# Mean of prior distribution
Distributions.mean(p::ESSPrior) = p.μ

-# Evaluate log-likelihood of proposals
-const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} =
-DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}
-
-(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f)

function DynamicPPL.tilde_assume(
-rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
+rng::Random.AbstractRNG, ctx::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
)
-return DynamicPPL.tilde_assume(
-rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi
-)
+return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi)
end

-function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi)
-return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi)
+function DynamicPPL.tilde_observe!!(
+ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi
+)
+return DynamicPPL.tilde_observe!!(ctx, SampleFromPrior(), right, left, vn, vi)
end
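Note: two assumptions drive this file. The ESS objective becomes a likelihood-only LogDensityFunction (the {:LogLikelihood} parameter above), because elliptical slice sampling draws its proposals from the Gaussian prior and only needs the likelihood for acceptance, and the result is stored back with the component-specific setloglikelihood!! instead of setlogp!!. For context, a minimal runnable ESS usage with a made-up model:

    using Turing

    @model function gdemo(x)
        m ~ Normal(0, 1)        # Gaussian prior, as ESS requires
        x ~ Normal(m, 0.5)
    end

    chain = sample(gdemo(1.0), ESS(), 200)   # proposals from the prior, accepted via the log-likelihood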
12 changes: 7 additions & 5 deletions src/mcmc/gibbs.jl
@@ -32,7 +32,7 @@ can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(ctx.context)
#
# Purpose: avoid triggering resampling of variables we're conditioning on.
# - Using standard `DynamicPPL.condition` results in conditioned variables being treated
-# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`.
+# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe!!`.
# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to
# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable
# rather than only for the "true" observations.
@@ -177,24 +177,26 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
DynamicPPL.tilde_assume(child_context, right, vn, vi)
elseif has_conditioned_gibbs(context, vn)
# Short-circuit the tilde assume if `vn` is present in `context`.
-value, lp, _ = DynamicPPL.tilde_assume(
+# TODO(mhauru) Fix accumulation here. In this branch anything that gets
+# accumulated just gets discarded with `_`.
+value, _ = DynamicPPL.tilde_assume(
child_context, right, vn, get_global_varinfo(context)
)
-value, lp, vi
+value, vi
else
# If the varname has not been conditioned on, nor is it a target variable, its
# presumably a new variable that should be sampled from its prior. We need to add
# this new variable to the global `varinfo` of the context, but not to the local one
# being used by the current sampler.
-value, lp, new_global_vi = DynamicPPL.tilde_assume(
+value, new_global_vi = DynamicPPL.tilde_assume(
child_context,
DynamicPPL.SampleFromPrior(),
right,
vn,
get_global_varinfo(context),
)
set_global_varinfo!(context, new_global_vi)
-value, lp, vi
+value, vi
end
end
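Note: the signature change assumed throughout this hunk is that tilde_assume no longer returns the log-probability as a separate value; it is tracked by accumulators inside the returned VarInfo, so the three-way destructuring shrinks to two. Schematically (shape of the change only, not a verified API):

    # DynamicPPL <= 0.36:        value, lp, vi = DynamicPPL.tilde_assume(ctx, right, vn, vi)
    # DynamicPPL 0.37 (assumed): the log-prior is accumulated inside vi instead
    value, vi = DynamicPPL.tilde_assume(ctx, right, vn, vi)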

15 changes: 7 additions & 8 deletions src/mcmc/hmc.jl
@@ -199,7 +199,7 @@ function DynamicPPL.initialstep(
end

# Cache current log density.
-log_density_old = getlogp(vi)
+log_density_old = getloglikelihood(vi)

# Find good eps if not provided one
if iszero(spl.alg.ϵ)
@@ -227,10 +227,12 @@
# Update `vi` based on acceptance
if t.stat.is_accept
vi = DynamicPPL.unflatten(vi, t.z.θ)
-vi = setlogp!!(vi, t.stat.log_density)
+# TODO(mhauru) Is setloglikelihood! the right thing here?
+vi = setloglikelihood!!(vi, t.stat.log_density)
else
vi = DynamicPPL.unflatten(vi, theta)
-vi = setlogp!!(vi, log_density_old)
+# TODO(mhauru) Is setloglikelihood! the right thing here?
+vi = setloglikelihood!!(vi, log_density_old)
end

transition = Transition(model, vi, t)
@@ -275,7 +277,7 @@ function AbstractMCMC.step(
vi = state.vi
if t.stat.is_accept
vi = DynamicPPL.unflatten(vi, t.z.θ)
-vi = setlogp!!(vi, t.stat.log_density)
+# TODO(mhauru) Is setloglikelihood! the right thing here?
+vi = setloglikelihood!!(vi, t.stat.log_density)
end

# Compute next transition and state.
@@ -501,10 +504,6 @@ function DynamicPPL.assume(
return DynamicPPL.assume(dist, vn, vi)
end

-function DynamicPPL.observe(::Sampler{<:Hamiltonian}, d::Distribution, value, vi)
-return DynamicPPL.observe(d, value, vi)
-end

####
#### Default HMC stepsize and mass matrix adaptor
####
4 changes: 0 additions & 4 deletions src/mcmc/is.jl
@@ -55,7 +55,3 @@ function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName
end
return r, 0, vi
end

-function DynamicPPL.observe(::Sampler{<:IS}, dist::Distribution, value, vi)
-return logpdf(dist, value), vi
-end
4 changes: 0 additions & 4 deletions src/mcmc/mh.jl
@@ -390,7 +390,3 @@ function DynamicPPL.assume(
retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
return retval
end

-function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi)
-return DynamicPPL.observe(SampleFromPrior(), d, value, vi)
-end
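Note: across hmc.jl, is.jl and mh.jl the sampler-specific DynamicPPL.observe overloads are deleted rather than ported. The working assumption is that observations now go through the default tilde_observe!! path, which accumulates each observation's logpdf into the stored log-likelihood, making the old forwarding methods redundant. For a single observation `value ~ dist` that bookkeeping is roughly equivalent to the following sketch, which uses only accessors that already appear in this diff:

    vi = DynamicPPL.setloglikelihood!!(
        vi, DynamicPPL.getloglikelihood(vi) + logpdf(dist, value)
    )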