rogue idea: make SamplingContext a leaf context + make a new InitialisationContext #955

Open
@penelopeysm


Why leaf context?

The main thing that made me come up with this was the observation that nested SamplingContext makes no sense semantically. You're either sampling with a sampler, or you're not -- you can't be sampling with two samplers.

Extra rambly context

Specifically, it was while working on TuringLang/Turing.jl#2588: in that PR, I had to make sure that any LogDensityFunction passed to sample is equipped with a SamplingContext before sending it down the AbstractMCMC pipeline. The easiest way to do this is to just wrap the existing context in a new SamplingContext, but what happens if the existing context already has a SamplingContext? Surely this should replace the existing SamplingContext (preferable), or error (less preferable).

It is technically possible to keep SamplingContext as a parent context while forbidding nesting (e.g. with an inner constructor), but that's just an objectively worse data structure.

Once we remove PriorContext and LikelihoodContext (technically done already, just needs to be released when ready), the only other leaf context would be DynamicTransformationContext. There is no world where we use DynamicTransformationContext together with SamplingContext. In other words they are mutually exclusive:

| This function ... | ...calls evaluate!! with this context |
| --- | --- |
| (e.g.) logjoint | DefaultContext |
| sample!! (pending #952) | SamplingContext |
| link!! | DynamicTransformationContext |

In other words, leaf contexts can actually be used to distinguish what we are doing with a model. (See also: #510, #953, for related discussions on this idea of using contexts to figure out what operations are being carried out.)
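To make the leaf/parent distinction concrete, here is a minimal sketch of what "replace the existing SamplingContext" could look like once it is a leaf. Every name here (`ToyContext`, `ToySamplingContext`, `ToyConditionContext`, `set_leaf`, `leaf`) is a hypothetical stand-in for illustration, not the DynamicPPL definitions:

```julia
# Toy stand-ins; these are NOT the DynamicPPL types.
abstract type ToyContext end

# Leaf contexts wrap nothing, so a context stack has exactly one of them.
struct ToyDefaultContext <: ToyContext end
struct ToySamplingContext{S} <: ToyContext
    sampler::S
end

# A parent context wraps exactly one child context.
struct ToyConditionContext{V<:NamedTuple,C<:ToyContext} <: ToyContext
    values::V
    child::C
end

# Swap the leaf of a stack while keeping every parent in place.
set_leaf(::ToyContext, new_leaf::ToyContext) = new_leaf
function set_leaf(ctx::ToyConditionContext, new_leaf::ToyContext)
    return ToyConditionContext(ctx.values, set_leaf(ctx.child, new_leaf))
end

# Walk down to the leaf, which identifies the operation being carried out.
leaf(ctx::ToyContext) = ctx
leaf(ctx::ToyConditionContext) = leaf(ctx.child)

ctx = ToyConditionContext((x = 1.0,), ToySamplingContext(:old_sampler))
new_ctx = set_leaf(ctx, ToySamplingContext(:new_sampler))
```

With this shape there is no way to build a stack containing two samplers: setting a new sampler can only ever replace the old one, and the conditioned values in the parent are untouched.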

Why InitialisationContext?

The other part of this proposal, which I think might be more controversial, is to introduce a new leaf context, i.e. InitialisationContext. The rough idea is that, just like how SamplingContext carries a sampler with it, InitialisationContext will carry an initialisation strategy with it. That is to say, I would like something like this:

abstract type InitStrategy end

# replacement for SampleFromPrior
struct Prior <: InitStrategy end

# replacement for SampleFromUniform
struct UniformLinked <: InitStrategy
    lower::Real
    upper::Real
end

struct InitContext{R<:Random.AbstractRNG,I<:InitStrategy} <: AbstractContext
    rng::R
    strategy::I
end

# I'm partly using `Init` instead of `Initialisation` because it's much shorter, but
# it also avoids the AmE/BrE spelling difference
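As a rough illustration of how a strategy carried by the context could drive initialisation, here is a self-contained sketch that only needs the `Random` standard library. The `Toy*` types and `init_value` are hypothetical names made up for this example, the "prior" is represented by a plain function `rng -> value` rather than a `Distribution`, and the uniform strategy ignores linking entirely:

```julia
using Random

# Hypothetical stand-ins for the proposed types.
abstract type ToyInitStrategy end

# Draw from the prior, here modelled as any callable `rng -> value`.
struct ToyPrior <: ToyInitStrategy end

# Draw uniformly from [lower, upper]. This toy version does not apply any
# transformation to or from linked space.
struct ToyUniform{T<:Real} <: ToyInitStrategy
    lower::T
    upper::T
end

struct ToyInitContext{R<:Random.AbstractRNG,I<:ToyInitStrategy}
    rng::R
    strategy::I
end

# One method per strategy: dispatch picks how the value is generated.
init_value(rng, ::ToyPrior, prior_sampler) = prior_sampler(rng)
init_value(rng, s::ToyUniform, prior_sampler) = s.lower + (s.upper - s.lower) * rand(rng)

# The context just forwards its rng and strategy.
init_value(ctx::ToyInitContext, prior_sampler) =
    init_value(ctx.rng, ctx.strategy, prior_sampler)

ctx = ToyInitContext(Random.MersenneTwister(468), ToyUniform(-2.0, 2.0))
x = init_value(ctx, randn)  # some value in [-2, 2]
```

Adding another way of initialising then means adding one struct and one `init_value` method, without touching any sampler code.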

One might ask why create a new leaf context? After all, right now we are using SamplingContext(rng, SampleFromPrior(), ...) to initialise values in a VarInfo just fine.

One nice benefit is that having a specific InitStrategy struct allows us to expose, at a high level, different ways of initialising values in a VarInfo. See TuringLang/Turing.jl#2476 (comment) -- I think the implementation of that would readily present itself if we had something like this.

I think the main difference though is to clarify the behaviour when coming across a variable that isn't yet initialised. The point would be that SamplingContext is exclusively used to allow for samplers to override the tilde-pipeline, whereas if you want to add new values to a VarInfo, you would have to use InitialisationContext. Right now, these two types of behaviour are all smushed together in the implementation for assume(..., ::SampleFromPrior, ...):

# TODO: Remove this thing.
# SampleFromPrior and SampleFromUniform
function assume(
    rng::Random.AbstractRNG,
    sampler::Union{SampleFromPrior,SampleFromUniform},
    dist::Distribution,
    vn::VarName,
    vi::VarInfoOrThreadSafeVarInfo,
)
    if haskey(vi, vn)
        # Always overwrite the parameters with new ones for `SampleFromUniform`.
        if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
            # TODO(mhauru) Is it important to unset the flag here? The `true` allows us
            # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
            # if that's okay.
            unset_flag!(vi, vn, "del", true)
            r = init(rng, dist, sampler)
            f = to_maybe_linked_internal_transform(vi, vn, dist)
            # TODO(mhauru) This should probably be call a function called setindex_internal!
            # Also, if we use !! we shouldn't ignore the return value.
            BangBang.setindex!!(vi, f(r), vn)
            setorder!(vi, vn, get_num_produce(vi))
        else
            # Otherwise we just extract it.
            r = vi[vn, dist]
        end
    else
        r = init(rng, dist, sampler)
        if istrans(vi)
            f = to_linked_internal_transform(vi, vn, dist)
            push!!(vi, vn, f(r), dist)
            # By default `push!!` sets the transformed flag to `false`.
            settrans!!(vi, true, vn)
        else
            push!!(vi, vn, r, dist)
        end
    end
    # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
    logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
    return r, logpdf(dist, r) - logjac, vi
end

# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`)
observe(sampler::AbstractSampler, right, left, vi) = observe(right, left, vi)
function observe(right::Distribution, left, vi)
    increment_num_produce!(vi)
    return Distributions.loglikelihood(right, left), vi
end

And it doesn't help that quite a number of actual samplers simply delegate to SampleFromPrior (grep for this in Turing) and so the behaviour is quite difficult to predict.

Having these two separate would allow for a cleaner separation of concerns.
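A toy version of that separation, as a sketch only: a `Dict` stands in for the VarInfo, a standard normal stands in for the prior, and `DemoInitContext`, `DemoSamplingContext` and `demo_assume` are hypothetical names. It also assumes that initialisation always stores a fresh draw, which is one possible design rather than something decided here:

```julia
using Random

# Hypothetical leaf contexts for the two separate concerns.
struct DemoInitContext{R<:Random.AbstractRNG}
    rng::R
end
struct DemoSamplingContext end

# Initialisation: the only code path allowed to write a value.
# Assumption: it always stores a fresh draw, even if `name` already exists.
function demo_assume(ctx::DemoInitContext, values::Dict{Symbol,Float64}, name::Symbol)
    values[name] = randn(ctx.rng)
    return values[name]
end

# Sampling: only reads existing values; a missing one is an explicit error
# rather than a silent draw from the prior.
function demo_assume(::DemoSamplingContext, values::Dict{Symbol,Float64}, name::Symbol)
    haskey(values, name) || throw(KeyError(name))
    return values[name]
end

vals = Dict{Symbol,Float64}()
demo_assume(DemoInitContext(Random.MersenneTwister(0)), vals, :x)  # adds :x
x = demo_assume(DemoSamplingContext(), vals, :x)                   # reads :x back

# Asking the sampling path for a variable that was never initialised fails loudly.
missing_errored = try
    demo_assume(DemoSamplingContext(), vals, :z)
    false
catch err
    err isa KeyError
end
```

Compared with the combined `assume` method above, each method here has one job, so what happens to an uninitialised variable depends only on which leaf context is in use.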

Potential issue 1: when the concerns aren't separate

The immediate concern I had with this is that sometimes there are uninitialised variables when sampling. For example:

@model function patho()
    x ~ Normal()
    if x > 0
        y ~ Normal()
    else
        z ~ Normal()
    end
end

It would be very possible to create a VarInfo that has x and y, and then on a second evaluation come across z, in which case we would have to decide how to deal with it.

The truth though is that this model currently errors:

julia> sample(patho(), NUTS(), 100)
ERROR: type NamedTuple has no field z [...]

julia> sample(patho(), MH(), 100)
ERROR: type NamedTuple has no field y [...]

so we aren't immediately losing any functionality, and indeed I think that having separation of concerns might help us implement this (cursed) functionality in a better way, if it's something we actually want to do.

Potential issue 2: interactions with other contexts

To the best of my knowledge, SamplingContext doesn't have any nasty interactions with other contexts, in the sense that (e.g.) I don't think it ever makes a difference whether SamplingContext is nested inside ConditionContext or the other way around. Gibbs, as always, is the one area where I cannot say this with confidence.
