-
Notifications
You must be signed in to change notification settings - Fork 35
[WIP] Add pure logdensity to Model #242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9006055
6414e42
95879e2
dfacb83
9c4f880
4dc65fc
cc421bd
5a279f7
4e1bb8c
6fbcdc5
74d620a
ee62924
13045ad
290bd38
d65337d
e7fbc15
c942f1f
ad1e852
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
@generated function _bijector(md::NamedTuple{names}; tuplify=false) where {names} | ||
expr = Expr(:tuple) | ||
for n in names | ||
e = quote | ||
if length(md.$n.dists) == 1 && | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This won't be type stable, it seems? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nope, but AFAIK there's no way to ensure that since |
||
md.$n.dists[1] isa $(Distributions.UnivariateDistribution) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should not be necessary to qualify the types and functions in the generated function (BTW maybe make it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, didn't know; thanks! And regarding |
||
$(Bijectors).bijector(md.$n.dists[1]) | ||
elseif tuplify | ||
$(Bijectors.Stacked)( | ||
map($(Bijectors).bijector, tuple(md.$n.dists...)), md.$n.ranges | ||
) | ||
else | ||
$(Bijectors.Stacked)(map($(Bijectors).bijector, md.$n.dists), md.$n.ranges) | ||
end | ||
end | ||
push!(expr.args, e) | ||
end | ||
|
||
return quote | ||
bs = NamedTuple{$names}($expr) | ||
return $(Bijectors).NamedBijector(bs) | ||
end | ||
end | ||
|
||
""" | ||
bijector(varinfo::VarInfo; tuplify=false) | ||
|
||
Returns a `NamedBijector` which can transform different variants of `varinfo`. | ||
|
||
If `tuplify` is true, then a type-stable bijector will be returned. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe one could choose a name that indicates more clearly that it helps with type stability? It almost sounds like it would return a tuple of bijectors. Maybe also There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Haha, yeah sorry; that was leftover from earlier. I also hate
|
||
""" | ||
Bijectors.bijector(vi::TypedVarInfo; kwargs...) = _bijector(vi.metadata; kwargs...) |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -84,11 +84,17 @@ end | |||||
|
||||||
function model(mod, linenumbernode, expr, warn) | ||||||
modelinfo = build_model_info(expr) | ||||||
modelinfo_logjoint = deepcopy(modelinfo) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice if the approach could cover also log prior, log likelihood, and other combinations of log densities - or does it already? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The intention was that it should, but I just realized that this suffers from the same issues as |
||||||
|
||||||
# Generate main body | ||||||
modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn) | ||||||
|
||||||
return build_output(modelinfo, linenumbernode) | ||||||
# Generate logjoint | ||||||
modelinfo_logjoint[:body] = generate_mainbody_logdensity( | ||||||
mod, modelinfo_logjoint[:modeldef][:body], warn | ||||||
) | ||||||
Comment on lines
+93
to
+95
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So it seems there is still a separate function generated, even with the new context? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah sorry, no. I currently have both for the sake of comparison. Ideally we don't need both. But as I mentioned before, there is actually a benefit to the "generated" |
||||||
|
||||||
return build_output(modelinfo, modelinfo_logjoint, linenumbernode) | ||||||
end | ||||||
|
||||||
""" | ||||||
|
@@ -339,14 +345,7 @@ hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing( | |||||
hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true | ||||||
hasmissing(T::Type) = false | ||||||
|
||||||
""" | ||||||
build_output(modelinfo, linenumbernode) | ||||||
|
||||||
Builds the output expression. | ||||||
""" | ||||||
function build_output(modelinfo, linenumbernode) | ||||||
## Build the anonymous evaluator from the user-provided model definition. | ||||||
|
||||||
function build_evaluator(modelinfo) | ||||||
# Remove the name. | ||||||
evaluatordef = deepcopy(modelinfo[:modeldef]) | ||||||
delete!(evaluatordef, :name) | ||||||
|
@@ -369,6 +368,55 @@ function build_output(modelinfo, linenumbernode) | |||||
# Replace the user-provided function body with the version created by DynamicPPL. | ||||||
evaluatordef[:body] = modelinfo[:body] | ||||||
|
||||||
return evaluatordef | ||||||
end | ||||||
|
||||||
function build_logjoint(modelinfo) | ||||||
# Remove the name. | ||||||
def = deepcopy(modelinfo[:modeldef]) | ||||||
def[:name] = :logjoint | ||||||
|
||||||
# Add the internal arguments to the user-specified arguments (positional + keywords). | ||||||
@gensym T | ||||||
def[:args] = vcat( | ||||||
[ | ||||||
:(__model__::$(DynamicPPL.Model)), | ||||||
:(__sampler__::$(DynamicPPL.AbstractSampler)), | ||||||
:(__context__::$(DynamicPPL.AbstractContext)), | ||||||
:(__variables__), | ||||||
], | ||||||
modelinfo[:allargs_exprs], | ||||||
[Expr(:kw, :(::Type{$T}), :Float64), ] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
) | ||||||
|
||||||
# Add the type-parameter. | ||||||
def[:whereparams] = (def[:whereparams]..., T) | ||||||
|
||||||
# Delete the keyword arguments. | ||||||
def[:kwargs] = [] | ||||||
|
||||||
# Replace the user-provided function body with the version created by DynamicPPL. | ||||||
def[:body] = quote | ||||||
__lp__ = zero($T) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One could add |
||||||
$(modelinfo[:body]) | ||||||
return __lp__ | ||||||
end | ||||||
|
||||||
return def | ||||||
end | ||||||
|
||||||
""" | ||||||
build_output(modelinfo, linenumbernode) | ||||||
|
||||||
Builds the output expression. | ||||||
""" | ||||||
function build_output(modelinfo, modelinfo_logjoint, linenumbernode) | ||||||
## Build logjoint. | ||||||
logjointdef = build_logjoint(modelinfo_logjoint) | ||||||
|
||||||
## Build the anonymous evaluator from the user-provided model definition. | ||||||
evaluatordef = build_evaluator(modelinfo) | ||||||
|
||||||
## Build the model function. | ||||||
|
||||||
# Extract the named tuple expression of all arguments and the default values. | ||||||
|
@@ -378,18 +426,20 @@ function build_output(modelinfo, linenumbernode) | |||||
# Update the function body of the user-specified model. | ||||||
# We use a name for the anonymous evaluator that does not conflict with other variables. | ||||||
modeldef = modelinfo[:modeldef] | ||||||
@gensym evaluator | ||||||
@gensym evaluator logjoint | ||||||
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure | ||||||
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode` | ||||||
# to the call site | ||||||
modeldef[:body] = MacroTools.@q begin | ||||||
$(linenumbernode) | ||||||
$evaluator = $(MacroTools.combinedef(evaluatordef)) | ||||||
$logjoint = $(MacroTools.combinedef(logjointdef)) | ||||||
return $(DynamicPPL.Model)( | ||||||
$(QuoteNode(modeldef[:name])), | ||||||
$evaluator, | ||||||
$allargs_namedtuple, | ||||||
$defaults_namedtuple, | ||||||
$logjoint, | ||||||
) | ||||||
end | ||||||
|
||||||
|
@@ -447,3 +497,150 @@ end | |||||
|
||||||
floatof(::Type{T}) where {T<:Real} = typeof(one(T) / one(T)) | ||||||
floatof(::Type) = Real # fallback if type inference failed | ||||||
|
||||||
##################### | ||||||
### `logdensity` ### | ||||||
##################### | ||||||
function generate_mainbody_logdensity(mod, expr, warn) | ||||||
return generate_mainbody_logdensity!(mod, Symbol[], expr, warn) | ||||||
end | ||||||
|
||||||
generate_mainbody_logdensity!(mod, found, x, warn) = x | ||||||
function generate_mainbody_logdensity!(mod, found, sym::Symbol, warn) | ||||||
if sym in DEPRECATED_INTERNALNAMES | ||||||
newsym = Symbol(:_, sym, :__) | ||||||
Base.depwarn( | ||||||
"internal variable `$sym` is deprecated, use `$newsym` instead.", | ||||||
:generate_mainbody_logdensity!, | ||||||
) | ||||||
return generate_mainbody_logdensity!(mod, found, newsym, warn) | ||||||
end | ||||||
|
||||||
if warn && sym in INTERNALNAMES && sym ∉ found | ||||||
@warn "you are using the internal variable `$sym`" | ||||||
push!(found, sym) | ||||||
end | ||||||
|
||||||
return sym | ||||||
end | ||||||
function generate_mainbody_logdensity!(mod, found, expr::Expr, warn) | ||||||
# Do not touch interpolated expressions | ||||||
expr.head === :$ && return expr.args[1] | ||||||
|
||||||
# If it's a macro, we expand it | ||||||
if Meta.isexpr(expr, :macrocall) | ||||||
return generate_mainbody_logdensity!( | ||||||
mod, found, macroexpand(mod, expr; recursive=true), warn | ||||||
) | ||||||
end | ||||||
|
||||||
# If it's a return, we instead return `__lp__`. | ||||||
if Meta.isexpr(expr, :return) | ||||||
returnbody = Expr( | ||||||
:block, | ||||||
map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)..., | ||||||
) | ||||||
return :($(returnbody); return __lp__) | ||||||
end | ||||||
|
||||||
# Modify dotted tilde operators. | ||||||
args_dottilde = getargs_dottilde(expr) | ||||||
if args_dottilde !== nothing | ||||||
L, R = args_dottilde | ||||||
left = generate_mainbody_logdensity!(mod, found, L, warn) | ||||||
return Base.remove_linenums!( | ||||||
generate_dot_tilde_logdensity( | ||||||
left, generate_mainbody_logdensity!(mod, found, R, warn) | ||||||
), | ||||||
) | ||||||
end | ||||||
|
||||||
# Modify tilde operators. | ||||||
args_tilde = getargs_tilde(expr) | ||||||
if args_tilde !== nothing | ||||||
L, R = args_tilde | ||||||
left = generate_mainbody_logdensity!(mod, found, L, warn) | ||||||
return Base.remove_linenums!( | ||||||
generate_tilde_logdensity( | ||||||
left, generate_mainbody_logdensity!(mod, found, R, warn) | ||||||
), | ||||||
) | ||||||
end | ||||||
|
||||||
return Expr( | ||||||
expr.head, | ||||||
map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)..., | ||||||
) | ||||||
end | ||||||
Comment on lines
+504
to
+574
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't check the details but it seems this code is very similar to the existing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Haha, I had the exact same thought as I was writing it out 😅 Left it as is for now though since I'm still on the edge whether or not we should have this vs. |
||||||
|
||||||
function generate_tilde_logdensity(left, right) | ||||||
# If the LHS is a literal, it is always an observation | ||||||
if !(left isa Symbol || left isa Expr) | ||||||
return quote | ||||||
__lp__ += $(DynamicPPL.tilde_observe)( | ||||||
__context__, | ||||||
__sampler__, | ||||||
$(DynamicPPL.check_tilde_rhs)($right), | ||||||
$left, | ||||||
__varinfo__, | ||||||
) | ||||||
end | ||||||
end | ||||||
|
||||||
@gensym vn inds isassumption | ||||||
|
||||||
# If it's not present in args of the model, we need to extract it from `__variables__`. | ||||||
return quote | ||||||
$vn = $(varname(left)) | ||||||
$inds = $(vinds(left)) | ||||||
$isassumption = $(DynamicPPL.isassumption(left)) | ||||||
if $isassumption | ||||||
$(vsym(left)) = __variables__.$(vsym(left)) | ||||||
end | ||||||
__lp__ += $(DynamicPPL.tilde_observe)( | ||||||
__context__, | ||||||
__sampler__, | ||||||
$(DynamicPPL.check_tilde_rhs)($right), | ||||||
$left, | ||||||
$vn, | ||||||
$inds, | ||||||
nothing, | ||||||
) | ||||||
end | ||||||
end | ||||||
|
||||||
function generate_dot_tilde_logdensity(left, right) | ||||||
# If the LHS is a literal, it is always an observation | ||||||
if !(left isa Symbol || left isa Expr) | ||||||
return quote | ||||||
__lp__ += $(DynamicPPL.dot_tilde_observe)( | ||||||
__context__, | ||||||
__sampler__, | ||||||
$(DynamicPPL.check_tilde_rhs)($right), | ||||||
$left, | ||||||
__varinfo__, | ||||||
) | ||||||
end | ||||||
end | ||||||
|
||||||
# Otherwise it is determined by the model or its value, | ||||||
# if the LHS represents an observation | ||||||
@gensym vn inds isassumption | ||||||
return quote | ||||||
$vn = $(varname(left)) | ||||||
$inds = $(vinds(left)) | ||||||
$isassumption = $(DynamicPPL.isassumption(left)) || $left === missing | ||||||
if $isassumption | ||||||
$(vsym(left)) = __variables__.$(vsym(left)) | ||||||
end | ||||||
__lp__ += $(DynamicPPL.dot_tilde_observe)( | ||||||
__context__, | ||||||
__sampler__, | ||||||
$(DynamicPPL.check_tilde_rhs)($right), | ||||||
$left, | ||||||
$vn, | ||||||
$inds, | ||||||
nothing, | ||||||
) | ||||||
end | ||||||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe it would make sense to move the bijector changes to a separate PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, you might be right here. But the reason why I added it is because we sort of "need" it if we support
NamedTuple
evaluations without usingVarInfo
, since there's no "linking" going on here. Therefore I figured I should just add it together with the introduction ofSimpleVarInfo
or w/e we end up going with.With that being said, I'm fine with making it a separate PR too:) Just wanted to explain my reasoning.