Skip to content

WrappedContext post introduction of SamplingContext #254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 19 commits into from

Conversation

torfjelde
Copy link
Member

@torfjelde torfjelde commented Jun 1, 2021

This PR introduces the WrappedContext discussed in #249 but in a post #253 world.

As I've mentioned before, I'm of the opinion that this hierarchy simplifies implementation and imposes some very natural constraints. The argument against taking this approach was that it was unclear whether it was worth it/the benefit wasn't quite clear.

But I'll leave this draft here in case we want to revisit the idea either for inclusion or just for discussion.

@devmotion devmotion mentioned this pull request Jun 3, 2021
1 task
bors bot pushed a commit that referenced this pull request Aug 4, 2021
This is motivated by the potential introduction of more contexts, e.g. #278, and has been brought up as an alternative (and better) approach to achieve parts of what we want to achieve in #254 .

(I hope you're proud of me @devmotion )

## Current state of things
Currently, if one wants to implement a new `AbstractContext` one _at least_ has to implement the following methods:

```julia
tilde_assume(...)
tilde_observe(...)
dot_tilde_assume(...)
dot_tilde_observe(...)
```

But there are also other methods that _should_ be implemented but generally aren't properly handled, e.g. `matchingvalue`. And there might be more methods in the future, e.g. `contextual_isassumption` in #254.

## This sucks
This means that:
1. Implementing a new behavior for `AbstractContext`, e.g. `contextual_isassumption`, requires you to:
   1. Find all implementations of `AbstractContext`, which is non-trivial! Most are here in DPPL, but some are in Turing.jl, and eventually we also want packages outside of the Turing.jl-umbrella to extend DPPL using contexts.
   2. Implement the method for that particular context.
2. Implementing a new `AbstractContext`, e.g. `Turing.OptimizationContext`, requires you to find all the methods to implement and then do so. Again, non-trivial.

This combinatorial blow up essentially means that we're super-reluctant to introduce new behaviors or new contexts, for good reasons.

And the stupid thing is that in most cases a context is only trying to modify maybe one or two "behaviors", e.g. `MiniBatchContext` only wants to change the `*tilde_observe` methods, and otherwise just defer to whatever implementation is available for its "childcontext" `minibatchcontext.context`. 

## Goal
A new `AbstractContext` should only (up to an additive factor) have to implement the behavior it _wants to change_, not _all_ behaviors. E.g. `MiniBatchContext` should really only have to overload `*tilde_observe`.

## Solution (this PR)
The above was the motivation for #254, but there we wanted a rather strict separation between certain types of contexts which we're reluctant to add due to its restrictive nature (in particular given how "recently" contexts where introduced).

This PR takes are more extensible and less restrictive approach of introducing some traits for `AbstractContext`.

As a starter I've just introduced a couple of traits to allow code-sharing between "parent contexts", e.g. `MiniBatchContext` and `ConditionContext` (from #278), and "leaf-contexts", e.g. `DefaultContext` and `PriorContext` (which IMO should have been a wrapper-context itself). Ideally we'd also define a promotion-system, e.g. what do we do if we're asked to combine a `MiniBatchContext` and `DefaultContext`? Well in that case we could either
1. replace `minibatchcontext.context` with `DefaultContext()`, or
2. recursively "rewrap" `DefaultContext()` in `minibatch.context`.

(1) has the issue that `MiniBatchContext` might be wrapping another context, e.g. `PrefixContext`, and so just replacing `.context` is dangerous. (2) seems like a better idea since `DefaultContext` should always be at the end of the stack, i.e. a "leaf", since it always exists the tilde-callstack (e.g. in `tilde_assume` we call `assume`, etc.).

Such a promotion system will require some thought though, but this PR will allow us to experiment with this (on top of providing a good approach to code-sharing).

## Examples

### `GeneratedQuantitiesContext`

In DPPL we have the `generated_quantities` but it sort of sucks because often these quantities are not relevant for sampling, thus we're adding unnecessary computation to the sampling process. One might want want to introduce a `@generated_quantities` macro that will only be executed when called from `generated_quantities`. This can easily be achieved with contexts:

```julia
using DynamicPPL, Distributions, Random
using DynamicPPL: AbstractContext, IsLeaf, IsParent, childcontext

struct GeneratedQuantitiesContext{Ctx} <: AbstractContext
    context::Ctx
end
GeneratedQuantitiesContext() = GeneratedQuantitiesContext(DefaultContext())

# Define the `NodeTrait` for `GeneratedQuantitiesContext`.
DynamicPPL.NodeTrait(context::GeneratedQuantitiesContext) = IsParent()
DynamicPPL.childcontext(context::GeneratedQuantitiesContext) = context.context

"""
    isgeneratedquantities(context)

Return `true` if `context` wants evaluation of model to execute
the `@generatedquantities` block.
"""
function isgeneratedquantities(context::AbstractContext)
    return isgeneratedquantities(DynamicPPL.NodeTrait(isgeneratedquantities, context), context)
end

# Define the behavior for the different `NodeType`s.
isgeneratedquantities(::IsLeaf, context::AbstractContext) = false
function isgeneratedquantities(::IsParent, context::AbstractContext)
    return isgeneratedquantities(childcontext(context))
end

# Specific implementations of `isgeneratedquantities`.
isgeneratedquantities(context::GeneratedQuantitiesContext) = true

"""
    @generatedquantities f(x)

Specify that `f(x)` should only if the model is run using `GeneratedQuantitiesContext`.
"""
macro generated_quantities(expr)
    return esc(generated_quantities_expr(expr))
end

function generated_quantities_expr(expr)
    return quote
        if isgeneratedquantities(__context__)
            $expr
        end
    end
end
```

And usage would be as follows:

```julia
julia> @model function demo(x)
           @generated_quantities a = []
           m ~ Normal()

           # "Expensive" piece of code that we don't want to compute
           # unless we're computing the generated quantities.
           @generated_quantities for i = 1:100
               push!(a, m + randn())
           end

           # Observe.
           x ~ Normal(m, 1.0)

           # Return additional fields if we're computing the generated quantities.
           @generated_quantities return (; x, m, logp = getlogp(__varinfo__), a)

           return nothing
       end
demo (generic function with 1 method)

julia> m = demo(1.0); var_info = VarInfo(m);

julia> m(var_info) # (✓) returns nothing

julia> m(var_info, GeneratedQuantitiesContext()) # (✓) returns everything
(x = 1.0, m = 0.7451107426181028, logp = -2.147956342556143, a = Any[0.929285513523637, 0.5979631005709274, 3.1944251790696794, -0.611727924858586, 1.0561547812845788, 0.9694358994096283, 0.9130715096769692, 0.018196803783751103, 0.919216919507385, 0.5382931170515716  …  0.24049636440948963, 0.7868300598622132, 0.18210113151764207, 1.9568444848346909, 0.4512687443970105, 0.5155449377942058, 0.32900420588898294, 1.4274186203915957, 0.5599167955770905, -0.5351355146677438])

julia> m(Random.GLOBAL_RNG, GeneratedQuantitiesContext()) # (✓) just works even when wrapped in `SamplingContext`
(x = 1.0, m = -0.7568553457880578, logp = -3.6675624266453637, a = Any[-0.4715652454870858, 2.0553700887015776, -0.6055293756513751, -1.6153963580018202, -0.6316036328835652, -0.9694585645577235, -1.234406947321252, 0.47075652867281415, -2.0403875768252187, -1.332736944079995  …  -1.123179524380186, -0.4445013585772502, -1.2366853403014721, -0.2726171409415225, 0.050910231154342234, -1.937702826603367, -2.109933658872988, -0.8300603173278494, 0.076480161355589, -1.2666452596129179])
```

This would also _just work_ for submodels, etc. This "conditional execution" could of course be generalized too. Such a conditional execution is very useful when you also want to work with Zygote, e.g. you don't want mutations in the model but for the post-processing steps, e.g. `predict` and `generated_quantities`, you do need it.

But whether or not we want this particular `@generatedquantities` macro in DPPL is not the point; the point is that we _can_ implement such a thing. Even more importantly, I can easily implement this from the "outside", no needing to touch DPPL to do it. 

### `ConditionContext`
See #278.

Co-authored-by: Hong Ge <[email protected]>
bors bot pushed a commit that referenced this pull request Aug 5, 2021
This is motivated by the potential introduction of more contexts, e.g. #278, and has been brought up as an alternative (and better) approach to achieve parts of what we want to achieve in #254 .

(I hope you're proud of me @devmotion )

## Current state of things
Currently, if one wants to implement a new `AbstractContext` one _at least_ has to implement the following methods:

```julia
tilde_assume(...)
tilde_observe(...)
dot_tilde_assume(...)
dot_tilde_observe(...)
```

But there are also other methods that _should_ be implemented but generally aren't properly handled, e.g. `matchingvalue`. And there might be more methods in the future, e.g. `contextual_isassumption` in #254.

## This sucks
This means that:
1. Implementing a new behavior for `AbstractContext`, e.g. `contextual_isassumption`, requires you to:
   1. Find all implementations of `AbstractContext`, which is non-trivial! Most are here in DPPL, but some are in Turing.jl, and eventually we also want packages outside of the Turing.jl-umbrella to extend DPPL using contexts.
   2. Implement the method for that particular context.
2. Implementing a new `AbstractContext`, e.g. `Turing.OptimizationContext`, requires you to find all the methods to implement and then do so. Again, non-trivial.

This combinatorial blow up essentially means that we're super-reluctant to introduce new behaviors or new contexts, for good reasons.

And the stupid thing is that in most cases a context is only trying to modify maybe one or two "behaviors", e.g. `MiniBatchContext` only wants to change the `*tilde_observe` methods, and otherwise just defer to whatever implementation is available for its "childcontext" `minibatchcontext.context`. 

## Goal
A new `AbstractContext` should only (up to an additive factor) have to implement the behavior it _wants to change_, not _all_ behaviors. E.g. `MiniBatchContext` should really only have to overload `*tilde_observe`.

## Solution (this PR)
The above was the motivation for #254, but there we wanted a rather strict separation between certain types of contexts which we're reluctant to add due to its restrictive nature (in particular given how "recently" contexts where introduced).

This PR takes are more extensible and less restrictive approach of introducing some traits for `AbstractContext`.

As a starter I've just introduced a couple of traits to allow code-sharing between "parent contexts", e.g. `MiniBatchContext` and `ConditionContext` (from #278), and "leaf-contexts", e.g. `DefaultContext` and `PriorContext` (which IMO should have been a wrapper-context itself). Ideally we'd also define a promotion-system, e.g. what do we do if we're asked to combine a `MiniBatchContext` and `DefaultContext`? Well in that case we could either
1. replace `minibatchcontext.context` with `DefaultContext()`, or
2. recursively "rewrap" `DefaultContext()` in `minibatch.context`.

(1) has the issue that `MiniBatchContext` might be wrapping another context, e.g. `PrefixContext`, and so just replacing `.context` is dangerous. (2) seems like a better idea since `DefaultContext` should always be at the end of the stack, i.e. a "leaf", since it always exists the tilde-callstack (e.g. in `tilde_assume` we call `assume`, etc.).

Such a promotion system will require some thought though, but this PR will allow us to experiment with this (on top of providing a good approach to code-sharing).

## Examples

### `GeneratedQuantitiesContext`

In DPPL we have the `generated_quantities` but it sort of sucks because often these quantities are not relevant for sampling, thus we're adding unnecessary computation to the sampling process. One might want want to introduce a `@generated_quantities` macro that will only be executed when called from `generated_quantities`. This can easily be achieved with contexts:

```julia
using DynamicPPL, Distributions, Random
using DynamicPPL: AbstractContext, IsLeaf, IsParent, childcontext

struct GeneratedQuantitiesContext{Ctx} <: AbstractContext
    context::Ctx
end
GeneratedQuantitiesContext() = GeneratedQuantitiesContext(DefaultContext())

# Define the `NodeTrait` for `GeneratedQuantitiesContext`.
DynamicPPL.NodeTrait(context::GeneratedQuantitiesContext) = IsParent()
DynamicPPL.childcontext(context::GeneratedQuantitiesContext) = context.context

"""
    isgeneratedquantities(context)

Return `true` if `context` wants evaluation of model to execute
the `@generatedquantities` block.
"""
function isgeneratedquantities(context::AbstractContext)
    return isgeneratedquantities(DynamicPPL.NodeTrait(isgeneratedquantities, context), context)
end

# Define the behavior for the different `NodeType`s.
isgeneratedquantities(::IsLeaf, context::AbstractContext) = false
function isgeneratedquantities(::IsParent, context::AbstractContext)
    return isgeneratedquantities(childcontext(context))
end

# Specific implementations of `isgeneratedquantities`.
isgeneratedquantities(context::GeneratedQuantitiesContext) = true

"""
    @generatedquantities f(x)

Specify that `f(x)` should only if the model is run using `GeneratedQuantitiesContext`.
"""
macro generated_quantities(expr)
    return esc(generated_quantities_expr(expr))
end

function generated_quantities_expr(expr)
    return quote
        if isgeneratedquantities(__context__)
            $expr
        end
    end
end
```

And usage would be as follows:

```julia
julia> @model function demo(x)
           @generated_quantities a = []
           m ~ Normal()

           # "Expensive" piece of code that we don't want to compute
           # unless we're computing the generated quantities.
           @generated_quantities for i = 1:100
               push!(a, m + randn())
           end

           # Observe.
           x ~ Normal(m, 1.0)

           # Return additional fields if we're computing the generated quantities.
           @generated_quantities return (; x, m, logp = getlogp(__varinfo__), a)

           return nothing
       end
demo (generic function with 1 method)

julia> m = demo(1.0); var_info = VarInfo(m);

julia> m(var_info) # (✓) returns nothing

julia> m(var_info, GeneratedQuantitiesContext()) # (✓) returns everything
(x = 1.0, m = 0.7451107426181028, logp = -2.147956342556143, a = Any[0.929285513523637, 0.5979631005709274, 3.1944251790696794, -0.611727924858586, 1.0561547812845788, 0.9694358994096283, 0.9130715096769692, 0.018196803783751103, 0.919216919507385, 0.5382931170515716  …  0.24049636440948963, 0.7868300598622132, 0.18210113151764207, 1.9568444848346909, 0.4512687443970105, 0.5155449377942058, 0.32900420588898294, 1.4274186203915957, 0.5599167955770905, -0.5351355146677438])

julia> m(Random.GLOBAL_RNG, GeneratedQuantitiesContext()) # (✓) just works even when wrapped in `SamplingContext`
(x = 1.0, m = -0.7568553457880578, logp = -3.6675624266453637, a = Any[-0.4715652454870858, 2.0553700887015776, -0.6055293756513751, -1.6153963580018202, -0.6316036328835652, -0.9694585645577235, -1.234406947321252, 0.47075652867281415, -2.0403875768252187, -1.332736944079995  …  -1.123179524380186, -0.4445013585772502, -1.2366853403014721, -0.2726171409415225, 0.050910231154342234, -1.937702826603367, -2.109933658872988, -0.8300603173278494, 0.076480161355589, -1.2666452596129179])
```

This would also _just work_ for submodels, etc. This "conditional execution" could of course be generalized too. Such a conditional execution is very useful when you also want to work with Zygote, e.g. you don't want mutations in the model but for the post-processing steps, e.g. `predict` and `generated_quantities`, you do need it.

But whether or not we want this particular `@generatedquantities` macro in DPPL is not the point; the point is that we _can_ implement such a thing. Even more importantly, I can easily implement this from the "outside", no needing to touch DPPL to do it. 

### `ConditionContext`
See #278.

Co-authored-by: Hong Ge <[email protected]>
@yebai
Copy link
Member

yebai commented Aug 14, 2021

I’m closing this now in favour of (private) discussions on ContextStack.

@yebai yebai closed this Aug 14, 2021
@yebai yebai deleted the tor/wrappedcontext-v3 branch January 31, 2022 20:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants