Skip to content

Commit 06d319c

Browse files
torfjeldeyebai
andcommitted
Introduction of SamplingContext: keeping it simple (#259)
This is #253 but the only motivation here is to get `SamplingContext` in, nothing relating to interactions with other contexts, etc. Co-authored-by: Hong Ge <[email protected]>
1 parent ef6da43 commit 06d319c

13 files changed

+653
-200
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ export AbstractVarInfo,
7575
SampleFromPrior,
7676
SampleFromUniform,
7777
# Contexts
78+
SamplingContext,
7879
DefaultContext,
7980
LikelihoodContext,
8081
PriorContext,

src/compiler.jl

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,7 @@ function generate_tilde(left, right)
286286
if !(left isa Symbol || left isa Expr)
287287
return quote
288288
$(DynamicPPL.tilde_observe!)(
289-
__context__,
290-
__sampler__,
291-
$(DynamicPPL.check_tilde_rhs)($right),
292-
$left,
293-
__varinfo__,
289+
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
294290
)
295291
end
296292
end
@@ -304,9 +300,7 @@ function generate_tilde(left, right)
304300
$isassumption = $(DynamicPPL.isassumption(left))
305301
if $isassumption
306302
$left = $(DynamicPPL.tilde_assume!)(
307-
__rng__,
308303
__context__,
309-
__sampler__,
310304
$(DynamicPPL.unwrap_right_vn)(
311305
$(DynamicPPL.check_tilde_rhs)($right), $vn
312306
)...,
@@ -316,7 +310,6 @@ function generate_tilde(left, right)
316310
else
317311
$(DynamicPPL.tilde_observe!)(
318312
__context__,
319-
__sampler__,
320313
$(DynamicPPL.check_tilde_rhs)($right),
321314
$left,
322315
$vn,
@@ -337,11 +330,7 @@ function generate_dot_tilde(left, right)
337330
if !(left isa Symbol || left isa Expr)
338331
return quote
339332
$(DynamicPPL.dot_tilde_observe!)(
340-
__context__,
341-
__sampler__,
342-
$(DynamicPPL.check_tilde_rhs)($right),
343-
$left,
344-
__varinfo__,
333+
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
345334
)
346335
end
347336
end
@@ -355,9 +344,7 @@ function generate_dot_tilde(left, right)
355344
$isassumption = $(DynamicPPL.isassumption(left))
356345
if $isassumption
357346
$left .= $(DynamicPPL.dot_tilde_assume!)(
358-
__rng__,
359347
__context__,
360-
__sampler__,
361348
$(DynamicPPL.unwrap_right_left_vns)(
362349
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
363350
)...,
@@ -367,7 +354,6 @@ function generate_dot_tilde(left, right)
367354
else
368355
$(DynamicPPL.dot_tilde_observe!)(
369356
__context__,
370-
__sampler__,
371357
$(DynamicPPL.check_tilde_rhs)($right),
372358
$left,
373359
$vn,
@@ -398,10 +384,8 @@ function build_output(modelinfo, linenumbernode)
398384
# Add the internal arguments to the user-specified arguments (positional + keywords).
399385
evaluatordef[:args] = vcat(
400386
[
401-
:(__rng__::$(Random.AbstractRNG)),
402387
:(__model__::$(DynamicPPL.Model)),
403388
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
404-
:(__sampler__::$(DynamicPPL.AbstractSampler)),
405389
:(__context__::$(DynamicPPL.AbstractContext)),
406390
],
407391
modelinfo[:allargs_exprs],
@@ -449,8 +433,12 @@ end
449433

450434
"""
451435
matchingvalue(sampler, vi, value)
436+
matchingvalue(context::AbstractContext, vi, value)
437+
438+
Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object.
452439
453-
Convert the `value` to the correct type for the `sampler` and the `vi` object.
440+
For a `context` that is _not_ a `SamplingContext`, we fall back to
441+
`matchingvalue(SampleFromPrior(), vi, value)`.
454442
"""
455443
function matchingvalue(sampler, vi, value)
456444
T = typeof(value)
@@ -465,7 +453,16 @@ function matchingvalue(sampler, vi, value)
465453
return value
466454
end
467455
end
468-
matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value)
456+
function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType)
457+
return get_matching_type(sampler, vi, value)
458+
end
459+
460+
function matchingvalue(context::AbstractContext, vi, value)
461+
return matchingvalue(SampleFromPrior(), vi, value)
462+
end
463+
function matchingvalue(context::SamplingContext, vi, value)
464+
return matchingvalue(context.sampler, vi, value)
465+
end
469466

470467
"""
471468
get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T}

0 commit comments

Comments
 (0)