@@ -75,6 +75,20 @@ function AbstractMCMC.sample(
75
75
initial_state= DynamicPPL. loadstate (resume_from),
76
76
kwargs... ,
77
77
)
78
+ # LDF needs to be set with SamplingContext, or else samplers cannot
79
+ # overload the tilde-pipeline.
80
+ if ! (ldf. context isa SamplingContext)
81
+ ldf = LogDensityFunction (
82
+ ldf. model, ldf. varinfo, SamplingContext (rng, spl); adtype= ldf. adtype
83
+ )
84
+ end
85
+ # Note that, in particular, sampling can mutate the variables in the LDF's
86
+ # varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
87
+ # ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
88
+ # that the parameters in the LDF are the initial parameters. So, we need to
89
+ # deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
90
+ # reproducible.
91
+ ldf = deepcopy (ldf)
78
92
# TODO : Right now, only generic checks are run. We could in principle
79
93
# specialise this to check for e.g. discrete variables with HMC
80
94
check_model && DynamicPPL. check_model (ldf. model; error_on_failure= true )
@@ -140,6 +154,20 @@ function AbstractMCMC.sample(
140
154
initial_state= DynamicPPL. loadstate (resume_from),
141
155
kwargs... ,
142
156
)
157
+ # LDF needs to be set with SamplingContext, or else samplers cannot
158
+ # overload the tilde-pipeline.
159
+ if ! (ldf. context isa SamplingContext)
160
+ ldf = LogDensityFunction (
161
+ ldf. model, ldf. varinfo, SamplingContext (rng, spl); adtype= ldf. adtype
162
+ )
163
+ end
164
+ # Note that, in particular, sampling can mutate the variables in the LDF's
165
+ # varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
166
+ # ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
167
+ # that the parameters in the LDF are the initial parameters. So, we need to
168
+ # deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
169
+ # reproducible.
170
+ ldf = deepcopy (ldf)
143
171
# TODO : Right now, only generic checks are run. We could in principle
144
172
# specialise this to check for e.g. discrete variables with HMC
145
173
check_model && DynamicPPL. check_model (ldf. model; error_on_failure= true )
0 commit comments