Skip to content

Commit 3ff1c95

Browse files
committed
fix reproducibility of sampling
1 parent 151a429 commit 3ff1c95

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

src/mcmc/abstractmcmc.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,20 @@ function AbstractMCMC.sample(
7575
initial_state=DynamicPPL.loadstate(resume_from),
7676
kwargs...,
7777
)
78+
# LDF needs to be set with SamplingContext, or else samplers cannot
79+
# overload the tilde-pipeline.
80+
if !(ldf.context isa SamplingContext)
81+
ldf = LogDensityFunction(
82+
ldf.model, ldf.varinfo, SamplingContext(rng, spl); adtype=ldf.adtype
83+
)
84+
end
85+
# Note that, in particular, sampling can mutate the variables in the LDF's
86+
# varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
87+
# ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
88+
# that the parameters in the LDF are the initial parameters. So, we need to
89+
# deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
90+
# reproducible.
91+
ldf = deepcopy(ldf)
7892
# TODO: Right now, only generic checks are run. We could in principle
7993
# specialise this to check for e.g. discrete variables with HMC
8094
check_model && DynamicPPL.check_model(ldf.model; error_on_failure=true)
@@ -140,6 +154,20 @@ function AbstractMCMC.sample(
140154
initial_state=DynamicPPL.loadstate(resume_from),
141155
kwargs...,
142156
)
157+
# LDF needs to be set with SamplingContext, or else samplers cannot
158+
# overload the tilde-pipeline.
159+
if !(ldf.context isa SamplingContext)
160+
ldf = LogDensityFunction(
161+
ldf.model, ldf.varinfo, SamplingContext(rng, spl); adtype=ldf.adtype
162+
)
163+
end
164+
# Note that, in particular, sampling can mutate the variables in the LDF's
165+
# varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
166+
# ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
167+
# that the parameters in the LDF are the initial parameters. So, we need to
168+
# deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
169+
# reproducible.
170+
ldf = deepcopy(ldf)
143171
# TODO: Right now, only generic checks are run. We could in principle
144172
# specialise this to check for e.g. discrete variables with HMC
145173
check_model && DynamicPPL.check_model(ldf.model; error_on_failure=true)

0 commit comments

Comments
 (0)