Skip to content

Commit 11a75ff

Browse files
torfjeldeyebai
andauthored
Add acclogp_assume!! and acclogp_observe!! (#565)
* add hooks for acclogp!! depending on whether it's from an `assume` or `observe` statement * bump patch version * Update src/context_implementations.jl * Update Project.toml Co-authored-by: Hong Ge <[email protected]> --------- Co-authored-by: Hong Ge <[email protected]>
1 parent b52e4c2 commit 11a75ff

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.24.3"
3+
4+
version = "0.24.4"
45

56
[deps]
67
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/context_implementations.jl

+15-5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
1414
require_gradient(spl::Sampler) = false
1515
require_particles(spl::Sampler) = false
1616

17+
# Allows samplers, etc. to hook into the final logp accumulation in the tilde-pipeline.
18+
function acclogp_assume!!(context::AbstractContext, vi::AbstractVarInfo, logp)
19+
return acclogp!!(context, vi, logp)
20+
end
21+
22+
function acclogp_observe!!(context::AbstractContext, vi::AbstractVarInfo, logp)
23+
return acclogp!!(context, vi, logp)
24+
end
25+
1726
# assume
1827
"""
1928
tilde_assume(context::SamplingContext, right, vn, vi)
@@ -115,7 +124,7 @@ probability of `vi` with the returned value.
115124
"""
116125
function tilde_assume!!(context, right, vn, vi)
117126
value, logp, vi = tilde_assume(context, right, vn, vi)
118-
return value, acclogp!!(context, vi, logp)
127+
return value, acclogp_assume!!(context, vi, logp)
119128
end
120129

121130
# observe
@@ -181,7 +190,7 @@ probability of `vi` with the returned value.
181190
"""
182191
function tilde_observe!!(context, right, left, vi)
183192
logp, vi = tilde_observe(context, right, left, vi)
184-
return left, acclogp!!(context, vi, logp)
193+
return left, acclogp_observe!!(context, vi, logp)
185194
end
186195

187196
function assume(rng, spl::Sampler, dist)
@@ -383,7 +392,7 @@ Falls back to `dot_tilde_assume(context, right, left, vn, vi)`.
383392
"""
384393
function dot_tilde_assume!!(context, right, left, vn, vi)
385394
value, logp, vi = dot_tilde_assume(context, right, left, vn, vi)
386-
return value, acclogp!!(context, vi, logp), vi
395+
return value, acclogp_assume!!(context, vi, logp), vi
387396
end
388397

389398
# `dot_assume`
@@ -539,7 +548,8 @@ function get_and_set_val!(
539548
if istrans(vi)
540549
push!!.((vi,), vns, reconstruct_and_link.((vi,), vns, dists, r), dists, (spl,))
541550
# NOTE: Need to add the correction.
542-
acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r)))
551+
# FIXME: This is not great.
552+
acclogp_assume!!(vi, sum(logabsdetjac.(bijector.(dists), r)))
543553
# `push!!` sets the trans-flag to `false` by default.
544554
settrans!!.((vi,), true, vns)
545555
else
@@ -634,7 +644,7 @@ Falls back to `dot_tilde_observe(context, right, left, vi)`.
634644
"""
635645
function dot_tilde_observe!!(context, right, left, vi)
636646
logp, vi = dot_tilde_observe(context, right, left, vi)
637-
return left, acclogp!!(context, vi, logp)
647+
return left, acclogp_observe!!(context, vi, logp)
638648
end
639649

640650
# Falls back to non-sampler definition.

0 commit comments

Comments
 (0)