Skip to content

Making use of Symbolics.jl/SymbolicUtils.jl #234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 13 commits into from
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Expand All @@ -19,5 +20,6 @@ Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9"
ChainRulesCore = "0.9.7, 0.10"
Distributions = "0.23.8, 0.24, 0.25"
MacroTools = "0.5.6"
Symbolics = "1"
ZygoteRules = "0.2"
julia = "1.3"
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ include("compiler.jl")
include("prob_macro.jl")
include("compat/ad.jl")
include("loglikelihoods.jl")
include("symbolic/Symbolic.jl")
include("submodel_macro.jl")

end # module
82 changes: 82 additions & 0 deletions src/symbolic/Symbolic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
module Symbolic

using ..DynamicPPL: DynamicPPL
import ..DynamicPPL:
Model, VarInfo, AbstractSampler, SampleFromPrior, VarName, DefaultContext

using Random: Random
using Bijectors: Bijectors
using Distributions
using Symbolics: Symbolics
import Symbolics: SymbolicUtils

issym(x::Union{Symbolics.Num,SymbolicUtils.Symbolic}) = true
issym(x) = false

include("rules.jl")
include("contexts.jl")

# Allow `num` to appear on the RHS of `~`.
DynamicPPL.check_tilde_rhs(x::Union{Symbolics.Num,SymbolicUtils.Symbolic}) = x

symbolize(args...; kwargs...) = symbolize(Random.GLOBAL_RNG, args...; kwargs...)
function symbolize(
rng::Random.AbstractRNG,
m::Model,
vi::VarInfo=VarInfo(m);
spl=SampleFromPrior(),
ctx=DefaultContext(),
include_data=false,
)
m(rng, vi, spl, ctx)
θ_orig = vi[spl]

# Symbolic `logpdf` for fixed observations.
# TODO: don't `collect` once symbolic arrays are mature enough.
Symbolics.@variables θ[1:length(θ_orig)]
vi = VarInfo{Real}(vi, spl, θ)
m(vi, ctx)

return vi, θ
end
Comment on lines +23 to +41
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Execute model once to get shape of latent variables.
  2. Construct symbolic variables.
  3. Execute model on symbolic variables.
  4. vi (the trace struct) now contains a symbolic representation of the logjoint retrievable through getlogp(vi).


function dependencies(ctx::SymbolicContext, vn::VarName)
right = ctx.vn2rights[vn]
r = Symbolics.value(right)

if !issym(r)
# No dependencies.
return []
end

args = SymbolicUtils.arguments(r)
return mapreduce(vcat, args) do a
Symbolics.get_variables(a)
end
end
function dependencies(ctx::SymbolicContext, symbolic=false)
vn2var = ctx.vn2var
var2vn = Dict(values(vn2var) .=> keys(vn2var))
return Dict(
(symbolic ? vn2var[vn] : vn) =>
map(x -> symbolic ? x : var2vn[x], dependencies(ctx, vn)) for
vn in keys(ctx.vn2var)
)
end

function dependencies(m::Model, symbolic=false)
ctx = SymbolicContext(DefaultContext())
vi = symbolize(m, VarInfo(m); ctx=ctx)

return dependencies(ctx, symbolic)
end
Comment on lines +67 to +72
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uses "contextual" dispatch to overload the corresponding *tilde_ statements, stroing the mapping from symbolic variable θ[i] to VarName.


function symbolic_logp(m::Model)
vi, θ = symbolize(m)
lp = DynamicPPL.getlogp(vi)
lp_analytic = analytic_rw(Symbolics.value(lp))
lp_analytic_num = addnum_rw(lp_analytic)

return lp_analytic_num, θ
end
end
62 changes: 62 additions & 0 deletions src/symbolic/contexts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
struct SymbolicContext{Ctx} <: DynamicPPL.AbstractContext
ctx::Ctx
vn2var::Dict
vn2rights::Dict
end
SymbolicContext() = SymbolicContext(DefaultContext())
SymbolicContext(ctx) = SymbolicContext(ctx, Dict(), Dict())

Symbolics.@register Distributions.loglikelihood(dist, x)

# assume
function DynamicPPL.tilde_assume(ctx::SymbolicContext, right, vn, inds, vi)
if Symbolic.issym(right) || (haskey(vi, vn) && Symbolic.issym(vi[vn]))
# Distribution is symbolic OR variable is.
ctx.vn2var[vn] = vi[vn]
ctx.vn2rights[vn] = right
end

return DynamicPPL.tilde_assume(ctx.ctx, right, vn, inds, vi)
end

function DynamicPPL.tilde_assume(rng, ctx::SymbolicContext, sampler, right, vn, inds, vi)
if Symbolic.issym(right) || (haskey(vi, vn) && Symbolic.issym(vi[vn]))
# Distribution is symbolic OR variable is.
ctx.vn2var[vn] = vi[vn]
ctx.vn2rights[vn] = right
end

return DynamicPPL.tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi)
end

# TODO: Make it more useful when working with symbolic observations.
# observe
function DynamicPPL.tilde_observe(ctx::SymbolicContext, right, left, vi)
if Symbolic.issym(right) || Symbolic.issym(left)
# TODO: implement
end
return DynamicPPL.tilde_observe(ctx.ctx, right, left, vi)
end
function DynamicPPL.tilde_observe(ctx::SymbolicContext, sampler, right, left, vi)
if Symbolic.issym(right) || Symbolic.issym(left)
# TODO: implement
end

return DynamicPPL.tilde_observe(ctx.ctx, sampler, right, left, vi)
end

function DynamicPPL.assume(dist::Symbolics.Num, vn::VarName, vi)
if !haskey(vi, vn)
error("variable $vn does not exist")
end
r = vi[vn]
return r, Bijectors.logpdf_with_trans(dist, r, DynamicPPL.istrans(vi, vn))
end

function DynamicPPL.observe(right::Symbolics.Num, left, vi)
return Distributions.loglikelihood(right, left)
end

# TODO: Implement `dot_tilde_*` methods.


75 changes: 75 additions & 0 deletions src/symbolic/rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
using Bijectors: Bijectors
using Symbolics: Symbolics
using Symbolics.SymbolicUtils

Symbolics.@register Bijectors.logpdf_with_trans(dist, r, istrans)

# Some predicates
isdist(d) = (d isa Type) && (d <: Distribution)
islogpdf(f::Function) = f === Distributions.logpdf
islogpdf(x) = false

# HACK: Apparently this is needed for disambiguiation.
# TODO: Open issue.
function Symbolics.:<ₑ(a::Real, b::Symbolics.Num)
return Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b))
end
function Symbolics.:<ₑ(a::Symbolics.Num, b::Real)
return Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b))
end

#############
### Rules ###
#############
# HACK: We'll wrap rewriters to add back `Num`. This way we can get jacobians and whatnot at then end.
const rmnum_rule = @rule (~x) => Symbolics.value(~x)
const addnum_rule = @rule (~x) => Symbolics.Num(~x)

# In the case where we want to work directly with the `x ~ Distribution` statements, the following rules can be useful:
const logpdf_rule = @rule (~x ~ ~d) =>
Distributions.logpdf(Symbolics.Num(~d), Symbolics.Num(~x));
const rand_rule = @rule (~x ~ ~d) => Distributions.rand(Symbolics.Num(~d))

# We don't want to trace into `Bijectors.logpdf_with_trans`, so we just replace it with `logpdf`.
islogpdf_with_trans(f::Function) = f === Bijectors.logpdf_with_trans
islogpdf_with_trans(x) = false
const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istrans) =>
logpdf(~dist, ~x)

# Attempt to expand `logpdf` to get analytical expressions.
# The idea is that `getlogpdf(d, args)` should return a method of the following signature:
#
# f(args..., x)
#
# which returns the logpdf.
# HACK: this is very hacky but you get the idea
import Distributions: StatsFuns
function getlogpdf(d, args)
replacements = Dict(:Normal => StatsFuns.normlogpdf, :Gamma => StatsFuns.gammalogpdf)

dsym = Symbol(d)
if haskey(replacements, dsym)
return replacements[dsym]
else
return d
end
end
Comment on lines +47 to +56
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea behind all this was to replace the "non-tracable" logpdf impls from Distributions.jl with traceable impls from StatsFuns.jl. After #292 we could probably avoid this by simply using MeasureTheory.jl instead:)

Also, this is super messy and probably not the greatest; I blame it on the fact that I had no idea what I was doing:)


const analytic_rule = @rule (~f::islogpdf)((~d::isdist)(~~args), ~x) =>
getlogpdf(~d, ~~args)(map(Symbolics.Num, (~~args))..., Symbolics.Num(~x))

#################
### Rewriters ###
#################
# TODO: these should probably be instantiated when needed, rather than here.
const analytic_rw = Rewriters.Postwalk(
Rewriters.Chain((
rmnum_rule, # 0. Remove `Num` so we're only working stuff from `SymbolicUtils.jl`.
logpdf_with_trans_rule, # 1. Replace `logpdf_with_trans` with `logpdf`.
analytic_rule, # 2. Attempt to replace `logpdf` with analytic expression.
))
)

# So we add back `Num` to all terms to allow differentiation.
const rmnum_rw = Rewriters.Postwalk(Rewriters.PassThrough(rmnum_rule))
const addnum_rw = Rewriters.Postwalk(Rewriters.PassThrough(addnum_rule))
5 changes: 5 additions & 0 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector)
)
end

function VarInfo{T}(old_vi::TypedVarInfo, spl, x::AbstractVector) where {T}
md = newmetadata(old_vi.metadata, Val(getspace(spl)), x)
return VarInfo(md, Base.RefValue{T}(0.0), Ref(get_num_produce(old_vi)))
end

function VarInfo(
rng::Random.AbstractRNG,
model::Model,
Expand Down