diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index acdb98183..9eef8fac8 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -76,7 +76,8 @@ export AbstractVarInfo, SampleFromPrior, SampleFromUniform, # Contexts - DefaultContext, + EvaluationContext, + SamplingContext, LikelihoodContext, PriorContext, MiniBatchContext, diff --git a/src/compiler.jl b/src/compiler.jl index bef7d11c2..908849519 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -52,6 +52,49 @@ end check_tilde_rhs(x::Distribution) = x check_tilde_rhs(x::AbstractArray{<:Distribution}) = x +""" + unwrap_right_vn(right, vn) + +Return the unwrapped distribution on the right-hand side and variable name on the left-hand +side of a `~` expression such as `x ~ Normal()`. + +This is used mainly to unwrap `NamedDist` distributions. +""" +unwrap_right_vn(right, vn) = right, vn +unwrap_right_vn(right::NamedDist, vn) = unwrap_right_vn(right.dist, right.name) + +""" + unwrap_right_left_vns(context, right, left, vns) + +Return the unwrapped distributions on the right-hand side and values and variable names on the +left-hand side of a `.~` expression such as `x .~ Normal()`. + +This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the +variables. +""" +unwrap_right_left_vns(right, left, vns) = right, left, vns +function unwrap_right_left_vns(right::NamedDist, left, vns) + return unwrap_right_left_vns(right.dist, left, right.name) +end +function unwrap_right_left_vns( + right::MultivariateDistribution, left::AbstractMatrix, vn::VarName +) + vns = map(axes(left, 2)) do i + return VarName(vn, (vn.indexing..., Tuple(i))) + end + return unwrap_right_left_vns(right, left, vns) +end +function unwrap_right_left_vns( + right::Union{Distribution,AbstractArray{<:Distribution}}, + left::AbstractArray, + vn::VarName, +) + vns = map(CartesianIndices(left)) do i + return VarName(vn, (vn.indexing..., Tuple(i))) + end + return unwrap_right_left_vns(right, left, vns) +end + ################# # Main Compiler # ################# @@ -242,12 +285,8 @@ function generate_tilde(left, right) # If the LHS is a literal, it is always an observation if !(left isa Symbol || left isa Expr) return quote - $(DynamicPPL.tilde_observe)( - __context__, - __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + $(DynamicPPL.tilde_observe!)( + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -260,19 +299,17 @@ function generate_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left = $(DynamicPPL.tilde_assume)( - __rng__, + $left = $(DynamicPPL.tilde_assume!)( __context__, - __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $vn, + $(DynamicPPL.unwrap_right_vn)( + $(DynamicPPL.check_tilde_rhs)($right), $vn + )..., $inds, __varinfo__, ) else - $(DynamicPPL.tilde_observe)( + $(DynamicPPL.tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -292,12 +329,8 @@ function generate_dot_tilde(left, right) # If the LHS is a literal, it is always an observation if !(left isa Symbol || left isa Expr) return quote - $(DynamicPPL.dot_tilde_observe)( - __context__, - __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + $(DynamicPPL.dot_tilde_observe!)( + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -310,20 +343,17 @@ function generate_dot_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left .= $(DynamicPPL.dot_tilde_assume)( - __rng__, + $left .= $(DynamicPPL.dot_tilde_assume!)( __context__, - __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - $vn, + $(DynamicPPL.unwrap_right_left_vns)( + $(DynamicPPL.check_tilde_rhs)($right), $left, $vn + )..., $inds, __varinfo__, ) else - $(DynamicPPL.dot_tilde_observe)( + $(DynamicPPL.dot_tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -354,10 +384,8 @@ function build_output(modelinfo, linenumbernode) # Add the internal arguments to the user-specified arguments (positional + keywords). evaluatordef[:args] = vcat( [ - :(__rng__::$(Random.AbstractRNG)), :(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo)), - :(__sampler__::$(DynamicPPL.AbstractSampler)), :(__context__::$(DynamicPPL.AbstractContext)), ], modelinfo[:allargs_exprs], @@ -367,7 +395,15 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = modelinfo[:body] + evaluatordef[:body] = quote + # in case someone accessed these + if __context__ isa $(DynamicPPL.SamplingContext) + __rng__ = __context__.rng + __sampler__ = __context__.sampler + end + + $(modelinfo[:body]) + end ## Build the model function. diff --git a/src/context_implementations.jl b/src/context_implementations.jl index afc5e4da3..344bc0816 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -17,100 +17,70 @@ require_particles(spl::Sampler) = false _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x -# assume -function tilde(rng, ctx::DefaultContext, sampler, right, vn::VarName, _, vi) - return _tilde(rng, sampler, right, vn, vi) -end -function tilde(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars !== nothing - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) - end - return _tilde(rng, sampler, right, vn, vi) -end -function tilde(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) - end - return _tilde(rng, sampler, NoDist(right), vn, vi) -end -function tilde(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi) - return tilde(rng, ctx.ctx, sampler, right, left, inds, vi) -end -function tilde(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi) - return tilde(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi) +function _getvalue(nt::NamedTuple, sym::Symbol, inds=()) + value = getfield(nt, sym) + return _getindex(value, inds) end +# assume """ - tilde_assume(rng, ctx, sampler, right, vn, inds, vi) + tilde_assume!(ctx, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde(rng, ctx, sampler, right, vn, inds, vi)`. +Falls back to `tilde(ctx, right, vn, inds, vi)`. """ -function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) - value, logp = tilde(rng, ctx, sampler, right, vn, inds, vi) +function tilde_assume!(ctx, right, vn, inds, vi) + value, logp = tilde_assume(ctx, right, nothing, vn, inds, vi) acclogp!(vi, logp) return value end -function _tilde(rng, sampler, right, vn::VarName, vi) - return assume(rng, sampler, right, vn, vi) +function tilde_assume(ctx::SamplingContext, right, left, vn, inds, vi) + return assume(ctx.rng, ctx.sampler, right, left, vn, inds, vi) end -function _tilde(rng, sampler, right::NamedDist, vn::VarName, vi) - return _tilde(rng, sampler, right.dist, right.name, vi) +function tilde_assume(ctx::EvaluationContext, right, left::Nothing, vn, inds, vi) + return assume(ctx.sampler, right, vi[vn], vn, inds, vi) +end +function tilde_assume(ctx::EvaluationContext, right, left, vn, inds, vi) + return assume(ctx.sampler, right, left, vn, inds, vi) end # observe -function tilde(ctx::DefaultContext, sampler, right, left, vi) - return _tilde(sampler, right, left, vi) -end -function tilde(ctx::PriorContext, sampler, right, left, vi) - return 0 -end -function tilde(ctx::LikelihoodContext, sampler, right, left, vi) - return _tilde(sampler, right, left, vi) -end -function tilde(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi) -end -function tilde(ctx::PrefixContext, sampler, right, left, vi) - return tilde(ctx.ctx, sampler, right, left, vi) +function tilde_observe(ctx::Union{SamplingContext,EvaluationContext}, right, left, vi) + return observe(ctx.sampler, right, left, vi) end """ - tilde_observe(ctx, sampler, right, left, vname, vinds, vi) + tilde_observe(ctx, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name +Falls back to `_tilde_observe(ctx, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe(ctx, sampler, right, left, vname, vinds, vi) - logp = tilde(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, right, left, vname, vinds, vi) + logp = tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end """ - tilde_observe(ctx, sampler, right, left, vi) + tilde_observe(ctx, right, left, vi) Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, sampler, right, left, vi)`. +Falls back to `_tilde_observe(ctx, right, left, vi)`. """ -function tilde_observe(ctx, sampler, right, left, vi) - logp = tilde(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, right, left, vi) + logp = tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end -_tilde(sampler, right, left, vi) = observe(sampler, right, left, vi) - function assume(rng, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end @@ -120,106 +90,96 @@ function observe(spl::Sampler, weight) end function assume( - rng, spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, + dist::Distribution, + left::Nothing, + vn, + inds, + vi, ) + r = init(rng, dist, spl) if haskey(vi, vn) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") - r = init(rng, dist, spl) - vi[vn] = vectorize(dist, r) - settrans!(vi, false, vn) - setorder!(vi, vn, get_num_produce(vi)) - else - r = vi[vn] - end + vi[vn] = vectorize(dist, r) + setorder!(vi, vn, get_num_produce(vi)) else - r = init(rng, dist, spl) push!(vi, vn, r, dist, spl) - settrans!(vi, false, vn) end + settrans!(vi, false, vn) return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end +function assume( + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, + dist::Distribution, + left, + vn, + inds, + vi, +) + r = left + if haskey(vi, vn) + vi[vn] = vectorize(dist, r) + setorder!(vi, vn, get_num_produce(vi)) + else + push!(vi, vn, r, dist, spl) + end + settrans!(vi, false, vn) + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) +end + +function assume( + spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, left, vn, inds, vi +) + return left, Bijectors.logpdf_with_trans(dist, left, istrans(vi, vn)) +end + function observe( - spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, value, vi + spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, left, vi ) increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value) + return Distributions.loglikelihood(dist, left) end # .~ functions # assume -function dot_tilde(rng, ctx::DefaultContext, sampler, right, left, vn::VarName, _, vi) - vns, dist = get_vns_and_dist(right, left, vn) - return _dot_tilde(rng, sampler, dist, left, vns, vi) -end -function dot_tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - var = _getindex(getfield(ctx.vars, getsym(vn)), inds) - vns, dist = get_vns_and_dist(right, var, vn) - set_val!(vi, vns, dist, var) - settrans!.(Ref(vi), false, vns) - else - vns, dist = get_vns_and_dist(right, left, vn) - end - return _dot_tilde(rng, sampler, NoDist.(dist), left, vns, vi) -end -function dot_tilde(rng, ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi) - return dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) -end -function dot_tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, inds, vi) - if ctx.vars !== nothing - var = _getindex(getfield(ctx.vars, getsym(vn)), inds) - vns, dist = get_vns_and_dist(right, var, vn) - set_val!(vi, vns, dist, var) - settrans!.(Ref(vi), false, vns) - else - vns, dist = get_vns_and_dist(right, left, vn) - end - return _dot_tilde(rng, sampler, dist, left, vns, vi) -end - """ - dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) + dot_tilde_assume!(ctx, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(ctx, right, left, vn, inds, vi)`. """ -function dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume!(ctx, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(ctx, right, nothing, vn, inds, vi) acclogp!(vi, logp) return value end -function get_vns_and_dist(dist::NamedDist, var, vn::VarName) - return get_vns_and_dist(dist.dist, var, dist.name) +function dot_tilde_assume(ctx::SamplingContext, right, left, vns, inds, vi) + return dot_assume(ctx.rng, ctx.sampler, right, vns, left, vi) end -function get_vns_and_dist(dist::MultivariateDistribution, var::AbstractMatrix, vn::VarName) - getvn = i -> VarName(vn, (vn.indexing..., (Colon(), i))) - return getvn.(1:size(var, 2)), dist -end -function get_vns_and_dist( - dist::Union{Distribution,AbstractArray{<:Distribution}}, var::AbstractArray, vn::VarName -) - getvn = ind -> VarName(vn, (vn.indexing..., Tuple(ind))) - return getvn.(CartesianIndices(var)), dist + +function dot_tilde_assume(ctx::EvaluationContext, right, left, vns, inds, vi) + return dot_assume(ctx.sampler, right, vns, left, vi) end -function _dot_tilde(rng, sampler, right, left, vns::AbstractArray{<:VarName}, vi) - return dot_assume(rng, sampler, right, vns, left, vi) +function dot_tilde_assume(ctx::EvaluationContext, right, left::Nothing, vns, inds, vi) + return dot_assume(ctx.sampler, right, vns, vi[vns], vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function _dot_tilde( +function dot_tilde_assume( rng, + ctx, sampler::AbstractSampler, right::Union{MultivariateDistribution,AbstractVector{<:MultivariateDistribution}}, left::AbstractMatrix{>:AbstractVector}, vn::AbstractVector{<:VarName}, + inds, vi, ) return throw(DimensionMismatch(AMBIGUITY_MSG)) @@ -230,29 +190,81 @@ function dot_assume( spl::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, vns::AbstractVector{<:VarName}, - var::AbstractMatrix, + var::Nothing, vi, ) @assert length(dist) == size(var, 1) r = get_and_set_val!(rng, vi, vns, dist, spl) lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) - var .= r + return r, lp +end + +function dot_assume( + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, + dist::MultivariateDistribution, + vns::AbstractVector{<:VarName}, + var, + vi, +) + @assert length(dist) == size(var, 1) + r = set_val!(vi, vns, dist, var) + lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) + return r, lp +end + +function dot_assume( + spl::Union{SampleFromPrior,SampleFromUniform}, + dist::MultivariateDistribution, + vns::AbstractVector{<:VarName}, + var::AbstractMatrix, + vi, +) + @assert length(dist) == size(var, 1) + lp = sum(Bijectors.logpdf_with_trans(dist, var, istrans(vi, vns[1]))) return var, lp end + function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, dists::Union{Distribution,AbstractArray{<:Distribution}}, vns::AbstractArray{<:VarName}, - var::AbstractArray, + var::Nothing, vi, ) r = get_and_set_val!(rng, vi, vns, dists, spl) # Make sure `r` is not a matrix for multivariate distributions lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) - var .= r + return r, lp +end + +function dot_assume( + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, + dists::Union{Distribution,AbstractArray{<:Distribution}}, + vns::AbstractArray{<:VarName}, + var, + vi, +) + r = set_val!(vi, vns, dists, var) + # Make sure `r` is not a matrix for multivariate distributions + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + return r, lp +end + +function dot_assume( + spl::Union{SampleFromPrior,SampleFromUniform}, + dists::Union{Distribution,AbstractArray{<:Distribution}}, + vns::AbstractArray{<:VarName}, + var::AbstractArray, + vi, +) + # Make sure `r` is not a matrix for multivariate distributions + lp = sum(Bijectors.logpdf_with_trans.(dists, var, istrans(vi, vns[1]))) return var, lp end + function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) return error( "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing assume statement" @@ -268,18 +280,12 @@ function get_and_set_val!( ) n = length(vns) if haskey(vi, vns[1]) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!(vi, vns[1], "del") - r = init(rng, dist, spl, n) - for i in 1:n - vn = vns[i] - vi[vn] = vectorize(dist, r[:, i]) - settrans!(vi, false, vn) - setorder!(vi, vn, get_num_produce(vi)) - end - else - r = vi[vns] + r = init(rng, dist, spl, n) + for i in 1:n + vn = vns[i] + vi[vn] = vectorize(dist, r[:, i]) + settrans!(vi, false, vn) + setorder!(vi, vn, get_num_produce(vi)) end else r = init(rng, dist, spl, n) @@ -300,20 +306,14 @@ function get_and_set_val!( spl::Union{SampleFromPrior,SampleFromUniform}, ) if haskey(vi, vns[1]) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!(vi, vns[1], "del") - f = (vn, dist) -> init(rng, dist, spl) - r = f.(vns, dists) - for i in eachindex(vns) - vn = vns[i] - dist = dists isa AbstractArray ? dists[i] : dists - vi[vn] = vectorize(dist, r[i]) - settrans!(vi, false, vn) - setorder!(vi, vn, get_num_produce(vi)) - end - else - r = reshape(vi[vec(vns)], size(vns)) + f = (vn, dist) -> init(rng, dist, spl) + r = f.(vns, dists) + for i in eachindex(vns) + vn = vns[i] + dist = dists isa AbstractArray ? dists[i] : dists + vi[vn] = vectorize(dist, r[i]) + settrans!(vi, false, vn) + setorder!(vi, vn, get_num_produce(vi)) end else f = (vn, dist) -> init(rng, dist, spl) @@ -348,53 +348,40 @@ function set_val!( end # observe -function dot_tilde(ctx::DefaultContext, sampler, right, left, vi) - return _dot_tilde(sampler, right, left, vi) -end -function dot_tilde(ctx::PriorContext, sampler, right, left, vi) - return 0 -end -function dot_tilde(ctx::LikelihoodContext, sampler, right, left, vi) - return _dot_tilde(sampler, right, left, vi) -end -function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * dot_tilde(ctx.ctx, sampler, right, left, vi) -end - """ - dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe!(ctx, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe(ctx, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi) - logp = dot_tilde(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, right, left, vn, inds, vi) + logp = dot_tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end """ - dot_tilde_observe(ctx, sampler, right, left, vi) + dot_tilde_observe!(ctx, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde(ctx, sampler, right, left, vi)`. +Falls back to `dot_tilde_observe(ctx, right, left, vi)`. """ -function dot_tilde_observe(ctx, sampler, right, left, vi) - logp = dot_tilde(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, right, left, vi) + logp = dot_tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end -function _dot_tilde(sampler, right, left::AbstractArray, vi) - return dot_observe(sampler, right, left, vi) +function dot_tilde_observe(ctx::Union{SamplingContext,EvaluationContext}, right, left, vi) + return dot_observe(ctx.sampler, right, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function _dot_tilde( +function dot_observe( sampler::AbstractSampler, right::Union{MultivariateDistribution,AbstractVector{<:MultivariateDistribution}}, left::AbstractMatrix{>:AbstractVector}, @@ -441,3 +428,9 @@ function dot_observe(spl::Sampler, ::Any, ::Any, ::Any) "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing observe statement" ) end + +# includes +include("context_implementations/prior.jl") +include("context_implementations/likelihood.jl") +include("context_implementations/minibatch.jl") +include("context_implementations/prefix.jl") diff --git a/src/context_implementations/likelihood.jl b/src/context_implementations/likelihood.jl new file mode 100644 index 000000000..14054856d --- /dev/null +++ b/src/context_implementations/likelihood.jl @@ -0,0 +1,31 @@ +function tilde_assume(ctx::LikelihoodContext, right, left, vn::VarName, inds, vi) + var = if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) + _getvalue(ctx.vars, getsym(vn), inds) + else + left + end + return tilde_assume( + rewrap(childcontext(ctx), EvaluationContext()), NoDist(right), var, vn, inds, vi + ) +end + +function tilde_observe(ctx::LikelihoodContext, right, left, vi) + return tilde_observe(childcontext(ctx), right, left, vi) +end + +function dot_tilde_assume( + ctx::LikelihoodContext, right, left, vns::AbstractArray{<:VarName{sym}}, inds, vi +) where {sym} + var = if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) + _getvalue(ctx.vars, sym, inds) + else + left + end + return dot_tilde_assume( + rewrap(childcontext(ctx), EvaluationContext()), NoDist.(right), var, vns, inds, vi + ) +end + +function dot_tilde_observe(ctx::LikelihoodContext, right, left, vi) + return dot_tilde_observe(childcontext(ctx), right, left, vi) +end diff --git a/src/context_implementations/minibatch.jl b/src/context_implementations/minibatch.jl new file mode 100644 index 000000000..91f5be57d --- /dev/null +++ b/src/context_implementations/minibatch.jl @@ -0,0 +1,15 @@ +function tilde_assume(ctx::MiniBatchContext, right, left, vn, inds, vi) + return tilde_assume(childcontext(ctx), right, left, vn, inds, vi) +end + +function tilde_observe(ctx::MiniBatchContext, right, left, vi) + return ctx.loglike_scalar * tilde_observe(childcontext(ctx), right, left, vi) +end + +function dot_tilde_assume(ctx::MiniBatchContext, right, left, vn, inds, vi) + return dot_tilde_assume(childcontext(ctx), right, left, vn, inds, vi) +end + +function dot_tilde_observe(ctx::MiniBatchContext, right, left, vi) + return ctx.loglike_scalar * dot_tilde_observe(childcontext(ctx), right, left, vi) +end diff --git a/src/context_implementations/prefix.jl b/src/context_implementations/prefix.jl new file mode 100644 index 000000000..09329ad2f --- /dev/null +++ b/src/context_implementations/prefix.jl @@ -0,0 +1,17 @@ +function tilde_assume(ctx::PrefixContext, right, left, vn, inds, vi) + return tilde_assume(childcontext(ctx), right, left, prefix(ctx, vn), inds, vi) +end + +function tilde_observe(ctx::PrefixContext, right, left, vi) + return tilde_observe(childcontext(ctx), right, left, vi) +end + +function dot_tilde_assume(ctx::PrefixContext, right, left, vn, inds, vi) + return dot_tilde_assume( + childcontext(ctx), right, left, map(Base.Fix1(prefix, ctx), vn), inds, vi + ) +end + +function dot_tilde_observe(ctx::PrefixContext, right, left, vi) + return dot_tilde_observe(childcontext(ctx), right, left, vi) +end diff --git a/src/context_implementations/prior.jl b/src/context_implementations/prior.jl new file mode 100644 index 000000000..c259254fe --- /dev/null +++ b/src/context_implementations/prior.jl @@ -0,0 +1,27 @@ +function tilde_assume(ctx::PriorContext, right, left, vn, inds, vi) + var = if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) + _getvalue(ctx.vars, getsym(vn), inds) + else + left + end + return tilde_assume(childcontext(ctx), right, var, vn, inds, vi) +end + +function tilde_observe(ctx::PriorContext, right, left, vi) + return 0 +end + +function dot_tilde_assume( + ctx::PriorContext, right, left, vns::AbstractArray{<:VarName{sym}}, inds, vi +) where {sym} + var = if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) + _getvalue(ctx.vars, getsym(vn), inds) + else + left + end + return dot_tilde_assume(childcontext(ctx), right, var, vns, inds, vi) +end + +function dot_tilde_observe(ctx::PriorContext, right, left, vi) + return 0 +end diff --git a/src/contexts.jl b/src/contexts.jl index 4d4f30bdc..1329d6703 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -4,80 +4,58 @@ The `DefaultContext` is used by default to compute log the joint probability of the data and parameters when running the model. """ -struct DefaultContext <: AbstractContext end +abstract type PrimitiveContext <: AbstractContext end +struct EvaluationContext{S<:AbstractSampler} <: PrimitiveContext + # TODO: do we even need the sampler these days? + sampler::S +end +EvaluationContext() = EvaluationContext(SampleFromPrior()) -""" - struct PriorContext{Tvars} <: AbstractContext - vars::Tvars - end +struct SamplingContext{R<:Random.AbstractRNG,S<:AbstractSampler} <: PrimitiveContext + rng::R + sampler::S +end +SamplingContext(sampler=SampleFromPrior()) = SamplingContext(Random.GLOBAL_RNG, sampler) + +######################## +### Wrapped contexts ### +######################## +abstract type WrappedContext{LeafCtx<:PrimitiveContext} <: AbstractContext end -The `PriorContext` enables the computation of the log prior of the parameters `vars` when -running the model. """ -struct PriorContext{Tvars} <: AbstractContext - vars::Tvars -end -PriorContext() = PriorContext(nothing) + childcontext(ctx) + +Returns the child-context of `ctx`. +Returns `nothing` if `ctx` is not a `WrappedContext`. """ - struct LikelihoodContext{Tvars} <: AbstractContext - vars::Tvars - end +childcontext(ctx::WrappedContext) = ctx.ctx +childcontext(ctx::AbstractContext) = nothing -The `LikelihoodContext` enables the computation of the log likelihood of the parameters when -running the model. `vars` can be used to evaluate the log likelihood for specific values -of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. """ -struct LikelihoodContext{Tvars} <: AbstractContext - vars::Tvars -end -LikelihoodContext() = LikelihoodContext(nothing) + unwrap(ctx::AbstractContext) +Returns the unwrapped context from `ctx`. """ - struct MiniBatchContext{Tctx, T} <: AbstractContext - ctx::Tctx - loglike_scalar::T - end +unwrap(ctx::WrappedContext) = unwrap(ctx.ctx) +unwrap(ctx::AbstractContext) = ctx -The `MiniBatchContext` enables the computation of -`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the -`loglike_scalar` field, typically equal to `the number of data points / batch size`. -This is useful in batch-based stochastic gradient descent algorithms to be optimizing -`log(prior) + log(likelihood of all the data points)` in the expectation. """ -struct MiniBatchContext{Tctx,T} <: AbstractContext - ctx::Tctx - loglike_scalar::T -end -function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) - return MiniBatchContext(ctx, npoints / batch_size) -end + unwrappedtype(ctx::AbstractContext) -struct PrefixContext{Prefix,C} <: AbstractContext - ctx::C -end -function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(ctx)}(ctx) -end +Returns the type of the unwrapped context from `ctx`. +""" +unwrappedtype(ctx::AbstractContext) = typeof(ctx) +unwrappedtype(ctx::WrappedContext{LeafCtx}) where {LeafCtx} = LeafCtx -const PREFIX_SEPARATOR = Symbol(".") +""" + rewrap(parent::WrappedContext, leaf::PrimitiveContext) -function PrefixContext{PrefixInner}( - ctx::PrefixContext{PrefixOuter} -) where {PrefixInner,PrefixOuter} - if @generated - :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, _prefix_seperator, PrefixInner)))}( - ctx.ctx - )) - else - PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx) - end -end +Rewraps `leaf` in `parent`. Supports nested `WrappedContext`. +""" +rewrap(::AbstractContext, leaf::PrimitiveContext) = leaf -function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} - if @generated - return :(VarName{$(QuoteNode(Symbol(Prefix, _prefix_seperator, Sym)))}(vn.indexing)) - else - VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) - end -end +include("contexts/prior.jl") +include("contexts/likelihood.jl") +include("contexts/minibatch.jl") +include("contexts/prefix.jl") diff --git a/src/contexts/likelihood.jl b/src/contexts/likelihood.jl new file mode 100644 index 000000000..d2b673af9 --- /dev/null +++ b/src/contexts/likelihood.jl @@ -0,0 +1,23 @@ +""" + struct LikelihoodContext{Tvars} <: AbstractContext + vars::Tvars + end + +The `LikelihoodContext` enables the computation of the log likelihood of the parameters when +running the model. `vars` can be used to evaluate the log likelihood for specific values +of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. +""" +struct LikelihoodContext{Tvars,Ctx,LeafCtx} <: WrappedContext{LeafCtx} + vars::Tvars + ctx::Ctx + + function LikelihoodContext(vars, ctx) + return new{typeof(vars),typeof(ctx),unwrappedtype(ctx)}(vars, ctx) + end +end +LikelihoodContext(vars=nothing) = LikelihoodContext(vars, EvaluationContext()) +LikelihoodContext(ctx::AbstractContext) = LikelihoodContext(nothing, ctx) + +function rewrap(parent::LikelihoodContext, leaf::PrimitiveContext) + return LikelihoodContext(parent.vars, rewrap(childcontext(parent), leaf)) +end diff --git a/src/contexts/minibatch.jl b/src/contexts/minibatch.jl new file mode 100644 index 000000000..a0c257b84 --- /dev/null +++ b/src/contexts/minibatch.jl @@ -0,0 +1,31 @@ +""" + struct MiniBatchContext{Tctx, T} <: AbstractContext + ctx::Tctx + loglike_scalar::T + end + +The `MiniBatchContext` enables the computation of +`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the +`loglike_scalar` field, typically equal to `the number of data points / batch size`. +This is useful in batch-based stochastic gradient descent algorithms to be optimizing +`log(prior) + log(likelihood of all the data points)` in the expectation. +""" +struct MiniBatchContext{T,Ctx,LeafCtx} <: WrappedContext{LeafCtx} + loglike_scalar::T + ctx::Ctx + + function MiniBatchContext(loglike_scalar, ctx::AbstractContext) + return new{typeof(loglike_scalar),typeof(ctx),unwrappedtype(ctx)}( + loglike_scalar, ctx + ) + end +end + +MiniBatchContext(loglike_scalar) = MiniBatchContext(loglike_scalar, EvaluationContext()) +function MiniBatchContext(ctx::AbstractContext=EvaluationContext(); batch_size, npoints) + return MiniBatchContext(npoints / batch_size, ctx) +end + +function rewrap(parent::MiniBatchContext, leaf::PrimitiveContext) + return MiniBatchContext(parent.loglike_scalar, rewrap(childcontext(parent), leaf)) +end diff --git a/src/contexts/prefix.jl b/src/contexts/prefix.jl new file mode 100644 index 000000000..63da8cbba --- /dev/null +++ b/src/contexts/prefix.jl @@ -0,0 +1,34 @@ +struct PrefixContext{Prefix,C,LeafCtx} <: WrappedContext{LeafCtx} + ctx::C + + function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} + return new{Prefix,typeof(ctx),unwrappedtype(ctx)}(ctx) + end +end +PrefixContext{Prefix}() where {Prefix} = PrefixContext{Prefix}(EvaluationContext()) + +function rewrap(parent::PrefixContext{Prefix}, leaf::PrimitiveContext) where {Prefix} + return PrefixContext{Prefix}(rewrap(childcontext(parent), leaf)) +end + +const PREFIX_SEPARATOR = Symbol(".") + +function PrefixContext{PrefixInner}( + ctx::PrefixContext{PrefixOuter} +) where {PrefixInner,PrefixOuter} + if @generated + :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( + ctx.ctx + )) + else + PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx) + end +end + +function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} + if @generated + return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(vn.indexing)) + else + VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) + end +end diff --git a/src/contexts/prior.jl b/src/contexts/prior.jl new file mode 100644 index 000000000..451752872 --- /dev/null +++ b/src/contexts/prior.jl @@ -0,0 +1,20 @@ +""" + struct PriorContext{Tvars} <: AbstractContext + vars::Tvars + end + +The `PriorContext` enables the computation of the log prior of the parameters `vars` when +running the model. +""" +struct PriorContext{Tvars,Ctx,LeafCtx} <: WrappedContext{LeafCtx} + vars::Tvars + ctx::Ctx + + PriorContext(vars, ctx) = new{typeof(vars),typeof(ctx),unwrappedtype(ctx)}(vars, ctx) +end +PriorContext(vars=nothing) = PriorContext(vars, EvaluationContext()) +PriorContext(ctx::AbstractContext) = PriorContext(nothing, ctx) + +function rewrap(parent::PriorContext, leaf::PrimitiveContext) + return PriorContext(parent.vars, rewrap(childcontext(parent), leaf)) +end diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 89672127a..247c19278 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -1,5 +1,5 @@ # Context version -struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext +struct PointwiseLikelihoodContext{A,Ctx,LeafCtx} <: WrappedContext{LeafCtx} loglikelihoods::A ctx::Ctx end @@ -7,7 +7,15 @@ end function PointwiseLikelihoodContext( likelihoods=Dict{VarName,Vector{Float64}}(), ctx::AbstractContext=LikelihoodContext() ) - return PointwiseLikelihoodContext{typeof(likelihoods),typeof(ctx)}(likelihoods, ctx) + return PointwiseLikelihoodContext{typeof(likelihoods),typeof(ctx),unwrappedtype(ctx)}( + likelihoods, ctx + ) +end + +function rewrap(parent::PointwiseLikelihoodContext, leaf::PrimitiveContext) + return PointwiseLikelihoodContext( + parent.loglikelihoods, rewrap(childcontext(parent), leaf) + ) end function Base.push!( @@ -52,25 +60,32 @@ function Base.push!( return ctx.loglikelihoods[vn] = logp end -function tilde_assume(rng, ctx::PointwiseLikelihoodContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi) +function tilde_assume(ctx::PointwiseLikelihoodContext, right, left, vn, inds, vi) + return tilde_assume(childcontext(ctx), right, left, vn, inds, vi) end -function dot_tilde_assume( - rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi -) - value, logp = dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume(ctx::PointwiseLikelihoodContext, right, left, vn, inds, vi) + return dot_tilde_assume(childcontext(ctx), right, left, vn, inds, vi) +end + +function tilde_observe!(ctx::PointwiseLikelihoodContext, right, left, vname, vinds, vi) + # This is slightly unfortunate since it is not completely generic... + # Ideally we would call `tilde_observe` recursively but then we don't get the + # loglikelihood value. + logp = tilde_observe(childcontext(ctx), right, left, vi) acclogp!(vi, logp) - return value + + # track loglikelihood value + push!(ctx, vname, logp) + + return left end -function tilde_observe( - ctx::PointwiseLikelihoodContext, sampler, right, left, vname, vinds, vi -) +function dot_tilde_observe!(ctx::PointwiseLikelihoodContext, right, left, vname, vinds, vi) # This is slightly unfortunate since it is not completely generic... # Ideally we would call `tilde_observe` recursively but then we don't get the # loglikelihood value. - logp = tilde(ctx.ctx, sampler, right, left, vi) + logp = dot_tilde_observe(childcontext(ctx), right, left, vi) acclogp!(vi, logp) # track loglikelihood value @@ -150,17 +165,16 @@ Dict{VarName,Array{Float64,2}} with 4 entries: """ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} # Get the data by executing the model once - spl = SampleFromPrior() vi = VarInfo(model) ctx = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}()) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters # Update the values - setval_and_resample!(vi, chain, sample_idx, chain_idx) + setval!(vi, chain, sample_idx, chain_idx) # Execute model - model(vi, spl, ctx) + model(vi, ctx) end niters = size(chain, 1) @@ -174,6 +188,6 @@ end function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) ctx = PointwiseLikelihoodContext(Dict{VarName,Float64}()) - model(varinfo, SampleFromPrior(), ctx) + model(varinfo, ctx) return ctx.loglikelihoods end diff --git a/src/model.jl b/src/model.jl index 7189b590e..5151283ad 100644 --- a/src/model.jl +++ b/src/model.jl @@ -86,12 +86,12 @@ function (model::Model)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + context::AbstractContext=SamplingContext(rng, sampler), ) if Threads.nthreads() == 1 - return evaluate_threadunsafe(rng, model, varinfo, sampler, context) + return evaluate_threadunsafe(model, varinfo, context) else - return evaluate_threadsafe(rng, model, varinfo, sampler, context) + return evaluate_threadsafe(model, varinfo, context) end end function (model::Model)(args...) @@ -108,8 +108,15 @@ function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) return model(rng, VarInfo(), SampleFromPrior(), context) end +# without VarInfo and without AbstractSampler +function (model::Model)( + rng::Random.AbstractRNG, varinfo::AbstractVarInfo, context::AbstractContext +) + return model(rng, varinfo, SampleFromPrior(), context) +end + """ - evaluate_threadunsafe(rng, model, varinfo, sampler, context) + evaluate_threadunsafe(model, varinfo, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -118,13 +125,13 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ -function evaluate_threadunsafe(rng, model, varinfo, sampler, context) +function evaluate_threadunsafe(model, varinfo, context) resetlogp!(varinfo) - return _evaluate(rng, model, varinfo, sampler, context) + return _evaluate(model, varinfo, context) end """ - evaluate_threadsafe(rng, model, varinfo, sampler, context) + evaluate_threadsafe(model, varinfo, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. @@ -134,24 +141,24 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ -function evaluate_threadsafe(rng, model, varinfo, sampler, context) +function evaluate_threadsafe(model, varinfo, context) resetlogp!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(rng, model, wrapper, sampler, context) + result = _evaluate(model, wrapper, context) setlogp!(varinfo, getlogp(wrapper)) return result end """ - _evaluate(rng, model::Model, varinfo, sampler, context) + _evaluate(model::Model, varinfo, context) -Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. +Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ @generated function _evaluate( - rng, model::Model{_F,argnames}, varinfo, sampler, context + model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(rng, model, varinfo, sampler, context, $(unwrap_args...))) + return :(model.f(model, varinfo, context, $(unwrap_args...))) end """ @@ -183,7 +190,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), DefaultContext()) + model(varinfo, SampleFromPrior(), EvaluationContext()) return getlogp(varinfo) end @@ -195,7 +202,7 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), PriorContext()) + model(varinfo, SampleFromPrior(), PriorContext(nothing, EvaluationContext())) return getlogp(varinfo) end @@ -207,7 +214,7 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), LikelihoodContext()) + model(varinfo, SampleFromPrior(), LikelihoodContext(nothing, EvaluationContext())) return getlogp(varinfo) end diff --git a/src/prob_macro.jl b/src/prob_macro.jl index d761e9fdc..dabd897e6 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -142,11 +142,11 @@ function logprior( # When all of model args are on the lhs of |, this is also equal to the logjoint. model = make_prior_model(left, right, _model) - vi = _vi === nothing ? VarInfo(deepcopy(model), PriorContext()) : _vi + vi = _vi === nothing ? VarInfo(deepcopy(model), PriorContext(SamplingContext())) : _vi foreach(keys(vi.metadata)) do n @assert n in keys(left) "Variable $n is not defined." end - model(vi, SampleFromPrior(), PriorContext(left)) + model(vi, SampleFromPrior(), PriorContext(left, EvaluationContext())) return getlogp(vi) end diff --git a/src/sampler.jl b/src/sampler.jl index 5e97a64e3..c124e8eb4 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -51,7 +51,7 @@ function AbstractMCMC.step( kwargs..., ) vi = VarInfo() - model(rng, vi, sampler) + model(rng, vi, sampler, SamplingContext()) return vi, nothing end @@ -78,9 +78,9 @@ function AbstractMCMC.step( # and https://github.com/TuringLang/Turing.jl/issues/1563 # to avoid that existing variables are resampled if _spl isa SampleFromUniform - model(rng, vi, SampleFromPrior()) + model(rng, vi, SampleFromPrior(), EvaluationContext()) else - model(rng, vi, _spl) + model(rng, vi, _spl, EvaluationContext()) end end diff --git a/src/varinfo.jl b/src/varinfo.jl index e5e71eed1..ea394a032 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -126,7 +126,7 @@ function VarInfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + context::AbstractContext=SamplingContext(), ) varinfo = VarInfo() model(rng, varinfo, sampler, context) @@ -1399,90 +1399,3 @@ function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) return indices end - -""" - setval_and_resample!(vi::AbstractVarInfo, x) - setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx, chain_idx) - -Set the values in `vi` to the provided values and those which are not present -in `x` or `chains` to *be* resampled. - -Note that this does *not* resample the values not provided! It will call `setflag!(vi, vn, "del")` -for variables `vn` for which no values are provided, which means that the next time we call `model(vi)` these -variables will be resampled. - -## Note -- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info. - -## Example -```jldoctest -julia> using DynamicPPL, Distributions, StableRNGs - -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1) - end - end; - -julia> rng = StableRNG(42); - -julia> m = demo([missing]); - -julia> var_info = DynamicPPL.VarInfo(rng, m); - -julia> var_info[@varname(m)] --0.6702516921145671 - -julia> var_info[@varname(x[1])] --0.22312984965118443 - -julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling - -julia> var_info[@varname(m)] # [✓] changed -100.0 - -julia> var_info[@varname(x[1])] # [✓] unchanged --0.22312984965118443 - -julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0` - -julia> var_info[@varname(m)] # [✓] unchanged -100.0 - -julia> var_info[@varname(x[1])] # [✓] changed -101.37363069798343 -``` - -## See also -- [`setval!`](@ref) -""" -function setval_and_resample!(vi::AbstractVarInfo, x) - return _apply!(_setval_and_resample_kernel!, vi, values(x), keys(x)) -end -function setval_and_resample!( - vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int -) - return _apply!( - _setval_and_resample_kernel!, - vi, - chains.value[sample_idx, :, chain_idx], - keys(chains), - ) -end - -function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) - indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) - if !isempty(indices) - sorted_indices = sort!(indices; by=i -> keys[i], lt=NaturalSort.natural) - val = reduce(vcat, values[sorted_indices]) - setval!(vi, val, vn) - settrans!(vi, false, vn) - else - # Ensures that we'll resample the variable corresponding to `vn` if we run - # the model on `vi` again. - set_flag!(vi, vn, "del") - end - - return indices -end diff --git a/test/varinfo.jl b/test/varinfo.jl index c936ad67c..26d5ce9c9 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -202,37 +202,6 @@ DynamicPPL.setval!(vicopy, (s=42,)) @test vicopy[m_vns] == 1:5 @test vicopy[s_vns] == 42 - - ### `setval_and_resample!` ### - if model == model_mv && vi == vi_untyped - # Trying to re-run model with `MvNormal` on `vi_untyped` will call - # `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError` - # so we skip this particular case. - continue - end - - vicopy = deepcopy(vi) - DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),)) - model(vicopy) - # Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)` - if model == model_uv - @test_broken vicopy[m_vns] == zeros(5) - else - @test vicopy[m_vns] == zeros(5) - end - @test vicopy[s_vns] != vi[s_vns] - - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) - ) - model(vicopy) - @test vicopy[m_vns] == 1:5 - @test vicopy[s_vns] != vi[s_vns] - - DynamicPPL.setval_and_resample!(vicopy, (s=42,)) - model(vicopy) - @test vicopy[m_vns] != 1:5 - @test vicopy[s_vns] == 42 end end end