@@ -247,7 +247,7 @@ getlogevidence(transitions, sampler, state) = missing
247
247
# This is type piracy (at least for SampleFromPrior).
248
248
function AbstractMCMC. bundle_samples (
249
249
ts:: Vector{<:Union{AbstractTransition,AbstractVarInfo}} ,
250
- model :: AbstractModel ,
250
+ model_or_ldf :: Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction} ,
251
251
spl:: Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler} ,
252
252
state,
253
253
chain_type:: Type{MCMCChains.Chains} ;
@@ -259,6 +259,11 @@ function AbstractMCMC.bundle_samples(
259
259
thinning= 1 ,
260
260
kwargs... ,
261
261
)
262
+ model = if model_or_ldf isa DynamicPPL. LogDensityFunction
263
+ model_or_ldf. model
264
+ else
265
+ model_or_ldf
266
+ end
262
267
# Convert transitions to array format.
263
268
# Also retrieve the variable names.
264
269
varnames, vals = _params_to_array (model, ts)
@@ -310,12 +315,17 @@ end
310
315
# This is type piracy (for SampleFromPrior).
311
316
function AbstractMCMC. bundle_samples (
312
317
ts:: Vector{<:Union{AbstractTransition,AbstractVarInfo}} ,
313
- model :: AbstractModel ,
318
+ model_or_ldf :: Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction} ,
314
319
spl:: Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler} ,
315
320
state,
316
321
chain_type:: Type{Vector{NamedTuple}} ;
317
322
kwargs... ,
318
323
)
324
+ model = if model_or_ldf isa DynamicPPL. LogDensityFunction
325
+ model_or_ldf. model
326
+ else
327
+ model_or_ldf
328
+ end
319
329
return map (ts) do t
320
330
# Construct a dictionary of pairs `vn => value`.
321
331
params = OrderedDict (getparams (model, t))
0 commit comments