Skip to content

Commit 12e7c27

Browse files
authored
Fix for type instability when using LKJChol (#548)
* concretize reshape in `reconstruct` for `LKJCholesky` to avoid type-instabilities * added tests for `demo_lkjchol` * bumped patch versionion
1 parent efd9da3 commit 12e7c27

File tree

3 files changed

+34
-2
lines changed

3 files changed

+34
-2
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.23.19"
3+
version = "0.23.20"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/utils.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val)
238238
reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val)
239239

240240
function reconstruct(dist::LKJCholesky, val::AbstractVector{<:Real})
241-
return reconstruct(dist, reshape(val, size(dist)))
241+
return reconstruct(dist, Matrix(reshape(val, size(dist))))
242242
end
243243
function reconstruct(dist::LKJCholesky, val::AbstractMatrix{<:Real})
244244
return Cholesky(val, dist.uplo, 0)

test/model.jl

+32
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ function innermost_distribution_type(d::Distributions.Product)
2525
return dists[1]
2626
end
2727

28+
is_typed_varinfo(::DynamicPPL.AbstractVarInfo) = false
29+
is_typed_varinfo(varinfo::DynamicPPL.TypedVarInfo) = true
30+
is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
31+
2832
@testset "model.jl" begin
2933
@testset "convenience functions" begin
3034
model = gdemo_default # defined in test/test_util.jl
@@ -329,4 +333,32 @@ end
329333
@test x_true.UL == result.x.UL
330334
end
331335
end
336+
337+
@testset "Type stability of models" begin
338+
models_to_test = [
339+
# FIXME: Fix issues with type-stability in `DEMO_MODELS`.
340+
# DynamicPPL.TestUtils.DEMO_MODELS...,
341+
DynamicPPL.TestUtils.demo_lkjchol(2),
342+
]
343+
@testset "$(model.f)" for model in models_to_test
344+
vns = DynamicPPL.TestUtils.varnames(model)
345+
example_values = DynamicPPL.TestUtils.rand(model)
346+
varinfos = filter(
347+
is_typed_varinfo,
348+
DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns),
349+
)
350+
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
351+
@test (@inferred(DynamicPPL.evaluate!!(model, varinfo, DefaultContext()));
352+
true)
353+
354+
varinfo_linked = DynamicPPL.link(varinfo, model)
355+
@test (
356+
@inferred(
357+
DynamicPPL.evaluate!!(model, varinfo_linked, DefaultContext())
358+
);
359+
true
360+
)
361+
end
362+
end
363+
end
332364
end

0 commit comments

Comments
 (0)