Skip to content

Commit 816e962

Browse files
torfjeldeyebai
andauthored
Fix for higher-dim Dirichlet, e.g. product_distribution (#586)
* fix for linking higher dimensional `Dirichlet` * bump patch version * bump Bijectors compat entry * added tests for high-dim Dirichlet * added `reconstruct` for `ArrayLikeVariate` * don't test functionality that relies on TuringLang/DistributionsAD.jl#264 yet * Update test/linking.jl * Update test/linking.jl * Update test/linking.jl --------- Co-authored-by: Hong Ge <[email protected]>
1 parent f32631f commit 816e962

File tree

4 files changed

+41
-1
lines changed

4 files changed

+41
-1
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ ADTypes = "0.2"
2929
AbstractMCMC = "5"
3030
AbstractPPL = "0.7"
3131
BangBang = "0.3"
32-
Bijectors = "0.13"
32+
Bijectors = "0.13.9"
3333
ChainRulesCore = "1"
3434
Compat = "4"
3535
ConstructionBase = "1.5.4"

src/abstract_varinfo.jl

+10
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,16 @@ function with_logabsdet_jacobian_and_reconstruct(f, dist, x)
765765
return with_logabsdet_jacobian(f, x_recon)
766766
end
767767

768+
# NOTE: Necessary to handle product distributions of `Dirichlet` and similar.
769+
function with_logabsdet_jacobian_and_reconstruct(
770+
f::Bijectors.Inverse{<:Bijectors.SimplexBijector}, dist, y
771+
)
772+
(d, ns...) = size(dist)
773+
yreshaped = reshape(y, d - 1, ns...)
774+
x, logjac = with_logabsdet_jacobian(f, yreshaped)
775+
return x, logjac
776+
end
777+
768778
# TODO: Once `(inv)link` isn't used heavily in `getindex(vi, vn)`, we can
769779
# just use `first ∘ with_logabsdet_jacobian` to reduce the maintenance burden.
770780
# NOTE: `reconstruct` is no-op if `val` is already of correct shape.

src/utils.jl

+5
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,11 @@ reconstruct(f, dist, val) = reconstruct(dist, val)
237237
reconstruct(::UnivariateDistribution, val::Real) = val
238238
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val)
239239
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val)
240+
function reconstruct(
241+
::Distribution{ArrayLikeVariate{N}}, val::AbstractArray{<:Real,N}
242+
) where {N}
243+
return copy(val)
244+
end
240245
reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val)
241246

242247
function reconstruct(dist::LKJCholesky, val::AbstractVector{<:Real})

test/linking.jl

+25
Original file line numberDiff line numberDiff line change
@@ -174,4 +174,29 @@ end
174174
end
175175
end
176176
end
177+
178+
# Related: https://github.com/TuringLang/Turing.jl/issues/2190
179+
@testset "High-dim Dirichlet" begin
180+
@model function demo_highdim_dirichlet(ns...)
181+
return x ~ filldist(Dirichlet(ones(2)), ns...)
182+
end
183+
@testset "ns=$ns" for ns in [
184+
(3,),
185+
# TODO: Uncomment once we have https://github.com/TuringLang/Bijectors.jl/pull/304
186+
# (3, 4), (3, 4, 5)
187+
]
188+
model = demo_highdim_dirichlet(ns...)
189+
example_values = rand(NamedTuple, model)
190+
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
191+
@testset "$(short_varinfo_name(vi))" for vi in vis
192+
# Linked.
193+
vi_linked = if mutable
194+
DynamicPPL.link!!(deepcopy(vi), model)
195+
else
196+
DynamicPPL.link(vi, model)
197+
end
198+
@test length(vi_linked[:]) == prod(ns)
199+
end
200+
end
201+
end
177202
end

0 commit comments

Comments
 (0)