Skip to content

Commit e673b69

Browse files
authored
Streamline internal de/conditioning interface (#776)
* Remove `condition` type piracy * Add tests for model conditioning syntax * Add tests for ConditionContext/decondition_context * Format * Bump patch version * Add ConditionContext docstring to docs * Fix type annotation of | in docs * Fix remaining bugs e.g. in nested `decondition_context`
1 parent 003ff2f commit e673b69

File tree

6 files changed

+199
-84
lines changed

6 files changed

+199
-84
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.33.0"
3+
version = "0.33.1"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/api.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ DynamicPPL.LogDensityFunction
6565
A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref).
6666

6767
```@docs
68-
|(::Model, ::Any)
68+
|(::Model, ::Union{Tuple,NamedTuple,AbstractDict{<:VarName}})
6969
condition
7070
DynamicPPL.conditioned
7171
```
@@ -403,6 +403,7 @@ LikelihoodContext
403403
PriorContext
404404
MiniBatchContext
405405
PrefixContext
406+
ConditionContext
406407
```
407408

408409
### Samplers

src/contexts.jl

+50-74
Original file line numberDiff line numberDiff line change
@@ -309,20 +309,40 @@ function prefix(model::Model, ::Val{x}) where {x}
309309
return contextualize(model, PrefixContext{Symbol(x)}(model.context))
310310
end
311311

312-
struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext
312+
"""
313+
314+
ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext}
315+
316+
Model context that contains values that are to be conditioned on. The values
317+
can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or
318+
an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1,
319+
@varname(b) => 2)`). The former is more performant, but the latter must be used
320+
when there are varnames that cannot be represented as symbols, e.g.
321+
`@varname(x[1])`.
322+
"""
323+
struct ConditionContext{
324+
Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext
325+
} <: AbstractContext
313326
values::Values
314327
context::Ctx
315328
end
316329

317330
const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}}
318331
const DictConditionContext = ConditionContext{<:AbstractDict}
319332

320-
ConditionContext(values) = ConditionContext(values, DefaultContext())
321-
322-
# Try to avoid nested `ConditionContext`.
333+
# Use DefaultContext as the default base context
334+
function ConditionContext(values::Union{NamedTuple,AbstractDict})
335+
return ConditionContext(values, DefaultContext())
336+
end
337+
# Optimisation when there are no values to condition on
338+
ConditionContext(::NamedTuple{()}, context::AbstractContext) = context
339+
# Collapse consecutive levels of `ConditionContext`. Note that this overrides
340+
# values inside the child context, thus giving precedence to the outermost
341+
# `ConditionContext`.
323342
function ConditionContext(values::NamedTuple, context::NamedConditionContext)
324-
# Note that this potentially overrides values from `context`, thus giving
325-
# precedence to the outmost `ConditionContext`.
343+
return ConditionContext(merge(context.values, values), childcontext(context))
344+
end
345+
function ConditionContext(values::AbstractDict{<:VarName}, context::DictConditionContext)
326346
return ConditionContext(merge(context.values, values), childcontext(context))
327347
end
328348

@@ -399,43 +419,6 @@ function getconditioned_nested(::IsParent, context, vn)
399419
end
400420
end
401421

402-
"""
403-
condition([context::AbstractContext,] values::NamedTuple)
404-
condition([context::AbstractContext]; values...)
405-
406-
Return `ConditionContext` with `values` and `context` if `values` is non-empty,
407-
otherwise return `context` which is [`DefaultContext`](@ref) by default.
408-
409-
See also: [`decondition`](@ref)
410-
"""
411-
AbstractPPL.condition(; values...) = condition(NamedTuple(values))
412-
AbstractPPL.condition(values::NamedTuple) = condition(DefaultContext(), values)
413-
function AbstractPPL.condition(value::Pair{<:VarName}, values::Pair{<:VarName}...)
414-
return condition((value, values...))
415-
end
416-
function AbstractPPL.condition(values::NTuple{<:Any,<:Pair{<:VarName}})
417-
return condition(DefaultContext(), values)
418-
end
419-
AbstractPPL.condition(context::AbstractContext, values::NamedTuple{()}) = context
420-
function AbstractPPL.condition(
421-
context::AbstractContext, values::Union{AbstractDict,NamedTuple}
422-
)
423-
return ConditionContext(values, context)
424-
end
425-
function AbstractPPL.condition(context::AbstractContext; values...)
426-
return condition(context, NamedTuple(values))
427-
end
428-
function AbstractPPL.condition(
429-
context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...
430-
)
431-
return condition(context, (value, values...))
432-
end
433-
function AbstractPPL.condition(
434-
context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}}
435-
)
436-
return condition(context, Dict(values))
437-
end
438-
439422
"""
440423
decondition(context::AbstractContext, syms...)
441424
@@ -445,41 +428,34 @@ Note that this recursively traverses contexts, deconditioning all along the way.
445428
446429
See also: [`condition`](@ref)
447430
"""
448-
AbstractPPL.decondition(::IsLeaf, context, args...) = context
449-
function AbstractPPL.decondition(::IsParent, context, args...)
450-
return setchildcontext(context, decondition(childcontext(context), args...))
431+
decondition_context(::IsLeaf, context, args...) = context
432+
function decondition_context(::IsParent, context, args...)
433+
return setchildcontext(context, decondition_context(childcontext(context), args...))
451434
end
452-
function AbstractPPL.decondition(context, args...)
453-
return decondition(NodeTrait(context), context, args...)
435+
function decondition_context(context, args...)
436+
return decondition_context(NodeTrait(context), context, args...)
454437
end
455-
function AbstractPPL.decondition(context::ConditionContext)
456-
return decondition(childcontext(context))
457-
end
458-
function AbstractPPL.decondition(context::ConditionContext, sym)
459-
return condition(
460-
decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym)
461-
)
438+
function decondition_context(context::ConditionContext)
439+
return decondition_context(childcontext(context))
462440
end
463-
function AbstractPPL.decondition(context::ConditionContext, sym, syms...)
464-
return decondition(
465-
condition(
466-
decondition(childcontext(context), syms...),
467-
BangBang.delete!!(context.values, sym),
468-
),
469-
syms...,
470-
)
471-
end
472-
473-
function AbstractPPL.decondition(
474-
context::NamedConditionContext, vn::VarName{sym}
475-
) where {sym}
476-
return condition(
477-
decondition(childcontext(context), vn), BangBang.delete!!(context.values, sym)
478-
)
441+
function decondition_context(context::ConditionContext, sym, syms...)
442+
new_values = deepcopy(context.values)
443+
for s in (sym, syms...)
444+
new_values = BangBang.delete!!(new_values, s)
445+
end
446+
return if length(new_values) == 0
447+
# No more values left, can unwrap
448+
decondition_context(childcontext(context), syms...)
449+
else
450+
ConditionContext(
451+
new_values, decondition_context(childcontext(context), sym, syms...)
452+
)
453+
end
479454
end
480-
function AbstractPPL.decondition(context::ConditionContext, vn::VarName)
481-
return condition(
482-
decondition(childcontext(context), vn), BangBang.delete!!(context.values, vn)
455+
function decondition_context(context::NamedConditionContext, vn::VarName{sym}) where {sym}
456+
return ConditionContext(
457+
BangBang.delete!!(context.values, sym),
458+
decondition_context(childcontext(context), vn),
483459
)
484460
end
485461

src/model.jl

+30-8
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ Return a `Model` which now treats variables on the right-hand side as observatio
9696
9797
See [`condition`](@ref) for more information and examples.
9898
"""
99-
Base.:|(model::Model, values) = condition(model, values)
99+
Base.:|(model::Model, values::Union{Pair,Tuple,NamedTuple,AbstractDict{<:VarName}}) =
100+
condition(model, values)
100101

101102
"""
102103
condition(model::Model; values...)
@@ -264,11 +265,32 @@ julia> conditioned_model_dict()
264265
1.0
265266
```
266267
"""
267-
AbstractPPL.condition(model::Model; values...) = condition(model, NamedTuple(values))
268-
function AbstractPPL.condition(model::Model, value, values...)
269-
return contextualize(model, condition(model.context, value, values...))
268+
function AbstractPPL.condition(model::Model, values...)
269+
# Positional arguments - need to handle cases carefully
270+
return contextualize(
271+
model, ConditionContext(_make_conditioning_values(values...), model.context)
272+
)
273+
end
274+
function AbstractPPL.condition(model::Model; values...)
275+
# Keyword arguments -- just convert to a NamedTuple
276+
return contextualize(model, ConditionContext(NamedTuple(values), model.context))
270277
end
271278

279+
"""
280+
_make_conditioning_values(vals...)
281+
282+
Convert different types of input to either a `NamedTuple` or `AbstractDict` of
283+
conditioning values, suitable for storage in a `ConditionContext`.
284+
285+
This handles all the cases where `vals` is either already a NamedTuple or
286+
AbstractDict (e.g. `model | (x=1, y=2)`), as well as if they are splatted (e.g.
287+
`condition(model, x=1, y=2)`).
288+
"""
289+
_make_conditioning_values(values::Union{NamedTuple,AbstractDict}) = values
290+
_make_conditioning_values(values::Tuple{Pair{<:VarName}}) = Dict(values)
291+
_make_conditioning_values(v::Pair{<:Symbol}, vs::Pair{<:Symbol}...) = NamedTuple(v, vs...)
292+
_make_conditioning_values(v::Pair{<:VarName}, vs::Pair{<:VarName}...) = Dict(v, vs...)
293+
272294
"""
273295
decondition(model::Model)
274296
decondition(model::Model, variables...)
@@ -379,7 +401,7 @@ true
379401
```
380402
"""
381403
function AbstractPPL.decondition(model::Model, syms...)
382-
return contextualize(model, decondition(model.context, syms...))
404+
return contextualize(model, decondition_context(model.context, syms...))
383405
end
384406

385407
"""
@@ -413,7 +435,7 @@ julia> # Returns all the variables we have conditioned on + their values.
413435
(x = 100.0, m = 1.0)
414436
415437
julia> # Nested ones also work (note that `PrefixContext` does nothing to the result).
416-
cm = condition(contextualize(m, PrefixContext{:a}(condition(m=1.0))), x=100.0);
438+
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0);
417439
418440
julia> conditioned(cm)
419441
(x = 100.0, m = 1.0)
@@ -425,15 +447,15 @@ julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed
425447
a.m
426448
427449
julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation.
428-
cm = condition(contextualize(m, PrefixContext{:a}(condition(var"a.m"=1.0))), x=100.0);
450+
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((var"a.m"=1.0,)))), x=100.0);
429451
430452
julia> conditioned(cm).x
431453
100.0
432454
433455
julia> conditioned(cm).var"a.m"
434456
1.0
435457
436-
julia> keys(VarInfo(cm)) # <= no variables are sampled
458+
julia> keys(VarInfo(cm)) # No variables are sampled
437459
VarName[]
438460
```
439461
"""

test/contexts.jl

+83
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using DynamicPPL:
1111
PointwiseLogdensityContext,
1212
contextual_isassumption,
1313
ConditionContext,
14+
decondition_context,
1415
hasconditioned,
1516
getconditioned,
1617
hasconditioned_nested,
@@ -196,6 +197,88 @@ end
196197
@test EnzymeCore.EnzymeRules.inactive_type(typeof(context))
197198
end
198199

200+
@testset "ConditionContext" begin
201+
@testset "Nesting" begin
202+
@testset "NamedTuple" begin
203+
n1 = (x=1, y=2)
204+
n2 = (x=3,)
205+
# Values from outer context should override inner one
206+
ctx1 = ConditionContext(n1, ConditionContext(n2))
207+
@test ctx1.values == (x=1, y=2)
208+
# Check that the two ConditionContexts are collapsed
209+
@test childcontext(ctx1) isa DefaultContext
210+
# Then test the nesting the other way round
211+
ctx2 = ConditionContext(n2, ConditionContext(n1))
212+
@test ctx2.values == (x=3, y=2)
213+
@test childcontext(ctx2) isa DefaultContext
214+
end
215+
216+
@testset "Dict" begin
217+
# Same tests as NamedTuple above
218+
d1 = Dict(@varname(x) => 1, @varname(y) => 2)
219+
d2 = Dict(@varname(x) => 3)
220+
ctx1 = ConditionContext(d1, ConditionContext(d2))
221+
@test ctx1.values == Dict(@varname(x) => 1, @varname(y) => 2)
222+
@test childcontext(ctx1) isa DefaultContext
223+
ctx2 = ConditionContext(d2, ConditionContext(d1))
224+
@test ctx2.values == Dict(@varname(x) => 3, @varname(y) => 2)
225+
@test childcontext(ctx2) isa DefaultContext
226+
end
227+
end
228+
229+
@testset "decondition_context" begin
230+
@testset "NamedTuple" begin
231+
ctx = ConditionContext((x=1, y=2, z=3))
232+
# Decondition all variables
233+
@test decondition_context(ctx) isa DefaultContext
234+
# Decondition only some variables
235+
dctx = decondition_context(ctx, :x)
236+
@test dctx isa ConditionContext
237+
@test dctx.values == (y=2, z=3)
238+
dctx = decondition_context(ctx, :y, :z)
239+
@test dctx isa ConditionContext
240+
@test dctx.values == (x=1,)
241+
# Decondition all variables manually
242+
@test decondition_context(ctx, :x, :y, :z) isa DefaultContext
243+
end
244+
245+
@testset "Dict" begin
246+
ctx = ConditionContext(
247+
Dict(@varname(x) => 1, @varname(y) => 2, @varname(z) => 3)
248+
)
249+
# Decondition all variables
250+
@test decondition_context(ctx) isa DefaultContext
251+
# Decondition only some variables
252+
dctx = decondition_context(ctx, @varname(x))
253+
@test dctx isa ConditionContext
254+
@test dctx.values == Dict(@varname(y) => 2, @varname(z) => 3)
255+
dctx = decondition_context(ctx, @varname(y), @varname(z))
256+
@test dctx isa ConditionContext
257+
@test dctx.values == Dict(@varname(x) => 1)
258+
# Decondition all variables manually
259+
@test decondition_context(ctx, @varname(x), @varname(y), @varname(z)) isa
260+
DefaultContext
261+
end
262+
263+
@testset "Nesting" begin
264+
ctx = ConditionContext(
265+
(x=1, y=2), ConditionContext(Dict(@varname(a) => 3, @varname(b) => 4))
266+
)
267+
# Decondition an outer variable
268+
dctx = decondition_context(ctx, :x)
269+
@test dctx.values == (y=2,)
270+
@test childcontext(dctx).values == Dict(@varname(a) => 3, @varname(b) => 4)
271+
# Decondition an inner variable
272+
dctx = decondition_context(ctx, @varname(a))
273+
@test dctx.values == (x=1, y=2)
274+
@test childcontext(dctx).values == Dict(@varname(b) => 4)
275+
# Try deconditioning everything
276+
dctx = decondition_context(ctx)
277+
@test dctx isa DefaultContext
278+
end
279+
end
280+
end
281+
199282
@testset "FixedContext" begin
200283
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
201284
retval = model()

0 commit comments

Comments
 (0)