Skip to content

Commit 740dca5

Browse files
committed
Some minor utility improvements (#452)
This PR does the following: - Moves the `varname_leaves` from `TestUtils` to main module. - It can be very useful in Turing.jl for constructing `Chains` and the like, so I think it's a good idea to make it part of the main module rather than keeping it "hidden" there. - Makes the default `varinfo` in the constructor of `LogDensityFunction` be `model.context` rather than a new `DynamicPPL.DefaultContext`. - The `context` pass to `evaluate!!` will override the leaf-context in `model.context`, and so the current default constructor always uses `DefaultContext` as the leaf-context, even if the `Model` has been `contextualize`d with some other leaf-context, e.g. `PriorContext`. This PR fixes this issue.
1 parent df2c975 commit 740dca5

File tree

4 files changed

+57
-23
lines changed

4 files changed

+57
-23
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.21.5"
3+
version = "0.21.6"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/logdensityfunction.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ $(FIELDS)
1010
```jldoctest
1111
julia> using Distributions
1212
13-
julia> using DynamicPPL: LogDensityFunction
13+
julia> using DynamicPPL: LogDensityFunction, contextualize
1414
1515
julia> @model function demo(x)
1616
m ~ Normal()
@@ -36,6 +36,12 @@ julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
3636
3737
julia> LogDensityProblems.logdensity(f, [0.0])
3838
-2.3378770664093453
39+
40+
julia> # This also respects the context in `model`.
41+
f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), VarInfo(model));
42+
43+
julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
44+
true
3945
```
4046
"""
4147
struct LogDensityFunction{V,M,C}
@@ -60,7 +66,7 @@ end
6066
function LogDensityFunction(
6167
model::Model,
6268
varinfo::AbstractVarInfo=VarInfo(model),
63-
context::AbstractContext=DefaultContext(),
69+
context::AbstractContext=model.context,
6470
)
6571
return LogDensityFunction(varinfo, model, context)
6672
end

src/test_utils.jl

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,8 @@ using Random: Random
1010
using Bijectors: Bijectors
1111
using Setfield: Setfield
1212

13-
"""
14-
varname_leaves(vn::VarName, val)
15-
16-
Return iterator over all varnames that are represented by `vn` on `val`,
17-
e.g. `varname_leaves(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`.
18-
"""
19-
varname_leaves(vn::VarName, val::Real) = [vn]
20-
function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
21-
return (
22-
VarName(vn, DynamicPPL.getlens(vn) Setfield.IndexLens(Tuple(I))) for
23-
I in CartesianIndices(val)
24-
)
25-
end
26-
function varname_leaves(vn::VarName, val::AbstractArray)
27-
return Iterators.flatten(
28-
varname_leaves(
29-
VarName(vn, DynamicPPL.getlens(vn) Setfield.IndexLens(Tuple(I))), val[I]
30-
) for I in CartesianIndices(val)
31-
)
32-
end
13+
# For backwards compat.
14+
using DynamicPPL: varname_leaves
3315

3416
"""
3517
update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)

src/utils.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,3 +740,49 @@ infer_nested_eltype(::Type{<:AbstractDict{<:Any,ET}}) where {ET} = infer_nested_
740740

741741
# No need + causes issues for some AD backends, e.g. Zygote.
742742
ChainRulesCore.@non_differentiable infer_nested_eltype(x)
743+
744+
"""
745+
varname_leaves(vn::VarName, val)
746+
747+
Return an iterator over all varnames that are represented by `vn` on `val`.
748+
749+
# Examples
750+
```jldoctest
751+
julia> using DynamicPPL: varname_leaves
752+
753+
julia> foreach(println, varname_leaves(@varname(x), rand(2)))
754+
x[1]
755+
x[2]
756+
757+
julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2)))
758+
x[1:2][1]
759+
x[1:2][2]
760+
761+
julia> x = (y = 1, z = [[2.0], [3.0]]);
762+
763+
julia> foreach(println, varname_leaves(@varname(x), x))
764+
x.y
765+
x.z[1][1]
766+
x.z[2][1]
767+
```
768+
"""
769+
varname_leaves(vn::VarName, ::Real) = [vn]
770+
function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
771+
return (
772+
VarName(vn, getlens(vn) Setfield.IndexLens(Tuple(I))) for
773+
I in CartesianIndices(val)
774+
)
775+
end
776+
function varname_leaves(vn::VarName, val::AbstractArray)
777+
return Iterators.flatten(
778+
varname_leaves(VarName(vn, getlens(vn) Setfield.IndexLens(Tuple(I))), val[I]) for
779+
I in CartesianIndices(val)
780+
)
781+
end
782+
function varname_leaves(vn::DynamicPPL.VarName, val::NamedTuple)
783+
iter = Iterators.map(keys(val)) do sym
784+
lens = Setfield.PropertyLens{sym}()
785+
varname_leaves(vn lens, get(val, lens))
786+
end
787+
return Iterators.flatten(iter)
788+
end

0 commit comments

Comments
 (0)