[WIP] Add pure logdensity to Model #242
Conversation
Is it actually necessary to build a separate function? What are the advantages over defining `logdensity` similar to `evaluate`, but with a special context that only accumulates the log density?
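For concreteness, here is a minimal self-contained sketch of what such an accumulate-only context might look like. All names here are hypothetical illustrations, not the actual DynamicPPL API: the context carries nothing but a log-density accumulator, and the tilde callbacks would only add to it.

```julia
# Hypothetical sketch (NOT the actual DynamicPPL API): a context whose only
# job is to accumulate the log density during an ordinary `evaluate` pass.
mutable struct LogDensityOnlyContext
    logp::Float64
end
LogDensityOnlyContext() = LogDensityOnlyContext(0.0)

# Both `assume`- and `observe`-style statements reduce to the same operation:
# add the statement's log density to the accumulator and move on.
function accumulate_logdensity!(ctx::LogDensityOnlyContext, logp::Real)
    ctx.logp += logp
    return ctx
end

ctx = LogDensityOnlyContext()
accumulate_logdensity!(ctx, -1.5)
accumulate_logdensity!(ctx, -0.5)
ctx.logp  # -2.0
```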
src/compiler.jl (outdated)

```diff
@@ -70,13 +70,19 @@ end
 function model(mod, linenumbernode, expr, warn)
     modelinfo = build_model_info(expr)
+    modelinfo_logπ = deepcopy(modelinfo)
```
I know we use it everywhere, but I really don't like the name `logπ` 😕 It's even a constant in LogExpFunctions/StatsFuns... Could we use something that is more descriptive? Maybe `logdensity`?
Yeah yeah, happy to do so! Generally I'm opposed to using Unicode symbols in structs anyway, because people might not always have access to them.

After removing the possible exceptions being thrown when the RHS isn't a distribution, it's still not quite there. Got any clue, @devmotion?

```julia
@benchmark $g($θ_nt, $([1.0, 1.0]))
@benchmark $(m.logπ)($m, $spl, $ctx, $θ_nt, $(Float64), $(m.args...))

@code_typed g(θ_nt, [1.0, 1.0])
@code_typed m.logπ(m, spl, ctx, θ_nt, Float64, m.args...)
```
Hmm, but you mean while still not passing in the …? But I still don't want to keep the variables in …
It seems like the possible redefinition of the variable causes the following statements to show up. I'm not entirely certain why or what's going on here, but it seems to cause the additional allocations.

EDIT: Nvm. I read a bit about SSA, and AFAIK the above is a Pi-statement, which represents information known at compile time, and a Phi-statement, which is a conditional depending on the referenced blocks. So in this case the Pi + Phi will be compiled away. I can confirm this by removing the …
Okay, so now I'm even more confused. I've removed the …

```julia
@benchmark $g($θ_nt, $([1.0, 1.0]))
@benchmark $(m.logπ)($m, $spl, $ctx, $θ_nt, $(Float64), $(m.args.x))

@code_typed g(θ_nt, [1.0, 1.0])
@code_typed m.logπ(m, spl, ctx, θ_nt, Float64, m.args...)
```
Yes, the idea would be to neither pass in …
Sounds like a good idea; I'll give this a try! Any thoughts on why we're still seeing a performance loss compared to the manual implementation?
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
So now I've added a … Also, I've added an implementation of …

I'm still confused about why there's still an overhead for the generated …

EDIT: Also, there is a benefit to the generated …
```diff
@@ -0,0 +1,33 @@
+@generated function _bijector(md::NamedTuple{names}; tuplify=false) where {names}
```
Maybe it would make sense to move the bijector changes to a separate PR?
Yeah, you might be right here. But the reason I added it is that we sort of "need" it if we support `NamedTuple` evaluations without using `VarInfo`, since there's no "linking" going on here. Therefore I figured I should just add it together with the introduction of `SimpleVarInfo`, or whatever we end up going with.

With that being said, I'm fine with making it a separate PR too :) Just wanted to explain my reasoning.
```julia
expr = Expr(:tuple)
for n in names
    e = quote
        if length(md.$n.dists) == 1 &&
```
This won't be type stable, it seems?
Nope, but AFAIK there's no way to ensure that, since `dists` can be a vector of anything. The aim isn't for this to be type-stable either; it's meant to be called once, and then you re-use the resulting bijector wherever you need it. So maybe we should also just remove the `@generated`?
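As a side note, the instability being discussed can be reproduced with a tiny self-contained toy (hypothetical function, not PR code): a branch on a runtime length/type check yields structurally different results, so inference can only conclude a non-concrete union.

```julia
# Hypothetical toy example (not PR code): the result type depends on a runtime
# check, just like the `length(md.$n.dists) == 1 && ... isa ...` branch above,
# so the compiler cannot infer a concrete return type.
f(xs) = length(xs) == 1 && xs[1] isa Number ? xs[1] : Tuple(xs)

f([1.0])       # a Float64
f([1.0, 2.0])  # a Tuple of Float64s

# Inference has to account for both branches:
Base.return_types(f, (Vector{Float64},))  # a non-concrete union of the two
```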
```julia
for n in names
    e = quote
        if length(md.$n.dists) == 1 &&
                md.$n.dists[1] isa $(Distributions.UnivariateDistribution)
```
It should not be necessary to qualify the types and functions in the generated function (BTW, maybe make it `if @generated`?).
Ah, didn't know; thanks!

And regarding `@generated`: starting to think it shouldn't even be a `@generated`...
```
Returns a `NamedBijector` which can transform different variants of `varinfo`.

If `tuplify` is true, then a type-stable bijector will be returned.
```
Maybe one could choose a name that indicates more clearly that it helps with type stability? It almost sounds like it would return a tuple of bijectors. Maybe `astuple` would also be a more "formal" API?
Haha, yeah sorry; that was a leftover from earlier. I also hate `tuplify` :) Maybe `astuple`, or maybe `static`? Or `typed`?
```julia
function generate_mainbody_logdensity(mod, expr, warn)
    return generate_mainbody_logdensity!(mod, Symbol[], expr, warn)
end

generate_mainbody_logdensity!(mod, found, x, warn) = x
function generate_mainbody_logdensity!(mod, found, sym::Symbol, warn)
    if sym in DEPRECATED_INTERNALNAMES
        newsym = Symbol(:_, sym, :__)
        Base.depwarn(
            "internal variable `$sym` is deprecated, use `$newsym` instead.",
            :generate_mainbody_logdensity!,
        )
        return generate_mainbody_logdensity!(mod, found, newsym, warn)
    end

    if warn && sym in INTERNALNAMES && sym ∉ found
        @warn "you are using the internal variable `$sym`"
        push!(found, sym)
    end

    return sym
end
function generate_mainbody_logdensity!(mod, found, expr::Expr, warn)
    # Do not touch interpolated expressions.
    expr.head === :$ && return expr.args[1]

    # If it's a macro, we expand it.
    if Meta.isexpr(expr, :macrocall)
        return generate_mainbody_logdensity!(
            mod, found, macroexpand(mod, expr; recursive=true), warn
        )
    end

    # If it's a return, we instead return `__lp__`.
    if Meta.isexpr(expr, :return)
        returnbody = Expr(
            :block,
            map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)...,
        )
        return :($(returnbody); return __lp__)
    end

    # Modify dotted tilde operators.
    args_dottilde = getargs_dottilde(expr)
    if args_dottilde !== nothing
        L, R = args_dottilde
        left = generate_mainbody_logdensity!(mod, found, L, warn)
        return Base.remove_linenums!(
            generate_dot_tilde_logdensity(
                left, generate_mainbody_logdensity!(mod, found, R, warn)
            ),
        )
    end

    # Modify tilde operators.
    args_tilde = getargs_tilde(expr)
    if args_tilde !== nothing
        L, R = args_tilde
        left = generate_mainbody_logdensity!(mod, found, L, warn)
        return Base.remove_linenums!(
            generate_tilde_logdensity(
                left, generate_mainbody_logdensity!(mod, found, R, warn)
            ),
        )
    end

    return Expr(
        expr.head,
        map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)...,
    )
end
```
I didn't check the details, but it seems this code is very similar to the existing `generate_mainbody!`. Maybe one could just unify them by passing the inner functions (`generate_dot_tilde`/`generate_dot_tilde_logdensity` and `generate_tilde`/`generate_tilde_logdensity`) as the first argument?
Haha, I had the exact same thought as I was writing it out 😅 Left it as is for now though, since I'm still on the fence about whether or not we should have this vs. `EvaluationContext` with `SimpleVarInfo`. If we do keep it, I 100% agree.
```julia
if ctx.ctx isa PriorContext
    tilde_observe(LikelihoodContext(), sampler, right, value, vn, inds, vi)
elseif ctx.ctx isa LikelihoodContext
    # Need to make it so that this isn't computed.
    tilde_observe(PriorContext(), sampler, right, value, vn, inds, vi)
else
    tilde_observe(ctx, sampler, right, value, vn, inds, vi)
end
```
It's not immediately clear to me why this is needed, and why we use `LikelihoodContext` for `PriorContext` and vice versa.

In general, I think implementation-wise it would be better to dispatch on `EvaluationContext` and not hardcode the type checks in the function body.
The reason why I use the `if` here is what I mentioned in #249, where it seemed more natural to have an `if` for readability. But I've added `LeafCtx` in that PR too, so we could dispatch on this.
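If the wrapped leaf context does become available for dispatch, the `if` chain above could in principle collapse into one method per leaf type. A self-contained sketch with hypothetical names (the real DynamicPPL signatures differ):

```julia
# Hypothetical sketch, not the actual DynamicPPL types: `EvalCtx{C}` stands in
# for an evaluation context wrapping a leaf context of type `C`, and `handle`
# stands in for the tilde callback.
struct EvalCtx{C}
    ctx::C
end
struct PriorCtx end
struct LikelihoodCtx end

handle(::EvalCtx{PriorCtx}) = :likelihood   # the swap done in the `if` chain above
handle(::EvalCtx{LikelihoodCtx}) = :prior
handle(::EvalCtx) = :default                # fallback for any other leaf context

handle(EvalCtx(PriorCtx()))  # :likelihood — resolved by dispatch, no `if` needed
```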
```julia
if vi === nothing
    return logp
else
    acclogp!(vi, logp)
    return left
end
```
Same here and below: I think it is cleaner not to add the check in the function body, but to dispatch on `::Nothing`. Then one can group all these special definitions together, and it becomes easier to explain and understand why they are needed.
```diff
@@ -409,7 +480,7 @@ function dot_observe(
     value::AbstractMatrix,
     vi,
 )
-    increment_num_produce!(vi)
+    vi isa VarInfo && increment_num_produce!(vi)
```
Also here, I think it might be cleaner to just dispatch on `vi::VarInfo`.
```diff
@@ -0,0 +1,21 @@
+struct SimpleVarInfo{NT,T} <: AbstractVarInfo
+    θ::NT
```
It seems it is necessary to save the parameters here?
Not sure what you mean here 😕 `SimpleVarInfo` is used in conjunction with `EvaluationContext`; it is not used in the generated `logjoint`. So `SimpleVarInfo` just keeps the variables for access in the `assume` and `observe` statements. Does that help?
Haha, yeah, now it's clear. I think I just mixed up the two approaches, the generated `logjoint` and the `VarInfo`/`EvaluationContext` one 😄
…into tor/logdensity
Co-authored-by: David Widmann <[email protected]>
```julia
        :(__variables__),
    ],
    modelinfo[:allargs_exprs],
    [Expr(:kw, :(::Type{$T}), :Float64), ]
```
[JuliaFormatter] reported by reviewdog 🐶

```suggestion
    [Expr(:kw, :(::Type{$T}), :Float64)],
```
Closed in favour of #242.
I'm currently working on a project where I need to "avoid" `VarInfo` due to inefficiencies, plus I only need HMC samplers for a very high-dimensional latent space. Therefore I've used the following workflow: a `make_logjoint` which takes the model constructor as an argument and returns a `logjoint` which I've implemented by hand.

But I thought "Why not just add this to `Model`? It would be super-useful for a lot of standard samplers." So here we are. It's at a very early stage though, so I would love some feedback!

Why is performance worse?
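For context, a minimal self-contained sketch of the hand-written workflow described above (all names illustrative; a real model would use Distributions.jl for the log-pdfs): `make_logjoint` closes over the data and returns a function mapping parameters straight to a log joint, with no `VarInfo` involved.

```julia
# Illustrative sketch only: a hand-rolled `make_logjoint` for a toy
# Gaussian-mean model, bypassing `VarInfo` entirely.
normal_logpdf(x, μ, σ) = -0.5 * log(2π * σ^2) - (x - μ)^2 / (2σ^2)

function make_logjoint(y)
    return function logjoint(θ)
        m = θ.m
        lp = normal_logpdf(m, 0.0, 1.0)       # prior: m ~ Normal(0, 1)
        lp += sum(normal_logpdf.(y, m, 1.0))  # likelihood: y .~ Normal(m, 1)
        return lp
    end
end

logjoint = make_logjoint([1.0, -1.0])
logjoint((m = 0.0,))  # a plain Float64 — exactly what an HMC sampler needs
```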
WAZ GOING ON?! Might the `nothing` in the `tilde_observe` not be propagated properly?