diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9eb4d9675..89564d613 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -75,7 +75,8 @@ export AbstractVarInfo, SampleFromPrior, SampleFromUniform, # Contexts - DefaultContext, + SamplingContext, + EvaluationContext, LikelihoodContext, PriorContext, MiniBatchContext, diff --git a/src/compiler.jl b/src/compiler.jl index 2e368d32b..e647df99c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -286,11 +286,7 @@ function generate_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.tilde_observe!)( - __context__, - __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -304,9 +300,7 @@ function generate_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left = $(DynamicPPL.tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_vn)( $(DynamicPPL.check_tilde_rhs)($right), $vn )..., @@ -316,7 +310,6 @@ function generate_tilde(left, right) else $(DynamicPPL.tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -337,11 +330,7 @@ function generate_dot_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.dot_tilde_observe!)( - __context__, - __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -355,9 +344,7 @@ function generate_dot_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left .= $(DynamicPPL.dot_tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $left, $vn )..., @@ -367,7 +354,6 @@ function generate_dot_tilde(left, right) else $(DynamicPPL.dot_tilde_observe!)( __context__, - __sampler__, 
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -398,10 +384,8 @@ function build_output(modelinfo, linenumbernode) # Add the internal arguments to the user-specified arguments (positional + keywords). evaluatordef[:args] = vcat( [ - :(__rng__::$(Random.AbstractRNG)), :(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo)), - :(__sampler__::$(DynamicPPL.AbstractSampler)), :(__context__::$(DynamicPPL.AbstractContext)), ], modelinfo[:allargs_exprs], @@ -411,7 +395,9 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = modelinfo[:body] + evaluatordef[:body] = quote + $(modelinfo[:body]) + end ## Build the model function. diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 60df298b5..044198447 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -18,86 +18,187 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x # assume -function tilde_assume(rng, ctx::DefaultContext, sampler, right, vn::VarName, _, vi) - return assume(rng, sampler, right, vn, vi) +""" + tilde_assume(context::SamplingContext, right, vn, inds, vi) + +Handle assumed variables, e.g., `x ~ Normal()` (where `x` does not occur in the model inputs), +accumulate the log probability, and return the sampled value with a context associated +with a sampler. + +Falls back to +```julia +assume(context.rng, context.sampler, right, vn, inds, vi) +``` +if the context `context.context` does not call any other context, as indicated by +[`unwrap_childcontext`](@ref). Otherwise, calls `tilde_assume(c, right, vn, inds, vi)` +where `c` is a context in which the order of the sampling context and its child are swapped. 
+""" +function tilde_assume(context::SamplingContext, right, vn, inds, vi) + if context.context isa Nothing + return assume(context.rng, context.sampler, right, vn, inds, vi) + end + + return tilde_assume(propogate_context(context), right, vn, inds, vi) end -function tilde_assume(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars !== nothing - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) + +function tilde_assume(context::EvaluationContext, right, vn, inds, vi) + if context.context isa Nothing + return assume(right, vn, inds, vi) + end + + return tilde_assume(propogate_context(context), right, vn, inds, vi) +end + +# Leaf contexts +function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return assume(rng, sampler, right, vn, vi) + return tilde_assume(PriorContext(context.context), right, vn, inds, vi) +end + +function tilde_assume(context::PriorContext, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end -function tilde_assume(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) + +function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return assume(rng, sampler, NoDist(right), vn, vi) + return tilde_assume(LikelihoodContext(context.context), right, vn, inds, vi) +end + +function tilde_assume(context::LikelihoodContext, right, vn, inds, vi) + return tilde_assume(context.context, NoDist(right), vn, inds, vi) end -function tilde_assume(rng, ctx::MiniBatchContext, sampler, right, 
left::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, left, inds, vi) + +function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end -function tilde_assume(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi) + +function tilde_assume(context::PrefixContext, right, vn, inds, vi) + return tilde_assume(context.context, right, prefix(context, vn), inds, vi) end """ - tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) + tilde_assume!(context, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)`. +Falls back to `tilde_assume!(context, right, vn, inds, vi)`. """ -function tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) - value, logp = tilde_assume(rng, ctx, sampler, right, vn, inds, vi) +function tilde_assume!(context, right, vn, inds, vi) + value, logp = tilde_assume(context, right, vn, inds, vi) acclogp!(vi, logp) return value end # observe -function tilde_observe(ctx::DefaultContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +""" + tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + +Handle observed variables with a `context` associated with a sampler. + +Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)` ignoring +the information about the sampler if the context `context.context` does not call any other +context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls +`tilde_observe(c, right, left, vname, vinds, vi)` where `c` is a context in +which the order of the sampling context and its child are swapped. 
+""" +function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + if context.context isa Nothing + return observe(right, left, vi) + end + + return tilde_observe(propogate_context(context), right, left, vname, vinds, vi) +end + +function tilde_observe(context::EvaluationContext, right, left, vname, vinds, vi) + if context.context isa Nothing + return observe(right, left, vi) + end + + return tilde_observe(propogate_context(context), right, left, vname, vinds, vi) +end + +""" + tilde_observe(context::SamplingContext, right, left, vi) + +Handle observed constants with a `context` associated with a sampler. + +Falls back to `tilde_observe(context.context, right, left, vi)` ignoring +the information about the sampler if the context `context.context` does not call any other +context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls +`tilde_observe(c, right, left, vi)` where `c` is a context in +which the order of the sampling context and its child are swapped. 
+""" +function tilde_observe(context::SamplingContext, right, left, vi) + if context.context isa Nothing + return observe(right, left, vi) + end + + return tilde_observe(propogate_context(context), right, left, vi) end -function tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 + +function tilde_observe(context::EvaluationContext, right, left, vi) + if context.context isa Nothing + return observe(right, left, vi) + end + + return tilde_observe(propogate_context(context), right, left, vi) end -function tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) + +# Leaf contexts +tilde_observe(::PriorContext, right, left, vi) = 0 +function tilde_observe(context::LikelihoodContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) end -function tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * tilde_observe(ctx.ctx, sampler, right, left, vi) + +# `MiniBatchContext` +function tilde_observe(context::MiniBatchContext, right, left, vi) + return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end -function tilde_observe(ctx::PrefixContext, sampler, right, left, vi) - return tilde_observe(ctx.ctx, sampler, right, left, vi) +function tilde_observe(context::MiniBatchContext, right, left, vname, vinds, vi) + return context.loglike_scalar * + tilde_observe(context.context, right, left, vname, vinds, vi) +end + +# `PrefixContext` +function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) + return tilde_observe(context.context, right, left, prefix(context, vname), vinds, vi) +end +function tilde_observe(context::PrefixContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) end """ - tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + tilde_observe!(context, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the 
model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(context, right, left, vname, vinds, vi) + logp = tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end """ - tilde_observe(ctx, sampler, right, left, vi) + tilde_observe!(context, right, left, vi) Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, sampler, right, left, vi)`. +Falls back to `tilde_observe(context, right, left, vi)`. """ -function tilde_observe!(ctx, sampler, right, left, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(context, right, left, vi) + logp = tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end @@ -110,94 +211,152 @@ function observe(spl::Sampler, weight) return error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))") end +# fallback without sampler +function assume(dist::Distribution, vn::VarName, inds, vi) + if !haskey(vi, vn) + error("variable $vn does not exist") + end + r = vi[vn] + return r, Bijectors.logpdf_with_trans(dist, vi[vn], istrans(vi, vn)) +end + +# SampleFromPrior and SampleFromUniform function assume( - rng, spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi + rng::Random.AbstractRNG, + sampler::Union{SampleFromPrior,SampleFromUniform}, + dist::Distribution, + vn::VarName, + inds, + vi, ) + # Always overwrite the parameters with new ones. 
+ r = init(rng, dist, sampler) if haskey(vi, vn) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") - r = init(rng, dist, spl) - vi[vn] = vectorize(dist, r) - settrans!(vi, false, vn) - setorder!(vi, vn, get_num_produce(vi)) - else - r = vi[vn] - end + vi[vn] = vectorize(dist, r) + setorder!(vi, vn, get_num_produce(vi)) else - r = init(rng, dist, spl) - push!(vi, vn, r, dist, spl) - settrans!(vi, false, vn) + push!(vi, vn, r, dist, sampler) end + settrans!(vi, false, vn) return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end -function observe( - spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, value, vi -) +# default fallback (used e.g. by `SampleFromPrior` and `SampleFromUniform`) +function observe(right::Distribution, left, vi) increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value) + return Distributions.loglikelihood(right, left) end # .~ functions # assume -function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, _, vi) - return dot_assume(rng, sampler, right, vns, left, vi) +""" + dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + +Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the +model inputs), accumulate the log probability, and return the sampled value for a context +associated with a sampler. + +Falls back to +```julia +dot_assume(context.rng, context.sampler, right, vn, left, vi) +``` +if the context `context.context` does not call any other context, as indicated by +[`unwrap_childcontext`](@ref). Otherwise, calls `dot_tilde_assume(c, right, left, vn, inds, vi)` +where `c` is a context in which the order of the sampling context and its child are swapped. 
+""" +function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + if context.context isa Nothing + return dot_assume(context.rng, context.sampler, right, vn, left, vi) + end + + return dot_tilde_assume(propogate_context(context), right, left, vn, inds, vi) +end + +function dot_tilde_assume(context::EvaluationContext, right, left, vn, inds, vi) + if context.context isa Nothing + return dot_assume(right, left, vn, inds, vi) + end + + return dot_tilde_assume(propogate_context(context), right, left, vn, inds, vi) end + +# `LikelihoodContext` function dot_tilde_assume( - rng, - ctx::LikelihoodContext, - sampler, - right, - left, - vns::AbstractArray{<:VarName{sym}}, - inds, - vi, -) where {sym} - if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) + context::LikelihoodContext{<:NamedTuple}, right, left, vn, inds, vi +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(LikelihoodContext(context.context), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(LikelihoodContext(context.context), right, left, vn, inds, vi) end - return dot_assume(rng, sampler, NoDist.(right), vns, left, vi) end -function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vns, inds, vi) - return dot_tilde_assume(rng, ctx.ctx, sampler, right, left, vns, inds, vi) + +function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, NoDist.(right), left, vn, inds, vi) end -function dot_tilde_assume( - rng, - ctx::PriorContext, - sampler, - right, - left, - vns::AbstractArray{<:VarName{sym}}, - inds, - vi, -) where {sym} - if ctx.vars !== nothing - var = _getindex(getfield(ctx.vars, 
sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) + +# `PriorContext` +function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, inds, vi) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(PriorContext(context.context), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(PriorContext(context.context), right, left, vn, inds, vi) end - return dot_assume(rng, sampler, right, vns, left, vi) +end + +function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) +end + +# `MiniBatchContext` +function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) +end + +# `PrefixContext` +function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) + return dot_tilde_assume( + context.context, right, left, prefix.(Ref(context), vn), inds, vi + ) end """ - dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) + dot_tilde_assume!(context, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. 
""" -function dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume!(context, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) acclogp!(vi, logp) return value end -# Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics +# `dot_assume` +function dot_assume( + dist::MultivariateDistribution, + var::AbstractMatrix, + vns::AbstractVector{<:VarName}, + inds, + vi, +) + @assert length(dist) == size(var, 1) + lp = sum(zip(vns, eachcol(var))) do (vn, ri) # destructure each (varname, column) pair + return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) + end + return var, lp +end function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -211,6 +370,19 @@ function dot_assume( lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) return r, lp end + +function dot_assume( + dists::Union{Distribution,AbstractArray{<:Distribution}}, + var::AbstractArray, + vns::AbstractArray{<:VarName}, + inds, + vi, +) + # Make sure `var` is not a matrix for multivariate distributions + lp = sum(Bijectors.logpdf_with_trans.(dists, var, istrans(vi, vns[1]))) + return var, lp +end + function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -319,84 +491,94 @@ function set_val!( end # observe -function dot_tilde_observe(ctx::DefaultContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) +""" + dot_tilde_observe(context::SamplingContext, right, left, vi) + +Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log +probability, and return the observed value for a context associated with a sampler. + +Falls back to `dot_tilde_observe(context.context, right, left, vi)` ignoring the sampler. 
+""" +function dot_tilde_observe( + context::Union{SamplingContext,EvaluationContext}, right, left, vi +) + if context.context isa Nothing + return dot_observe(right, left, vi) + end + + return dot_tilde_observe(propogate_context(context), right, left, vi) +end + +# Leaf contexts +dot_tilde_observe(::PriorContext, right, left, vi) = 0 +function dot_tilde_observe(context::LikelihoodContext, right, left, vi) + return dot_tilde_observe(context.context, right, left, vi) end -function dot_tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 + +# `MiniBatchContext` +function dot_tilde_observe(context::MiniBatchContext, right, left, vi) + return context.loglike_scalar * dot_tilde_observe(context.context, right, left, vi) end -function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) +function dot_tilde_observe(context::MiniBatchContext, right, left, vname, vinds, vi) + return context.loglike_scalar * + dot_tilde_observe(context.context, right, left, vname, vinds, vi) end -function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) + +# `PrefixContext` +function dot_tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) + return dot_tilde_observe( + context.context, right, left, prefix(context, vname), vinds, vi + ) +end +function dot_tilde_observe(context::PrefixContext, right, left, vi) + return dot_tilde_observe(context.context, right, left, vi) end """ - dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe!(context, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), accumulate the log probability, and return the observed value. 
-Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(ctx, sampler, right, left, vn, inds, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(context, right, left, vn, inds, vi) + logp = dot_tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end """ - dot_tilde_observe!(ctx, sampler, right, left, vi) + dot_tilde_observe!(context, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)`. +Falls back to `dot_tilde_observe(context, right, left, vi)`. """ -function dot_tilde_observe!(ctx, sampler, right, left, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(context, right, left, vi) + logp = dot_tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, - dist::MultivariateDistribution, - value::AbstractMatrix, - vi, -) +function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) increment_num_produce!(vi) @debug "dist = $dist" @debug "value = $value" return Distributions.loglikelihood(dist, value) end -function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::Distribution, - value::AbstractArray, - vi, -) +function dot_observe(dists::Distribution, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return Distributions.loglikelihood(dists, value) end -function dot_observe( - 
spl::Union{SampleFromPrior,SampleFromUniform}, - dists::AbstractArray{<:Distribution}, - value::AbstractArray, - vi, -) +function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) end -function dot_observe(spl::Sampler, ::Any, ::Any, ::Any) - return error( - "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing observe statement" - ) -end diff --git a/src/contexts.jl b/src/contexts.jl index 2c23531c6..39c30aa4c 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,41 +1,116 @@ """ - struct DefaultContext <: AbstractContext end + unwrap_childcontext(context::AbstractContext) + +Return a tuple of the child context of a `context`, or `nothing` if the context does +not wrap any other context, and a function `f(c::AbstractContext)` that constructs +an instance of `context` in which the child context is replaced with `c`. + +Falls back to `(nothing, _ -> context)`. +""" +function unwrap_childcontext(context::AbstractContext) + reconstruct_context(@nospecialize(x)) = context + return nothing, reconstruct_context +end + +""" + propogate_context(context::AbstractContext) + +Wrap `context` in its child-context using [`unwrap_childcontext`](@ref), effectively +swapping the order of the two contexts. +""" +function propogate_context(context::AbstractContext) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + return reconstruct_c(reconstruct_context(child_of_c)) +end + +""" + SamplingContext(rng=Random.GLOBAL_RNG, sampler=SampleFromPrior(), context=nothing) + +Create a context that allows you to sample parameters with the `sampler` when running the model. +The `context` determines how the returned log density is computed when running the model. 
+ +See also: [`EvaluationContext`](@ref) +""" +struct SamplingContext{S<:AbstractSampler,C,R} <: AbstractContext + rng::R + sampler::S + context::C +end + +function SamplingContext( + rng::Random.AbstractRNG, sampler::AbstractSampler=SampleFromPrior() +) + return SamplingContext(rng, sampler, nothing) +end +function SamplingContext(sampler::AbstractSampler=SampleFromPrior()) + return SamplingContext(Random.GLOBAL_RNG, sampler) +end + +function unwrap_childcontext(context::SamplingContext) + function reconstruct_samplingcontext(c::Union{AbstractContext,Nothing}) + return SamplingContext(context.rng, context.sampler, c) + end + return context.context, reconstruct_samplingcontext +end -The `DefaultContext` is used by default to compute log the joint probability of the data -and parameters when running the model. """ -struct DefaultContext <: AbstractContext end + EvaluationContext(context=nothing) + +Create a context that allows you to evaluate the model without performing any sampling. +The `context` determines how the returned log density is computed when running the model. + +See also: [`SamplingContext`](@ref) +""" +struct EvaluationContext{Ctx} <: AbstractContext + context::Ctx +end + +EvaluationContext() = EvaluationContext(nothing) + +function unwrap_childcontext(context::EvaluationContext) + function reconstruct_evaluationcontext(c::Union{AbstractContext,Nothing}) + return EvaluationContext(c) + end + return context.context, reconstruct_evaluationcontext +end """ - struct PriorContext{Tvars} <: AbstractContext + struct PriorContext{Tvars,Ctx} <: AbstractContext vars::Tvars + context::Ctx end The `PriorContext` enables the computation of the log prior of the parameters `vars` when running the model. 
""" -struct PriorContext{Tvars} <: AbstractContext +struct PriorContext{Tvars,Ctx} <: AbstractContext vars::Tvars + context::Ctx end -PriorContext() = PriorContext(nothing) +PriorContext(vars=nothing) = PriorContext(vars, EvaluationContext()) +PriorContext(context::AbstractContext) = PriorContext(nothing, context) """ - struct LikelihoodContext{Tvars} <: AbstractContext + struct LikelihoodContext{Tvars,Ctx} <: AbstractContext vars::Tvars + context::Ctx end The `LikelihoodContext` enables the computation of the log likelihood of the parameters when running the model. `vars` can be used to evaluate the log likelihood for specific values of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. """ -struct LikelihoodContext{Tvars} <: AbstractContext +struct LikelihoodContext{Tvars,Ctx} <: AbstractContext vars::Tvars + context::Ctx end -LikelihoodContext() = LikelihoodContext(nothing) +LikelihoodContext(vars=nothing) = LikelihoodContext(vars, EvaluationContext()) +LikelihoodContext(context::AbstractContext) = LikelihoodContext(nothing, context) """ struct MiniBatchContext{Tctx, T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end @@ -46,31 +121,49 @@ This is useful in batch-based stochastic gradient descent algorithms to be optim `log(prior) + log(likelihood of all the data points)` in the expectation. 
""" struct MiniBatchContext{Tctx,T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end -function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) - return MiniBatchContext(ctx, npoints / batch_size) +function MiniBatchContext(context=EvaluationContext(); batch_size, npoints) + return MiniBatchContext(context, npoints / batch_size) +end + +function unwrap_childcontext(context::MiniBatchContext) + function reconstruct_minibatchcontext(c::AbstractContext) + return MiniBatchContext(c, context.loglike_scalar) + end + return context.context, reconstruct_minibatchcontext end +""" + PrefixContext{Prefix}(context) + +Create a context that allows you to use the wrapped `context` when running the model and +adds the `Prefix` to all parameters. + +This context is useful in nested models to ensure that the names of the parameters are +unique. + +See also: [`@submodel`](@ref) +""" struct PrefixContext{Prefix,C} <: AbstractContext - ctx::C + context::C end -function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(ctx)}(ctx) +function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} + return PrefixContext{Prefix,typeof(context)}(context) end const PREFIX_SEPARATOR = Symbol(".") function PrefixContext{PrefixInner}( - ctx::PrefixContext{PrefixOuter} + context::PrefixContext{PrefixOuter} ) where {PrefixInner,PrefixOuter} if @generated :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( - ctx.ctx + context.context )) else - PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx) + PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(context.context) end end @@ -81,3 +174,10 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + +function unwrap_childcontext(context::PrefixContext{P}) where {P} + function 
reconstruct_prefixcontext(c::AbstractContext) + return PrefixContext{P}(c) + end + return context.context, reconstruct_prefixcontext +end diff --git a/src/model.jl b/src/model.jl index 7189b590e..430e09993 100644 --- a/src/model.jl +++ b/src/model.jl @@ -86,14 +86,20 @@ function (model::Model)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + context::Union{AbstractContext,Nothing}=nothing, ) + return model(varinfo, SamplingContext(rng, sampler, context)) +end + +(model::Model)(context::AbstractContext) = model(VarInfo(), context) +function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 - return evaluate_threadunsafe(rng, model, varinfo, sampler, context) + return evaluate_threadunsafe(model, varinfo, context) else - return evaluate_threadsafe(rng, model, varinfo, sampler, context) + return evaluate_threadsafe(model, varinfo, context) end end + function (model::Model)(args...) return model(Random.GLOBAL_RNG, args...) end @@ -109,7 +115,7 @@ function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) end """ - evaluate_threadunsafe(rng, model, varinfo, sampler, context) + evaluate_threadunsafe(model, varinfo, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -118,13 +124,13 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ -function evaluate_threadunsafe(rng, model, varinfo, sampler, context) +function evaluate_threadunsafe(model, varinfo, context) resetlogp!(varinfo) - return _evaluate(rng, model, varinfo, sampler, context) + return _evaluate(model, varinfo, context) end """ - evaluate_threadsafe(rng, model, varinfo, sampler, context) + evaluate_threadsafe(model, varinfo, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. 
@@ -134,24 +140,27 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ -function evaluate_threadsafe(rng, model, varinfo, sampler, context) +function evaluate_threadsafe(model, varinfo, context) resetlogp!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(rng, model, wrapper, sampler, context) + result = _evaluate(model, wrapper, context) setlogp!(varinfo, getlogp(wrapper)) return result end """ - _evaluate(rng, model::Model, varinfo, sampler, context) + _evaluate(model::Model, varinfo, context) -Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. +Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ @generated function _evaluate( - rng, model::Model{_F,argnames}, varinfo, sampler, context + model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(rng, model, varinfo, sampler, context, $(unwrap_args...))) + return quote + sampler = context isa $(SamplingContext) ? context.sampler : SampleFromPrior() + model.f(model, varinfo, context, $(unwrap_args...)) + end end """ @@ -183,7 +192,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), DefaultContext()) + model(varinfo, EvaluationContext()) return getlogp(varinfo) end @@ -195,7 +204,7 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). 
""" function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), PriorContext()) + model(varinfo, PriorContext()) return getlogp(varinfo) end @@ -207,7 +216,7 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), LikelihoodContext()) + model(varinfo, LikelihoodContext()) return getlogp(varinfo) end diff --git a/src/varinfo.jl b/src/varinfo.jl index fe3262dd5..01b55b9de 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -126,7 +126,7 @@ function VarInfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + context::Union{AbstractContext,Nothing}=nothing, ) varinfo = VarInfo() model(rng, varinfo, sampler, context) diff --git a/src/varname.jl b/src/varname.jl index bb936a4ce..343bb0da8 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -39,3 +39,6 @@ Possibly existing indices of `varname` are neglected. ) where {s,missings,_F,_a,_T} return s in missings end + +# HACK: Type-piracy. Is this really the way to go? +AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym