Skip to content

Commit 29a6c7e

Browse files
authored
Handle nested PrefixContext (#787)
* Prefix varnames appropriately inside check_model_and_trace * Fix values_as_in_model as well * Add test for check_model with manual prefix * Add values_as_in_model tests * Add tests for prefix nesting * Bump Project.toml
1 parent 00e7ee3 commit 29a6c7e

File tree

7 files changed

+73
-19
lines changed

7 files changed

+73
-19
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.34.1"
3+
version = "0.34.2"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/contexts.jl

+9-6
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ end
261261

262262
const PREFIX_SEPARATOR = Symbol(".")
263263

264+
# TODO(penelopeysm): Prefixing arguably occurs the wrong way round here
264265
function PrefixContext{PrefixInner}(
265266
context::PrefixContext{PrefixOuter}
266267
) where {PrefixInner,PrefixOuter}
@@ -273,13 +274,15 @@ function PrefixContext{PrefixInner}(
273274
end
274275
end
275276

276-
function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
277-
if @generated
278-
return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getoptic(vn)))
279-
else
280-
VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getoptic(vn))
281-
end
277+
# TODO(penelopeysm): Prefixing arguably occurs the wrong way round here
278+
function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
279+
return prefix(
280+
childcontext(ctx), VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getoptic(vn))
281+
)
282282
end
283+
prefix(ctx::AbstractContext, vn::VarName) = prefix(NodeTrait(ctx), ctx, vn)
284+
prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn
285+
prefix(::IsParent, ctx::AbstractContext, vn::VarName) = prefix(childcontext(ctx), vn)
283286

284287
"""
285288
prefix(model::Model, x)

src/debug_utils.jl

+12-11
Original file line numberDiff line numberDiff line change
@@ -239,50 +239,51 @@ function DynamicPPL.setchildcontext(context::DebugContext, child)
239239
end
240240

241241
function record_varname!(context::DebugContext, varname::VarName, dist)
242-
if haskey(context.varnames_seen, varname)
242+
prefixed_varname = prefix(context, varname)
243+
if haskey(context.varnames_seen, prefixed_varname)
243244
if context.error_on_failure
244-
error("varname $varname used multiple times in model")
245+
error("varname $prefixed_varname used multiple times in model")
245246
else
246-
@warn "varname $varname used multiple times in model"
247+
@warn "varname $prefixed_varname used multiple times in model"
247248
end
248-
context.varnames_seen[varname] += 1
249+
context.varnames_seen[prefixed_varname] += 1
249250
else
250251
# We need to check:
251252
# 1. Does this `varname` subsume any of the other keys.
252253
# 2. Does any of the other keys subsume `varname`.
253254
vns = collect(keys(context.varnames_seen))
254255
# Is `varname` subsumed by any of the other keys?
255-
idx_parent = findfirst(Base.Fix2(subsumes, varname), vns)
256+
idx_parent = findfirst(Base.Fix2(subsumes, prefixed_varname), vns)
256257
if idx_parent !== nothing
257258
varname_parent = vns[idx_parent]
258259
if context.error_on_failure
259260
error(
260-
"varname $(varname_parent) used multiple times in model (subsumes $varname)",
261+
"varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)",
261262
)
262263
else
263-
@warn "varname $(varname_parent) used multiple times in model (subsumes $varname)"
264+
@warn "varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)"
264265
end
265266
# Update count of parent.
266267
context.varnames_seen[varname_parent] += 1
267268
else
268269
# Does `varname` subsume any of the other keys?
269-
idx_child = findfirst(Base.Fix1(subsumes, varname), vns)
270+
idx_child = findfirst(Base.Fix1(subsumes, prefixed_varname), vns)
270271
if idx_child !== nothing
271272
varname_child = vns[idx_child]
272273
if context.error_on_failure
273274
error(
274-
"varname $(varname_child) used multiple times in model (subsumed by $varname)",
275+
"varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)",
275276
)
276277
else
277-
@warn "varname $(varname_child) used multiple times in model (subsumed by $varname)"
278+
@warn "varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)"
278279
end
279280

280281
# Update count of child.
281282
context.varnames_seen[varname_child] += 1
282283
end
283284
end
284285

285-
context.varnames_seen[varname] = 1
286+
context.varnames_seen[prefixed_varname] = 1
286287
end
287288
end
288289

src/values_as_in_model.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ is_extracting_values(::IsParent, ::AbstractContext) = false
4545
is_extracting_values(::IsLeaf, ::AbstractContext) = false
4646

4747
function Base.push!(context::ValuesAsInModelContext, vn::VarName, value)
48-
return setindex!(context.values, copy(value), vn)
48+
return setindex!(context.values, copy(value), prefix(context, vn))
4949
end
5050

5151
function broadcast_push!(context::ValuesAsInModelContext, vns, values)

test/contexts.jl

+20
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,26 @@ end
162162
@test getoptic(vn_prefixed) === getoptic(vn)
163163
end
164164

165+
@testset "nested within arbitrary context stacks" begin
166+
vn = @varname(x[1])
167+
ctx1 = PrefixContext{:a}(DefaultContext())
168+
ctx2 = SamplingContext(ctx1)
169+
ctx3 = PrefixContext{:b}(ctx2)
170+
ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3)
171+
vn_prefixed1 = prefix(ctx1, vn)
172+
vn_prefixed2 = prefix(ctx2, vn)
173+
vn_prefixed3 = prefix(ctx3, vn)
174+
vn_prefixed4 = prefix(ctx4, vn)
175+
@test DynamicPPL.getsym(vn_prefixed1) == Symbol("a.x")
176+
@test DynamicPPL.getsym(vn_prefixed2) == Symbol("a.x")
177+
@test DynamicPPL.getsym(vn_prefixed3) == Symbol("a.b.x")
178+
@test DynamicPPL.getsym(vn_prefixed4) == Symbol("a.b.x")
179+
@test DynamicPPL.getoptic(vn_prefixed1) === DynamicPPL.getoptic(vn)
180+
@test DynamicPPL.getoptic(vn_prefixed2) === DynamicPPL.getoptic(vn)
181+
@test DynamicPPL.getoptic(vn_prefixed3) === DynamicPPL.getoptic(vn)
182+
@test DynamicPPL.getoptic(vn_prefixed4) === DynamicPPL.getoptic(vn)
183+
end
184+
165185
context = DynamicPPL.PrefixContext{:prefix}(SamplingContext())
166186
@testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
167187
# Sample with the context.

test/debug_utils.jl

+9
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,15 @@
6060
end
6161
model = ModelOuterWorking()
6262
@test check_model(model; error_on_failure=true)
63+
64+
# With manual prefixing, https://github.com/TuringLang/DynamicPPL.jl/issues/785
65+
@model function ModelOuterWorking2()
66+
x1 ~ to_submodel(prefix(ModelInner(), :a), false)
67+
x2 ~ to_submodel(prefix(ModelInner(), :b), false)
68+
return (x1, x2)
69+
end
70+
model = ModelOuterWorking2()
71+
@test check_model(model; error_on_failure=true)
6372
end
6473

6574
@testset "subsumes (x then x[1])" begin

test/model.jl

+21
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,27 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
429429
end
430430
end
431431
end
432+
433+
@testset "Prefixing" begin
434+
@model inner() = x ~ Normal()
435+
436+
@model function outer_auto_prefix()
437+
a ~ to_submodel(inner(), true)
438+
b ~ to_submodel(inner(), true)
439+
return nothing
440+
end
441+
@model function outer_manual_prefix()
442+
a ~ to_submodel(prefix(inner(), :a), false)
443+
b ~ to_submodel(prefix(inner(), :b), false)
444+
return nothing
445+
end
446+
447+
for model in (outer_auto_prefix(), outer_manual_prefix())
448+
vi = VarInfo(model)
449+
vns = Set(keys(values_as_in_model(model, false, vi)))
450+
@test vns == Set([@varname(var"a.x"), @varname(var"b.x")])
451+
end
452+
end
432453
end
433454

434455
@testset "Erroneous model call" begin

0 commit comments

Comments
 (0)