Skip to content

Commit d635a17

Browse files
torfjeldepenelopeysmgithub-actions[bot]
authored
Some clean up of contexts (#711)
* fixed incorrect implementation of `dot_tilde_assume` for `PrefixContext` * removed `vars` field from `PriorContext` and `LikelihoodContext` as it's no longer used functionality (was dropped when we dropped the logprob-macro) * replaced `NoDist` with `nodist` * fixed method ambiguity issue * added missing `Distributions.rand!` definition for `NoDist` * added more elaborate testing of evaluation of contexts * added `DynamicPPL.TestUtils.test_context` for testing contexts and replaced much of the `test/contexts.jl` with calls to this method * added proper testing for PrefixContext of all demo models * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added some dropped tests * Update src/test_utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/test_utils.jl Co-authored-by: Penelope Yong <[email protected]> * Update test/debug_utils.jl Co-authored-by: Penelope Yong <[email protected]> * bump patch version --------- Co-authored-by: Penelope Yong <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 48921d3 commit d635a17

File tree

7 files changed

+136
-251
lines changed

7 files changed

+136
-251
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.30.5"
3+
version = "0.30.6"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/context_implementations.jl

Lines changed: 7 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -77,49 +77,11 @@ function tilde_assume(
7777
return tilde_assume(rng, childcontext(context), args...)
7878
end
7979

80-
function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi)
81-
if haskey(context.vars, getsym(vn))
82-
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
83-
settrans!!(vi, false, vn)
84-
end
85-
return tilde_assume(PriorContext(), right, vn, vi)
86-
end
87-
function tilde_assume(
88-
rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi
89-
)
90-
if haskey(context.vars, getsym(vn))
91-
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
92-
settrans!!(vi, false, vn)
93-
end
94-
return tilde_assume(rng, PriorContext(), sampler, right, vn, vi)
95-
end
96-
97-
function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi)
98-
if haskey(context.vars, getsym(vn))
99-
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
100-
settrans!!(vi, false, vn)
101-
end
102-
return tilde_assume(LikelihoodContext(), right, vn, vi)
103-
end
104-
function tilde_assume(
105-
rng::Random.AbstractRNG,
106-
context::LikelihoodContext{<:NamedTuple},
107-
sampler,
108-
right,
109-
vn,
110-
vi,
111-
)
112-
if haskey(context.vars, getsym(vn))
113-
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
114-
settrans!!(vi, false, vn)
115-
end
116-
return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi)
117-
end
11880
function tilde_assume(::LikelihoodContext, right, vn, vi)
119-
return assume(NoDist(right), vn, vi)
81+
return assume(nodist(right), vn, vi)
12082
end
12183
function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi)
122-
return assume(rng, sampler, NoDist(right), vn, vi)
84+
return assume(rng, sampler, nodist(right), vn, vi)
12385
end
12486

12587
function tilde_assume(context::PrefixContext, right, vn, vi)
@@ -328,37 +290,6 @@ function dot_tilde_assume(
328290
end
329291

330292
# `LikelihoodContext`
331-
function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi)
332-
return if haskey(context.vars, getsym(vn))
333-
var = get(context.vars, vn)
334-
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
335-
set_val!(vi, _vns, _right, _left)
336-
settrans!!.((vi,), false, _vns)
337-
dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi)
338-
else
339-
dot_tilde_assume(LikelihoodContext(), right, left, vn, vi)
340-
end
341-
end
342-
function dot_tilde_assume(
343-
rng::Random.AbstractRNG,
344-
context::LikelihoodContext{<:NamedTuple},
345-
sampler,
346-
right,
347-
left,
348-
vn,
349-
vi,
350-
)
351-
return if haskey(context.vars, getsym(vn))
352-
var = get(context.vars, vn)
353-
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
354-
set_val!(vi, _vns, _right, _left)
355-
settrans!!.((vi,), false, _vns)
356-
dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi)
357-
else
358-
dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi)
359-
end
360-
end
361-
362293
function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi)
363294
return dot_assume(nodist(right), left, vn, vi)
364295
end
@@ -368,46 +299,16 @@ function dot_tilde_assume(
368299
return dot_assume(rng, sampler, nodist(right), vn, left, vi)
369300
end
370301

371-
# `PriorContext`
372-
function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi)
373-
return if haskey(context.vars, getsym(vn))
374-
var = get(context.vars, vn)
375-
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
376-
set_val!(vi, _vns, _right, _left)
377-
settrans!!.((vi,), false, _vns)
378-
dot_tilde_assume(PriorContext(), _right, _left, _vns, vi)
379-
else
380-
dot_tilde_assume(PriorContext(), right, left, vn, vi)
381-
end
382-
end
383-
function dot_tilde_assume(
384-
rng::Random.AbstractRNG,
385-
context::PriorContext{<:NamedTuple},
386-
sampler,
387-
right,
388-
left,
389-
vn,
390-
vi,
391-
)
392-
return if haskey(context.vars, getsym(vn))
393-
var = get(context.vars, vn)
394-
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
395-
set_val!(vi, _vns, _right, _left)
396-
settrans!!.((vi,), false, _vns)
397-
dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi)
398-
else
399-
dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi)
400-
end
401-
end
402-
403302
# `PrefixContext`
404303
function dot_tilde_assume(context::PrefixContext, right, left, vn, vi)
405-
return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), vi)
304+
return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi)
406305
end
407306

408-
function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi)
307+
function dot_tilde_assume(
308+
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi
309+
)
409310
return dot_tilde_assume(
410-
rng, context.context, sampler, right, prefix.(Ref(context), vn), vi
311+
rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi
411312
)
412313
end
413314

src/contexts.jl

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ DefaultContext()
5353
julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior
5454
5555
julia> DynamicPPL.childcontext(ctx_prior)
56-
PriorContext{Nothing}(nothing)
56+
PriorContext()
5757
```
5858
"""
5959
setchildcontext
@@ -97,7 +97,7 @@ ParentContext(ParentContext(DefaultContext()))
9797
9898
julia> # Replace the leaf context with another leaf.
9999
leafcontext(setleafcontext(ctx, PriorContext()))
100-
PriorContext{Nothing}(nothing)
100+
PriorContext()
101101
102102
julia> # Append another parent context.
103103
setleafcontext(ctx, ParentContext(DefaultContext()))
@@ -195,32 +195,19 @@ struct DefaultContext <: AbstractContext end
195195
NodeTrait(context::DefaultContext) = IsLeaf()
196196

197197
"""
198-
struct PriorContext{Tvars} <: AbstractContext
199-
vars::Tvars
200-
end
198+
PriorContext <: AbstractContext
201199
202-
The `PriorContext` enables the computation of the log prior of the parameters `vars` when
203-
running the model.
200+
A leaf context resulting in the exclusion of likelihood terms when running the model.
204201
"""
205-
struct PriorContext{Tvars} <: AbstractContext
206-
vars::Tvars
207-
end
208-
PriorContext() = PriorContext(nothing)
202+
struct PriorContext <: AbstractContext end
209203
NodeTrait(context::PriorContext) = IsLeaf()
210204

211205
"""
212-
struct LikelihoodContext{Tvars} <: AbstractContext
213-
vars::Tvars
214-
end
206+
LikelihoodContext <: AbstractContext
215207
216-
The `LikelihoodContext` enables the computation of the log likelihood of the parameters when
217-
running the model. `vars` can be used to evaluate the log likelihood for specific values
218-
of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default.
208+
A leaf context resulting in the exclusion of prior terms when running the model.
219209
"""
220-
struct LikelihoodContext{Tvars} <: AbstractContext
221-
vars::Tvars
222-
end
223-
LikelihoodContext() = LikelihoodContext(nothing)
210+
struct LikelihoodContext <: AbstractContext end
224211
NodeTrait(context::LikelihoodContext) = IsLeaf()
225212

226213
"""

src/distribution_wrappers.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ Base.length(dist::NoDist) = Base.length(dist.dist)
4242
Base.size(dist::NoDist) = Base.size(dist.dist)
4343

4444
Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist)
45+
# NOTE(torfjelde): Need this to avoid stack overflow.
46+
function Distributions.rand!(
47+
rng::Random.AbstractRNG,
48+
d::NoDist{Distributions.ArrayLikeVariate{N}},
49+
x::AbstractArray{<:Real,N},
50+
) where {N}
51+
return Distributions.rand!(rng, d.dist, x)
52+
end
4553
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
4654
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
4755
function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})

src/test_utils/contexts.jl

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,6 @@
33
#
44
# Utilities for testing contexts.
55

6-
"""
7-
test_context_interface(context)
8-
9-
Test that `context` implements the `AbstractContext` interface.
10-
"""
11-
function test_context_interface(context)
12-
# Is a subtype of `AbstractContext`.
13-
@test context isa DynamicPPL.AbstractContext
14-
# Should implement `NodeTrait.`
15-
@test DynamicPPL.NodeTrait(context) isa Union{DynamicPPL.IsParent,DynamicPPL.IsLeaf}
16-
# If it's a parent.
17-
if DynamicPPL.NodeTrait(context) == DynamicPPL.IsParent
18-
# Should implement `childcontext` and `setchildcontext`
19-
@test DynamicPPL.setchildcontext(context, DynamicPPL.childcontext(context)) ==
20-
context
21-
end
22-
end
23-
246
"""
257
Context that multiplies each log-prior by mod
268
used to test whether varwise_logpriors respects child-context.
@@ -60,3 +42,71 @@ function DynamicPPL.dot_tilde_observe(
6042
logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi)
6143
return logp * context.mod, vi
6244
end
45+
46+
# Dummy context to test nested behaviors.
47+
struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext
48+
context::C
49+
end
50+
TestParentContext() = TestParentContext(DefaultContext())
51+
DynamicPPL.NodeTrait(::TestParentContext) = DynamicPPL.IsParent()
52+
DynamicPPL.childcontext(context::TestParentContext) = context.context
53+
DynamicPPL.setchildcontext(::TestParentContext, child) = TestParentContext(child)
54+
function Base.show(io::IO, c::TestParentContext)
55+
return print(io, "TestParentContext(", DynamicPPL.childcontext(c), ")")
56+
end
57+
58+
"""
59+
test_context(context::AbstractContext, model::Model)
60+
61+
Test that `context` correctly implements the `AbstractContext` interface for `model`.
62+
63+
This method ensures that `context`
64+
- Correctly implements the `AbstractContext` interface.
65+
- Correctly implements the tilde-pipeline.
66+
"""
67+
function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model)
68+
# `NodeTrait`.
69+
node_trait = DynamicPPL.NodeTrait(context)
70+
# Throw error immediately if it it's missing a `NodeTrait` implementation.
71+
node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} ||
72+
throw(ValueError("Invalid NodeTrait: $node_trait"))
73+
74+
# The interface methods.
75+
if node_trait isa DynamicPPL.IsParent
76+
# `childcontext` and `setchildcontext`
77+
# With new child context
78+
childcontext_new = TestParentContext()
79+
@test DynamicPPL.childcontext(
80+
DynamicPPL.setchildcontext(context, childcontext_new)
81+
) == childcontext_new
82+
end
83+
84+
# To see change, let's make sure we're using a different leaf context than the current.
85+
leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext
86+
PriorContext()
87+
else
88+
DefaultContext()
89+
end
90+
@test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) ==
91+
leafcontext_new
92+
93+
# Setting the child context to a leaf should now change the leafcontext accordingly.
94+
context_with_new_leaf = DynamicPPL.setchildcontext(context, leafcontext_new)
95+
@test DynamicPPL.setchildcontext(context_with_new_leaf) ===
96+
DynamicPPL.setleafcontext(context_with_new_leaf) ===
97+
leafcontext_new
98+
99+
# Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded).
100+
# The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it.
101+
# NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the
102+
# context might alter which variables are present, their names, etc., e.g. `PrefixContext`.
103+
# TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos.
104+
# Untyped varinfo.
105+
varinfo_untyped = DynamicPPL.VarInfo()
106+
@test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true)
107+
@test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true)
108+
# Typed varinfo.
109+
varinfo_typed = DynamicPPL.TypedVarInfo(varinfo_untyped)
110+
@test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true)
111+
@test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true)
112+
end

0 commit comments

Comments
 (0)