Skip to content

[WIP] Add pure logdensity to Model #242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,14 @@ include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
include("threadsafe.jl")
include("context_implementations.jl")
include("compiler.jl")
include("prob_macro.jl")
include("compat/ad.jl")
include("loglikelihoods.jl")
include("bijectors.jl")
include("submodel_macro.jl")

end # module
32 changes: 32 additions & 0 deletions src/bijectors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
@generated function _bijector(md::NamedTuple{names}; tuplify=false) where {names}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it would make sense to move the bijector changes to a separate PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, you might be right here. But the reason why I added it is because we sort of "need" it if we support NamedTuple evaluations without using VarInfo, since there's no "linking" going on here. Therefore I figured I should just add it together with the introduction of SimpleVarInfo or w/e we end up going with.

With that being said, I'm fine with making it a separate PR too:) Just wanted to explain my reasoning.

expr = Expr(:tuple)
for n in names
e = quote
if length(md.$n.dists) == 1 &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't be type stable, it seems?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, but AFAIK there's no way to ensure that since dists can be a vector of anything. The aim isn't for this to be type-stable either; it's meant be called once and then you re-use the resulting bijector wherever you need. So mby we should also just remove the "@generated`?

md.$n.dists[1] isa $(Distributions.UnivariateDistribution)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should not be necessary to qualify the types and functions in the generated function (BTW maybe make it if @generated?).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, didn't know; thanks!

And regarding @generated; starting to think it shouldn't even be a @generated...

$(Bijectors).bijector(md.$n.dists[1])
elseif tuplify
$(Bijectors.Stacked)(
map($(Bijectors).bijector, tuple(md.$n.dists...)), md.$n.ranges
)
else
$(Bijectors.Stacked)(map($(Bijectors).bijector, md.$n.dists), md.$n.ranges)
end
end
push!(expr.args, e)
end

return quote
bs = NamedTuple{$names}($expr)
return $(Bijectors).NamedBijector(bs)
end
end

"""
bijector(varinfo::VarInfo; tuplify=false)

Returns a `NamedBijector` which can transform different variants of `varinfo`.

If `tuplify` is true, then a type-stable bijector will be returned.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe one could choose a name that indicates more clearly that it helps with type stability? It almost sounds like it would return a tuple of bijectors. Maybe also astuple would be a more "formal" API?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, yeah sorry; that was leftover from earlier. I also hate tuplify:)

astuple maybe, or maybe static? Or typed?

"""
Bijectors.bijector(vi::TypedVarInfo; kwargs...) = _bijector(vi.metadata; kwargs...)
217 changes: 207 additions & 10 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,17 @@ end

function model(mod, linenumbernode, expr, warn)
modelinfo = build_model_info(expr)
modelinfo_logjoint = deepcopy(modelinfo)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice if the approach could cover also log prior, log likelihood, and other combinations of log densities - or does it already?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intention was that it should, but I just realized that this suffers from the same issues as EvaluationContext: it will "force" observe and thus if you provide PriorContext or LikelihoodContext it won't do what you expect.


# Generate main body
modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn)

return build_output(modelinfo, linenumbernode)
# Generate logjoint
modelinfo_logjoint[:body] = generate_mainbody_logdensity(
mod, modelinfo_logjoint[:modeldef][:body], warn
)
Comment on lines +93 to +95
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it seems there is still a separate function generated, even with the new context?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah sorry, no. I currently have both for the sake of comparison. Ideally we don't need both. But as I mentioned before, there is actually a benefit to the "generated" logjoint.


return build_output(modelinfo, modelinfo_logjoint, linenumbernode)
end

"""
Expand Down Expand Up @@ -339,14 +345,7 @@ hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(
hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true
hasmissing(T::Type) = false

"""
build_output(modelinfo, linenumbernode)

Builds the output expression.
"""
function build_output(modelinfo, linenumbernode)
## Build the anonymous evaluator from the user-provided model definition.

function build_evaluator(modelinfo)
# Remove the name.
evaluatordef = deepcopy(modelinfo[:modeldef])
delete!(evaluatordef, :name)
Expand All @@ -369,6 +368,55 @@ function build_output(modelinfo, linenumbernode)
# Replace the user-provided function body with the version created by DynamicPPL.
evaluatordef[:body] = modelinfo[:body]

return evaluatordef
end

function build_logjoint(modelinfo)
# Remove the name.
def = deepcopy(modelinfo[:modeldef])
def[:name] = :logjoint

# Add the internal arguments to the user-specified arguments (positional + keywords).
@gensym T
def[:args] = vcat(
[
:(__model__::$(DynamicPPL.Model)),
:(__sampler__::$(DynamicPPL.AbstractSampler)),
:(__context__::$(DynamicPPL.AbstractContext)),
:(__variables__),
],
modelinfo[:allargs_exprs],
[Expr(:kw, :(::Type{$T}), :Float64), ]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
[Expr(:kw, :(::Type{$T}), :Float64), ]
[Expr(:kw, :(::Type{$T}), :Float64)],

)

# Add the type-parameter.
def[:whereparams] = (def[:whereparams]..., T)

# Delete the keyword arguments.
def[:kwargs] = []

# Replace the user-provided function body with the version created by DynamicPPL.
def[:body] = quote
__lp__ = zero($T)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One could add __lp__ to the internal variable names. But since we want to move away from the warnings it's probably not a good idea 😛

$(modelinfo[:body])
return __lp__
end

return def
end

"""
build_output(modelinfo, linenumbernode)

Builds the output expression.
"""
function build_output(modelinfo, modelinfo_logjoint, linenumbernode)
## Build logjoint.
logjointdef = build_logjoint(modelinfo_logjoint)

## Build the anonymous evaluator from the user-provided model definition.
evaluatordef = build_evaluator(modelinfo)

## Build the model function.

# Extract the named tuple expression of all arguments and the default values.
Expand All @@ -378,18 +426,20 @@ function build_output(modelinfo, linenumbernode)
# Update the function body of the user-specified model.
# We use a name for the anonymous evaluator that does not conflict with other variables.
modeldef = modelinfo[:modeldef]
@gensym evaluator
@gensym evaluator logjoint
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
# to the call site
modeldef[:body] = MacroTools.@q begin
$(linenumbernode)
$evaluator = $(MacroTools.combinedef(evaluatordef))
$logjoint = $(MacroTools.combinedef(logjointdef))
return $(DynamicPPL.Model)(
$(QuoteNode(modeldef[:name])),
$evaluator,
$allargs_namedtuple,
$defaults_namedtuple,
$logjoint,
)
end

Expand Down Expand Up @@ -447,3 +497,150 @@ end

floatof(::Type{T}) where {T<:Real} = typeof(one(T) / one(T))
floatof(::Type) = Real # fallback if type inference failed

#####################
### `logdensity` ###
#####################
function generate_mainbody_logdensity(mod, expr, warn)
return generate_mainbody_logdensity!(mod, Symbol[], expr, warn)
end

generate_mainbody_logdensity!(mod, found, x, warn) = x
function generate_mainbody_logdensity!(mod, found, sym::Symbol, warn)
if sym in DEPRECATED_INTERNALNAMES
newsym = Symbol(:_, sym, :__)
Base.depwarn(
"internal variable `$sym` is deprecated, use `$newsym` instead.",
:generate_mainbody_logdensity!,
)
return generate_mainbody_logdensity!(mod, found, newsym, warn)
end

if warn && sym in INTERNALNAMES && sym ∉ found
@warn "you are using the internal variable `$sym`"
push!(found, sym)
end

return sym
end
function generate_mainbody_logdensity!(mod, found, expr::Expr, warn)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
return generate_mainbody_logdensity!(
mod, found, macroexpand(mod, expr; recursive=true), warn
)
end

# If it's a return, we instead return `__lp__`.
if Meta.isexpr(expr, :return)
returnbody = Expr(
:block,
map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)...,
)
return :($(returnbody); return __lp__)
end

# Modify dotted tilde operators.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
left = generate_mainbody_logdensity!(mod, found, L, warn)
return Base.remove_linenums!(
generate_dot_tilde_logdensity(
left, generate_mainbody_logdensity!(mod, found, R, warn)
),
)
end

# Modify tilde operators.
args_tilde = getargs_tilde(expr)
if args_tilde !== nothing
L, R = args_tilde
left = generate_mainbody_logdensity!(mod, found, L, warn)
return Base.remove_linenums!(
generate_tilde_logdensity(
left, generate_mainbody_logdensity!(mod, found, R, warn)
),
)
end

return Expr(
expr.head,
map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)...,
)
end
Comment on lines +504 to +574
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't check the details but it seems this code is very similar to the existing generate_mainbody!. Maybe one could just pass unify them by passing the inner functions (generate_dot_tilde/generate_dot_tilde_logdensity and generate_tilde/generate_tilde_logdensity) as first argument?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, I had the exact same thought as I was writing it out 😅 Left it as is for now though since I'm still on the edge whether or not we should have this vs. EvaluationContext with SimpleVarInfo. If we do keep it, I 100% agree.


function generate_tilde_logdensity(left, right)
# If the LHS is a literal, it is always an observation
if !(left isa Symbol || left isa Expr)
return quote
__lp__ += $(DynamicPPL.tilde_observe)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
__varinfo__,
)
end
end

@gensym vn inds isassumption

# If it's not present in args of the model, we need to extract it from `__variables__`.
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$(vsym(left)) = __variables__.$(vsym(left))
end
__lp__ += $(DynamicPPL.tilde_observe)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$inds,
nothing,
)
end
end

function generate_dot_tilde_logdensity(left, right)
# If the LHS is a literal, it is always an observation
if !(left isa Symbol || left isa Expr)
return quote
__lp__ += $(DynamicPPL.dot_tilde_observe)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
__varinfo__,
)
end
end

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn inds isassumption
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left)) || $left === missing
if $isassumption
$(vsym(left)) = __variables__.$(vsym(left))
end
__lp__ += $(DynamicPPL.dot_tilde_observe)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$inds,
nothing,
)
end
end
Loading