@@ -14,6 +14,15 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
14
14
require_gradient (spl:: Sampler ) = false
15
15
require_particles (spl:: Sampler ) = false
16
16
17
+ # Allows samplers, etc. to hook into the final logp accumulation in the tilde-pipeline.
18
+ function acclogp_assume!! (context:: AbstractContext , vi:: AbstractVarInfo , logp)
19
+ return acclogp!! (context, vi, logp)
20
+ end
21
+
22
+ function acclogp_observe!! (context:: AbstractContext , vi:: AbstractVarInfo , logp)
23
+ return acclogp!! (context, vi, logp)
24
+ end
25
+
17
26
# assume
18
27
"""
19
28
tilde_assume(context::SamplingContext, right, vn, vi)
@@ -115,7 +124,7 @@ probability of `vi` with the returned value.
115
124
"""
116
125
function tilde_assume!! (context, right, vn, vi)
117
126
value, logp, vi = tilde_assume (context, right, vn, vi)
118
- return value, acclogp !! (context, vi, logp)
127
+ return value, acclogp_assume !! (context, vi, logp)
119
128
end
120
129
121
130
# observe
@@ -181,7 +190,7 @@ probability of `vi` with the returned value.
181
190
"""
182
191
function tilde_observe!! (context, right, left, vi)
183
192
logp, vi = tilde_observe (context, right, left, vi)
184
- return left, acclogp !! (context, vi, logp)
193
+ return left, acclogp_observe !! (context, vi, logp)
185
194
end
186
195
187
196
function assume (rng, spl:: Sampler , dist)
@@ -383,7 +392,7 @@ Falls back to `dot_tilde_assume(context, right, left, vn, vi)`.
383
392
"""
384
393
function dot_tilde_assume!! (context, right, left, vn, vi)
385
394
value, logp, vi = dot_tilde_assume (context, right, left, vn, vi)
386
- return value, acclogp !! (context, vi, logp), vi
395
+ return value, acclogp_assume !! (context, vi, logp), vi
387
396
end
388
397
389
398
# `dot_assume`
@@ -539,7 +548,8 @@ function get_and_set_val!(
539
548
if istrans (vi)
540
549
push!! .((vi,), vns, reconstruct_and_link .((vi,), vns, dists, r), dists, (spl,))
541
550
# NOTE: Need to add the correction.
542
- acclogp!! (vi, sum (logabsdetjac .(bijector .(dists), r)))
551
+ # FIXME : This is not great.
552
+ acclogp_assume!! (vi, sum (logabsdetjac .(bijector .(dists), r)))
543
553
# `push!!` sets the trans-flag to `false` by default.
544
554
settrans!! .((vi,), true , vns)
545
555
else
@@ -634,7 +644,7 @@ Falls back to `dot_tilde_observe(context, right, left, vi)`.
634
644
"""
635
645
function dot_tilde_observe!! (context, right, left, vi)
636
646
logp, vi = dot_tilde_observe (context, right, left, vi)
637
- return left, acclogp !! (context, vi, logp)
647
+ return left, acclogp_observe !! (context, vi, logp)
638
648
end
639
649
640
650
# Falls back to non-sampler definition.
0 commit comments