1
+ """
2
+ TrackedValue{T}
3
+
4
+ A struct that wraps something on the right-hand side of `:=`. This is needed
5
+ because the DynamicPPL compiler actually converts `lhs := rhs` to `lhs ~
6
+ TrackedValue(rhs)` (so that we can hit the `tilde_assume` method below). Having
7
+ the rhs wrapped in a TrackedValue makes sure that the logpdf of the rhs is not
8
+ computed (as it wouldn't make sense).
9
+ """
1
10
struct TrackedValue{T}
2
11
value:: T
3
12
end
@@ -24,17 +33,27 @@ struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext
24
33
values:: OrderedDict
25
34
" whether to extract variables on the LHS of :="
26
35
include_colon_eq:: Bool
36
+ " varnames to be tracked; `nothing` means track all varnames"
37
+ tracked_varnames:: Union{Nothing,Array{<:VarName}}
27
38
" child context"
28
39
context:: C
29
40
end
30
- function ValuesAsInModelContext (include_colon_eq, context:: AbstractContext )
31
- return ValuesAsInModelContext (OrderedDict (), include_colon_eq, context)
41
+ function ValuesAsInModelContext (
42
+ include_colon_eq:: Bool ,
43
+ tracked_varnames:: Union{Nothing,Array{<:VarName}} ,
44
+ context:: AbstractContext ,
45
+ )
46
+ return ValuesAsInModelContext (
47
+ OrderedDict (), include_colon_eq, tracked_varnames, context
48
+ )
32
49
end
33
50
34
51
NodeTrait (:: ValuesAsInModelContext ) = IsParent ()
35
52
childcontext (context:: ValuesAsInModelContext ) = context. context
36
53
function setchildcontext (context:: ValuesAsInModelContext , child)
37
- return ValuesAsInModelContext (context. values, context. include_colon_eq, child)
54
+ return ValuesAsInModelContext (
55
+ context. values, context. include_colon_eq, context. tracked_varnames, child
56
+ )
38
57
end
39
58
40
59
is_extracting_values (context:: ValuesAsInModelContext ) = context. include_colon_eq
63
82
64
83
# `tilde_asssume`
65
84
function tilde_assume (context:: ValuesAsInModelContext , right, vn, vi)
66
- if is_tracked_value (right)
85
+ is_tracked_value_right = is_tracked_value (right)
86
+ if is_tracked_value_right
67
87
value = right. value
68
88
logp = zero (getlogp (vi))
69
89
else
70
90
value, logp, vi = tilde_assume (childcontext (context), right, vn, vi)
71
91
end
72
92
# Save the value.
73
- push! (context, vn, value)
74
- # Save the value.
93
+ if is_tracked_value_right ||
94
+ isnothing (context. tracked_varnames) ||
95
+ any (tracked_vn -> subsumes (tracked_vn, vn), context. tracked_varnames)
96
+ push! (context, vn, value)
97
+ end
75
98
# Pass on.
76
99
return value, logp, vi
77
100
end
78
101
function tilde_assume (
79
102
rng:: Random.AbstractRNG , context:: ValuesAsInModelContext , sampler, right, vn, vi
80
103
)
81
- if is_tracked_value (right)
104
+ is_tracked_value_right = is_tracked_value (right)
105
+ if is_tracked_value_right
82
106
value = right. value
83
107
logp = zero (getlogp (vi))
84
108
else
85
109
value, logp, vi = tilde_assume (rng, childcontext (context), sampler, right, vn, vi)
86
110
end
87
111
# Save the value.
88
- push! (context, vn, value)
112
+ if is_tracked_value_right ||
113
+ isnothing (context. tracked_varnames) ||
114
+ any (tracked_vn -> subsumes (tracked_vn, vn), context. tracked_varnames)
115
+ push! (context, vn, value)
116
+ end
89
117
# Pass on.
90
118
return value, logp, vi
91
119
end
@@ -167,9 +195,39 @@ function values_as_in_model(
167
195
model:: Model ,
168
196
include_colon_eq:: Bool ,
169
197
varinfo:: AbstractVarInfo ,
198
+ tracked_varnames= tracked_varnames (model),
170
199
context:: AbstractContext = DefaultContext (),
171
200
)
172
- context = ValuesAsInModelContext (include_colon_eq, context)
201
+ tracked_varnames = isnothing (tracked_varnames) ? nothing : collect (tracked_varnames)
202
+ context = ValuesAsInModelContext (include_colon_eq, tracked_varnames, context)
173
203
evaluate!! (model, varinfo, context)
174
204
return context. values
175
205
end
206
+
207
+ """
208
+ tracked_varnames(model::Model)
209
+
210
+ Returns a set of `VarName`s that the model should track.
211
+
212
+ By default, this returns `nothing`, which means that all `VarName`s should be
213
+ tracked.
214
+
215
+ If you want to track only a subset of `VarName`s, you can override this method
216
+ in your model definition:
217
+
218
+ ```julia
219
+ @model function mymodel()
220
+ x ~ Normal()
221
+ y ~ Normal(x, 1)
222
+ end
223
+
224
+ DynamicPPL.tracked_varnames(::Model{typeof(mymodel)}) = [@varname(y)]
225
+ ```
226
+
227
+ Then, when you sample from `mymodel()`, only the value of `y` will be tracked
228
+ (and not `x`).
229
+
230
+ Note that quantities on the left-hand side of `:=` are always tracked, and will
231
+ ignore the varnames specified in this method.
232
+ """
233
+ tracked_varnames (:: Model ) = nothing
0 commit comments