Skip to content

Commit 4f31771

Browse files
torfjeldeyebai
andcommitted
Simplification of tilde-callstack (#252)
This PR introduces the simplification of the tilde-callstack as discussed in #249. Copy-pasted from there: - Remove unnecessary complexity in `~` implementation. - Current calling hierarchy for a `~` statement is: - `tilde_assume` -> `tilde(rng, ...)` -> `_tilde(rng, ...)` -> `assume` - `tilde_observe` -> `tilde(...)` -> `_tilde(...)` -> `observe` - Similarly for `dot_tilde_assume` and `dot_tilde_observe`. - This is super-confusing and difficult to debug. - `_tilde` is currently only used for `NamedDist` to allow overriding the variable-name used for a particular `~` statement. - Propose the following changes: - Remove `_tilde` and handle `NamedDist` _before_ calling `tilde_assume`, etc. by using a `unpack_right_vns` (and `unpack_right_left_vns` for dot-statements) (thanks to @devmotion) - Rename `tilde_assume` (`tilde_observe`) and to `tilde_assume!` (`tilde_observe!`), and `tilde(rng, ...)` (`tilde(...)`) to `tilde_assume(rng, ...)` (`tilde_observe(...)`). - `tilde_assume!` simply calls `tilde_assume` followed by `acclogp(varinfo, result_from_tilde_assume)`, so the `!` here is to indicate that it's mutating the `logp` field in `VarInfo`. Co-authored-by: Hong Ge <[email protected]>
1 parent c95ccfa commit 4f31771

14 files changed

+728
-268
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.11.4"
3+
version = "0.12.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/DynamicPPL.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,20 @@ export AbstractVarInfo,
7575
SampleFromPrior,
7676
SampleFromUniform,
7777
# Contexts
78+
SamplingContext,
7879
DefaultContext,
7980
LikelihoodContext,
8081
PriorContext,
8182
MiniBatchContext,
8283
PrefixContext,
8384
assume,
8485
dot_assume,
85-
observer,
86+
observe,
8687
dot_observe,
87-
tilde,
88-
dot_tilde,
88+
tilde_assume,
89+
tilde_observe,
90+
dot_tilde_assume,
91+
dot_tilde_observe,
8992
# Pseudo distributions
9093
NamedDist,
9194
NoDist,

src/compiler.jl

Lines changed: 72 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,49 @@ end
6161
check_tilde_rhs(x::Distribution) = x
6262
check_tilde_rhs(x::AbstractArray{<:Distribution}) = x
6363

64+
"""
65+
unwrap_right_vn(right, vn)
66+
67+
Return the unwrapped distribution on the right-hand side and variable name on the left-hand
68+
side of a `~` expression such as `x ~ Normal()`.
69+
70+
This is used mainly to unwrap `NamedDist` distributions.
71+
"""
72+
unwrap_right_vn(right, vn) = right, vn
73+
unwrap_right_vn(right::NamedDist, vn) = unwrap_right_vn(right.dist, right.name)
74+
75+
"""
76+
unwrap_right_left_vns(right, left, vns)
77+
78+
Return the unwrapped distributions on the right-hand side and values and variable names on the
79+
left-hand side of a `.~` expression such as `x .~ Normal()`.
80+
81+
This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the
82+
variables.
83+
"""
84+
unwrap_right_left_vns(right, left, vns) = right, left, vns
85+
function unwrap_right_left_vns(right::NamedDist, left, vns)
86+
return unwrap_right_left_vns(right.dist, left, right.name)
87+
end
88+
function unwrap_right_left_vns(
89+
right::MultivariateDistribution, left::AbstractMatrix, vn::VarName
90+
)
91+
vns = map(axes(left, 2)) do i
92+
return VarName(vn, (vn.indexing..., Tuple(i)))
93+
end
94+
return unwrap_right_left_vns(right, left, vns)
95+
end
96+
function unwrap_right_left_vns(
97+
right::Union{Distribution,AbstractArray{<:Distribution}},
98+
left::AbstractArray,
99+
vn::VarName,
100+
)
101+
vns = map(CartesianIndices(left)) do i
102+
return VarName(vn, (vn.indexing..., Tuple(i)))
103+
end
104+
return unwrap_right_left_vns(right, left, vns)
105+
end
106+
64107
#################
65108
# Main Compiler #
66109
#################
@@ -256,12 +299,8 @@ function generate_tilde(left, right)
256299
# If the LHS is a literal, it is always an observation
257300
if isliteral(left)
258301
return quote
259-
$(DynamicPPL.tilde_observe)(
260-
__context__,
261-
__sampler__,
262-
$(DynamicPPL.check_tilde_rhs)($right),
263-
$left,
264-
__varinfo__,
302+
$(DynamicPPL.tilde_observe!)(
303+
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
265304
)
266305
end
267306
end
@@ -274,19 +313,17 @@ function generate_tilde(left, right)
274313
$inds = $(vinds(left))
275314
$isassumption = $(DynamicPPL.isassumption(left))
276315
if $isassumption
277-
$left = $(DynamicPPL.tilde_assume)(
278-
__rng__,
316+
$left = $(DynamicPPL.tilde_assume!)(
279317
__context__,
280-
__sampler__,
281-
$(DynamicPPL.check_tilde_rhs)($right),
282-
$vn,
318+
$(DynamicPPL.unwrap_right_vn)(
319+
$(DynamicPPL.check_tilde_rhs)($right), $vn
320+
)...,
283321
$inds,
284322
__varinfo__,
285323
)
286324
else
287-
$(DynamicPPL.tilde_observe)(
325+
$(DynamicPPL.tilde_observe!)(
288326
__context__,
289-
__sampler__,
290327
$(DynamicPPL.check_tilde_rhs)($right),
291328
$left,
292329
$vn,
@@ -306,12 +343,8 @@ function generate_dot_tilde(left, right)
306343
# If the LHS is a literal, it is always an observation
307344
if isliteral(left)
308345
return quote
309-
$(DynamicPPL.dot_tilde_observe)(
310-
__context__,
311-
__sampler__,
312-
$(DynamicPPL.check_tilde_rhs)($right),
313-
$left,
314-
__varinfo__,
346+
$(DynamicPPL.dot_tilde_observe!)(
347+
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
315348
)
316349
end
317350
end
@@ -324,20 +357,17 @@ function generate_dot_tilde(left, right)
324357
$inds = $(vinds(left))
325358
$isassumption = $(DynamicPPL.isassumption(left))
326359
if $isassumption
327-
$left .= $(DynamicPPL.dot_tilde_assume)(
328-
__rng__,
360+
$left .= $(DynamicPPL.dot_tilde_assume!)(
329361
__context__,
330-
__sampler__,
331-
$(DynamicPPL.check_tilde_rhs)($right),
332-
$left,
333-
$vn,
362+
$(DynamicPPL.unwrap_right_left_vns)(
363+
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
364+
)...,
334365
$inds,
335366
__varinfo__,
336367
)
337368
else
338-
$(DynamicPPL.dot_tilde_observe)(
369+
$(DynamicPPL.dot_tilde_observe!)(
339370
__context__,
340-
__sampler__,
341371
$(DynamicPPL.check_tilde_rhs)($right),
342372
$left,
343373
$vn,
@@ -368,10 +398,8 @@ function build_output(modelinfo, linenumbernode)
368398
# Add the internal arguments to the user-specified arguments (positional + keywords).
369399
evaluatordef[:args] = vcat(
370400
[
371-
:(__rng__::$(Random.AbstractRNG)),
372401
:(__model__::$(DynamicPPL.Model)),
373402
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
374-
:(__sampler__::$(DynamicPPL.AbstractSampler)),
375403
:(__context__::$(DynamicPPL.AbstractContext)),
376404
],
377405
modelinfo[:allargs_exprs],
@@ -419,8 +447,12 @@ end
419447

420448
"""
421449
matchingvalue(sampler, vi, value)
450+
matchingvalue(context::AbstractContext, vi, value)
451+
452+
Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object.
422453
423-
Convert the `value` to the correct type for the `sampler` and the `vi` object.
454+
For a `context` that is _not_ a `SamplingContext`, we fall back to
455+
`matchingvalue(SampleFromPrior(), vi, value)`.
424456
"""
425457
function matchingvalue(sampler, vi, value)
426458
T = typeof(value)
@@ -435,7 +467,16 @@ function matchingvalue(sampler, vi, value)
435467
return value
436468
end
437469
end
438-
matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value)
470+
function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType)
471+
return get_matching_type(sampler, vi, value)
472+
end
473+
474+
function matchingvalue(context::AbstractContext, vi, value)
475+
return matchingvalue(SampleFromPrior(), vi, value)
476+
end
477+
function matchingvalue(context::SamplingContext, vi, value)
478+
return matchingvalue(context.sampler, vi, value)
479+
end
439480

440481
"""
441482
get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T}

0 commit comments

Comments
 (0)