Skip to content

Commit 34be85c

Browse files
Also pass in context as an argument to acclogp!! (#563)
* also pass in `context` as an argument to `acclogp!!` * Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * bump patch version --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent c9489aa commit 34be85c

File tree

4 files changed

+18
-8
lines changed

4 files changed

+18
-8
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.24.1"
3+
version = "0.24.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/abstract_varinfo.jl

+10-2
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,20 @@ Set the log of the joint probability of the observed data and parameters sampled
107107
function setlogp!! end
108108

109109
"""
110-
acclogp!!(vi::AbstractVarInfo, logp)
110+
acclogp!!([context::AbstractContext, ]vi::AbstractVarInfo, logp)
111111
112112
Add `logp` to the value of the log of the joint probability of the observed data and
113113
parameters sampled in `vi`, mutating if it makes sense.
114114
"""
115-
function acclogp!! end
115+
function acclogp!!(context::AbstractContext, vi::AbstractVarInfo, logp)
116+
return acclogp!!(NodeTrait(context), context, vi, logp)
117+
end
118+
function acclogp!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp)
119+
return acclogp!!(vi, logp)
120+
end
121+
function acclogp!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp)
122+
return acclogp!!(childcontext(context), vi, logp)
123+
end
116124

117125
"""
118126
resetlogp!!(vi::AbstractVarInfo)

src/context_implementations.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ probability of `vi` with the returned value.
115115
"""
116116
function tilde_assume!!(context, right, vn, vi)
117117
value, logp, vi = tilde_assume(context, right, vn, vi)
118-
return value, acclogp!!(vi, logp)
118+
return value, acclogp!!(context, vi, logp)
119119
end
120120

121121
# observe
@@ -181,7 +181,7 @@ probability of `vi` with the returned value.
181181
"""
182182
function tilde_observe!!(context, right, left, vi)
183183
logp, vi = tilde_observe(context, right, left, vi)
184-
return left, acclogp!!(vi, logp)
184+
return left, acclogp!!(context, vi, logp)
185185
end
186186

187187
function assume(rng, spl::Sampler, dist)
@@ -383,7 +383,7 @@ Falls back to `dot_tilde_assume(context, right, left, vn, vi)`.
383383
"""
384384
function dot_tilde_assume!!(context, right, left, vn, vi)
385385
value, logp, vi = dot_tilde_assume(context, right, left, vn, vi)
386-
return value, acclogp!!(vi, logp), vi
386+
return value, acclogp!!(context, vi, logp), vi
387387
end
388388

389389
# `dot_assume`
@@ -634,7 +634,7 @@ Falls back to `dot_tilde_observe(context, right, left, vi)`.
634634
"""
635635
function dot_tilde_observe!!(context, right, left, vi)
636636
logp, vi = dot_tilde_observe(context, right, left, vi)
637-
return left, acclogp!!(vi, logp)
637+
return left, acclogp!!(context, vi, logp)
638638
end
639639

640640
# Falls back to non-sampler definition.

src/utils.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ true
7070
"""
7171
macro addlogprob!(ex)
7272
return quote
73-
$(esc(:(__varinfo__))) = acclogp!!($(esc(:(__varinfo__))), $(esc(ex)))
73+
$(esc(:(__varinfo__))) = acclogp!!(
74+
$(esc(:(__context__))), $(esc(:(__varinfo__))), $(esc(ex))
75+
)
7476
end
7577
end
7678

0 commit comments

Comments
 (0)