Skip to content

Commit 2c9f970

Browse files
committed
Add new interface for AbstractMCMC.sample with LDFCompatibleAlgorithm
1 parent ae7abde commit 2c9f970

File tree

2 files changed

+388
-5
lines changed

2 files changed

+388
-5
lines changed

src/mcmc/Inference.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ getlogevidence(transitions, sampler, state) = missing
247247
# This is type piracy (at least for SampleFromPrior).
248248
function AbstractMCMC.bundle_samples(
249249
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
250-
model::AbstractModel,
250+
model_or_ldf::Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction},
251251
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
252252
state,
253253
chain_type::Type{MCMCChains.Chains};
@@ -259,6 +259,11 @@ function AbstractMCMC.bundle_samples(
259259
thinning=1,
260260
kwargs...,
261261
)
262+
model = if model_or_ldf isa DynamicPPL.LogDensityFunction
263+
model_or_ldf.model
264+
else
265+
model_or_ldf
266+
end
262267
# Convert transitions to array format.
263268
# Also retrieve the variable names.
264269
varnames, vals = _params_to_array(model, ts)
@@ -310,12 +315,17 @@ end
310315
# This is type piracy (for SampleFromPrior).
311316
function AbstractMCMC.bundle_samples(
312317
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
313-
model::AbstractModel,
318+
model_or_ldf::Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction},
314319
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
315320
state,
316321
chain_type::Type{Vector{NamedTuple}};
317322
kwargs...,
318323
)
324+
model = if model_or_ldf isa DynamicPPL.LogDensityFunction
325+
model_or_ldf.model
326+
else
327+
model_or_ldf
328+
end
319329
return map(ts) do t
320330
# Construct a dictionary of pairs `vn => value`.
321331
params = OrderedDict(getparams(model, t))

0 commit comments

Comments
 (0)