Skip to content

Commit 61b564e

Browse files
torfjeldeyebai
andcommitted
Introduction of SamplingContext: keeping it simple (#259)
This is #253 but the only motivation here is to get `SamplingContext` in, nothing relating to interactions with other contexts, etc. Co-authored-by: Hong Ge <[email protected]>
1 parent ef6da43 commit 61b564e

12 files changed

+649
-194
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ export AbstractVarInfo,
7575
SampleFromPrior,
7676
SampleFromUniform,
7777
# Contexts
78+
SamplingContext,
7879
DefaultContext,
7980
LikelihoodContext,
8081
PriorContext,

src/compiler.jl

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,7 @@ function generate_tilde(left, right)
286286
if !(left isa Symbol || left isa Expr)
287287
return quote
288288
$(DynamicPPL.tilde_observe!)(
289-
__context__,
290-
__sampler__,
291-
$(DynamicPPL.check_tilde_rhs)($right),
292-
$left,
293-
__varinfo__,
289+
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
294290
)
295291
end
296292
end
@@ -304,9 +300,7 @@ function generate_tilde(left, right)
304300
$isassumption = $(DynamicPPL.isassumption(left))
305301
if $isassumption
306302
$left = $(DynamicPPL.tilde_assume!)(
307-
__rng__,
308303
__context__,
309-
__sampler__,
310304
$(DynamicPPL.unwrap_right_vn)(
311305
$(DynamicPPL.check_tilde_rhs)($right), $vn
312306
)...,
@@ -316,7 +310,6 @@ function generate_tilde(left, right)
316310
else
317311
$(DynamicPPL.tilde_observe!)(
318312
__context__,
319-
__sampler__,
320313
$(DynamicPPL.check_tilde_rhs)($right),
321314
$left,
322315
$vn,
@@ -337,11 +330,7 @@ function generate_dot_tilde(left, right)
337330
if !(left isa Symbol || left isa Expr)
338331
return quote
339332
$(DynamicPPL.dot_tilde_observe!)(
340-
__context__,
341-
__sampler__,
342-
$(DynamicPPL.check_tilde_rhs)($right),
343-
$left,
344-
__varinfo__,
333+
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
345334
)
346335
end
347336
end
@@ -355,9 +344,7 @@ function generate_dot_tilde(left, right)
355344
$isassumption = $(DynamicPPL.isassumption(left))
356345
if $isassumption
357346
$left .= $(DynamicPPL.dot_tilde_assume!)(
358-
__rng__,
359347
__context__,
360-
__sampler__,
361348
$(DynamicPPL.unwrap_right_left_vns)(
362349
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
363350
)...,
@@ -367,7 +354,6 @@ function generate_dot_tilde(left, right)
367354
else
368355
$(DynamicPPL.dot_tilde_observe!)(
369356
__context__,
370-
__sampler__,
371357
$(DynamicPPL.check_tilde_rhs)($right),
372358
$left,
373359
$vn,
@@ -398,10 +384,8 @@ function build_output(modelinfo, linenumbernode)
398384
# Add the internal arguments to the user-specified arguments (positional + keywords).
399385
evaluatordef[:args] = vcat(
400386
[
401-
:(__rng__::$(Random.AbstractRNG)),
402387
:(__model__::$(DynamicPPL.Model)),
403388
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
404-
:(__sampler__::$(DynamicPPL.AbstractSampler)),
405389
:(__context__::$(DynamicPPL.AbstractContext)),
406390
],
407391
modelinfo[:allargs_exprs],
@@ -411,7 +395,9 @@ function build_output(modelinfo, linenumbernode)
411395
evaluatordef[:kwargs] = []
412396

413397
# Replace the user-provided function body with the version created by DynamicPPL.
414-
evaluatordef[:body] = modelinfo[:body]
398+
evaluatordef[:body] = quote
399+
$(modelinfo[:body])
400+
end
415401

416402
## Build the model function.
417403

@@ -449,8 +435,12 @@ end
449435

450436
"""
451437
matchingvalue(sampler, vi, value)
438+
matchingvalue(context::AbstractContext, vi, value)
439+
440+
Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object.
452441
453-
Convert the `value` to the correct type for the `sampler` and the `vi` object.
442+
For a `context` that is _not_ a `SamplingContext`, we fall back to
443+
`matchingvalue(SampleFromPrior(), vi, value)`.
454444
"""
455445
function matchingvalue(sampler, vi, value)
456446
T = typeof(value)
@@ -467,6 +457,13 @@ function matchingvalue(sampler, vi, value)
467457
end
468458
matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value)
469459

460+
function matchingvalue(context::AbstractContext, vi, value)
461+
return matchingvalue(SampleFromPrior(), vi, value)
462+
end
463+
function matchingvalue(context::SamplingContext, vi, value)
464+
return matchingvalue(context.sampler, vi, value)
465+
end
466+
470467
"""
471468
get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T}
472469

0 commit comments

Comments
 (0)