[Merged by Bors] - Introduce traits for contexts #286
function dot_tilde_assume(context::AbstractContext, args...)
    return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), context, args...)
end
function dot_tilde_assume(rng, context::AbstractContext, args...)
IMO we should make `context` always be the first argument, even after unwrapping in `SamplingContext`. This will further reduce redundancy by a factor of 2.
We only did it this way to attempt to be non-breaking (it wasn't 😅 ), so we might as well make the change.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
bors try
Excellent work, thanks @torfjelde - it seems like a good design direction to follow. I'm overall happy with the design and changes, although there are some further points we can discuss and improve in a separate PR, e.g. whether we want to introduce a `StackedContext` type as a generic container for all non-primitive contexts. This `StackedContext` can make the relationship between the contexts inside it explicit and can provide a set of APIs for manipulating them. With that, we can consider making all existing contexts primitive.
Also, it seems that we cannot unify the `setchildcontext` function using traits, which is slightly inconsistent with the overall trait-based design.
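One plausible reading of the `setchildcontext` remark, sketched below with toy types: rebuilding a parent around a new child needs that parent's other fields, so each parent type supplies its own method, while a walk that replaces only the leaf can be written once against the traits. Everything here is an assumption made for illustration: `ScaleContext`, `TagContext`, and the helper name `setleafcontext` are hypothetical, the one-argument `NodeTrait(ctx)` is a simplification of the two-argument form in the diff, and none of this is DynamicPPL's actual code.

```julia
# Self-contained toy definitions; they shadow nothing because DynamicPPL is not loaded.
abstract type AbstractContext end
struct IsLeaf end
struct IsParent end

# Two leaf contexts (stand-ins for the real ones of the same name).
struct DefaultContext <: AbstractContext end
struct PriorContext <: AbstractContext end
NodeTrait(::DefaultContext) = IsLeaf()
NodeTrait(::PriorContext) = IsLeaf()

# Two hypothetical parent contexts, each carrying one extra field.
struct ScaleContext{C<:AbstractContext} <: AbstractContext
    context::C
    scale::Float64
end
struct TagContext{C<:AbstractContext} <: AbstractContext
    context::C
    tag::Symbol
end
NodeTrait(::ScaleContext) = IsParent()
NodeTrait(::TagContext) = IsParent()
childcontext(ctx::Union{ScaleContext,TagContext}) = ctx.context

# Per-type methods: each parent knows how to rebuild itself around a new child.
setchildcontext(ctx::ScaleContext, child) = ScaleContext(child, ctx.scale)
setchildcontext(ctx::TagContext, child) = TagContext(child, ctx.tag)

# Trait-based and written once: walk down the stack and swap only the leaf.
setleafcontext(ctx::AbstractContext, leaf) = setleafcontext(NodeTrait(ctx), ctx, leaf)
setleafcontext(::IsLeaf, ctx::AbstractContext, leaf) = leaf
function setleafcontext(::IsParent, ctx::AbstractContext, leaf)
    return setchildcontext(ctx, setleafcontext(childcontext(ctx), leaf))
end

stack = ScaleContext(TagContext(DefaultContext(), :a), 2.0)
swapped = setleafcontext(stack, PriorContext())
```

In this sketch `swapped` keeps both wrappers and their fields, and only the innermost context changes from `DefaultContext` to `PriorContext`.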
    return matchingvalue(NodeTrait(matchingvalue, context), context, vi, value)
end
function matchingvalue(::IsLeaf, context::AbstractContext, vi, value)
    return matchingvalue(SampleFromPrior(), vi, value)
end
function matchingvalue(::IsParent, context::AbstractContext, vi, value)
    return matchingvalue(childcontext(context), vi, value)
end
This is also very nice as it allows us to define something like a `CUDAContext` in the future which will ensure that all the arguments are moved to the GPU before execution.
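A minimal sketch of that idea, assuming a CPU-only stand-in so it runs without a GPU: a hypothetical `Float32Context` converts values before handing them to its child, in the way a `CUDAContext` might move them to the device. The function `convertvalue` and the context type are invented for illustration and follow only the shape of the `matchingvalue` methods in the diff; the one-argument `NodeTrait(ctx)` is also a simplification.

```julia
# Self-contained toy definitions, not DynamicPPL's implementation.
abstract type AbstractContext end
struct IsLeaf end
struct IsParent end

struct DefaultContext <: AbstractContext end
NodeTrait(::DefaultContext) = IsLeaf()

# Hypothetical parent context standing in for a device-transfer context.
struct Float32Context{C<:AbstractContext} <: AbstractContext
    context::C
end
NodeTrait(::Float32Context) = IsParent()
childcontext(ctx::Float32Context) = ctx.context

# Trait-based fallbacks: leaves return the value, parents defer to their child.
convertvalue(ctx::AbstractContext, value) = convertvalue(NodeTrait(ctx), ctx, value)
convertvalue(::IsLeaf, ctx::AbstractContext, value) = value
function convertvalue(::IsParent, ctx::AbstractContext, value)
    return convertvalue(childcontext(ctx), value)
end

# The only method the new context adds: convert, then defer to the child.
function convertvalue(ctx::Float32Context, value)
    return convertvalue(childcontext(ctx), Float32.(value))
end

converted = convertvalue(Float32Context(DefaultContext()), [1.0, 2.0])
untouched = convertvalue(DefaultContext(), [1.0, 2.0])
```

Here `converted` has element type `Float32` while `untouched` stays `Float64`; a real GPU context would presumably replace the `Float32.(value)` call with a device transfer.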
@testset "setchildcontext" begin
    @testset "nested contexts" begin
        # Both of the following should result in the same context.
        context1 = ParentContext(ParentContext(ParentContext()))
A note for the future: there should be a validation of context constructors, e.g. here the innermost context is still a parent context (without a leaf context), which is problematic. This is one of the reasons why the current context type hierarchy is a bit unsafe and confusing. One possible fix is to introduce an explicit context container `StackedContext`, then make all existing contexts primitive.
No no, `ParentContext` does have a leaf. It's just that it has a default constructor that uses `DefaultContext`.
This validation is done by the type parameters of the struct itself, so there's no need for explicit checks. If you try to construct a parent context without a context, well, then no such constructor exists.
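A small sketch of the point about constructors, assuming the test helper looks roughly like the struct below (the `C<:AbstractContext` bound and the `innermost` helper are assumptions added for illustration, not code from the PR):

```julia
# Self-contained toy definitions, not the PR's test code.
abstract type AbstractContext end
struct DefaultContext <: AbstractContext end

struct ParentContext{C<:AbstractContext} <: AbstractContext
    context::C
end
# Zero-argument convenience constructor: the child defaults to a leaf.
ParentContext() = ParentContext(DefaultContext())

# Hypothetical helper that walks down to the innermost context.
innermost(ctx::ParentContext) = innermost(ctx.context)
innermost(ctx::AbstractContext) = ctx

nested = ParentContext(ParentContext(ParentContext()))
leaf = innermost(nested)
```

Under these assumptions `leaf` is a `DefaultContext`, and a call such as `ParentContext(1)` fails with a `MethodError` because no constructor accepts a non-context argument.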
> One possible fix is to introduce an explicit context container StackedContext, then make all existing contexts primitive.

IIUC this doesn't solve anything though; all it does is require us to unwrap from `StackedContext`, but we're still left with all of the rest of the code in this PR.
That's correct - `StackedContext` won't simplify the code, but 1) it makes primitive contexts easy to reason about, thus improving clarity (I find the current `ParentContext` quite hard to follow), and 2) we can add extensive validity checking for `StackedContext`. This is also possible for the current `ParentContext` but is laborious since it involves repetitive work for each `ParentContext` subtype.
Could you maybe outline a bit more in detail what you mean? I think I'm not properly understanding 😕
I'm mostly noting this here as a reminder for the future. I'll create an issue for it later with more examples and explanations.
Wonderful; thanks! :)
bors r+
This is motivated by the potential introduction of more contexts, e.g. #278, and has been brought up as an alternative (and better) approach to achieve parts of what we want to achieve in #254. (I hope you're proud of me @devmotion)

## Current state of things

Currently, if one wants to implement a new `AbstractContext` one _at least_ has to implement the following methods:

```julia
tilde_assume(...)
tilde_observe(...)
dot_tilde_assume(...)
dot_tilde_observe(...)
```

But there are also other methods that _should_ be implemented but generally aren't properly handled, e.g. `matchingvalue`. And there might be more methods in the future, e.g. `contextual_isassumption` in #254.

## This sucks

This means that:

1. Implementing a new behavior for `AbstractContext`, e.g. `contextual_isassumption`, requires you to:
   1. Find all implementations of `AbstractContext`, which is non-trivial! Most are here in DPPL, but some are in Turing.jl, and eventually we also want packages outside of the Turing.jl-umbrella to extend DPPL using contexts.
   2. Implement the method for that particular context.
2. Implementing a new `AbstractContext`, e.g. `Turing.OptimizationContext`, requires you to find all the methods to implement and then do so. Again, non-trivial.

This combinatorial blow-up essentially means that we're super-reluctant to introduce new behaviors or new contexts, for good reasons.

And the stupid thing is that in most cases a context is only trying to modify maybe one or two "behaviors", e.g. `MiniBatchContext` only wants to change the `*tilde_observe` methods, and otherwise just defer to whatever implementation is available for its "childcontext" `minibatchcontext.context`.

## Goal

A new `AbstractContext` should only (up to an additive factor) have to implement the behavior it _wants to change_, not _all_ behaviors. E.g. `MiniBatchContext` should really only have to overload `*tilde_observe`.
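The goal can be sketched with toy types: a parent context overrides one behavior and inherits the other through trait-based fallbacks. All names below are assumptions made for illustration (`ScaledContext` loosely stands in for `MiniBatchContext`, and `assume_logp` / `observe_logp` stand in for the tilde methods); this is not DynamicPPL's actual API.

```julia
# Self-contained toy definitions, not DynamicPPL's implementation.
abstract type AbstractContext end
struct IsLeaf end
struct IsParent end

struct DefaultContext <: AbstractContext end
NodeTrait(::DefaultContext) = IsLeaf()

# Hypothetical parent context that scales observation log-likelihoods.
struct ScaledContext{C<:AbstractContext} <: AbstractContext
    context::C
    scale::Float64
end
NodeTrait(::ScaledContext) = IsParent()
childcontext(ctx::ScaledContext) = ctx.context

# Two toy behaviors, each with generic leaf and parent fallbacks.
assume_logp(ctx::AbstractContext, lp) = assume_logp(NodeTrait(ctx), ctx, lp)
assume_logp(::IsLeaf, ctx::AbstractContext, lp) = lp
assume_logp(::IsParent, ctx::AbstractContext, lp) = assume_logp(childcontext(ctx), lp)

observe_logp(ctx::AbstractContext, lp) = observe_logp(NodeTrait(ctx), ctx, lp)
observe_logp(::IsLeaf, ctx::AbstractContext, lp) = lp
observe_logp(::IsParent, ctx::AbstractContext, lp) = observe_logp(childcontext(ctx), lp)

# The only behavior ScaledContext overrides.
function observe_logp(ctx::ScaledContext, lp)
    return ctx.scale * observe_logp(childcontext(ctx), lp)
end

ctx = ScaledContext(DefaultContext(), 10.0)
assumed = assume_logp(ctx, -1.5)
observed = observe_logp(ctx, -1.5)
```

With this setup `assumed` passes through unchanged and `observed` is scaled, and adding a third behavior later would only require the three generic methods rather than one method per context type.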

## Solution (this PR)

The above was the motivation for #254, but there we wanted a rather strict separation between certain types of contexts which we're reluctant to add due to its restrictive nature (in particular given how "recently" contexts were introduced).

This PR takes a more extensible and less restrictive approach of introducing some traits for `AbstractContext`.

As a starter I've just introduced a couple of traits to allow code-sharing between "parent contexts", e.g. `MiniBatchContext` and `ConditionContext` (from #278), and "leaf-contexts", e.g. `DefaultContext` and `PriorContext` (which IMO should have been a wrapper-context itself). Ideally we'd also define a promotion-system, e.g. what do we do if we're asked to combine a `MiniBatchContext` and `DefaultContext`? Well in that case we could either

1. replace `minibatchcontext.context` with `DefaultContext()`, or
2. recursively "rewrap" `DefaultContext()` in `minibatchcontext.context`.

(1) has the issue that `MiniBatchContext` might be wrapping another context, e.g. `PrefixContext`, and so just replacing `.context` is dangerous. (2) seems like a better idea since `DefaultContext` should always be at the end of the stack, i.e. a "leaf", since it always exits the tilde-callstack (e.g. in `tilde_assume` we call `assume`, etc.).

Such a promotion system will require some thought though, but this PR will allow us to experiment with this (on top of providing a good approach to code-sharing).

## Examples

### `GeneratedQuantitiesContext`

In DPPL we have the `generated_quantities` but it sort of sucks because often these quantities are not relevant for sampling, thus we're adding unnecessary computation to the sampling process. One might want to introduce a `@generated_quantities` macro that will only be executed when called from `generated_quantities`.
This can easily be achieved with contexts:

```julia
using DynamicPPL, Distributions, Random
using DynamicPPL: AbstractContext, IsLeaf, IsParent, childcontext

struct GeneratedQuantitiesContext{Ctx} <: AbstractContext
    context::Ctx
end
GeneratedQuantitiesContext() = GeneratedQuantitiesContext(DefaultContext())

# Define the `NodeTrait` for `GeneratedQuantitiesContext`.
DynamicPPL.NodeTrait(context::GeneratedQuantitiesContext) = IsParent()
DynamicPPL.childcontext(context::GeneratedQuantitiesContext) = context.context

"""
    isgeneratedquantities(context)

Return `true` if `context` wants evaluation of model to execute the `@generated_quantities` block.
"""
function isgeneratedquantities(context::AbstractContext)
    return isgeneratedquantities(DynamicPPL.NodeTrait(isgeneratedquantities, context), context)
end

# Define the behavior for the different `NodeTrait`s.
isgeneratedquantities(::IsLeaf, context::AbstractContext) = false
function isgeneratedquantities(::IsParent, context::AbstractContext)
    return isgeneratedquantities(childcontext(context))
end

# Specific implementations of `isgeneratedquantities`.
isgeneratedquantities(context::GeneratedQuantitiesContext) = true

"""
    @generated_quantities f(x)

Specify that `f(x)` should only be executed if the model is run using `GeneratedQuantitiesContext`.
"""
macro generated_quantities(expr)
    return esc(generated_quantities_expr(expr))
end

function generated_quantities_expr(expr)
    return quote
        if isgeneratedquantities(__context__)
            $expr
        end
    end
end
```

And usage would be as follows:

```julia
julia> @model function demo(x)
           @generated_quantities a = []
           m ~ Normal()

           # "Expensive" piece of code that we don't want to compute
           # unless we're computing the generated quantities.
           @generated_quantities for i = 1:100
               push!(a, m + randn())
           end

           # Observe.
           x ~ Normal(m, 1.0)

           # Return additional fields if we're computing the generated quantities.
           @generated_quantities return (; x, m, logp = getlogp(__varinfo__), a)

           return nothing
       end
demo (generic function with 1 method)

julia> m = demo(1.0); var_info = VarInfo(m);

julia> m(var_info) # (✓) returns nothing

julia> m(var_info, GeneratedQuantitiesContext()) # (✓) returns everything
(x = 1.0, m = 0.7451107426181028, logp = -2.147956342556143, a = Any[0.929285513523637, 0.5979631005709274, 3.1944251790696794, -0.611727924858586, 1.0561547812845788, 0.9694358994096283, 0.9130715096769692, 0.018196803783751103, 0.919216919507385, 0.5382931170515716 … 0.24049636440948963, 0.7868300598622132, 0.18210113151764207, 1.9568444848346909, 0.4512687443970105, 0.5155449377942058, 0.32900420588898294, 1.4274186203915957, 0.5599167955770905, -0.5351355146677438])

julia> m(Random.GLOBAL_RNG, GeneratedQuantitiesContext()) # (✓) just works even when wrapped in `SamplingContext`
(x = 1.0, m = -0.7568553457880578, logp = -3.6675624266453637, a = Any[-0.4715652454870858, 2.0553700887015776, -0.6055293756513751, -1.6153963580018202, -0.6316036328835652, -0.9694585645577235, -1.234406947321252, 0.47075652867281415, -2.0403875768252187, -1.332736944079995 … -1.123179524380186, -0.4445013585772502, -1.2366853403014721, -0.2726171409415225, 0.050910231154342234, -1.937702826603367, -2.109933658872988, -0.8300603173278494, 0.076480161355589, -1.2666452596129179])
```

This would also _just work_ for submodels, etc. This "conditional execution" could of course be generalized too. Such a conditional execution is very useful when you also want to work with Zygote, e.g. you don't want mutations in the model but for the post-processing steps, e.g. `predict` and `generated_quantities`, you do need it.

But whether or not we want this particular `@generated_quantities` macro in DPPL is not the point; the point is that we _can_ implement such a thing. Even more importantly, I can easily implement this from the "outside", without needing to touch DPPL to do it.

### `ConditionContext`

See #278.

Co-authored-by: Hong Ge <[email protected]>
bors cancel
@yebai Can you maybe approve TuringLang/Turing.jl#1672 first so we can run integration tests properly? EDIT: Also just for future reference, I think it's best if you allow the PR-creator to officially do the |
Done.
Great, thanks! I'll make a release, and once that's done we can run |
bors try
bors cancel
bors try
try
Already running a review
bors r+