Skip to content

[Merged by Bors] - Introduce traits for contexts #286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from

Conversation

torfjelde
Copy link
Member

@torfjelde torfjelde commented Jul 24, 2021

This is motivated by the potential introduction of more contexts, e.g. #278, and has been brought up as an alternative (and better) approach to achieve parts of what we want to achieve in #254 .

(I hope you're proud of me @devmotion )

Current state of things

Currently, if one wants to implement a new AbstractContext one at least has to implement the following methods:

tilde_assume(...)
tilde_observe(...)
dot_tilde_assume(...)
dot_tilde_observe(...)

But there are also other methods that should be implemented but generally aren't properly handled, e.g. matchingvalue. And there might be more methods in the future, e.g. contextual_isassumption in #254.

This sucks

This means that:

  1. Implementing a new behavior for AbstractContext, e.g. contextual_isassumption, requires you to:
    1. Find all implementations of AbstractContext, which is non-trivial! Most are here in DPPL, but some are in Turing.jl, and eventually we also want packages outside of the Turing.jl-umbrella to extend DPPL using contexts.
    2. Implement the method for that particular context.
  2. Implementing a new AbstractContext, e.g. Turing.OptimizationContext, requires you to find all the methods to implement and then do so. Again, non-trivial.

This combinatorial blow up essentially means that we're super-reluctant to introduce new behaviors or new contexts, for good reasons.

And the stupid thing is that in most cases a context is only trying to modify maybe one or two "behaviors", e.g. MiniBatchContext only wants to change the *tilde_observe methods, and otherwise just defer to whatever implementation is available for its "childcontext" minibatchcontext.context.

Goal

A new AbstractContext should only (up to an additive factor) have to implement the behavior it wants to change, not all behaviors. E.g. MiniBatchContext should really only have to overload *tilde_observe.

Solution (this PR)

The above was the motivation for #254, but there we wanted a rather strict separation between certain types of contexts which we're reluctant to add due to its restrictive nature (in particular given how "recently" contexts where introduced).

This PR takes are more extensible and less restrictive approach of introducing some traits for AbstractContext.

As a starter I've just introduced a couple of traits to allow code-sharing between "parent contexts", e.g. MiniBatchContext and ConditionContext (from #278), and "leaf-contexts", e.g. DefaultContext and PriorContext (which IMO should have been a wrapper-context itself). Ideally we'd also define a promotion-system, e.g. what do we do if we're asked to combine a MiniBatchContext and DefaultContext? Well in that case we could either

  1. replace minibatchcontext.context with DefaultContext(), or
  2. recursively "rewrap" DefaultContext() in minibatch.context.

(1) has the issue that MiniBatchContext might be wrapping another context, e.g. PrefixContext, and so just replacing .context is dangerous. (2) seems like a better idea since DefaultContext should always be at the end of the stack, i.e. a "leaf", since it always exists the tilde-callstack (e.g. in tilde_assume we call assume, etc.).

Such a promotion system will require some thought though, but this PR will allow us to experiment with this (on top of providing a good approach to code-sharing).

Examples

GeneratedQuantitiesContext

In DPPL we have the generated_quantities but it sort of sucks because often these quantities are not relevant for sampling, thus we're adding unnecessary computation to the sampling process. One might want want to introduce a @generated_quantities macro that will only be executed when called from generated_quantities. This can easily be achieved with contexts:

using DynamicPPL, Distributions, Random
using DynamicPPL: AbstractContext, IsLeaf, IsParent, childcontext

struct GeneratedQuantitiesContext{Ctx} <: AbstractContext
    context::Ctx
end
GeneratedQuantitiesContext() = GeneratedQuantitiesContext(DefaultContext())

# Define the `NodeTrait` for `GeneratedQuantitiesContext`.
DynamicPPL.NodeTrait(context::GeneratedQuantitiesContext) = IsParent()
DynamicPPL.childcontext(context::GeneratedQuantitiesContext) = context.context

"""
    isgeneratedquantities(context)

Return `true` if `context` wants evaluation of model to execute
the `@generatedquantities` block.
"""
function isgeneratedquantities(context::AbstractContext)
    return isgeneratedquantities(DynamicPPL.NodeTrait(isgeneratedquantities, context), context)
end

# Define the behavior for the different `NodeType`s.
isgeneratedquantities(::IsLeaf, context::AbstractContext) = false
function isgeneratedquantities(::IsParent, context::AbstractContext)
    return isgeneratedquantities(childcontext(context))
end

# Specific implementations of `isgeneratedquantities`.
isgeneratedquantities(context::GeneratedQuantitiesContext) = true

"""
    @generatedquantities f(x)

Specify that `f(x)` should only if the model is run using `GeneratedQuantitiesContext`.
"""
macro generated_quantities(expr)
    return esc(generated_quantities_expr(expr))
end

function generated_quantities_expr(expr)
    return quote
        if isgeneratedquantities(__context__)
            $expr
        end
    end
end

And usage would be as follows:

julia> @model function demo(x)
           @generated_quantities a = []
           m ~ Normal()

           # "Expensive" piece of code that we don't want to compute
           # unless we're computing the generated quantities.
           @generated_quantities for i = 1:100
               push!(a, m + randn())
           end

           # Observe.
           x ~ Normal(m, 1.0)

           # Return additional fields if we're computing the generated quantities.
           @generated_quantities return (; x, m, logp = getlogp(__varinfo__), a)

           return nothing
       end
demo (generic function with 1 method)

julia> m = demo(1.0); var_info = VarInfo(m);

julia> m(var_info) # (✓) returns nothing

julia> m(var_info, GeneratedQuantitiesContext()) # (✓) returns everything
(x = 1.0, m = 0.7451107426181028, logp = -2.147956342556143, a = Any[0.929285513523637, 0.5979631005709274, 3.1944251790696794, -0.611727924858586, 1.0561547812845788, 0.9694358994096283, 0.9130715096769692, 0.018196803783751103, 0.919216919507385, 0.5382931170515716    0.24049636440948963, 0.7868300598622132, 0.18210113151764207, 1.9568444848346909, 0.4512687443970105, 0.5155449377942058, 0.32900420588898294, 1.4274186203915957, 0.5599167955770905, -0.5351355146677438])

julia> m(Random.GLOBAL_RNG, GeneratedQuantitiesContext()) # (✓) just works even when wrapped in `SamplingContext`
(x = 1.0, m = -0.7568553457880578, logp = -3.6675624266453637, a = Any[-0.4715652454870858, 2.0553700887015776, -0.6055293756513751, -1.6153963580018202, -0.6316036328835652, -0.9694585645577235, -1.234406947321252, 0.47075652867281415, -2.0403875768252187, -1.332736944079995    -1.123179524380186, -0.4445013585772502, -1.2366853403014721, -0.2726171409415225, 0.050910231154342234, -1.937702826603367, -2.109933658872988, -0.8300603173278494, 0.076480161355589, -1.2666452596129179])

This would also just work for submodels, etc. This "conditional execution" could of course be generalized too. Such a conditional execution is very useful when you also want to work with Zygote, e.g. you don't want mutations in the model but for the post-processing steps, e.g. predict and generated_quantities, you do need it.

But whether or not we want this particular @generatedquantities macro in DPPL is not the point; the point is that we can implement such a thing. Even more importantly, I can easily implement this from the "outside", no needing to touch DPPL to do it.

ConditionContext

See #278.

function dot_tilde_assume(context::AbstractContext, args...)
return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), context, args...)
end
function dot_tilde_assume(rng, context::AbstractContext, args...)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO we should make context always be the first argument, even after unwrapping in SamplingContext. This will further reduce redundancy by a factor of 2.

We only did it this way to attempt to be non-breaking (it wasn't 😅 ), so might was well make the change

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
torfjelde and others added 3 commits July 29, 2021 04:33
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@torfjelde
Copy link
Member Author

bors try

bors bot added a commit that referenced this pull request Jul 29, 2021
@torfjelde torfjelde marked this pull request as ready for review July 30, 2021 15:09
Copy link
Member

@yebai yebai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent work, thanks @torfjelde - it seems like a good design direction to follow. I'm overall happy with the design and changes, although there are some further places we can discuss and improve in a separate PR, e.g. whether we want to introduce a StackedContext type as a generic container for all non-primative contexts. This StackContext can make the relationship of contexts inside explicit and can provide a set of APIs for manipulating them. With that, we can consider making all existing contexts primative.

Also, it seems that we cannot unify the setchildcontext function using traits, which is slightly inconsistent with the overall trait-based design.

Comment on lines +504 to +511
return matchingvalue(NodeTrait(matchingvalue, context), context, vi, value)
end
function matchingvalue(::IsLeaf, context::AbstractContext, vi, value)
return matchingvalue(SampleFromPrior(), vi, value)
end
function matchingvalue(::IsParent, context::AbstractContext, vi, value)
return matchingvalue(childcontext(context), vi, value)
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also very nice as it allows us to define something like a CUDAContext in the future which will ensure that all the arguments are moved to the GPU before execution.

@testset "setchildcontext" begin
@testset "nested contexts" begin
# Both of the following should result in the same context.
context1 = ParentContext(ParentContext(ParentContext()))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A note for the future: there should be a validation of context constructors, e.g. here the innermost context is still a parent context (without a leaf context), which is problematic. This is one of the reasons why the current context type hierarchy is a bit unsafe and confusing. One possible fix is to introduce an explicit context container StackedContext, then make all existing contexts primitive.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No no, ParentContext does have a leaf. It's just that it has a default constructor that uses DefaultContext.
This validation is done by the type-parameters of the struct itself, so there's no need for explicit checks. If you try to construct a parent context without a context, well, then no such constructor exists.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One possible fix is to introduce an explicit context container StackedContext, then make all existing contexts primitive.

IIUC this doesn't solve anything though; all it does is that we need to the unwrap from StackedContext but we're still left with all of the rest of the codebase in this PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's correct - StackedStack won't simplify the code, but it 1) make primitive contexts easy to reason about, thus improve clarity. I find the current ParentContext quite hard to follow. 2) We can add extensive validity checking for StackedContext. This is also possible for the current ParentContext but is laborious since it involves repetitive work for each ParentContext subtype.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you maybe outline a bit more in detail what you mean? I think I'm not properly understanding 😕

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m more of commenting on this for a future reminder. I’ll create an issue for it later with more examples and explanations.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wonderful; thanks!:)

@yebai
Copy link
Member

yebai commented Aug 4, 2021

bors r+

bors bot pushed a commit that referenced this pull request Aug 4, 2021
This is motivated by the potential introduction of more contexts, e.g. #278, and has been brought up as an alternative (and better) approach to achieve parts of what we want to achieve in #254 .

(I hope you're proud of me @devmotion )

## Current state of things
Currently, if one wants to implement a new `AbstractContext` one _at least_ has to implement the following methods:

```julia
tilde_assume(...)
tilde_observe(...)
dot_tilde_assume(...)
dot_tilde_observe(...)
```

But there are also other methods that _should_ be implemented but generally aren't properly handled, e.g. `matchingvalue`. And there might be more methods in the future, e.g. `contextual_isassumption` in #254.

## This sucks
This means that:
1. Implementing a new behavior for `AbstractContext`, e.g. `contextual_isassumption`, requires you to:
   1. Find all implementations of `AbstractContext`, which is non-trivial! Most are here in DPPL, but some are in Turing.jl, and eventually we also want packages outside of the Turing.jl-umbrella to extend DPPL using contexts.
   2. Implement the method for that particular context.
2. Implementing a new `AbstractContext`, e.g. `Turing.OptimizationContext`, requires you to find all the methods to implement and then do so. Again, non-trivial.

This combinatorial blow up essentially means that we're super-reluctant to introduce new behaviors or new contexts, for good reasons.

And the stupid thing is that in most cases a context is only trying to modify maybe one or two "behaviors", e.g. `MiniBatchContext` only wants to change the `*tilde_observe` methods, and otherwise just defer to whatever implementation is available for its "childcontext" `minibatchcontext.context`. 

## Goal
A new `AbstractContext` should only (up to an additive factor) have to implement the behavior it _wants to change_, not _all_ behaviors. E.g. `MiniBatchContext` should really only have to overload `*tilde_observe`.

## Solution (this PR)
The above was the motivation for #254, but there we wanted a rather strict separation between certain types of contexts which we're reluctant to add due to its restrictive nature (in particular given how "recently" contexts where introduced).

This PR takes are more extensible and less restrictive approach of introducing some traits for `AbstractContext`.

As a starter I've just introduced a couple of traits to allow code-sharing between "parent contexts", e.g. `MiniBatchContext` and `ConditionContext` (from #278), and "leaf-contexts", e.g. `DefaultContext` and `PriorContext` (which IMO should have been a wrapper-context itself). Ideally we'd also define a promotion-system, e.g. what do we do if we're asked to combine a `MiniBatchContext` and `DefaultContext`? Well in that case we could either
1. replace `minibatchcontext.context` with `DefaultContext()`, or
2. recursively "rewrap" `DefaultContext()` in `minibatch.context`.

(1) has the issue that `MiniBatchContext` might be wrapping another context, e.g. `PrefixContext`, and so just replacing `.context` is dangerous. (2) seems like a better idea since `DefaultContext` should always be at the end of the stack, i.e. a "leaf", since it always exists the tilde-callstack (e.g. in `tilde_assume` we call `assume`, etc.).

Such a promotion system will require some thought though, but this PR will allow us to experiment with this (on top of providing a good approach to code-sharing).

## Examples

### `GeneratedQuantitiesContext`

In DPPL we have the `generated_quantities` but it sort of sucks because often these quantities are not relevant for sampling, thus we're adding unnecessary computation to the sampling process. One might want want to introduce a `@generated_quantities` macro that will only be executed when called from `generated_quantities`. This can easily be achieved with contexts:

```julia
using DynamicPPL, Distributions, Random
using DynamicPPL: AbstractContext, IsLeaf, IsParent, childcontext

struct GeneratedQuantitiesContext{Ctx} <: AbstractContext
    context::Ctx
end
GeneratedQuantitiesContext() = GeneratedQuantitiesContext(DefaultContext())

# Define the `NodeTrait` for `GeneratedQuantitiesContext`.
DynamicPPL.NodeTrait(context::GeneratedQuantitiesContext) = IsParent()
DynamicPPL.childcontext(context::GeneratedQuantitiesContext) = context.context

"""
    isgeneratedquantities(context)

Return `true` if `context` wants evaluation of model to execute
the `@generatedquantities` block.
"""
function isgeneratedquantities(context::AbstractContext)
    return isgeneratedquantities(DynamicPPL.NodeTrait(isgeneratedquantities, context), context)
end

# Define the behavior for the different `NodeType`s.
isgeneratedquantities(::IsLeaf, context::AbstractContext) = false
function isgeneratedquantities(::IsParent, context::AbstractContext)
    return isgeneratedquantities(childcontext(context))
end

# Specific implementations of `isgeneratedquantities`.
isgeneratedquantities(context::GeneratedQuantitiesContext) = true

"""
    @generatedquantities f(x)

Specify that `f(x)` should only if the model is run using `GeneratedQuantitiesContext`.
"""
macro generated_quantities(expr)
    return esc(generated_quantities_expr(expr))
end

function generated_quantities_expr(expr)
    return quote
        if isgeneratedquantities(__context__)
            $expr
        end
    end
end
```

And usage would be as follows:

```julia
julia> @model function demo(x)
           @generated_quantities a = []
           m ~ Normal()

           # "Expensive" piece of code that we don't want to compute
           # unless we're computing the generated quantities.
           @generated_quantities for i = 1:100
               push!(a, m + randn())
           end

           # Observe.
           x ~ Normal(m, 1.0)

           # Return additional fields if we're computing the generated quantities.
           @generated_quantities return (; x, m, logp = getlogp(__varinfo__), a)

           return nothing
       end
demo (generic function with 1 method)

julia> m = demo(1.0); var_info = VarInfo(m);

julia> m(var_info) # (✓) returns nothing

julia> m(var_info, GeneratedQuantitiesContext()) # (✓) returns everything
(x = 1.0, m = 0.7451107426181028, logp = -2.147956342556143, a = Any[0.929285513523637, 0.5979631005709274, 3.1944251790696794, -0.611727924858586, 1.0561547812845788, 0.9694358994096283, 0.9130715096769692, 0.018196803783751103, 0.919216919507385, 0.5382931170515716  …  0.24049636440948963, 0.7868300598622132, 0.18210113151764207, 1.9568444848346909, 0.4512687443970105, 0.5155449377942058, 0.32900420588898294, 1.4274186203915957, 0.5599167955770905, -0.5351355146677438])

julia> m(Random.GLOBAL_RNG, GeneratedQuantitiesContext()) # (✓) just works even when wrapped in `SamplingContext`
(x = 1.0, m = -0.7568553457880578, logp = -3.6675624266453637, a = Any[-0.4715652454870858, 2.0553700887015776, -0.6055293756513751, -1.6153963580018202, -0.6316036328835652, -0.9694585645577235, -1.234406947321252, 0.47075652867281415, -2.0403875768252187, -1.332736944079995  …  -1.123179524380186, -0.4445013585772502, -1.2366853403014721, -0.2726171409415225, 0.050910231154342234, -1.937702826603367, -2.109933658872988, -0.8300603173278494, 0.076480161355589, -1.2666452596129179])
```

This would also _just work_ for submodels, etc. This "conditional execution" could of course be generalized too. Such a conditional execution is very useful when you also want to work with Zygote, e.g. you don't want mutations in the model but for the post-processing steps, e.g. `predict` and `generated_quantities`, you do need it.

But whether or not we want this particular `@generatedquantities` macro in DPPL is not the point; the point is that we _can_ implement such a thing. Even more importantly, I can easily implement this from the "outside", no needing to touch DPPL to do it. 

### `ConditionContext`
See #278.

Co-authored-by: Hong Ge <[email protected]>
@torfjelde
Copy link
Member Author

bors cancel

@torfjelde
Copy link
Member Author

torfjelde commented Aug 4, 2021

@yebai Can you maybe approve TuringLang/Turing.jl#1672 first so we can run integration tests properly?

EDIT: Also just for future reference, I think it's best if you allow the PR-creator to officially do the bors r+ since they might have thoughts or concerns pop up even after the reviewer approval. There have been some premature merges as a result of this, e.g. the bug in my previous PR to fix the unwrap_right_left_vns thingy where I discovered the bug before merge but didn't get around to fixing it before the merge was triggered, only for me to return to the PR too late + a released version with the bug in it.

@yebai
Copy link
Member

yebai commented Aug 4, 2021

@yebai Can you maybe approve TuringLang/Turing.jl#1672 first so we can run integration tests properly?

Done.

@torfjelde
Copy link
Member Author

Great, thanks! I'll make a release, and once that's done we can run bors r+ again.

@torfjelde
Copy link
Member Author

bors try

bors bot added a commit that referenced this pull request Aug 4, 2021
@torfjelde
Copy link
Member Author

bors cancel

@torfjelde
Copy link
Member Author

bors try

@bors
Copy link
Contributor

bors bot commented Aug 4, 2021

try

Already running a review

@yebai
Copy link
Member

yebai commented Aug 5, 2021

Bors r+

bors bot pushed a commit that referenced this pull request Aug 5, 2021
This is motivated by the potential introduction of more contexts, e.g. #278, and has been brought up as an alternative (and better) approach to achieve parts of what we want to achieve in #254 .

(I hope you're proud of me @devmotion )

## Current state of things
Currently, if one wants to implement a new `AbstractContext` one _at least_ has to implement the following methods:

```julia
tilde_assume(...)
tilde_observe(...)
dot_tilde_assume(...)
dot_tilde_observe(...)
```

But there are also other methods that _should_ be implemented but generally aren't properly handled, e.g. `matchingvalue`. And there might be more methods in the future, e.g. `contextual_isassumption` in #254.

## This sucks
This means that:
1. Implementing a new behavior for `AbstractContext`, e.g. `contextual_isassumption`, requires you to:
   1. Find all implementations of `AbstractContext`, which is non-trivial! Most are here in DPPL, but some are in Turing.jl, and eventually we also want packages outside of the Turing.jl-umbrella to extend DPPL using contexts.
   2. Implement the method for that particular context.
2. Implementing a new `AbstractContext`, e.g. `Turing.OptimizationContext`, requires you to find all the methods to implement and then do so. Again, non-trivial.

This combinatorial blow up essentially means that we're super-reluctant to introduce new behaviors or new contexts, for good reasons.

And the stupid thing is that in most cases a context is only trying to modify maybe one or two "behaviors", e.g. `MiniBatchContext` only wants to change the `*tilde_observe` methods, and otherwise just defer to whatever implementation is available for its "childcontext" `minibatchcontext.context`. 

## Goal
A new `AbstractContext` should only (up to an additive factor) have to implement the behavior it _wants to change_, not _all_ behaviors. E.g. `MiniBatchContext` should really only have to overload `*tilde_observe`.

## Solution (this PR)
The above was the motivation for #254, but there we wanted a rather strict separation between certain types of contexts which we're reluctant to add due to its restrictive nature (in particular given how "recently" contexts where introduced).

This PR takes are more extensible and less restrictive approach of introducing some traits for `AbstractContext`.

As a starter I've just introduced a couple of traits to allow code-sharing between "parent contexts", e.g. `MiniBatchContext` and `ConditionContext` (from #278), and "leaf-contexts", e.g. `DefaultContext` and `PriorContext` (which IMO should have been a wrapper-context itself). Ideally we'd also define a promotion-system, e.g. what do we do if we're asked to combine a `MiniBatchContext` and `DefaultContext`? Well in that case we could either
1. replace `minibatchcontext.context` with `DefaultContext()`, or
2. recursively "rewrap" `DefaultContext()` in `minibatch.context`.

(1) has the issue that `MiniBatchContext` might be wrapping another context, e.g. `PrefixContext`, and so just replacing `.context` is dangerous. (2) seems like a better idea since `DefaultContext` should always be at the end of the stack, i.e. a "leaf", since it always exists the tilde-callstack (e.g. in `tilde_assume` we call `assume`, etc.).

Such a promotion system will require some thought though, but this PR will allow us to experiment with this (on top of providing a good approach to code-sharing).

## Examples

### `GeneratedQuantitiesContext`

In DPPL we have the `generated_quantities` but it sort of sucks because often these quantities are not relevant for sampling, thus we're adding unnecessary computation to the sampling process. One might want want to introduce a `@generated_quantities` macro that will only be executed when called from `generated_quantities`. This can easily be achieved with contexts:

```julia
using DynamicPPL, Distributions, Random
using DynamicPPL: AbstractContext, IsLeaf, IsParent, childcontext

struct GeneratedQuantitiesContext{Ctx} <: AbstractContext
    context::Ctx
end
GeneratedQuantitiesContext() = GeneratedQuantitiesContext(DefaultContext())

# Define the `NodeTrait` for `GeneratedQuantitiesContext`.
DynamicPPL.NodeTrait(context::GeneratedQuantitiesContext) = IsParent()
DynamicPPL.childcontext(context::GeneratedQuantitiesContext) = context.context

"""
    isgeneratedquantities(context)

Return `true` if `context` wants evaluation of model to execute
the `@generatedquantities` block.
"""
function isgeneratedquantities(context::AbstractContext)
    return isgeneratedquantities(DynamicPPL.NodeTrait(isgeneratedquantities, context), context)
end

# Define the behavior for the different `NodeType`s.
isgeneratedquantities(::IsLeaf, context::AbstractContext) = false
function isgeneratedquantities(::IsParent, context::AbstractContext)
    return isgeneratedquantities(childcontext(context))
end

# Specific implementations of `isgeneratedquantities`.
isgeneratedquantities(context::GeneratedQuantitiesContext) = true

"""
    @generatedquantities f(x)

Specify that `f(x)` should only if the model is run using `GeneratedQuantitiesContext`.
"""
macro generated_quantities(expr)
    return esc(generated_quantities_expr(expr))
end

function generated_quantities_expr(expr)
    return quote
        if isgeneratedquantities(__context__)
            $expr
        end
    end
end
```

And usage would be as follows:

```julia
julia> @model function demo(x)
           @generated_quantities a = []
           m ~ Normal()

           # "Expensive" piece of code that we don't want to compute
           # unless we're computing the generated quantities.
           @generated_quantities for i = 1:100
               push!(a, m + randn())
           end

           # Observe.
           x ~ Normal(m, 1.0)

           # Return additional fields if we're computing the generated quantities.
           @generated_quantities return (; x, m, logp = getlogp(__varinfo__), a)

           return nothing
       end
demo (generic function with 1 method)

julia> m = demo(1.0); var_info = VarInfo(m);

julia> m(var_info) # (✓) returns nothing

julia> m(var_info, GeneratedQuantitiesContext()) # (✓) returns everything
(x = 1.0, m = 0.7451107426181028, logp = -2.147956342556143, a = Any[0.929285513523637, 0.5979631005709274, 3.1944251790696794, -0.611727924858586, 1.0561547812845788, 0.9694358994096283, 0.9130715096769692, 0.018196803783751103, 0.919216919507385, 0.5382931170515716  …  0.24049636440948963, 0.7868300598622132, 0.18210113151764207, 1.9568444848346909, 0.4512687443970105, 0.5155449377942058, 0.32900420588898294, 1.4274186203915957, 0.5599167955770905, -0.5351355146677438])

julia> m(Random.GLOBAL_RNG, GeneratedQuantitiesContext()) # (✓) just works even when wrapped in `SamplingContext`
(x = 1.0, m = -0.7568553457880578, logp = -3.6675624266453637, a = Any[-0.4715652454870858, 2.0553700887015776, -0.6055293756513751, -1.6153963580018202, -0.6316036328835652, -0.9694585645577235, -1.234406947321252, 0.47075652867281415, -2.0403875768252187, -1.332736944079995  …  -1.123179524380186, -0.4445013585772502, -1.2366853403014721, -0.2726171409415225, 0.050910231154342234, -1.937702826603367, -2.109933658872988, -0.8300603173278494, 0.076480161355589, -1.2666452596129179])
```

This would also _just work_ for submodels, etc. This "conditional execution" could of course be generalized too. Such a conditional execution is very useful when you also want to work with Zygote, e.g. you don't want mutations in the model but for the post-processing steps, e.g. `predict` and `generated_quantities`, you do need it.

But whether or not we want this particular `@generatedquantities` macro in DPPL is not the point; the point is that we _can_ implement such a thing. Even more importantly, I can easily implement this from the "outside", no needing to touch DPPL to do it. 

### `ConditionContext`
See #278.

Co-authored-by: Hong Ge <[email protected]>
@bors bors bot changed the title Introduce traits for contexts [Merged by Bors] - Introduce traits for contexts Aug 5, 2021
@bors bors bot closed this Aug 5, 2021
@bors bors bot deleted the tor/context-traits branch August 5, 2021 12:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants