Skip to content

Commit bb8adb7

Browse files
committed
Implement tracked_varnames
1 parent 061acbe commit bb8adb7

File tree

5 files changed

+161
-63
lines changed

5 files changed

+161
-63
lines changed

src/values_as_in_model.jl

+67-9
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
"""
2+
TrackedValue{T}
3+
4+
A struct that wraps something on the right-hand side of `:=`. This is needed
5+
because the DynamicPPL compiler actually converts `lhs := rhs` to `lhs ~
6+
TrackedValue(rhs)` (so that we can hit the `tilde_assume` method below). Having
7+
the rhs wrapped in a TrackedValue makes sure that the logpdf of the rhs is not
8+
computed (as it wouldn't make sense).
9+
"""
110
struct TrackedValue{T}
211
value::T
312
end
@@ -24,17 +33,27 @@ struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext
2433
values::OrderedDict
2534
"whether to extract variables on the LHS of :="
2635
include_colon_eq::Bool
36+
"varnames to be tracked; `nothing` means track all varnames"
37+
tracked_varnames::Union{Nothing,Array{<:VarName}}
2738
"child context"
2839
context::C
2940
end
30-
function ValuesAsInModelContext(include_colon_eq, context::AbstractContext)
31-
return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context)
41+
function ValuesAsInModelContext(
42+
include_colon_eq::Bool,
43+
tracked_varnames::Union{Nothing,Array{<:VarName}},
44+
context::AbstractContext,
45+
)
46+
return ValuesAsInModelContext(
47+
OrderedDict(), include_colon_eq, tracked_varnames, context
48+
)
3249
end
3350

3451
NodeTrait(::ValuesAsInModelContext) = IsParent()
3552
childcontext(context::ValuesAsInModelContext) = context.context
3653
function setchildcontext(context::ValuesAsInModelContext, child)
37-
return ValuesAsInModelContext(context.values, context.include_colon_eq, child)
54+
return ValuesAsInModelContext(
55+
context.values, context.include_colon_eq, context.tracked_varnames, child
56+
)
3857
end
3958

4059
is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq
@@ -63,29 +82,38 @@ end
6382

6483
# `tilde_asssume`
6584
function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
66-
if is_tracked_value(right)
85+
is_tracked_value_right = is_tracked_value(right)
86+
if is_tracked_value_right
6787
value = right.value
6888
logp = zero(getlogp(vi))
6989
else
7090
value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
7191
end
7292
# Save the value.
73-
push!(context, vn, value)
74-
# Save the value.
93+
if is_tracked_value_right ||
94+
isnothing(context.tracked_varnames) ||
95+
any(tracked_vn -> subsumes(tracked_vn, vn), context.tracked_varnames)
96+
push!(context, vn, value)
97+
end
7598
# Pass on.
7699
return value, logp, vi
77100
end
78101
function tilde_assume(
79102
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
80103
)
81-
if is_tracked_value(right)
104+
is_tracked_value_right = is_tracked_value(right)
105+
if is_tracked_value_right
82106
value = right.value
83107
logp = zero(getlogp(vi))
84108
else
85109
value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
86110
end
87111
# Save the value.
88-
push!(context, vn, value)
112+
if is_tracked_value_right ||
113+
isnothing(context.tracked_varnames) ||
114+
any(tracked_vn -> subsumes(tracked_vn, vn), context.tracked_varnames)
115+
push!(context, vn, value)
116+
end
89117
# Pass on.
90118
return value, logp, vi
91119
end
@@ -167,9 +195,39 @@ function values_as_in_model(
167195
model::Model,
168196
include_colon_eq::Bool,
169197
varinfo::AbstractVarInfo,
198+
tracked_varnames=tracked_varnames(model),
170199
context::AbstractContext=DefaultContext(),
171200
)
172-
context = ValuesAsInModelContext(include_colon_eq, context)
201+
tracked_varnames = isnothing(tracked_varnames) ? nothing : collect(tracked_varnames)
202+
context = ValuesAsInModelContext(include_colon_eq, tracked_varnames, context)
173203
evaluate!!(model, varinfo, context)
174204
return context.values
175205
end
206+
207+
"""
208+
tracked_varnames(model::Model)
209+
210+
Returns a set of `VarName`s that the model should track.
211+
212+
By default, this returns `nothing`, which means that all `VarName`s should be
213+
tracked.
214+
215+
If you want to track only a subset of `VarName`s, you can override this method
216+
in your model definition:
217+
218+
```julia
219+
@model function mymodel()
220+
x ~ Normal()
221+
y ~ Normal(x, 1)
222+
end
223+
224+
DynamicPPL.tracked_varnames(::Model{typeof(mymodel)}) = [@varname(y)]
225+
```
226+
227+
Then, when you sample from `mymodel()`, only the value of `y` will be tracked
228+
(and not `x`).
229+
230+
Note that quantities on the left-hand side of `:=` are always tracked, and will
231+
ignore the varnames specified in this method.
232+
"""
233+
tracked_varnames(::Model) = nothing

test/compiler.jl

-12
Original file line numberDiff line numberDiff line change
@@ -728,18 +728,6 @@ module Issue537 end
728728
varinfo = VarInfo(model)
729729
@test haskey(varinfo, @varname(x))
730730
@test !haskey(varinfo, @varname(y))
731-
732-
# While `values_as_in_model` should contain both `x` and `y`, if
733-
# include_colon_eq is set to `true`.
734-
values = values_as_in_model(model, true, deepcopy(varinfo))
735-
@test haskey(values, @varname(x))
736-
@test haskey(values, @varname(y))
737-
738-
# And if include_colon_eq is set to `false`, then `values` should
739-
# only contain `x`.
740-
values = values_as_in_model(model, false, deepcopy(varinfo))
741-
@test haskey(values, @varname(x))
742-
@test !haskey(values, @varname(y))
743731
end
744732
end
745733

test/model.jl

-42
Original file line numberDiff line numberDiff line change
@@ -410,48 +410,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
410410
end
411411
end
412412

413-
@testset "values_as_in_model" begin
414-
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
415-
vns = DynamicPPL.TestUtils.varnames(model)
416-
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
417-
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
418-
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
419-
# We can set the include_colon_eq arg to false because none of
420-
# the demo models contain :=. The behaviour when
421-
# include_colon_eq is true is tested in test/compiler.jl
422-
realizations = values_as_in_model(model, false, varinfo)
423-
# Ensure that all variables are found.
424-
vns_found = collect(keys(realizations))
425-
@test vns vns_found == vns vns_found
426-
# Ensure that the values are the same.
427-
for vn in vns
428-
@test realizations[vn] == varinfo[vn]
429-
end
430-
end
431-
end
432-
433-
@testset "Prefixing" begin
434-
@model inner() = x ~ Normal()
435-
436-
@model function outer_auto_prefix()
437-
a ~ to_submodel(inner(), true)
438-
b ~ to_submodel(inner(), true)
439-
return nothing
440-
end
441-
@model function outer_manual_prefix()
442-
a ~ to_submodel(prefix(inner(), :a), false)
443-
b ~ to_submodel(prefix(inner(), :b), false)
444-
return nothing
445-
end
446-
447-
for model in (outer_auto_prefix(), outer_manual_prefix())
448-
vi = VarInfo(model)
449-
vns = Set(keys(values_as_in_model(model, false, vi)))
450-
@test vns == Set([@varname(var"a.x"), @varname(var"b.x")])
451-
end
452-
end
453-
end
454-
455413
@testset "Erroneous model call" begin
456414
# Calling a model with the wrong arguments used to lead to infinite recursion, see
457415
# https://github.com/TuringLang/Turing.jl/issues/2182. This guards against it.

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ include("test_util.jl")
5151
include("varinfo.jl")
5252
include("simple_varinfo.jl")
5353
include("model.jl")
54+
include("values_as_in_model.jl")
5455
include("sampler.jl")
5556
include("independence.jl")
5657
include("distribution_wrappers.jl")

test/values_as_in_model.jl

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
@testset "values_as_in_model" begin
2+
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
3+
vns = DynamicPPL.TestUtils.varnames(model)
4+
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
5+
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
6+
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
7+
# We can set the include_colon_eq arg to false because none of
8+
# the demo models contain :=. The behaviour when
9+
# include_colon_eq is true is tested in test/compiler.jl
10+
realizations = values_as_in_model(model, false, varinfo)
11+
# Ensure that all variables are found.
12+
vns_found = collect(keys(realizations))
13+
@test vns vns_found == vns vns_found
14+
# Ensure that the values are the same.
15+
for vn in vns
16+
@test realizations[vn] == varinfo[vn]
17+
end
18+
end
19+
end
20+
21+
@testset "support for :=" begin
22+
@model function demo_tracked()
23+
x ~ Normal()
24+
y := 100 + x
25+
return (; x, y)
26+
end
27+
@model function demo_tracked_submodel()
28+
return vals ~ to_submodel(demo_tracked(), false)
29+
end
30+
31+
for model in [demo_tracked(), demo_tracked_submodel()]
32+
values = values_as_in_model(model, true, VarInfo(model))
33+
@test haskey(values, @varname(x))
34+
@test haskey(values, @varname(y))
35+
36+
values = values_as_in_model(model, false, VarInfo(model))
37+
@test haskey(values, @varname(x))
38+
@test !haskey(values, @varname(y))
39+
end
40+
end
41+
42+
@testset "Prefixing" begin
43+
@model inner() = x ~ Normal()
44+
45+
@model function outer_auto_prefix()
46+
a ~ to_submodel(inner(), true)
47+
b ~ to_submodel(inner(), true)
48+
return nothing
49+
end
50+
@model function outer_manual_prefix()
51+
a ~ to_submodel(prefix(inner(), :a), false)
52+
b ~ to_submodel(prefix(inner(), :b), false)
53+
return nothing
54+
end
55+
56+
for model in (outer_auto_prefix(), outer_manual_prefix())
57+
vi = VarInfo(model)
58+
vns = Set(keys(values_as_in_model(model, false, vi)))
59+
@test vns == Set([@varname(var"a.x"), @varname(var"b.x")])
60+
end
61+
end
62+
63+
@testset "Track only specific varnames" begin
64+
@model function track_specific()
65+
x = Vector{Float64}(undef, 2)
66+
# Include a vector x to test for correct subsumption behaviour
67+
for i in eachindex(x)
68+
x[i] ~ Normal()
69+
end
70+
y ~ Normal(x[1], 1)
71+
z := sum(x)
72+
end
73+
74+
model = track_specific()
75+
vi = VarInfo(model)
76+
77+
# Specify varnames to be tracked directly as an argument to `values_as_in_model`
78+
values = values_as_in_model(model, true, vi, [@varname(x)])
79+
# Since x subsumes both x[1] and x[2], they should be included
80+
@test haskey(values, @varname(x[1]))
81+
@test haskey(values, @varname(x[2]))
82+
@test !haskey(values, @varname(y))
83+
@test haskey(values, @varname(z)) # := is always included
84+
85+
# Specify instead using `tracked_varnames` method
86+
DynamicPPL.tracked_varnames(::Model{typeof(track_specific)}) = [@varname(y)]
87+
values = values_as_in_model(model, true, vi)
88+
@test !haskey(values, @varname(x[1]))
89+
@test !haskey(values, @varname(x[2]))
90+
@test haskey(values, @varname(y))
91+
@test haskey(values, @varname(z)) # := is always included
92+
end
93+
end

0 commit comments

Comments
 (0)