From d70e1be46058e912121b41dd0e2f0724c57474c1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:21:20 +0100 Subject: [PATCH 01/22] added sampling context and unwrap_childcontext --- src/contexts.jl | 63 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/src/contexts.jl b/src/contexts.jl index 4d4f30bdc..1ee43f2b2 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,3 +1,39 @@ +""" + unwrap_childcontext(context::AbstractContext) + +Return a tuple of the child context of a `context`, or `nothing` if the context does +not wrap any other context, and a function `f(c::AbstractContext)` that constructs +an instance of `context` in which the child context is replaced with `c`. + +Falls back to `(nothing, _ -> context)`. +""" +function unwrap_childcontext(context::AbstractContext) + reconstruct_context(@nospecialize(x)) = context + return nothing, reconstruct_context +end + +""" + SamplingContext(rng, sampler, context) + +Create a context that allows you to sample parameters with the `sampler` when running the model. +The `context` determines how the returned log density is computed when running the model. + +See also: [`JointContext`](@ref), [`LoglikelihoodContext`](@ref), [`PriorContext`](@ref) +""" +struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext + rng::R + sampler::S + context::C +end + +function unwrap_childcontext(context::SamplingContext) + child = context.context + function reconstruct_samplingcontext(c::AbstractContext) + return SamplingContext(context.rng, context.sampler, c) + end + return child, reconstruct_samplingcontext +end + """ struct DefaultContext <: AbstractContext end @@ -53,6 +89,25 @@ function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) return MiniBatchContext(ctx, npoints / batch_size) end +function unwrap_childcontext(context::MiniBatchContext) + child = context.context + function reconstruct_minibatchcontext(c::AbstractContext) + return MiniBatchContext(c, context.loglike_scalar) + end + return child, reconstruct_minibatchcontext +end + +""" + PrefixContext{Prefix}(context) + +Create a context that allows you to use the wrapped `context` when running the model and +adds the `Prefix` to all parameters. + +This context is useful in nested models to ensure that the names of the parameters are +unique. + +See also: [`@submodel`](@ref) +""" struct PrefixContext{Prefix,C} <: AbstractContext ctx::C end @@ -81,3 +136,11 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + +function unwrap_childcontext(context::PrefixContext{P}) where {P} + child = context.context + function reconstruct_prefixcontext(c::AbstractContext) + return PrefixContext{P}(c) + end + return child, reconstruct_prefixcontext +end From f74399031eca9b5eaab824c16416a825c489f59c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:23:31 +0100 Subject: [PATCH 02/22] updated tilde methods --- src/context_implementations.jl | 451 +++++++++++++++++++++++++-------- 1 file changed, 352 insertions(+), 99 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 0698b6cdf..8aa8ddfca 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -18,28 +18,103 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x # assume -function tilde_assume(rng, ctx::DefaultContext, sampler, right, vn::VarName, _, vi) - return assume(rng, sampler, right, vn, vi) +""" + tilde_assume(context::SamplingContext, right, vn, inds, vi) + +Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), +accumulate the log probability, and return the sampled value with a context associated +with a sampler. + +Falls back to +```julia +tilde_assume(context.rng, context.ctx, context.sampler, right, vn, inds, vi) +``` +if the context `context.ctx` does not call any other context, as indicated by +[`unwrap_childcontext`](@ref). Otherwise, calls `tilde_assume(c, right, vn, inds, vi)` +where `c` is a context in which the order of the sampling context and its child are swapped. +""" +function tilde_assume(context::SamplingContext, right, vn, inds, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + return if child_of_c === nothing + tilde_assume(context.rng, c, context.sampler, right, vn, inds, vi) + else + tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, vn, inds, vi) + end +end + +# Leaf contexts +tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, inds, vi) +function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi) + return assume(rng, sampler, right, vn, inds, vi) +end + +function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_assume(PriorContext(), right, vn, inds, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, + context::PriorContext{<:NamedTuple}, + sampler, + right, + vn, + inds, + vi, +) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) end -function tilde_assume(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars !== nothing - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) +function tilde_assume(::PriorContext, right, vn, inds, vi) + return assume(right, vn, inds, vi) +end +function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) + return assume(rng, sampler, right, vn, inds, vi) +end + +function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return assume(rng, sampler, right, vn, vi) + return tilde_assume(LikelihoodContext(), right, vn, inds, vi) end -function tilde_assume(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) +function tilde_assume( + rng::Random.AbstractRNG, + context::LikelihoodContext{<:NamedTuple}, + sampler, + right, + vn, + inds, + vi, +) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return assume(rng, sampler, NoDist(right), vn, vi) + return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) +end +function tilde_assume(::LikelihoodContext, right, vn, inds, vi) + return assume(NoDist(right), vn, inds, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi +) + return assume(rng, sampler, NoDist(right), vn, inds, vi) end -function tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, left, inds, vi) + +function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) + return tilde_assume(context.ctx, right, vn, inds, vi) end -function tilde_assume(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi) + +function tilde_assume(context::PrefixContext, right, vn, inds, vi) + return tilde_assume(context.ctx, right, prefix(context, vn), inds, vi) end """ @@ -50,27 +125,76 @@ accumulate the log probability, and return the sampled value. Falls back to `tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)`. """ -function tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) - value, logp = tilde_assume(rng, ctx, sampler, right, vn, inds, vi) +function tilde_assume!(ctx, sampler, right, vn, inds, vi) + value, logp = tilde_assume(ctx, sampler, right, vn, inds, vi) acclogp!(vi, logp) return value end # observe -function tilde_observe(ctx::DefaultContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +""" + tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + +Handle observed variables with a `context` associated with a sampler. +Falls back to `tilde_observe(context.ctx, right, left, vname, vinds, vi)` ignoring +the information about the sampler if the context `context.ctx` does not call any other +context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls +`tilde_observe(c, right, left, vname, vinds, vi)` where `c` is a context in +which the order of the sampling context and its child are swapped. +""" +function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + fallback_context = if child_of_c !== nothing + reconstruct_c(reconstruct_context(child_of_c)) + else + c + end + return tilde_observe(fallback_context, right, left, vname, vinds, vi) end -function tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 + +""" + tilde_observe(context::SamplingContext, right, left, vi) + +Handle observed constants with a `context` associated with a sampler. +Falls back to `tilde_observe(context.ctx, right, left, vi)` ignoring +the information about the sampler if the context `context.ctx` does not call any other +context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls +`tilde_observe(c, right, left, vi)` where `c` is a context in +which the order of the sampling context and its child are swapped. +""" +function tilde_observe(context::SamplingContext, right, left, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + fallback_context = if child_of_c !== nothing + reconstruct_c(reconstruct_context(child_of_c)) + else + c + end + return tilde_observe(fallback_context, right, left, vi) +end + +# Leaf contexts +tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) +tilde_observe(::PriorContext, right, left, vi) = 0 +tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) + +# `MiniBatchContext` +function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) + return context.loglike_scalar * tilde_observe(context.ctx, right, left, vi) end -function tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) + return context.loglike_scalar * tilde_observe(context.ctx, right, left, vname, vinds, vi) end -function tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * tilde_observe(ctx.ctx, sampler, right, left, vi) + +# `PrefixContext` +function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) + return tilde_observe( + context.ctx, right, left, prefix(context, vname), vinds, vi + ) end -function tilde_observe(ctx::PrefixContext, sampler, right, left, vi) - return tilde_observe(ctx.ctx, sampler, right, left, vi) +function tilde_observe(context::PrefixContext, right, left, vi) + return tilde_observe(context.ctx, right, left, vi) end """ @@ -112,77 +236,179 @@ function observe(spl::Sampler, weight) return error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))") end +# fallback without sampler +function assume(dist::Distribution, vn::VarName, inds, vi) + if !haskey(vi, vn) + error("variable $vn does not exist") + end + r = vi[vn] + return r, Bijectors.logpdf_with_trans(dist, vi[vn], istrans(vi, vn)) +end + +# SampleFromPrior and SampleFromUniform function assume( - rng, spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi + rng::Random.AbstractRNG, + sampler::Union{SampleFromPrior,SampleFromUniform}, + dist::Distribution, + vn::VarName, + inds, + vi, ) + # Always overwrite the parameters with new ones. + r = init(rng, dist, sampler) if haskey(vi, vn) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") - r = init(rng, dist, spl) - vi[vn] = vectorize(dist, r) - settrans!(vi, false, vn) - setorder!(vi, vn, get_num_produce(vi)) - else - r = vi[vn] - end + vi[vn] = vectorize(dist, r) + setorder!(vi, vn, get_num_produce(vi)) else - r = init(rng, dist, spl) - push!(vi, vn, r, dist, spl) - settrans!(vi, false, vn) + push!(vi, vn, r, dist, sampler) end + settrans!(vi, false, vn) return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end -function observe( - spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, value, vi -) +# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) +function observe(right::Distribution, left, vi) increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value) + return Distributions.loglikelihood(right, left) end # .~ functions # assume -function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, _, vi) +""" + dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + +Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the +model inputs), accumulate the log probability, and return the sampled value for a context +associated with a sampler. + +Falls back to +```julia +dot_tilde_assume(context.rng, context.ctx, context.sampler, right, left, vn, inds, vi) +``` +if the context `context.ctx` does not call any other context, as indicated by +[`unwrap_childcontext`](@ref). Otherwise, calls `dot_tilde_assume(c, right, left, vn, inds, vi)` +where `c` is a context in which the order of the sampling context and its child are swapped. +""" +function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + return if child_of_c === nothing + dot_tilde_assume(context.rng, c, context.sampler, right, left, vn, inds, vi) + else + dot_tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi) + end +end + +# `DefaultContext` +function dot_tilde_assume(ctx::DefaultContext, sampler, right, left, vns, inds, vi) + return dot_assume(right, vns, left, vi) +end + +function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(rng, sampler, right, vns, left, vi) end + +# `LikelihoodContext` function dot_tilde_assume( - rng, - ctx::LikelihoodContext, + context::LikelihoodContext{<:NamedTuple}, right, left, vn, inds, vi +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) + end +end +function dot_tilde_assume( + rng::Random.AbstractRNG, + context::LikelihoodContext{<:NamedTuple}, sampler, right, left, - vns::AbstractArray{<:VarName{sym}}, + vn, inds, vi, -) where {sym} - if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) + else + dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) end - return dot_assume(rng, sampler, NoDist.(right), vns, left, vi) end -function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vns, inds, vi) - return dot_tilde_assume(rng, ctx.ctx, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) + value, logp = dot_assume(NoDist.(right), left, vn, inds, vi) + acclogp!(vi, logp) + return value end function dot_tilde_assume( - rng, - ctx::PriorContext, + rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi +) + value, logp = dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) + acclogp!(vi, logp) + return value +end + +# `PriorContext` +function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, inds, vi) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) + end +end +function dot_tilde_assume( + rng::Random.AbstractRNG, + context::PriorContext{<:NamedTuple}, sampler, right, left, - vns::AbstractArray{<:VarName{sym}}, + vn, inds, vi, -) where {sym} - if ctx.vars !== nothing - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) + else + dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) end - return dot_assume(rng, sampler, right, vns, left, vi) +end +function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) + value, logp = dot_assume(right, left, vn, inds, vi) + acclogp!(vi, logp) + return value +end +function dot_tilde_assume( + rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi +) + value, logp = dot_assume(rng, sampler, right, left, vn, inds, vi) + acclogp!(vi, logp) + return value +end + +# `MiniBatchContext` +function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.ctx, right, left, vn, inds, vi) +end + +# `PrefixContext` +function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.ctx, right, prefix.(Ref(context), vn), inds, vi) end """ @@ -193,13 +419,26 @@ model inputs), accumulate the log probability, and return the sampled value. Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume!(ctx, sampler, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi) acclogp!(vi, logp) return value end -# Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics +# `dot_assume` +function dot_assume( + dist::MultivariateDistribution, + var::AbstractMatrix, + vns::AbstractVector{<:VarName}, + inds, + vi, +) + @assert length(dist) == size(var, 1) + lp = sum(zip(vns, eachcol(var))) do vn, ri + return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) + end + return var, lp +end function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -214,6 +453,19 @@ function dot_assume( var .= r return var, lp end + +function dot_assume( + dists::Union{Distribution,AbstractArray{<:Distribution}}, + var::AbstractArray, + vns::AbstractArray{<:VarName}, + inds, + vi, +) + # Make sure `var` is not a matrix for multivariate distributions + lp = sum(Bijectors.logpdf_with_trans.(dists, var, istrans(vi, vns[1]))) + return var, lp +end + function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -323,18 +575,38 @@ function set_val!( end # observe -function dot_tilde_observe(ctx::DefaultContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) -end -function dot_tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 -end -function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) +""" + dot_tilde_observe(context::SamplingContext, right, left, vi) + +Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log +probability, and return the observed value for a context associated with a sampler. + +Falls back to `dot_tilde_observe(context.ctx, right, left, vi) ignoring the sampler. +""" +function dot_tilde_observe(context::SamplingContext, right, left, vi) + return dot_tilde_observe(context.ctx, right, left, vname, vinds, vi) end + +# Leaf contexts +dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) +dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) = dot_observe(right, left, vi) + +# `MiniBatchContext` function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) end +function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vname, vinds, vi) + return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) +end + +# `PrefixContext` +function dot_tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) + return dot_tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) +end +function dot_tilde_observe(context::PrefixContext, right, left, vi) + return dot_tilde_observe(context.ctx, right, left, vi) +end """ dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) @@ -366,41 +638,22 @@ function dot_tilde_observe!(ctx, sampler, right, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, - dist::MultivariateDistribution, - value::AbstractMatrix, - vi, -) +function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) increment_num_produce!(vi) @debug "dist = $dist" @debug "value = $value" return Distributions.loglikelihood(dist, value) end -function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::Distribution, - value::AbstractArray, - vi, -) +function dot_observe(dists::Distribution, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return Distributions.loglikelihood(dists, value) end -function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::AbstractArray{<:Distribution}, - value::AbstractArray, - vi, -) +function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) end -function dot_observe(spl::Sampler, ::Any, ::Any, ::Any) - return error( - "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing observe statement" - ) -end + From 3d2e7e2b4dfb2462e6732eb3f0b3ac5494d1696f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:23:48 +0100 Subject: [PATCH 03/22] updated model call signature --- src/model.jl | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/model.jl b/src/model.jl index 7189b590e..250b89721 100644 --- a/src/model.jl +++ b/src/model.jl @@ -88,12 +88,18 @@ function (model::Model)( sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) + return model(SamplingContext(rng, sampler, context), varinfo) +end + +(model::Model)(context::AbstractContext) = model(VarInfo(), context) +function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 - return evaluate_threadunsafe(rng, model, varinfo, sampler, context) + return evaluate_threadunsafe(model, varinfo, sampler, context) else - return evaluate_threadsafe(rng, model, varinfo, sampler, context) + return evaluate_threadsafe(model, varinfo, sampler, context) end end + function (model::Model)(args...) return model(Random.GLOBAL_RNG, args...) end @@ -109,7 +115,7 @@ function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) end """ - evaluate_threadunsafe(rng, model, varinfo, sampler, context) + evaluate_threadunsafe(model, varinfo, sampler, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -118,13 +124,13 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ -function evaluate_threadunsafe(rng, model, varinfo, sampler, context) +function evaluate_threadunsafe(model, varinfo, sampler, context) resetlogp!(varinfo) - return _evaluate(rng, model, varinfo, sampler, context) + return _evaluate(model, varinfo, sampler, context) end """ - evaluate_threadsafe(rng, model, varinfo, sampler, context) + evaluate_threadsafe(model, varinfo, sampler, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. @@ -134,24 +140,24 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ -function evaluate_threadsafe(rng, model, varinfo, sampler, context) +function evaluate_threadsafe(model, varinfo, sampler, context) resetlogp!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(rng, model, wrapper, sampler, context) + result = _evaluate(model, wrapper, sampler, context) setlogp!(varinfo, getlogp(wrapper)) return result end """ - _evaluate(rng, model::Model, varinfo, sampler, context) + _evaluate(model::Model, varinfo, sampler, context) Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. """ @generated function _evaluate( - rng, model::Model{_F,argnames}, varinfo, sampler, context + model::Model{_F,argnames}, varinfo, sampler, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(rng, model, varinfo, sampler, context, $(unwrap_args...))) + return :(model.f(model, varinfo, sampler, context, $(unwrap_args...))) end """ @@ -183,7 +189,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), DefaultContext()) + model(varinfo, DefaultContext()) return getlogp(varinfo) end @@ -195,7 +201,7 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), PriorContext()) + model(varinfo, PriorContext()) return getlogp(varinfo) end @@ -207,7 +213,7 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), LikelihoodContext()) + model(varinfo, LikelihoodContext()) return getlogp(varinfo) end From 4f1d39694083bae41ac7e94cbb19640639512879 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:24:04 +0100 Subject: [PATCH 04/22] updated compiler --- src/compiler.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 20d8bf8ef..bc906f58c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -394,10 +394,8 @@ function build_output(modelinfo, linenumbernode) # Add the internal arguments to the user-specified arguments (positional + keywords). evaluatordef[:args] = vcat( [ - :(__rng__::$(Random.AbstractRNG)), :(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo)), - :(__sampler__::$(DynamicPPL.AbstractSampler)), :(__context__::$(DynamicPPL.AbstractContext)), ], modelinfo[:allargs_exprs], @@ -407,7 +405,15 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = modelinfo[:body] + evaluatordef[:body] = quote + # in case someone accessed these + if __context__ isa $(DynamicPPL.SamplingContext) + __rng__ = __context__.rng + __sampler__ = __context__.sampler + end + + $(modelinfo[:body]) + end ## Build the model function. From b187d74efcfb5b9482f39022252477b5a0bc2cb9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:27:37 +0100 Subject: [PATCH 05/22] formatting --- src/context_implementations.jl | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 8aa8ddfca..a0bf0381a 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -45,7 +45,9 @@ end # Leaf contexts tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, inds, vi) -function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi) +function tilde_assume( + rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi +) return assume(rng, sampler, right, vn, inds, vi) end @@ -184,14 +186,13 @@ function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) return context.loglike_scalar * tilde_observe(context.ctx, right, left, vi) end function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return context.loglike_scalar * tilde_observe(context.ctx, right, left, vname, vinds, vi) + return context.loglike_scalar * + tilde_observe(context.ctx, right, left, vname, vinds, vi) end # `PrefixContext` function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return tilde_observe( - context.ctx, right, left, prefix(context, vname), vinds, vi - ) + return tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) end function tilde_observe(context::PrefixContext, right, left, vi) return tilde_observe(context.ctx, right, left, vi) @@ -296,7 +297,9 @@ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) return if child_of_c === nothing dot_tilde_assume(context.rng, c, context.sampler, right, left, vn, inds, vi) else - dot_tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi) + dot_tilde_assume( + reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi + ) end end @@ -590,14 +593,17 @@ end # Leaf contexts dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) = dot_observe(right, left, vi) +function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) + return dot_observe(right, left, vi) +end # `MiniBatchContext` function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) end function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) + return ctx.loglike_scalar * + dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) end # `PrefixContext` @@ -656,4 +662,3 @@ function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) end - From ee99f8ce5676c5cb571417e6dd4b3d9570922bf5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:30:54 +0100 Subject: [PATCH 06/22] added getsym for vectors --- src/varname.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/varname.jl b/src/varname.jl index bb936a4ce..40c5c25e9 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -39,3 +39,7 @@ Possibly existing indices of `varname` are neglected. ) where {s,missings,_F,_a,_T} return s in missings end + + +# HACK: Type-piracy. Is this really the way to go? +AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym From c4845d08b34b58bc30762b4be944c8924c9794e2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:35:12 +0100 Subject: [PATCH 07/22] Update src/varname.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/varname.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/varname.jl b/src/varname.jl index 40c5c25e9..343bb0da8 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -40,6 +40,5 @@ Possibly existing indices of `varname` are neglected. return s in missings end - # HACK: Type-piracy. Is this really the way to go? AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym From a0c05f39315c93d9ed43cefe227255310172f339 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:43:42 +0100 Subject: [PATCH 08/22] fixed some signatures for Model --- src/model.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/model.jl b/src/model.jl index 250b89721..8d353d2de 100644 --- a/src/model.jl +++ b/src/model.jl @@ -88,7 +88,7 @@ function (model::Model)( sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) - return model(SamplingContext(rng, sampler, context), varinfo) + return model(varinfo, SamplingContext(rng, sampler, context)) end (model::Model)(context::AbstractContext) = model(VarInfo(), context) @@ -115,7 +115,7 @@ function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) end """ - evaluate_threadunsafe(model, varinfo, sampler, context) + evaluate_threadunsafe(model, varinfo, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -124,13 +124,13 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ -function evaluate_threadunsafe(model, varinfo, sampler, context) +function evaluate_threadunsafe(model, varinfo, context) resetlogp!(varinfo) - return _evaluate(model, varinfo, sampler, context) + return _evaluate(model, varinfo, context) end """ - evaluate_threadsafe(model, varinfo, sampler, context) + evaluate_threadsafe(model, varinfo, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. @@ -140,24 +140,24 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ -function evaluate_threadsafe(model, varinfo, sampler, context) +function evaluate_threadsafe(model, varinfo, context) resetlogp!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(model, wrapper, sampler, context) + result = _evaluate(model, wrapper, context) setlogp!(varinfo, getlogp(wrapper)) return result end """ - _evaluate(model::Model, varinfo, sampler, context) + _evaluate(model::Model, varinfo, context) Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. """ @generated function _evaluate( - model::Model{_F,argnames}, varinfo, sampler, context + model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, sampler, context, $(unwrap_args...))) + return :(model.f(model, varinfo, context, $(unwrap_args...))) end """ From 307cd7e1f3a1a02bd79284ccfc640848961e2bd9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:49:44 +0100 Subject: [PATCH 09/22] fixed a method call --- src/model.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/model.jl b/src/model.jl index 8d353d2de..3a01f9bf3 100644 --- a/src/model.jl +++ b/src/model.jl @@ -94,9 +94,9 @@ end (model::Model)(context::AbstractContext) = model(VarInfo(), context) function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 - return evaluate_threadunsafe(model, varinfo, sampler, context) + return evaluate_threadunsafe(model, varinfo, context) else - return evaluate_threadsafe(model, varinfo, sampler, context) + return evaluate_threadsafe(model, varinfo, context) end end @@ -151,7 +151,7 @@ end """ _evaluate(model::Model, varinfo, context) -Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. +Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ @generated function _evaluate( model::Model{_F,argnames}, varinfo, context From 597277119922a24d0783b78134a8f541eb05dc0e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 01:00:49 +0100 Subject: [PATCH 10/22] fixed method signatures --- src/compiler.jl | 8 ------ src/context_implementations.jl | 48 +++++++++++++++++----------------- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index bc906f58c..8201a82f4 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -283,7 +283,6 @@ function generate_tilde(left, right) return quote $(DynamicPPL.tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__, @@ -300,9 +299,7 @@ function generate_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left = $(DynamicPPL.tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_vn)( $(DynamicPPL.check_tilde_rhs)($right), $vn )..., @@ -312,7 +309,6 @@ function generate_tilde(left, right) else $(DynamicPPL.tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -334,7 +330,6 @@ function generate_dot_tilde(left, right) return quote $(DynamicPPL.dot_tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__, @@ -351,9 +346,7 @@ function generate_dot_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left .= $(DynamicPPL.dot_tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $left, $vn )..., @@ -363,7 +356,6 @@ function generate_dot_tilde(left, right) else $(DynamicPPL.dot_tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a0bf0381a..d6ff3b5bd 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -120,15 +120,15 @@ function tilde_assume(context::PrefixContext, right, vn, inds, vi) end """ - tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) + tilde_assume!(ctx, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)`. +Falls back to `tilde_assume!(ctx, right, vn, inds, vi)`. """ -function tilde_assume!(ctx, sampler, right, vn, inds, vi) - value, logp = tilde_assume(ctx, sampler, right, vn, inds, vi) +function tilde_assume!(ctx, right, vn, inds, vi) + value, logp = tilde_assume(ctx, right, vn, inds, vi) acclogp!(vi, logp) return value end @@ -199,30 +199,30 @@ function tilde_observe(context::PrefixContext, right, left, vi) end """ - tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + tilde_observe!(ctx, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe(ctx, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, right, left, vname, vinds, vi) + logp = tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end """ - tilde_observe(ctx, sampler, right, left, vi) + tilde_observe(ctx, right, left, vi) Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, sampler, right, left, vi)`. +Falls back to `tilde(ctx, right, left, vi)`. """ -function tilde_observe!(ctx, sampler, right, left, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, right, left, vi) + logp = tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end @@ -415,15 +415,15 @@ function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) end """ - dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) + dot_tilde_assume!(ctx, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(ctx, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume!(ctx, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(ctx, right, left, vn, inds, vi) acclogp!(vi, logp) return value end @@ -615,30 +615,30 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) end """ - dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe!(ctx, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe(ctx, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(ctx, sampler, right, left, vn, inds, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, right, left, vn, inds, vi) + logp = dot_tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end """ - dot_tilde_observe!(ctx, sampler, right, left, vi) + dot_tilde_observe!(ctx, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)`. +Falls back to `dot_tilde_observe(ctx, right, left, vi)`. """ -function dot_tilde_observe!(ctx, sampler, right, left, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, right, left, vi) + logp = dot_tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end From c4ecd0e676ab58aa13ef9995289d4368868d212c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 01:08:56 +0100 Subject: [PATCH 11/22] sort of fixed the matchingvalue functionality for model --- src/model.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index 3a01f9bf3..2d74949c1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -157,7 +157,10 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, context, $(unwrap_args...))) + return quote + sampler = context isa $(SamplingContext) ? context.sampler : SampleFromPrior() + model.f(model, varinfo, context, $(unwrap_args...)) + end end """ From a34b51cd60fffe8ff45903153550ff3486680f91 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 03:36:55 +0100 Subject: [PATCH 12/22] formatting --- src/compiler.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8201a82f4..dc70ae267 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -282,10 +282,7 @@ function generate_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.tilde_observe!)( - __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -329,10 +326,7 @@ function generate_dot_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.dot_tilde_observe!)( - __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end From e4a2cf81154e44b23af16e473d1361cf11225fc4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:02:41 +0100 Subject: [PATCH 13/22] removed left-over acclogp! that should not be here anymore --- src/context_implementations.jl | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 42a336479..b088577f5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -345,16 +345,12 @@ function dot_tilde_assume( end end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) - value, logp = dot_assume(NoDist.(right), left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(NoDist.(right), left, vn, inds, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi ) - value, logp = dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) end # `PriorContext` @@ -390,16 +386,12 @@ function dot_tilde_assume( end end function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) - value, logp = dot_assume(right, left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(right, left, vn, inds, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi ) - value, logp = dot_assume(rng, sampler, right, left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(rng, sampler, right, left, vn, inds, vi) end # `MiniBatchContext` From 7605785fff5407d558dd920720180efc0e41d885 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:04:29 +0100 Subject: [PATCH 14/22] export SamplingContext --- src/DynamicPPL.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index acdb98183..3ad30972c 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -76,6 +76,7 @@ export AbstractVarInfo, SampleFromPrior, SampleFromUniform, # Contexts + SamplingContext, DefaultContext, LikelihoodContext, PriorContext, From 354ac52b0d2115b2df7d437407b98140a208ba5d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:38:25 +0100 Subject: [PATCH 15/22] use context instead of ctx to refer to contexts --- src/context_implementations.jl | 64 +++++++++++++++++----------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index b088577f5..f859b4619 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -120,15 +120,15 @@ function tilde_assume(context::PrefixContext, right, vn, inds, vi) end """ - tilde_assume!(ctx, right, vn, inds, vi) + tilde_assume!(context, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde_assume!(ctx, right, vn, inds, vi)`. +Falls back to `tilde_assume!(context, right, vn, inds, vi)`. """ -function tilde_assume!(ctx, right, vn, inds, vi) - value, logp = tilde_assume(ctx, right, vn, inds, vi) +function tilde_assume!(context, right, vn, inds, vi) + value, logp = tilde_assume(context, right, vn, inds, vi) acclogp!(vi, logp) return value end @@ -199,30 +199,30 @@ function tilde_observe(context::PrefixContext, right, left, vi) end """ - tilde_observe!(ctx, right, left, vname, vinds, vi) + tilde_observe!(context, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe(ctx, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(ctx, right, left, vname, vinds, vi) - logp = tilde_observe(ctx, right, left, vi) +function tilde_observe!(context, right, left, vname, vinds, vi) + logp = tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end """ - tilde_observe(ctx, right, left, vi) + tilde_observe(context, right, left, vi) Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, right, left, vi)`. +Falls back to `tilde(context, right, left, vi)`. """ -function tilde_observe!(ctx, right, left, vi) - logp = tilde_observe(ctx, right, left, vi) +function tilde_observe!(context, right, left, vi) + logp = tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end @@ -302,11 +302,11 @@ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) end # `DefaultContext` -function dot_tilde_assume(ctx::DefaultContext, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(right, vns, left, vi) end -function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(rng, sampler, right, vns, left, vi) end @@ -405,15 +405,15 @@ function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) end """ - dot_tilde_assume!(ctx, right, left, vn, inds, vi) + dot_tilde_assume!(context, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(ctx, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(ctx, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(ctx, right, left, vn, inds, vi) +function dot_tilde_assume!(context, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) acclogp!(vi, logp) return value end @@ -583,17 +583,17 @@ end # Leaf contexts dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) +function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) return dot_observe(right, left, vi) end # `MiniBatchContext` -function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) +function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) + return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vi) end -function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return ctx.loglike_scalar * - dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) +function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) + return context.loglike_scalar * + dot_tilde_observe(context.ctx, sampler, right, left, vname, vinds, vi) end # `PrefixContext` @@ -605,30 +605,30 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) end """ - dot_tilde_observe!(ctx, right, left, vname, vinds, vi) + dot_tilde_observe!(context, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(ctx, right, left, vn, inds, vi) - logp = dot_tilde_observe(ctx, right, left, vi) +function dot_tilde_observe!(context, right, left, vn, inds, vi) + logp = dot_tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end """ - dot_tilde_observe!(ctx, right, left, vi) + dot_tilde_observe!(context, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, right, left, vi)`. +Falls back to `dot_tilde_observe(context, right, left, vi)`. """ -function dot_tilde_observe!(ctx, right, left, vi) - logp = dot_tilde_observe(ctx, right, left, vi) +function dot_tilde_observe!(context, right, left, vi) + logp = dot_tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end From b7a2b3b5b5483eb11741a9a2c7b3135abaecbcad Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:38:46 +0100 Subject: [PATCH 16/22] formatting --- src/context_implementations.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f859b4619..a8f279804 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -591,7 +591,9 @@ end function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vi) end -function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) +function dot_tilde_observe( + context::MiniBatchContext, sampler, right, left, vname, vinds, vi +) return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vname, vinds, vi) end From 9e0fc9a9eecb6f74d54e0b4a1fad9cf94c0b41eb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:39:41 +0100 Subject: [PATCH 17/22] use context instead of ctx for variables --- src/contexts.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 1ee43f2b2..8598fb633 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -71,7 +71,7 @@ LikelihoodContext() = LikelihoodContext(nothing) """ struct MiniBatchContext{Tctx, T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end @@ -82,11 +82,11 @@ This is useful in batch-based stochastic gradient descent algorithms to be optim `log(prior) + log(likelihood of all the data points)` in the expectation. """ struct MiniBatchContext{Tctx,T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end -function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) - return MiniBatchContext(ctx, npoints / batch_size) +function MiniBatchContext(context=DefaultContext(); batch_size, npoints) + return MiniBatchContext(context, npoints / batch_size) end function unwrap_childcontext(context::MiniBatchContext) @@ -109,23 +109,23 @@ unique. See also: [`@submodel`](@ref) """ struct PrefixContext{Prefix,C} <: AbstractContext - ctx::C + context::C end -function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(ctx)}(ctx) +function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} + return PrefixContext{Prefix,typeof(context)}(context) end const PREFIX_SEPARATOR = Symbol(".") function PrefixContext{PrefixInner}( - ctx::PrefixContext{PrefixOuter} + context::PrefixContext{PrefixOuter} ) where {PrefixInner,PrefixOuter} if @generated :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, _prefix_seperator, PrefixInner)))}( - ctx.ctx + context.context )) else - PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx) + PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(context.context) end end From 7a4a1a38ca6895c401e366e9b2707117b5ce36e5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:40:18 +0100 Subject: [PATCH 18/22] use context instead of ctx to refer to contexts --- src/context_implementations.jl | 47 ++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a8f279804..e66501aee 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -27,9 +27,9 @@ with a sampler. Falls back to ```julia -tilde_assume(context.rng, context.ctx, context.sampler, right, vn, inds, vi) +tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) ``` -if the context `context.ctx` does not call any other context, as indicated by +if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `tilde_assume(c, right, vn, inds, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. """ @@ -112,11 +112,11 @@ function tilde_assume( end function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) - return tilde_assume(context.ctx, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end function tilde_assume(context::PrefixContext, right, vn, inds, vi) - return tilde_assume(context.ctx, right, prefix(context, vn), inds, vi) + return tilde_assume(context.context, right, prefix(context, vn), inds, vi) end """ @@ -138,8 +138,8 @@ end tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) Handle observed variables with a `context` associated with a sampler. -Falls back to `tilde_observe(context.ctx, right, left, vname, vinds, vi)` ignoring -the information about the sampler if the context `context.ctx` does not call any other +Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)` ignoring +the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `tilde_observe(c, right, left, vname, vinds, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. @@ -159,8 +159,8 @@ end tilde_observe(context::SamplingContext, right, left, vi) Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.ctx, right, left, vi)` ignoring -the information about the sampler if the context `context.ctx` does not call any other +Falls back to `tilde_observe(context.context, right, left, vi)` ignoring +the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `tilde_observe(c, right, left, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. @@ -183,19 +183,19 @@ tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) # `MiniBatchContext` function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * tilde_observe(context.ctx, right, left, vi) + return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) return context.loglike_scalar * - tilde_observe(context.ctx, right, left, vname, vinds, vi) + tilde_observe(context.context, right, left, vname, vinds, vi) end # `PrefixContext` function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) + return tilde_observe(context.context, right, left, prefix(context, vname), vinds, vi) end function tilde_observe(context::PrefixContext, right, left, vi) - return tilde_observe(context.ctx, right, left, vi) + return tilde_observe(context.context, right, left, vi) end """ @@ -283,9 +283,9 @@ associated with a sampler. Falls back to ```julia -dot_tilde_assume(context.rng, context.ctx, context.sampler, right, left, vn, inds, vi) +dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, inds, vi) ``` -if the context `context.ctx` does not call any other context, as indicated by +if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `dot_tilde_assume(c, right, left, vn, inds, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. """ @@ -396,12 +396,12 @@ end # `MiniBatchContext` function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.ctx, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) end # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.ctx, right, prefix.(Ref(context), vn), inds, vi) + return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), inds, vi) end """ @@ -574,10 +574,10 @@ end Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value for a context associated with a sampler. -Falls back to `dot_tilde_observe(context.ctx, right, left, vi) ignoring the sampler. +Falls back to `dot_tilde_observe(context.context, right, left, vi) ignoring the sampler. """ function dot_tilde_observe(context::SamplingContext, right, left, vi) - return dot_tilde_observe(context.ctx, right, left, vname, vinds, vi) + return dot_tilde_observe(context.context, right, left, vname, vinds, vi) end # Leaf contexts @@ -589,21 +589,24 @@ end # `MiniBatchContext` function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vi) + return context.loglike_scalar * + dot_tilde_observe(context.context, sampler, right, left, vi) end function dot_tilde_observe( context::MiniBatchContext, sampler, right, left, vname, vinds, vi ) return context.loglike_scalar * - dot_tilde_observe(context.ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe(context.context, sampler, right, left, vname, vinds, vi) end # `PrefixContext` function dot_tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return dot_tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) + return dot_tilde_observe( + context.context, right, left, prefix(context, vname), vinds, vi + ) end function dot_tilde_observe(context::PrefixContext, right, left, vi) - return dot_tilde_observe(context.ctx, right, left, vi) + return dot_tilde_observe(context.context, right, left, vi) end """ From 7899473512e85089f0a6497823e804c0a2a7e12c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 01:20:27 +0100 Subject: [PATCH 19/22] Update src/compiler.jl Co-authored-by: David Widmann --- src/compiler.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index dc70ae267..8734b72ed 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -392,12 +392,6 @@ function build_output(modelinfo, linenumbernode) # Replace the user-provided function body with the version created by DynamicPPL. evaluatordef[:body] = quote - # in case someone accessed these - if __context__ isa $(DynamicPPL.SamplingContext) - __rng__ = __context__.rng - __sampler__ = __context__.sampler - end - $(modelinfo[:body]) end From 1630476c742f82e3f45096923e39a4b2da6150a2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 01:20:41 +0100 Subject: [PATCH 20/22] Update src/context_implementations.jl Co-authored-by: David Widmann --- src/context_implementations.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e66501aee..4fd787c86 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -138,6 +138,7 @@ end tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) Handle observed variables with a `context` associated with a sampler. + Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)` ignoring the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls From 6892d2b1276ef92f6765c3272673ac7a58682465 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 01:37:22 +0100 Subject: [PATCH 21/22] Apply suggestions from code review Co-authored-by: David Widmann --- src/context_implementations.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 4fd787c86..5647cd5fc 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -160,6 +160,7 @@ end tilde_observe(context::SamplingContext, right, left, vi) Handle observed constants with a `context` associated with a sampler. + Falls back to `tilde_observe(context.context, right, left, vi)` ignoring the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls From f52550f7538d19c1b1b099cca454ab43ec5bfdda Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 4 Jun 2021 08:31:53 +0100 Subject: [PATCH 22/22] introduce EvaluationContext and make PriorContext and LikelihoodContext wrappers --- src/DynamicPPL.jl | 2 +- src/context_implementations.jl | 241 ++++++++++++--------------------- src/contexts.jl | 79 ++++++++--- src/model.jl | 4 +- src/varinfo.jl | 2 +- 5 files changed, 145 insertions(+), 183 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f6ba5bcb5..66dc16487 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -76,7 +76,7 @@ export AbstractVarInfo, SampleFromUniform, # Contexts SamplingContext, - DefaultContext, + EvaluationContext, LikelihoodContext, PriorContext, MiniBatchContext, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d1fd7b0ba..044198447 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -34,50 +34,32 @@ if the context `context.context` does not call any other context, as indicated b where `c` is a context in which the order of the sampling context and its child are swapped. """ function tilde_assume(context::SamplingContext, right, vn, inds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - tilde_assume(context.rng, c, context.sampler, right, vn, inds, vi) - else - tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, vn, inds, vi) + if context.context isa Nothing + return assume(context.rng, context.sampler, right, vn, inds, vi) end -end -# Leaf contexts -tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, inds, vi) -function tilde_assume( - rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi -) - return assume(rng, sampler, right, vn, inds, vi) + return tilde_assume(propogate_context(context), right, vn, inds, vi) end -function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) - if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) +function tilde_assume(context::EvaluationContext, right, vn, inds, vi) + if context.context isa Nothing + return assume(right, vn, inds, vi) end - return tilde_assume(PriorContext(), right, vn, inds, vi) + + return tilde_assume(propogate_context(context), right, vn, inds, vi) end -function tilde_assume( - rng::Random.AbstractRNG, - context::PriorContext{<:NamedTuple}, - sampler, - right, - vn, - inds, - vi, -) + +# Leaf contexts +function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) -end -function tilde_assume(::PriorContext, right, vn, inds, vi) - return assume(right, vn, inds, vi) + return tilde_assume(PriorContext(context.context), right, vn, inds, vi) end -function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) - return assume(rng, sampler, right, vn, inds, vi) + +function tilde_assume(context::PriorContext, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) @@ -85,30 +67,11 @@ function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return tilde_assume(LikelihoodContext(), right, vn, inds, vi) + return tilde_assume(LikelihoodContext(context.context), right, vn, inds, vi) end -function tilde_assume( - rng::Random.AbstractRNG, - context::LikelihoodContext{<:NamedTuple}, - sampler, - right, - vn, - inds, - vi, -) - if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) - end - return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) -end -function tilde_assume(::LikelihoodContext, right, vn, inds, vi) - return assume(NoDist(right), vn, inds, vi) -end -function tilde_assume( - rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi -) - return assume(rng, sampler, NoDist(right), vn, inds, vi) + +function tilde_assume(context::LikelihoodContext, right, vn, inds, vi) + return tilde_assume(context.context, NoDist(right), vn, inds, vi) end function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) @@ -146,14 +109,19 @@ context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls which the order of the sampling context and its child are swapped. """ function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - fallback_context = if child_of_c !== nothing - reconstruct_c(reconstruct_context(child_of_c)) - else - c + if context.context isa Nothing + return observe(right, left, vi) end - return tilde_observe(fallback_context, right, left, vname, vinds, vi) + + return tilde_observe(propogate_context(context), right, left, vname, vinds, vi) +end + +function tilde_observe(context::EvaluationContext, right, left, vname, vinds, vi) + if context.context isa Nothing + return observe(right, left, vi) + end + + return tilde_observe(propogate_context(context), right, left, vname, vinds, vi) end """ @@ -168,26 +136,32 @@ context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls which the order of the sampling context and its child are swapped. """ function tilde_observe(context::SamplingContext, right, left, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - fallback_context = if child_of_c !== nothing - reconstruct_c(reconstruct_context(child_of_c)) - else - c + if context.context isa Nothing + return observe(right, left, vi) end - return tilde_observe(fallback_context, right, left, vi) + + return tilde_observe(propogate_context(context), right, left, vi) +end + +function tilde_observe(context::EvaluationContext, right, left, vi) + if context.context isa Nothing + return observe(right, left, vi) + end + + return tilde_observe(propogate_context(context), right, left, vi) end # Leaf contexts -tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) tilde_observe(::PriorContext, right, left, vi) = 0 -tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) +function tilde_observe(context::LikelihoodContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) +end # `MiniBatchContext` -function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) +function tilde_observe(context::MiniBatchContext, right, left, vi) return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end -function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) +function tilde_observe(context::MiniBatchContext, right, left, vname, vinds, vi) return context.loglike_scalar * tilde_observe(context.context, right, left, vname, vinds, vi) end @@ -292,24 +266,19 @@ if the context `context.context` does not call any other context, as indicated b where `c` is a context in which the order of the sampling context and its child are swapped. """ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - dot_tilde_assume(context.rng, c, context.sampler, right, left, vn, inds, vi) - else - dot_tilde_assume( - reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi - ) + if context.context isa Nothing + return dot_assume(context.rng, context.sampler, right, vn, left, vi) end -end -# `DefaultContext` -function dot_tilde_assume(::DefaultContext, sampler, right, left, vns, inds, vi) - return dot_assume(right, vns, left, vi) + return dot_tilde_assume(propogate_context(context), right, left, vn, inds, vi) end -function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) - return dot_assume(rng, sampler, right, vns, left, vi) +function dot_tilde_assume(context::EvaluationContext, right, left, vn, inds, vi) + if context.context isa Nothing + return dot_assume(right, left, vn, inds, vi) + end + + return dot_tilde_assume(propogate_context(context), right, left, vn, inds, vi) end # `LikelihoodContext` @@ -321,38 +290,14 @@ function dot_tilde_assume( _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) + dot_tilde_assume(LikelihoodContext(context.context), _right, _left, _vns, inds, vi) else - dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) - end -end -function dot_tilde_assume( - rng::Random.AbstractRNG, - context::LikelihoodContext{<:NamedTuple}, - sampler, - right, - left, - vn, - inds, - vi, -) - return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) - else - dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) + dot_tilde_assume(LikelihoodContext(context.context), right, left, vn, inds, vi) end end + function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) - return dot_assume(NoDist.(right), left, vn, inds, vi) -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi -) - return dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) + return dot_tilde_assume(context.context, NoDist.(right), left, vn, inds, vi) end # `PriorContext` @@ -362,38 +307,14 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) - else - dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) - end -end -function dot_tilde_assume( - rng::Random.AbstractRNG, - context::PriorContext{<:NamedTuple}, - sampler, - right, - left, - vn, - inds, - vi, -) - return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) + dot_tilde_assume(PriorContext(context.context), _right, _left, _vns, inds, vi) else - dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) + dot_tilde_assume(PriorContext(context.context), right, left, vn, inds, vi) end end + function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) - return dot_assume(right, left, vn, inds, vi) -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi -) - return dot_assume(rng, sampler, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) end # `MiniBatchContext` @@ -403,7 +324,9 @@ end # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), inds, vi) + return dot_tilde_assume( + context.context, right, left, prefix.(Ref(context), vn), inds, vi + ) end """ @@ -576,27 +499,29 @@ probability, and return the observed value for a context associated with a sampl Falls back to `dot_tilde_observe(context.context, right, left, vi) ignoring the sampler. """ -function dot_tilde_observe(context::SamplingContext, right, left, vi) - return dot_tilde_observe(context.context, right, left, vname, vinds, vi) +function dot_tilde_observe( + context::Union{SamplingContext,EvaluationContext}, right, left, vi +) + if context.context isa Nothing + return dot_observe(right, left, vi) + end + + return dot_tilde_observe(propogate_context(context), right, left, vi) end # Leaf contexts -dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) -dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) - return dot_observe(right, left, vi) +dot_tilde_observe(::PriorContext, right, left, vi) = 0 +function dot_tilde_observe(context::LikelihoodContext, right, left, vi) + return dot_tilde_observe(context.context, right, left, vi) end # `MiniBatchContext` -function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * - dot_tilde_observe(context.context, sampler, right, left, vi) +function dot_tilde_observe(context::MiniBatchContext, right, left, vi) + return context.loglike_scalar * dot_tilde_observe(context.context, right, left, vi) end -function dot_tilde_observe( - context::MiniBatchContext, sampler, right, left, vname, vinds, vi -) +function dot_tilde_observe(context::MiniBatchContext, right, left, vname, vinds, vi) return context.loglike_scalar * - dot_tilde_observe(context.context, sampler, right, left, vname, vinds, vi) + dot_tilde_observe(context.context, right, left, vname, vinds, vi) end # `PrefixContext` diff --git a/src/contexts.jl b/src/contexts.jl index 6daa18776..39c30aa4c 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -13,61 +13,100 @@ function unwrap_childcontext(context::AbstractContext) end """ - SamplingContext(rng, sampler, context) + propogate_context(context::AbstractContext) + +Wrap `context` in its child-context using [`unwrap_childcontext`](@ref), effectively +swapping the order of the two contexts. +""" +function propogate_context(context::AbstractContext) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + return reconstruct_c(reconstruct_context(child_of_c)) +end + +""" + SamplingContext(rng=Random.GLOBAL_RNG, sampler=SampleFromPrior(), context=nothing) Create a context that allows you to sample parameters with the `sampler` when running the model. The `context` determines how the returned log density is computed when running the model. -See also: [`JointContext`](@ref), [`LoglikelihoodContext`](@ref), [`PriorContext`](@ref) +See also: [`EvaluationContext.`](@ref) """ -struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext +struct SamplingContext{S<:AbstractSampler,C,R} <: AbstractContext rng::R sampler::S context::C end +function SamplingContext( + rng::Random.AbstractRNG, sampler::AbstractSampler=SampleFromPrior() +) + return SamplingContext(rng, sampler, nothing) +end +function SamplingContext(sampler::AbstractSampler=SampleFromPrior()) + return SamplingContext(Random.GLOBAL_RNG, sampler) +end + function unwrap_childcontext(context::SamplingContext) - child = context.context - function reconstruct_samplingcontext(c::AbstractContext) + function reconstruct_samplingcontext(c::Union{AbstractContext,Nothing}) return SamplingContext(context.rng, context.sampler, c) end - return child, reconstruct_samplingcontext + return context.context, reconstruct_samplingcontext end """ - struct DefaultContext <: AbstractContext end + EvaluationContext(context=nothing) -The `DefaultContext` is used by default to compute log the joint probability of the data -and parameters when running the model. +Create a context that allows you to evaluate the model without performing any sampling. +The `context` determines how the returned log density is computed when running the model. + +See also: [`SamplingContext`](@ref) """ -struct DefaultContext <: AbstractContext end +struct EvaluationContext{Ctx} <: AbstractContext + context::Ctx +end + +EvaluationContext() = EvaluationContext(nothing) + +function unwrap_childcontext(context::EvaluationContext) + function reconstruct_evaluationcontext(c::Union{AbstractContext,Nothing}) + return EvaluationContext(c) + end + return context.context, reconstruct_evaluationcontext +end """ - struct PriorContext{Tvars} <: AbstractContext + struct PriorContext{Tvars,Ctx} <: AbstractContext vars::Tvars + context::Ctx end The `PriorContext` enables the computation of the log prior of the parameters `vars` when running the model. """ -struct PriorContext{Tvars} <: AbstractContext +struct PriorContext{Tvars,Ctx} <: AbstractContext vars::Tvars + context::Ctx end -PriorContext() = PriorContext(nothing) +PriorContext(vars=nothing) = PriorContext(vars, EvaluationContext()) +PriorContext(context::AbstractContext) = PriorContext(nothing, context) """ - struct LikelihoodContext{Tvars} <: AbstractContext + struct LikelihoodContext{Tvars,Ctx} <: AbstractContext vars::Tvars + context::Ctx end The `LikelihoodContext` enables the computation of the log likelihood of the parameters when running the model. `vars` can be used to evaluate the log likelihood for specific values of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. """ -struct LikelihoodContext{Tvars} <: AbstractContext +struct LikelihoodContext{Tvars,Ctx} <: AbstractContext vars::Tvars + context::Ctx end -LikelihoodContext() = LikelihoodContext(nothing) +LikelihoodContext(vars=nothing) = LikelihoodContext(vars, EvaluationContext()) +LikelihoodContext(context::AbstractContext) = LikelihoodContext(nothing, context) """ struct MiniBatchContext{Tctx, T} <: AbstractContext @@ -85,16 +124,15 @@ struct MiniBatchContext{Tctx,T} <: AbstractContext context::Tctx loglike_scalar::T end -function MiniBatchContext(context=DefaultContext(); batch_size, npoints) +function MiniBatchContext(context=EvaluationContext(); batch_size, npoints) return MiniBatchContext(context, npoints / batch_size) end function unwrap_childcontext(context::MiniBatchContext) - child = context.context function reconstruct_minibatchcontext(c::AbstractContext) return MiniBatchContext(c, context.loglike_scalar) end - return child, reconstruct_minibatchcontext + return context.context, reconstruct_minibatchcontext end """ @@ -138,9 +176,8 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} end function unwrap_childcontext(context::PrefixContext{P}) where {P} - child = context.context function reconstruct_prefixcontext(c::AbstractContext) return PrefixContext{P}(c) end - return child, reconstruct_prefixcontext + return context.context, reconstruct_prefixcontext end diff --git a/src/model.jl b/src/model.jl index 2d74949c1..430e09993 100644 --- a/src/model.jl +++ b/src/model.jl @@ -86,7 +86,7 @@ function (model::Model)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + context::Union{AbstractContext,Nothing}=nothing, ) return model(varinfo, SamplingContext(rng, sampler, context)) end @@ -192,7 +192,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, DefaultContext()) + model(varinfo, EvaluationContext()) return getlogp(varinfo) end diff --git a/src/varinfo.jl b/src/varinfo.jl index fe3262dd5..01b55b9de 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -126,7 +126,7 @@ function VarInfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + context::Union{AbstractContext,Nothing}=nothing, ) varinfo = VarInfo() model(rng, varinfo, sampler, context)