@@ -286,11 +286,7 @@ function generate_tilde(left, right)
286
286
if ! (left isa Symbol || left isa Expr)
287
287
return quote
288
288
$ (DynamicPPL. tilde_observe!)(
289
- __context__,
290
- __sampler__,
291
- $ (DynamicPPL. check_tilde_rhs)($ right),
292
- $ left,
293
- __varinfo__,
289
+ __context__, $ (DynamicPPL. check_tilde_rhs)($ right), $ left, __varinfo__
294
290
)
295
291
end
296
292
end
@@ -304,9 +300,7 @@ function generate_tilde(left, right)
304
300
$ isassumption = $ (DynamicPPL. isassumption (left))
305
301
if $ isassumption
306
302
$ left = $ (DynamicPPL. tilde_assume!)(
307
- __rng__,
308
303
__context__,
309
- __sampler__,
310
304
$ (DynamicPPL. unwrap_right_vn)(
311
305
$ (DynamicPPL. check_tilde_rhs)($ right), $ vn
312
306
). .. ,
@@ -316,7 +310,6 @@ function generate_tilde(left, right)
316
310
else
317
311
$ (DynamicPPL. tilde_observe!)(
318
312
__context__,
319
- __sampler__,
320
313
$ (DynamicPPL. check_tilde_rhs)($ right),
321
314
$ left,
322
315
$ vn,
@@ -337,11 +330,7 @@ function generate_dot_tilde(left, right)
337
330
if ! (left isa Symbol || left isa Expr)
338
331
return quote
339
332
$ (DynamicPPL. dot_tilde_observe!)(
340
- __context__,
341
- __sampler__,
342
- $ (DynamicPPL. check_tilde_rhs)($ right),
343
- $ left,
344
- __varinfo__,
333
+ __context__, $ (DynamicPPL. check_tilde_rhs)($ right), $ left, __varinfo__
345
334
)
346
335
end
347
336
end
@@ -355,9 +344,7 @@ function generate_dot_tilde(left, right)
355
344
$ isassumption = $ (DynamicPPL. isassumption (left))
356
345
if $ isassumption
357
346
$ left .= $ (DynamicPPL. dot_tilde_assume!)(
358
- __rng__,
359
347
__context__,
360
- __sampler__,
361
348
$ (DynamicPPL. unwrap_right_left_vns)(
362
349
$ (DynamicPPL. check_tilde_rhs)($ right), $ left, $ vn
363
350
). .. ,
@@ -367,7 +354,6 @@ function generate_dot_tilde(left, right)
367
354
else
368
355
$ (DynamicPPL. dot_tilde_observe!)(
369
356
__context__,
370
- __sampler__,
371
357
$ (DynamicPPL. check_tilde_rhs)($ right),
372
358
$ left,
373
359
$ vn,
@@ -398,10 +384,8 @@ function build_output(modelinfo, linenumbernode)
398
384
# Add the internal arguments to the user-specified arguments (positional + keywords).
399
385
evaluatordef[:args ] = vcat (
400
386
[
401
- :(__rng__:: $ (Random. AbstractRNG)),
402
387
:(__model__:: $ (DynamicPPL. Model)),
403
388
:(__varinfo__:: $ (DynamicPPL. AbstractVarInfo)),
404
- :(__sampler__:: $ (DynamicPPL. AbstractSampler)),
405
389
:(__context__:: $ (DynamicPPL. AbstractContext)),
406
390
],
407
391
modelinfo[:allargs_exprs ],
449
433
450
434
"""
451
435
matchingvalue(sampler, vi, value)
436
+ matchingvalue(context::AbstractContext, vi, value)
437
+
438
+ Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object.
452
439
453
- Convert the `value` to the correct type for the `sampler` and the `vi` object.
440
+ For a `context` that is _not_ a `SamplingContext`, we fall back to
441
+ `matchingvalue(SampleFromPrior(), vi, value)`.
454
442
"""
455
443
function matchingvalue (sampler, vi, value)
456
444
T = typeof (value)
@@ -465,7 +453,16 @@ function matchingvalue(sampler, vi, value)
465
453
return value
466
454
end
467
455
end
468
- matchingvalue (sampler, vi, value:: FloatOrArrayType ) = get_matching_type (sampler, vi, value)
456
+ function matchingvalue (sampler:: AbstractSampler , vi, value:: FloatOrArrayType )
457
+ return get_matching_type (sampler, vi, value)
458
+ end
459
+
460
+ function matchingvalue (context:: AbstractContext , vi, value)
461
+ return matchingvalue (SampleFromPrior (), vi, value)
462
+ end
463
+ function matchingvalue (context:: SamplingContext , vi, value)
464
+ return matchingvalue (context. sampler, vi, value)
465
+ end
469
466
470
467
"""
471
468
get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T}
0 commit comments