-
Notifications
You must be signed in to change notification settings - Fork 32
Making use of Symbolics.jl/SymbolicUtils.jl #234
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
783d81d
0421aed
14ad023
03067a0
785242e
a3a9742
284d85d
817533c
daa1319
70c9997
23141be
18ae91b
b19700e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
module Symbolic | ||
|
||
using ..DynamicPPL: DynamicPPL | ||
import ..DynamicPPL: | ||
Model, VarInfo, AbstractSampler, SampleFromPrior, VarName, DefaultContext | ||
|
||
using Random: Random | ||
using Bijectors: Bijectors | ||
using Distributions | ||
using Symbolics: Symbolics | ||
import Symbolics: SymbolicUtils | ||
|
||
issym(x::Union{Symbolics.Num,SymbolicUtils.Symbolic}) = true | ||
issym(x) = false | ||
|
||
include("rules.jl") | ||
include("contexts.jl") | ||
|
||
# Allow `num` to appear on the RHS of `~`. | ||
DynamicPPL.check_tilde_rhs(x::Union{Symbolics.Num,SymbolicUtils.Symbolic}) = x | ||
|
||
symbolize(args...; kwargs...) = symbolize(Random.GLOBAL_RNG, args...; kwargs...) | ||
function symbolize( | ||
rng::Random.AbstractRNG, | ||
m::Model, | ||
vi::VarInfo=VarInfo(m); | ||
spl=SampleFromPrior(), | ||
ctx=DefaultContext(), | ||
include_data=false, | ||
) | ||
m(rng, vi, spl, ctx) | ||
θ_orig = vi[spl] | ||
|
||
# Symbolic `logpdf` for fixed observations. | ||
# TODO: don't `collect` once symbolic arrays are mature enough. | ||
Symbolics.@variables θ[1:length(θ_orig)] | ||
vi = VarInfo{Real}(vi, spl, θ) | ||
m(vi, ctx) | ||
|
||
return vi, θ | ||
end | ||
|
||
function dependencies(ctx::SymbolicContext, vn::VarName) | ||
right = ctx.vn2rights[vn] | ||
r = Symbolics.value(right) | ||
|
||
if !issym(r) | ||
# No dependencies. | ||
return [] | ||
end | ||
|
||
args = SymbolicUtils.arguments(r) | ||
return mapreduce(vcat, args) do a | ||
Symbolics.get_variables(a) | ||
end | ||
end | ||
function dependencies(ctx::SymbolicContext, symbolic=false) | ||
vn2var = ctx.vn2var | ||
var2vn = Dict(values(vn2var) .=> keys(vn2var)) | ||
return Dict( | ||
(symbolic ? vn2var[vn] : vn) => | ||
map(x -> symbolic ? x : var2vn[x], dependencies(ctx, vn)) for | ||
vn in keys(ctx.vn2var) | ||
) | ||
end | ||
|
||
function dependencies(m::Model, symbolic=false) | ||
ctx = SymbolicContext(DefaultContext()) | ||
vi = symbolize(m, VarInfo(m); ctx=ctx) | ||
|
||
return dependencies(ctx, symbolic) | ||
end | ||
Comment on lines
+67
to
+72
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Uses "contextual" dispatch to overload the corresponding |
||
|
||
function symbolic_logp(m::Model) | ||
vi, θ = symbolize(m) | ||
lp = DynamicPPL.getlogp(vi) | ||
lp_analytic = analytic_rw(Symbolics.value(lp)) | ||
lp_analytic_num = addnum_rw(lp_analytic) | ||
|
||
return lp_analytic_num, θ | ||
end | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
struct SymbolicContext{Ctx} <: DynamicPPL.AbstractContext | ||
ctx::Ctx | ||
vn2var::Dict | ||
vn2rights::Dict | ||
end | ||
SymbolicContext() = SymbolicContext(DefaultContext()) | ||
SymbolicContext(ctx) = SymbolicContext(ctx, Dict(), Dict()) | ||
|
||
Symbolics.@register Distributions.loglikelihood(dist, x) | ||
|
||
# assume | ||
function DynamicPPL.tilde_assume(ctx::SymbolicContext, right, vn, inds, vi) | ||
if Symbolic.issym(right) || (haskey(vi, vn) && Symbolic.issym(vi[vn])) | ||
# Distribution is symbolic OR variable is. | ||
ctx.vn2var[vn] = vi[vn] | ||
ctx.vn2rights[vn] = right | ||
end | ||
|
||
return DynamicPPL.tilde_assume(ctx.ctx, right, vn, inds, vi) | ||
end | ||
|
||
function DynamicPPL.tilde_assume(rng, ctx::SymbolicContext, sampler, right, vn, inds, vi) | ||
if Symbolic.issym(right) || (haskey(vi, vn) && Symbolic.issym(vi[vn])) | ||
# Distribution is symbolic OR variable is. | ||
ctx.vn2var[vn] = vi[vn] | ||
ctx.vn2rights[vn] = right | ||
end | ||
|
||
return DynamicPPL.tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi) | ||
end | ||
|
||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# TODO: Make it more useful when working with symbolic observations. | ||
# observe | ||
function DynamicPPL.tilde_observe(ctx::SymbolicContext, right, left, vi) | ||
if Symbolic.issym(right) || Symbolic.issym(left) | ||
# TODO: implement | ||
end | ||
return DynamicPPL.tilde_observe(ctx.ctx, right, left, vi) | ||
end | ||
function DynamicPPL.tilde_observe(ctx::SymbolicContext, sampler, right, left, vi) | ||
if Symbolic.issym(right) || Symbolic.issym(left) | ||
# TODO: implement | ||
end | ||
|
||
return DynamicPPL.tilde_observe(ctx.ctx, sampler, right, left, vi) | ||
end | ||
|
||
function DynamicPPL.assume(dist::Symbolics.Num, vn::VarName, vi) | ||
if !haskey(vi, vn) | ||
error("variable $vn does not exist") | ||
end | ||
r = vi[vn] | ||
return r, Bijectors.logpdf_with_trans(dist, r, DynamicPPL.istrans(vi, vn)) | ||
end | ||
|
||
function DynamicPPL.observe(right::Symbolics.Num, left, vi) | ||
return Distributions.loglikelihood(right, left) | ||
end | ||
|
||
# TODO: Implement `dot_tilde_*` methods. | ||
|
||
|
||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
using Bijectors: Bijectors | ||
using Symbolics: Symbolics | ||
using Symbolics.SymbolicUtils | ||
|
||
Symbolics.@register Bijectors.logpdf_with_trans(dist, r, istrans) | ||
|
||
# Some predicates | ||
isdist(d) = (d isa Type) && (d <: Distribution) | ||
islogpdf(f::Function) = f === Distributions.logpdf | ||
islogpdf(x) = false | ||
|
||
# HACK: Apparently this is needed for disambiguiation. | ||
# TODO: Open issue. | ||
function Symbolics.:<ₑ(a::Real, b::Symbolics.Num) | ||
return Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b)) | ||
end | ||
function Symbolics.:<ₑ(a::Symbolics.Num, b::Real) | ||
return Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b)) | ||
end | ||
|
||
############# | ||
### Rules ### | ||
############# | ||
# HACK: We'll wrap rewriters to add back `Num`. This way we can get jacobians and whatnot at then end. | ||
const rmnum_rule = @rule (~x) => Symbolics.value(~x) | ||
const addnum_rule = @rule (~x) => Symbolics.Num(~x) | ||
|
||
# In the case where we want to work directly with the `x ~ Distribution` statements, the following rules can be useful: | ||
const logpdf_rule = @rule (~x ~ ~d) => | ||
Distributions.logpdf(Symbolics.Num(~d), Symbolics.Num(~x)); | ||
const rand_rule = @rule (~x ~ ~d) => Distributions.rand(Symbolics.Num(~d)) | ||
|
||
# We don't want to trace into `Bijectors.logpdf_with_trans`, so we just replace it with `logpdf`. | ||
islogpdf_with_trans(f::Function) = f === Bijectors.logpdf_with_trans | ||
islogpdf_with_trans(x) = false | ||
const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istrans) => | ||
logpdf(~dist, ~x) | ||
|
||
# Attempt to expand `logpdf` to get analytical expressions. | ||
# The idea is that `getlogpdf(d, args)` should return a method of the following signature: | ||
# | ||
# f(args..., x) | ||
# | ||
# which returns the logpdf. | ||
# HACK: this is very hacky but you get the idea | ||
import Distributions: StatsFuns | ||
function getlogpdf(d, args) | ||
replacements = Dict(:Normal => StatsFuns.normlogpdf, :Gamma => StatsFuns.gammalogpdf) | ||
|
||
dsym = Symbol(d) | ||
if haskey(replacements, dsym) | ||
return replacements[dsym] | ||
else | ||
return d | ||
end | ||
end | ||
Comment on lines
+47
to
+56
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The idea behind all this was to replace the "non-tracable" Also, this is super messy and probably not the greatest; I blame it on the fact that I had no idea what I was doing:) |
||
|
||
const analytic_rule = @rule (~f::islogpdf)((~d::isdist)(~~args), ~x) => | ||
getlogpdf(~d, ~~args)(map(Symbolics.Num, (~~args))..., Symbolics.Num(~x)) | ||
|
||
################# | ||
### Rewriters ### | ||
################# | ||
# TODO: these should probably be instantiated when needed, rather than here. | ||
const analytic_rw = Rewriters.Postwalk( | ||
Rewriters.Chain(( | ||
rmnum_rule, # 0. Remove `Num` so we're only working stuff from `SymbolicUtils.jl`. | ||
logpdf_with_trans_rule, # 1. Replace `logpdf_with_trans` with `logpdf`. | ||
analytic_rule, # 2. Attempt to replace `logpdf` with analytic expression. | ||
)) | ||
) | ||
|
||
# So we add back `Num` to all terms to allow differentiation. | ||
const rmnum_rw = Rewriters.Postwalk(Rewriters.PassThrough(rmnum_rule)) | ||
const addnum_rw = Rewriters.Postwalk(Rewriters.PassThrough(addnum_rule)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
vi
(the trace struct) now contains a symbolic representation of the logjoint retrievable throughgetlogp(vi)
.