@@ -286,11 +286,7 @@ function generate_tilde(left, right)
286
286
if ! (left isa Symbol || left isa Expr)
287
287
return quote
288
288
$ (DynamicPPL. tilde_observe!)(
289
- __context__,
290
- __sampler__,
291
- $ (DynamicPPL. check_tilde_rhs)($ right),
292
- $ left,
293
- __varinfo__,
289
+ __context__, $ (DynamicPPL. check_tilde_rhs)($ right), $ left, __varinfo__
294
290
)
295
291
end
296
292
end
@@ -304,9 +300,7 @@ function generate_tilde(left, right)
304
300
$ isassumption = $ (DynamicPPL. isassumption (left))
305
301
if $ isassumption
306
302
$ left = $ (DynamicPPL. tilde_assume!)(
307
- __rng__,
308
303
__context__,
309
- __sampler__,
310
304
$ (DynamicPPL. unwrap_right_vn)(
311
305
$ (DynamicPPL. check_tilde_rhs)($ right), $ vn
312
306
). .. ,
@@ -316,7 +310,6 @@ function generate_tilde(left, right)
316
310
else
317
311
$ (DynamicPPL. tilde_observe!)(
318
312
__context__,
319
- __sampler__,
320
313
$ (DynamicPPL. check_tilde_rhs)($ right),
321
314
$ left,
322
315
$ vn,
@@ -337,11 +330,7 @@ function generate_dot_tilde(left, right)
337
330
if ! (left isa Symbol || left isa Expr)
338
331
return quote
339
332
$ (DynamicPPL. dot_tilde_observe!)(
340
- __context__,
341
- __sampler__,
342
- $ (DynamicPPL. check_tilde_rhs)($ right),
343
- $ left,
344
- __varinfo__,
333
+ __context__, $ (DynamicPPL. check_tilde_rhs)($ right), $ left, __varinfo__
345
334
)
346
335
end
347
336
end
@@ -355,9 +344,7 @@ function generate_dot_tilde(left, right)
355
344
$ isassumption = $ (DynamicPPL. isassumption (left))
356
345
if $ isassumption
357
346
$ left .= $ (DynamicPPL. dot_tilde_assume!)(
358
- __rng__,
359
347
__context__,
360
- __sampler__,
361
348
$ (DynamicPPL. unwrap_right_left_vns)(
362
349
$ (DynamicPPL. check_tilde_rhs)($ right), $ left, $ vn
363
350
). .. ,
@@ -367,7 +354,6 @@ function generate_dot_tilde(left, right)
367
354
else
368
355
$ (DynamicPPL. dot_tilde_observe!)(
369
356
__context__,
370
- __sampler__,
371
357
$ (DynamicPPL. check_tilde_rhs)($ right),
372
358
$ left,
373
359
$ vn,
@@ -398,10 +384,8 @@ function build_output(modelinfo, linenumbernode)
398
384
# Add the internal arguments to the user-specified arguments (positional + keywords).
399
385
evaluatordef[:args ] = vcat (
400
386
[
401
- :(__rng__:: $ (Random. AbstractRNG)),
402
387
:(__model__:: $ (DynamicPPL. Model)),
403
388
:(__varinfo__:: $ (DynamicPPL. AbstractVarInfo)),
404
- :(__sampler__:: $ (DynamicPPL. AbstractSampler)),
405
389
:(__context__:: $ (DynamicPPL. AbstractContext)),
406
390
],
407
391
modelinfo[:allargs_exprs ],
@@ -411,7 +395,9 @@ function build_output(modelinfo, linenumbernode)
411
395
evaluatordef[:kwargs ] = []
412
396
413
397
# Replace the user-provided function body with the version created by DynamicPPL.
414
- evaluatordef[:body ] = modelinfo[:body ]
398
+ evaluatordef[:body ] = quote
399
+ $ (modelinfo[:body ])
400
+ end
415
401
416
402
# # Build the model function.
417
403
449
435
450
436
"""
451
437
matchingvalue(sampler, vi, value)
438
+ matchingvalue(context::AbstractContext, vi, value)
439
+
440
+ Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object.
452
441
453
- Convert the `value` to the correct type for the `sampler` and the `vi` object.
442
+ For a `context` that is _not_ a `SamplingContext`, we fall back to
443
+ `matchingvalue(SampleFromPrior(), vi, value)`.
454
444
"""
455
445
function matchingvalue (sampler, vi, value)
456
446
T = typeof (value)
@@ -467,6 +457,13 @@ function matchingvalue(sampler, vi, value)
467
457
end
468
458
matchingvalue (sampler, vi, value:: FloatOrArrayType ) = get_matching_type (sampler, vi, value)
469
459
460
+ function matchingvalue (context:: AbstractContext , vi, value)
461
+ return matchingvalue (SampleFromPrior (), vi, value)
462
+ end
463
+ function matchingvalue (context:: SamplingContext , vi, value)
464
+ return matchingvalue (context. sampler, vi, value)
465
+ end
466
+
470
467
"""
471
468
get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T}
472
469
0 commit comments