Skip to content

Commit 549d9b1

Browse files
torfjeldegithub-actions[bot]yebai
authored
Fix for LKJCholesky (#521)
* simplification of vectorize and make use of non-dist version in SimpleVarInfo * added special reconstruct for LKJCholeksy * make use of vectorize in setval! for VarInfo * added tests * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed test_setval! not working when the true value is not a vector * okay now we actually fixed the test_setval! * Update test/test_util.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Project.toml (#522) --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Hong Ge <[email protected]>
1 parent 7ef5da7 commit 549d9b1

File tree

7 files changed

+68
-14
lines changed

7 files changed

+68
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.23.13"
3+
version = "0.23.14"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/simple_varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ values_as(vi::SimpleVarInfo) = vi.values
544544
values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values
545545
function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T}
546546
isempty(vi) && return T[]
547-
return mapreduce(v -> vec([v;]), vcat, values(vi.values))
547+
return mapreduce(vectorize, vcat, values(vi.values))
548548
end
549549
function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict}
550550
return ConstructionBase.constructorof(D)(zip(keys(vi), values(vi.values)))

src/utils.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,10 @@ invlink_transform(dist) = inverse(link_transform(dist))
209209
# Helper functions for vectorize/reconstruct values #
210210
#####################################################
211211

212-
vectorize(d, r) = vec(r)
213-
vectorize(d::UnivariateDistribution, r::Real) = [r]
214-
vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r)
215-
vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
216-
vectorize(d::Distribution{CholeskyVariate}, r::Cholesky) = copy(vec(r.UL))
212+
vectorize(d, r) = vectorize(r)
213+
vectorize(r::Real) = [r]
214+
vectorize(r::AbstractArray{<:Real}) = copy(vec(r))
215+
vectorize(r::Cholesky) = copy(vec(r.UL))
217216

218217
# NOTE:
219218
# We cannot use reconstruct{T} because val is always Vector{Real} then T will be Real.
@@ -237,6 +236,15 @@ reconstruct(::UnivariateDistribution, val::Real) = val
237236
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val)
238237
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val)
239238
reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val)
239+
240+
function reconstruct(dist::LKJCholesky, val::AbstractVector{<:Real})
241+
return reconstruct(dist, reshape(val, size(dist)))
242+
end
243+
function reconstruct(dist::LKJCholesky, val::AbstractMatrix{<:Real})
244+
return Cholesky(val, dist.uplo, 0)
245+
end
246+
reconstruct(::LKJCholesky, val::Cholesky) = val
247+
240248
function reconstruct(
241249
::Inverse{Bijectors.VecCholeskyBijector}, ::LKJCholesky, val::AbstractVector
242250
)

src/varinfo.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,12 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`.
325325
The values may or may not be transformed to Euclidean space.
326326
"""
327327
setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn)
328-
setval!(md::Metadata, val, vn::VarName) = md.vals[getrange(md, vn)] = [val;]
328+
function setval!(md::Metadata, val::AbstractVector, vn::VarName)
329+
return md.vals[getrange(md, vn)] = val
330+
end
331+
function setval!(md::Metadata, val, vn::VarName)
332+
return md.vals[getrange(md, vn)] = vectorize(getdist(md, vn), val)
333+
end
329334

330335
"""
331336
getval(vi::VarInfo, vns::Vector{<:VarName})

test/linking.jl

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,51 @@ end
9191
end
9292
end
9393

94+
@testset "LKJCholesky" begin
95+
@testset "uplo=$uplo" for uplo in ['L', 'U']
96+
@model demo_lkj(d) = x ~ LKJCholesky(d, 1.0, uplo)
97+
@testset "d=$d" for d in [2, 3, 5]
98+
model = demo_lkj(d)
99+
dist = LKJCholesky(d, 1.0, uplo)
100+
values_original = rand(model)
101+
vis = DynamicPPL.TestUtils.setup_varinfos(
102+
model, values_original, (@varname(x),)
103+
)
104+
@testset "$(short_varinfo_name(vi))" for vi in vis
105+
val = vi[@varname(x), dist]
106+
# Ensure that `reconstruct` works as intended.
107+
@test val isa Cholesky
108+
@test val.uplo == uplo
109+
110+
@test length(vi[:]) == d^2
111+
lp = logpdf(dist, val)
112+
lp_model = logjoint(model, vi)
113+
@test lp_model lp
114+
# Linked.
115+
vi_linked = DynamicPPL.link!!(deepcopy(vi), model)
116+
@test length(vi_linked[:]) == d * (d - 1) ÷ 2
117+
# Should now include the log-absdet-jacobian correction.
118+
@test !(getlogp(vi_linked) lp)
119+
# Invlinked.
120+
vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model)
121+
@test length(vi_invlinked[:]) == d^2
122+
@test getlogp(vi_invlinked) lp
123+
end
124+
end
125+
end
126+
end
127+
94128
# Related: https://github.com/TuringLang/DynamicPPL.jl/issues/504
95-
@testset "dirichlet" begin
129+
@testset "Dirichlet" begin
96130
@model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0)
97131
@testset "d=$d" for d in [2, 3, 5]
98132
model = demo_dirichlet(d)
99133
vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),))
100134
@testset "$(short_varinfo_name(vi))" for vi in vis
101135
lp = logpdf(Dirichlet(d, 1.0), vi[:])
102136
@test length(vi[:]) == d
103-
@test getlogp(vi) lp
137+
lp_model = logjoint(model, vi)
138+
@test lp_model lp
104139
# Linked.
105140
vi_linked = DynamicPPL.link!!(deepcopy(vi), model)
106141
@test length(vi_linked[:]) == d - 1

test/test_util.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,18 @@ function test_setval!(model, chain; sample_idx=1, chain_idx=1)
6161
nt = DynamicPPL.tonamedtuple(var_info)
6262
for (k, (vals, names)) in pairs(nt)
6363
for (n, v) in zip(names, vals)
64-
chain_val = if Symbol(n) keys(chain)
64+
if Symbol(n) keys(chain)
6565
# Assume it's a group
66-
vec(MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx])
66+
chain_val = vec(
67+
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
68+
)
69+
v_true = vec(v)
6770
else
68-
chain[sample_idx, n, chain_idx]
71+
chain_val = chain[sample_idx, n, chain_idx]
72+
v_true = v
6973
end
70-
@test v == chain_val
74+
75+
@test v_true == chain_val
7176
end
7277
end
7378
end

test/turing/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ setprogress!(false)
1010
Random.seed!(100)
1111

1212
# load test utilities
13+
include(joinpath(pathof(DynamicPPL), "..", "..", "test", "test_util.jl"))
1314
include(joinpath(pathof(Turing), "..", "..", "test", "test_utils", "numerical_tests.jl"))
1415

1516
@testset "Turing" begin

0 commit comments

Comments
 (0)