@@ -309,20 +309,40 @@ function prefix(model::Model, ::Val{x}) where {x}
309
309
return contextualize (model, PrefixContext {Symbol(x)} (model. context))
310
310
end
311
311
312
- struct ConditionContext{Values,Ctx<: AbstractContext } <: AbstractContext
312
+ """
313
+
314
+ ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext}
315
+
316
+ Model context that contains values that are to be conditioned on. The values
317
+ can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or
318
+ an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1,
319
+ @varname(b) => 2)`). The former is more performant, but the latter must be used
320
+ when there are varnames that cannot be represented as symbols, e.g.
321
+ `@varname(x[1])`.
322
+ """
323
+ struct ConditionContext{
324
+ Values<: Union{NamedTuple,AbstractDict{<:VarName}} ,Ctx<: AbstractContext
325
+ } <: AbstractContext
313
326
values:: Values
314
327
context:: Ctx
315
328
end
316
329
317
330
const NamedConditionContext{Names} = ConditionContext{<: NamedTuple{Names} }
318
331
const DictConditionContext = ConditionContext{<: AbstractDict }
319
332
320
- ConditionContext (values) = ConditionContext (values, DefaultContext ())
321
-
322
- # Try to avoid nested `ConditionContext`.
333
+ # Use DefaultContext as the default base context
334
+ function ConditionContext (values:: Union{NamedTuple,AbstractDict} )
335
+ return ConditionContext (values, DefaultContext ())
336
+ end
337
+ # Optimisation when there are no values to condition on
338
+ ConditionContext (:: NamedTuple{()} , context:: AbstractContext ) = context
339
+ # Collapse consecutive levels of `ConditionContext`. Note that this overrides
340
+ # values inside the child context, thus giving precedence to the outermost
341
+ # `ConditionContext`.
323
342
function ConditionContext (values:: NamedTuple , context:: NamedConditionContext )
324
- # Note that this potentially overrides values from `context`, thus giving
325
- # precedence to the outmost `ConditionContext`.
343
+ return ConditionContext (merge (context. values, values), childcontext (context))
344
+ end
345
+ function ConditionContext (values:: AbstractDict{<:VarName} , context:: DictConditionContext )
326
346
return ConditionContext (merge (context. values, values), childcontext (context))
327
347
end
328
348
@@ -399,43 +419,6 @@ function getconditioned_nested(::IsParent, context, vn)
399
419
end
400
420
end
401
421
402
- """
403
- condition([context::AbstractContext,] values::NamedTuple)
404
- condition([context::AbstractContext]; values...)
405
-
406
- Return `ConditionContext` with `values` and `context` if `values` is non-empty,
407
- otherwise return `context` which is [`DefaultContext`](@ref) by default.
408
-
409
- See also: [`decondition`](@ref)
410
- """
411
- AbstractPPL. condition (; values... ) = condition (NamedTuple (values))
412
- AbstractPPL. condition (values:: NamedTuple ) = condition (DefaultContext (), values)
413
- function AbstractPPL. condition (value:: Pair{<:VarName} , values:: Pair{<:VarName} ...)
414
- return condition ((value, values... ))
415
- end
416
- function AbstractPPL. condition (values:: NTuple{<:Any,<:Pair{<:VarName}} )
417
- return condition (DefaultContext (), values)
418
- end
419
- AbstractPPL. condition (context:: AbstractContext , values:: NamedTuple{()} ) = context
420
- function AbstractPPL. condition (
421
- context:: AbstractContext , values:: Union{AbstractDict,NamedTuple}
422
- )
423
- return ConditionContext (values, context)
424
- end
425
- function AbstractPPL. condition (context:: AbstractContext ; values... )
426
- return condition (context, NamedTuple (values))
427
- end
428
- function AbstractPPL. condition (
429
- context:: AbstractContext , value:: Pair{<:VarName} , values:: Pair{<:VarName} ...
430
- )
431
- return condition (context, (value, values... ))
432
- end
433
- function AbstractPPL. condition (
434
- context:: AbstractContext , values:: NTuple{<:Any,Pair{<:VarName}}
435
- )
436
- return condition (context, Dict (values))
437
- end
438
-
439
422
"""
440
423
decondition(context::AbstractContext, syms...)
441
424
@@ -445,41 +428,34 @@ Note that this recursively traverses contexts, deconditioning all along the way.
445
428
446
429
See also: [`condition`](@ref)
447
430
"""
448
- AbstractPPL . decondition (:: IsLeaf , context, args... ) = context
449
- function AbstractPPL . decondition (:: IsParent , context, args... )
450
- return setchildcontext (context, decondition (childcontext (context), args... ))
431
+ decondition_context (:: IsLeaf , context, args... ) = context
432
+ function decondition_context (:: IsParent , context, args... )
433
+ return setchildcontext (context, decondition_context (childcontext (context), args... ))
451
434
end
452
- function AbstractPPL . decondition (context, args... )
453
- return decondition (NodeTrait (context), context, args... )
435
+ function decondition_context (context, args... )
436
+ return decondition_context (NodeTrait (context), context, args... )
454
437
end
455
- function AbstractPPL. decondition (context:: ConditionContext )
456
- return decondition (childcontext (context))
457
- end
458
- function AbstractPPL. decondition (context:: ConditionContext , sym)
459
- return condition (
460
- decondition (childcontext (context), sym), BangBang. delete!! (context. values, sym)
461
- )
438
+ function decondition_context (context:: ConditionContext )
439
+ return decondition_context (childcontext (context))
462
440
end
463
- function AbstractPPL. decondition (context:: ConditionContext , sym, syms... )
464
- return decondition (
465
- condition (
466
- decondition (childcontext (context), syms... ),
467
- BangBang. delete!! (context. values, sym),
468
- ),
469
- syms... ,
470
- )
471
- end
472
-
473
- function AbstractPPL. decondition (
474
- context:: NamedConditionContext , vn:: VarName{sym}
475
- ) where {sym}
476
- return condition (
477
- decondition (childcontext (context), vn), BangBang. delete!! (context. values, sym)
478
- )
441
+ function decondition_context (context:: ConditionContext , sym, syms... )
442
+ new_values = deepcopy (context. values)
443
+ for s in (sym, syms... )
444
+ new_values = BangBang. delete!! (new_values, s)
445
+ end
446
+ return if length (new_values) == 0
447
+ # No more values left, can unwrap
448
+ decondition_context (childcontext (context), syms... )
449
+ else
450
+ ConditionContext (
451
+ new_values, decondition_context (childcontext (context), sym, syms... )
452
+ )
453
+ end
479
454
end
480
- function AbstractPPL. decondition (context:: ConditionContext , vn:: VarName )
481
- return condition (
482
- decondition (childcontext (context), vn), BangBang. delete!! (context. values, vn)
455
+ function decondition_context (context:: NamedConditionContext , vn:: VarName{sym} ) where {sym}
456
+ return ConditionContext (
457
+ BangBang. delete!! (context. values, sym),
458
+ decondition_context (childcontext (context), vn),
483
459
)
484
460
end
485
461
0 commit comments