diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f6ba5bcb5..66dc16487 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -76,7 +76,7 @@ export AbstractVarInfo, SampleFromUniform, # Contexts SamplingContext, - DefaultContext, + EvaluationContext, LikelihoodContext, PriorContext, MiniBatchContext, diff --git a/src/compiler.jl b/src/compiler.jl index 8734b72ed..8fbb77f16 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -391,7 +391,15 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. + @gensym leafctx evaluatordef[:body] = quote + # in case someone accessed these + $leafctx = DynamicPPL.unwrap(__context__) + if $leafctx isa $(DynamicPPL.SamplingContext) + __rng__ = $leafctx.rng + __sampler__ = $leafctx.sampler + end + $(modelinfo[:body]) end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 5647cd5fc..f970b475f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -18,105 +18,41 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x # assume -""" - tilde_assume(context::SamplingContext, right, vn, inds, vi) - -Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value with a context associated -with a sampler. - -Falls back to -```julia -tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) -``` -if the context `context.context` does not call any other context, as indicated by -[`unwrap_childcontext`](@ref). Otherwise, calls `tilde_assume(c, right, vn, inds, vi)` -where `c` is a context in which the order of the sampling context and its child are swapped. -""" +# Leaf contexts function tilde_assume(context::SamplingContext, right, vn, inds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - tilde_assume(context.rng, c, context.sampler, right, vn, inds, vi) - else - tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, vn, inds, vi) - end + return assume(context.rng, context.sampler, right, vn, inds, vi) end +tilde_assume(context::EvaluationContext, right, vn, inds, vi) = assume(right, vn, inds, vi) -# Leaf contexts -tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, inds, vi) -function tilde_assume( - rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi -) - return assume(rng, sampler, right, vn, inds, vi) +# Default for `WrappedContext` +function tilde_assume(context::WrappedContext, right, left, inds, vi) + return tilde_assume(childcontext(context), right, left, inds, vi) end +# `PriorContext` function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return tilde_assume(PriorContext(), right, vn, inds, vi) -end -function tilde_assume( - rng::Random.AbstractRNG, - context::PriorContext{<:NamedTuple}, - sampler, - right, - vn, - inds, - vi, -) - if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) - end - return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) -end -function tilde_assume(::PriorContext, right, vn, inds, vi) - return assume(right, vn, inds, vi) -end -function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) - return assume(rng, sampler, right, vn, inds, vi) + return tilde_assume(PriorContext(childcontext(context)), right, vn, inds, vi) end -function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) - if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) - end - return tilde_assume(LikelihoodContext(), right, vn, inds, vi) +# `LikelihoodContext` +function tilde_assume(context::LikelihoodContext, right, vn, inds, vi) + return tilde_assume(childcontext(context), NoDist(right), vn, inds, vi) end -function tilde_assume( - rng::Random.AbstractRNG, - context::LikelihoodContext{<:NamedTuple}, - sampler, - right, - vn, - inds, - vi, -) +function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) -end -function tilde_assume(::LikelihoodContext, right, vn, inds, vi) - return assume(NoDist(right), vn, inds, vi) -end -function tilde_assume( - rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi -) - return assume(rng, sampler, NoDist(right), vn, inds, vi) -end - -function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) - return tilde_assume(context.context, right, vn, inds, vi) + return tilde_assume(LikelihoodContext(childcontext(context)), right, vn, inds, vi) end +# `PrefixContext` function tilde_assume(context::PrefixContext, right, vn, inds, vi) - return tilde_assume(context.context, right, prefix(context, vn), inds, vi) + return tilde_assume(childcontext(context), right, prefix(context, vn), inds, vi) end """ @@ -134,70 +70,22 @@ function tilde_assume!(context, right, vn, inds, vi) end # observe -""" - tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) - -Handle observed variables with a `context` associated with a sampler. - -Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)` ignoring -the information about the sampler if the context `context.context` does not call any other -context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls -`tilde_observe(c, right, left, vname, vinds, vi)` where `c` is a context in -which the order of the sampling context and its child are swapped. -""" -function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - fallback_context = if child_of_c !== nothing - reconstruct_c(reconstruct_context(child_of_c)) - else - c - end - return tilde_observe(fallback_context, right, left, vname, vinds, vi) +# Leaf contexts +function tilde_observe(context::Union{SamplingContext,EvaluationContext}, right, left, vi) + return observe(right, left, vi) end -""" - tilde_observe(context::SamplingContext, right, left, vi) - -Handle observed constants with a `context` associated with a sampler. - -Falls back to `tilde_observe(context.context, right, left, vi)` ignoring -the information about the sampler if the context `context.context` does not call any other -context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls -`tilde_observe(c, right, left, vi)` where `c` is a context in -which the order of the sampling context and its child are swapped. -""" -function tilde_observe(context::SamplingContext, right, left, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - fallback_context = if child_of_c !== nothing - reconstruct_c(reconstruct_context(child_of_c)) - else - c - end - return tilde_observe(fallback_context, right, left, vi) +# Default for `WrappedContext` +function tilde_observe(context::WrappedContext, right, left, vi) + return tilde_observe(childcontext(context), right, left, vi) end -# Leaf contexts -tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) -tilde_observe(::PriorContext, right, left, vi) = 0 -tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) +# `PriorContext` +tilde_observe(context::PriorContext, right, left, vi) = 0 # `MiniBatchContext` -function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * tilde_observe(context.context, right, left, vi) -end -function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return context.loglike_scalar * - tilde_observe(context.context, right, left, vname, vinds, vi) -end - -# `PrefixContext` -function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return tilde_observe(context.context, right, left, prefix(context, vname), vinds, vi) -end -function tilde_observe(context::PrefixContext, right, left, vi) - return tilde_observe(context.context, right, left, vi) +function tilde_observe(context::MiniBatchContext, right, left, vi) + return context.loglike_scalar * tilde_observe(childcontext(context), right, left, vi) end """ @@ -276,84 +164,40 @@ end # .~ functions # assume -""" - dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) - -Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the -model inputs), accumulate the log probability, and return the sampled value for a context -associated with a sampler. - -Falls back to -```julia -dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, inds, vi) -``` -if the context `context.context` does not call any other context, as indicated by -[`unwrap_childcontext`](@ref). Otherwise, calls `dot_tilde_assume(c, right, left, vn, inds, vi)` -where `c` is a context in which the order of the sampling context and its child are swapped. -""" -function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - dot_tilde_assume(context.rng, c, context.sampler, right, left, vn, inds, vi) - else - dot_tilde_assume( - reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi - ) - end +# Leaf contexts +function dot_tilde_assume(context::SamplingContext, right, left, vns, _, vi) + return dot_assume(context.rng, context.sampler, right, vns, left, vi) end - -# `DefaultContext` -function dot_tilde_assume(::DefaultContext, sampler, right, left, vns, inds, vi) - return dot_assume(right, vns, left, vi) +function dot_tilde_assume(context::EvaluationContext, right, left, vns, inds, vi) + return dot_assume(right, vns, left, inds, vi) end -function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) - return dot_assume(rng, sampler, right, vns, left, vi) +# Default for `WrappedContext` +function dot_tilde_assume(context::WrappedContext, right, left, vns, inds, vi) + return dot_tilde_assume(childcontext(context), right, vns, left, vi) end # `LikelihoodContext` -function dot_tilde_assume( - context::LikelihoodContext{<:NamedTuple}, right, left, vn, inds, vi -) - return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) - else - dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) - end +function dot_tilde_assume(context::LikelihoodContext, right, left, vns, inds, vi) + return dot_tilde_assume(childcontext(context), NoDist.(right), vns, left, vi) end function dot_tilde_assume( - rng::Random.AbstractRNG, - context::LikelihoodContext{<:NamedTuple}, - sampler, - right, - left, - vn, - inds, - vi, + context::LikelihoodContext{<:NamedTuple}, right, left, vn, inds, vi ) return if haskey(context.vars, getsym(vn)) var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) + dot_tilde_assume( + LikelihoodContext(childcontext(context)), _right, _left, _vns, inds, vi + ) else - dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) + dot_tilde_assume( + LikelihoodContext(childcontext(context)), right, left, vn, inds, vi + ) end end -function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) - return dot_assume(NoDist.(right), left, vn, inds, vi) -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi -) - return dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) -end # `PriorContext` function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, inds, vi) @@ -362,48 +206,17 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) - else - dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) - end -end -function dot_tilde_assume( - rng::Random.AbstractRNG, - context::PriorContext{<:NamedTuple}, - sampler, - right, - left, - vn, - inds, - vi, -) - return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) + dot_tilde_assume(PriorContext(childcontext(context)), _right, _left, _vns, inds, vi) else - dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) + dot_tilde_assume(PriorContext(childcontext(context)), right, left, vn, inds, vi) end end -function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) - return dot_assume(right, left, vn, inds, vi) -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi -) - return dot_assume(rng, sampler, right, left, vn, inds, vi) -end - -# `MiniBatchContext` -function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, left, vn, inds, vi) -end # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), inds, vi) + return dot_tilde_assume( + childcontext(context), right, prefix.(Ref(context), vn), inds, vi + ) end """ @@ -570,46 +383,46 @@ function set_val!( end # observe -""" - dot_tilde_observe(context::SamplingContext, right, left, vi) - -Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log -probability, and return the observed value for a context associated with a sampler. - -Falls back to `dot_tilde_observe(context.context, right, left, vi) ignoring the sampler. -""" -function dot_tilde_observe(context::SamplingContext, right, left, vi) - return dot_tilde_observe(context.context, right, left, vname, vinds, vi) +function dot_tilde_observe( + context::Union{SamplingContext,EvaluationContext}, right, left, vi +) + return dot_observe(right, left, vi) end - -# Leaf contexts -dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) -dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) +function dot_tilde_observe( + context::Union{SamplingContext,EvaluationContext}, right, left, vname, vinds, vi +) return dot_observe(right, left, vi) end +# Default for `WrappedContext` +function dot_tilde_observe(context::WrappedContext, right, left, vi) + return dot_tilde_observe(childcontext(context), right, left, vi) +end +function dot_tilde_observe(context::WrappedContext, right, left, vname, vinds, vi) + return dot_tilde_observe(childcontext(context), right, left, vname, vinds, vi) +end + +# `PriorContext` +dot_tilde_observe(context::PriorContext, right, left, vi) = 0 +dot_tilde_observe(context::PriorContext, right, left, vname, vinds, vi) = 0 # `MiniBatchContext` function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) return context.loglike_scalar * - dot_tilde_observe(context.context, sampler, right, left, vi) + dot_tilde_observe(childcontext(context), sampler, right, left, vi) end function dot_tilde_observe( context::MiniBatchContext, sampler, right, left, vname, vinds, vi ) return context.loglike_scalar * - dot_tilde_observe(context.context, sampler, right, left, vname, vinds, vi) + dot_tilde_observe(childcontext(context), sampler, right, left, vname, vinds, vi) end # `PrefixContext` function dot_tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) return dot_tilde_observe( - context.context, right, left, prefix(context, vname), vinds, vi + childcontext(context), right, left, prefix(context, vname), vinds, vi ) end -function dot_tilde_observe(context::PrefixContext, right, left, vi) - return dot_tilde_observe(context.context, right, left, vi) -end """ dot_tilde_observe!(context, right, left, vname, vinds, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 8598fb633..ef4d600ea 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,46 +1,53 @@ +abstract type PrimitiveContext <: AbstractContext end +struct EvaluationContext{S<:AbstractSampler} <: PrimitiveContext + # TODO: do we even need the sampler these days? + sampler::S +end +EvaluationContext() = EvaluationContext(SampleFromPrior()) + +struct SamplingContext{R<:Random.AbstractRNG,S<:AbstractSampler} <: PrimitiveContext + rng::R + sampler::S +end +SamplingContext(sampler=SampleFromPrior()) = SamplingContext(Random.GLOBAL_RNG, sampler) + +######################## +### Wrapped contexts ### +######################## +abstract type WrappedContext{LeafCtx<:PrimitiveContext} <: AbstractContext end + """ - unwrap_childcontext(context::AbstractContext) + childcontext(context) -Return a tuple of the child context of a `context`, or `nothing` if the context does -not wrap any other context, and a function `f(c::AbstractContext)` that constructs -an instance of `context` in which the child context is replaced with `c`. +Returns the child-context of `context`. -Falls back to `(nothing, _ -> context)`. +Returns `nothing` if `context` is not a `WrappedContext`. """ -function unwrap_childcontext(context::AbstractContext) - reconstruct_context(@nospecialize(x)) = context - return nothing, reconstruct_context -end +childcontext(context::WrappedContext) = context.context +childcontext(context::AbstractContext) = nothing """ - SamplingContext(rng, sampler, context) + unwrap(context::AbstractContext) -Create a context that allows you to sample parameters with the `sampler` when running the model. -The `context` determines how the returned log density is computed when running the model. +Returns the unwrapped context from `context`. +""" +unwrap(context::WrappedContext) = unwrap(context.context) +unwrap(context::AbstractContext) = context -See also: [`JointContext`](@ref), [`LoglikelihoodContext`](@ref), [`PriorContext`](@ref) """ -struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext - rng::R - sampler::S - context::C -end + unwrappedtype(context::AbstractContext) -function unwrap_childcontext(context::SamplingContext) - child = context.context - function reconstruct_samplingcontext(c::AbstractContext) - return SamplingContext(context.rng, context.sampler, c) - end - return child, reconstruct_samplingcontext -end +Returns the type of the unwrapped context from `context`. +""" +unwrappedtype(context::AbstractContext) = typeof(context) +unwrappedtype(context::WrappedContext{LeafCtx}) where {LeafCtx} = LeafCtx """ - struct DefaultContext <: AbstractContext end + rewrap(parent::WrappedContext, leaf::PrimitiveContext) -The `DefaultContext` is used by default to compute log the joint probability of the data -and parameters when running the model. +Rewraps `leaf` in `parent`. Supports nested `WrappedContext`. """ -struct DefaultContext <: AbstractContext end +rewrap(::AbstractContext, leaf::PrimitiveContext) = leaf """ struct PriorContext{Tvars} <: AbstractContext @@ -50,10 +57,20 @@ struct DefaultContext <: AbstractContext end The `PriorContext` enables the computation of the log prior of the parameters `vars` when running the model. """ -struct PriorContext{Tvars} <: AbstractContext +struct PriorContext{Tvars,Ctx,LeafCtx} <: WrappedContext{LeafCtx} vars::Tvars + context::Ctx + + function PriorContext(vars, context) + return new{typeof(vars),typeof(context),unwrappedtype(context)}(vars, context) + end +end +PriorContext(vars=nothing) = PriorContext(vars, EvaluationContext()) +PriorContext(context::AbstractContext) = PriorContext(nothing, context) + +function rewrap(parent::PriorContext, leaf::PrimitiveContext) + return PriorContext(parent.vars, rewrap(childcontext(parent), leaf)) end -PriorContext() = PriorContext(nothing) """ struct LikelihoodContext{Tvars} <: AbstractContext @@ -64,10 +81,20 @@ The `LikelihoodContext` enables the computation of the log likelihood of the par running the model. `vars` can be used to evaluate the log likelihood for specific values of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. """ -struct LikelihoodContext{Tvars} <: AbstractContext +struct LikelihoodContext{Tvars,Ctx,LeafCtx} <: WrappedContext{LeafCtx} vars::Tvars + context::Ctx + + function LikelihoodContext(vars, context) + return new{typeof(vars),typeof(context),unwrappedtype(context)}(vars, context) + end +end +LikelihoodContext(vars=nothing) = LikelihoodContext(vars, EvaluationContext()) +LikelihoodContext(context::AbstractContext) = LikelihoodContext(nothing, context) + +function rewrap(parent::LikelihoodContext, leaf::PrimitiveContext) + return LikelihoodContext(parent.vars, rewrap(childcontext(parent), leaf)) end -LikelihoodContext() = LikelihoodContext(nothing) """ struct MiniBatchContext{Tctx, T} <: AbstractContext @@ -81,20 +108,24 @@ The `MiniBatchContext` enables the computation of This is useful in batch-based stochastic gradient descent algorithms to be optimizing `log(prior) + log(likelihood of all the data points)` in the expectation. """ -struct MiniBatchContext{Tctx,T} <: AbstractContext - context::Tctx +struct MiniBatchContext{T,Ctx,LeafCtx} <: WrappedContext{LeafCtx} loglike_scalar::T + context::Ctx + + function MiniBatchContext(loglike_scalar, context::AbstractContext) + return new{typeof(loglike_scalar),typeof(context),unwrappedtype(context)}( + loglike_scalar, context + ) + end end -function MiniBatchContext(context=DefaultContext(); batch_size, npoints) - return MiniBatchContext(context, npoints / batch_size) + +MiniBatchContext(loglike_scalar) = MiniBatchContext(loglike_scalar, EvaluationContext()) +function MiniBatchContext(context::AbstractContext=EvaluationContext(); batch_size, npoints) + return MiniBatchContext(npoints / batch_size, context) end -function unwrap_childcontext(context::MiniBatchContext) - child = context.context - function reconstruct_minibatchcontext(c::AbstractContext) - return MiniBatchContext(c, context.loglike_scalar) - end - return child, reconstruct_minibatchcontext +function rewrap(parent::MiniBatchContext, leaf::PrimitiveContext) + return MiniBatchContext(parent.loglike_scalar, rewrap(childcontext(parent), leaf)) end """ @@ -108,11 +139,17 @@ unique. See also: [`@submodel`](@ref) """ -struct PrefixContext{Prefix,C} <: AbstractContext +struct PrefixContext{Prefix,C,LeafCtx} <: WrappedContext{LeafCtx} context::C + + function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} + return new{Prefix,typeof(context),unwrappedtype(context)}(context) + end end -function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(context)}(context) +PrefixContext{Prefix}() where {Prefix} = PrefixContext{Prefix}(EvaluationContext()) + +function rewrap(parent::PrefixContext{Prefix}, leaf::PrimitiveContext) where {Prefix} + return PrefixContext{Prefix}(rewrap(childcontext(parent), leaf)) end const PREFIX_SEPARATOR = Symbol(".") @@ -121,7 +158,7 @@ function PrefixContext{PrefixInner}( context::PrefixContext{PrefixOuter} ) where {PrefixInner,PrefixOuter} if @generated - :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, _prefix_seperator, PrefixInner)))}( + :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( context.context )) else @@ -131,16 +168,8 @@ end function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} if @generated - return :(VarName{$(QuoteNode(Symbol(Prefix, _prefix_seperator, Sym)))}(vn.indexing)) + return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(vn.indexing)) else VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end - -function unwrap_childcontext(context::PrefixContext{P}) where {P} - child = context.context - function reconstruct_prefixcontext(c::AbstractContext) - return PrefixContext{P}(c) - end - return child, reconstruct_prefixcontext -end diff --git a/src/model.jl b/src/model.jl index 2d74949c1..e7d86fe2a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -86,9 +86,12 @@ function (model::Model)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + context::AbstractContext=SamplingContext(rng, sampler), ) - return model(varinfo, SamplingContext(rng, sampler, context)) + # In case `context` is a `WrapperContext` of sorts, we need to `rewrap` to ensure + # that context has a `SamplingContext` as the leaf context. + context_new = rewrap(context, SamplingContext(rng, sampler)) + return model(varinfo, context_new) end (model::Model)(context::AbstractContext) = model(VarInfo(), context) @@ -158,7 +161,7 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] return quote - sampler = context isa $(SamplingContext) ? context.sampler : SampleFromPrior() + sampler = unwrap(context).sampler model.f(model, varinfo, context, $(unwrap_args...)) end end diff --git a/src/varinfo.jl b/src/varinfo.jl index fe3262dd5..8e030f4a1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -126,7 +126,7 @@ function VarInfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + context::AbstractContext=SamplingContext(rng, sampler), ) varinfo = VarInfo() model(rng, varinfo, sampler, context)