diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index acdb98183..32cfd6136 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -120,12 +120,14 @@ include("varname.jl")
 include("distribution_wrappers.jl")
 include("contexts.jl")
 include("varinfo.jl")
+include("simple_varinfo.jl")
 include("threadsafe.jl")
 include("context_implementations.jl")
 include("compiler.jl")
 include("prob_macro.jl")
 include("compat/ad.jl")
 include("loglikelihoods.jl")
+include("bijectors.jl")
 include("submodel_macro.jl")
 
 end # module
diff --git a/src/bijectors.jl b/src/bijectors.jl
new file mode 100644
index 000000000..8a1701455
--- /dev/null
+++ b/src/bijectors.jl
@@ -0,0 +1,32 @@
+@generated function _bijector(md::NamedTuple{names}; tuplify=false) where {names}
+    expr = Expr(:tuple)  # collects one bijector expression per symbol in `names`
+    for n in names
+        e = quote
+            if length(md.$n.dists) == 1 &&
+               md.$n.dists[1] isa $(Distributions.UnivariateDistribution)
+                $(Bijectors).bijector(md.$n.dists[1])  # single univariate: no `Stacked` wrapper
+            elseif tuplify
+                $(Bijectors.Stacked)(
+                    map($(Bijectors).bijector, tuple(md.$n.dists...)), md.$n.ranges
+                )
+            else
+                $(Bijectors.Stacked)(map($(Bijectors).bijector, md.$n.dists), md.$n.ranges)
+            end
+        end
+        push!(expr.args, e)
+    end
+
+    return quote
+        bs = NamedTuple{$names}($expr)  # keyed by the same names as `md`
+        return $(Bijectors).NamedBijector(bs)
+    end
+end
+
+"""
+    bijector(varinfo::TypedVarInfo; tuplify=false)
+
+Return a `NamedBijector` holding one bijector per variable symbol in `varinfo.metadata`.
+
+If `tuplify` is true, each symbol's distributions are splatted into a tuple before being mapped and stacked; this is intended to return a type-stable bijector.
+"""
+Bijectors.bijector(vi::TypedVarInfo; kwargs...) = _bijector(vi.metadata; kwargs...)
diff --git a/src/compiler.jl b/src/compiler.jl
index bef7d11c2..8450d782a 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -84,11 +84,17 @@ end
 
 function model(mod, linenumbernode, expr, warn)
     modelinfo = build_model_info(expr)
+    modelinfo_logjoint = deepcopy(modelinfo)
 
     # Generate main body
     modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn)
 
-    return build_output(modelinfo, linenumbernode)
+    # Generate logjoint
+    modelinfo_logjoint[:body] = generate_mainbody_logdensity(
+        mod, modelinfo_logjoint[:modeldef][:body], warn
+    )
+
+    return build_output(modelinfo, modelinfo_logjoint, linenumbernode)
 end
 
 """
@@ -339,14 +345,7 @@ hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(
 hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true
 hasmissing(T::Type) = false
 
-"""
-    build_output(modelinfo, linenumbernode)
-
-Builds the output expression.
-"""
-function build_output(modelinfo, linenumbernode)
-    ## Build the anonymous evaluator from the user-provided model definition.
-
+function build_evaluator(modelinfo)
     # Remove the name.
     evaluatordef = deepcopy(modelinfo[:modeldef])
     delete!(evaluatordef, :name)
@@ -369,6 +368,55 @@ function build_output(modelinfo, linenumbernode)
     # Replace the user-provided function body with the version created by DynamicPPL.
     evaluatordef[:body] = modelinfo[:body]
 
+    return evaluatordef
+end
+
+function build_logjoint(modelinfo)
+    # Copy the user-provided definition and name the copy `logjoint`.
+    def = deepcopy(modelinfo[:modeldef])
+    def[:name] = :logjoint
+
+    # Add the internal arguments to the user-specified arguments (positional + keywords).
+    @gensym T
+    def[:args] = vcat(
+        [
+            :(__model__::$(DynamicPPL.Model)),
+            :(__sampler__::$(DynamicPPL.AbstractSampler)),
+            :(__context__::$(DynamicPPL.AbstractContext)),
+            :(__variables__),
+        ],
+        modelinfo[:allargs_exprs],
+        [Expr(:kw, :(::Type{$T}), :Float64), ]
+    )
+
+    # Add the type-parameter.
+    def[:whereparams] = (def[:whereparams]..., T)
+
+    # Delete the keyword arguments.
+    def[:kwargs] = []
+
+    # Replace the user-provided function body with the version created by DynamicPPL.
+    def[:body] = quote
+        __lp__ = zero($T)
+        $(modelinfo[:body])
+        return __lp__
+    end
+
+    return def
+end
+
+"""
+    build_output(modelinfo, modelinfo_logjoint, linenumbernode)
+
+Build the output expression from the evaluator (`modelinfo`) and the log-joint (`modelinfo_logjoint`).
+"""
+function build_output(modelinfo, modelinfo_logjoint, linenumbernode)
+    ## Build logjoint.
+    logjointdef = build_logjoint(modelinfo_logjoint)
+
+    ## Build the anonymous evaluator from the user-provided model definition.
+    evaluatordef = build_evaluator(modelinfo)
+
     ## Build the model function.
 
     # Extract the named tuple expression of all arguments and the default values.
@@ -378,18 +426,20 @@ function build_output(modelinfo, linenumbernode)
     # Update the function body of the user-specified model.
     # We use a name for the anonymous evaluator that does not conflict with other variables.
     modeldef = modelinfo[:modeldef]
-    @gensym evaluator
+    @gensym evaluator logjoint
     # We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
     # that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
     # to the call site
     modeldef[:body] = MacroTools.@q begin
         $(linenumbernode)
         $evaluator = $(MacroTools.combinedef(evaluatordef))
+        $logjoint = $(MacroTools.combinedef(logjointdef))
         return $(DynamicPPL.Model)(
             $(QuoteNode(modeldef[:name])),
             $evaluator,
             $allargs_namedtuple,
             $defaults_namedtuple,
+            $logjoint,
         )
     end
 
@@ -447,3 +497,150 @@ end
 
 floatof(::Type{T}) where {T<:Real} = typeof(one(T) / one(T))
 floatof(::Type) = Real # fallback if type inference failed
+
+#####################
+### `logdensity` ###
+#####################
+function generate_mainbody_logdensity(mod, expr, warn)
+    return generate_mainbody_logdensity!(mod, Symbol[], expr, warn)
+end
+
+generate_mainbody_logdensity!(mod, found, x, warn) = x
+function generate_mainbody_logdensity!(mod, found, sym::Symbol, warn)
+    if sym in DEPRECATED_INTERNALNAMES
+        newsym = Symbol(:_, sym, :__)
+        Base.depwarn(
+            "internal variable `$sym` is deprecated, use `$newsym` instead.",
+            :generate_mainbody_logdensity!,
+        )
+        return generate_mainbody_logdensity!(mod, found, newsym, warn)
+    end
+
+    if warn && sym in INTERNALNAMES && sym ∉ found
+        @warn "you are using the internal variable `$sym`"
+        push!(found, sym)
+    end
+
+    return sym
+end
+function generate_mainbody_logdensity!(mod, found, expr::Expr, warn)
+    # Do not touch interpolated expressions
+    expr.head === :$ && return expr.args[1]
+
+    # If it's a macro, we expand it
+    if Meta.isexpr(expr, :macrocall)
+        return generate_mainbody_logdensity!(
+            mod, found, macroexpand(mod, expr; recursive=true), warn
+        )
+    end
+
+    # If it's a return, we still evaluate its arguments but return `__lp__` instead.
+    if Meta.isexpr(expr, :return)
+        returnbody = Expr(
+            :block,
+            map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)...,
+        )
+        return :($(returnbody); return __lp__)
+    end
+
+    # Modify dotted tilde operators.
+    args_dottilde = getargs_dottilde(expr)
+    if args_dottilde !== nothing
+        L, R = args_dottilde
+        left = generate_mainbody_logdensity!(mod, found, L, warn)
+        return Base.remove_linenums!(
+            generate_dot_tilde_logdensity(
+                left, generate_mainbody_logdensity!(mod, found, R, warn)
+            ),
+        )
+    end
+
+    # Modify tilde operators.
+    args_tilde = getargs_tilde(expr)
+    if args_tilde !== nothing
+        L, R = args_tilde
+        left = generate_mainbody_logdensity!(mod, found, L, warn)
+        return Base.remove_linenums!(
+            generate_tilde_logdensity(
+                left, generate_mainbody_logdensity!(mod, found, R, warn)
+            ),
+        )
+    end
+
+    return Expr(
+        expr.head,
+        map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)...,
+    )
+end
+
+function generate_tilde_logdensity(left, right)
+    # If the LHS is a literal, it is always an observation
+    if !(left isa Symbol || left isa Expr)
+        return quote
+            __lp__ += $(DynamicPPL.tilde_observe)(
+                __context__,
+                __sampler__,
+                $(DynamicPPL.check_tilde_rhs)($right),
+                $left,
+                nothing,  # no varinfo in the logjoint function: `tilde_observe` returns the logp
+            )
+        end
+    end
+
+    @gensym vn inds isassumption
+
+    # If it's not present in args of the model, we need to extract it from `__variables__`.
+    return quote
+        $vn = $(varname(left))
+        $inds = $(vinds(left))
+        $isassumption = $(DynamicPPL.isassumption(left))
+        if $isassumption
+            $(vsym(left)) = __variables__.$(vsym(left))
+        end
+        __lp__ += $(DynamicPPL.tilde_observe)(
+            __context__,
+            __sampler__,
+            $(DynamicPPL.check_tilde_rhs)($right),
+            $left,
+            $vn,
+            $inds,
+            nothing,
+        )
+    end
+end
+
+function generate_dot_tilde_logdensity(left, right)
+    # If the LHS is a literal, it is always an observation
+    if !(left isa Symbol || left isa Expr)
+        return quote
+            __lp__ += $(DynamicPPL.dot_tilde_observe)(
+                __context__,
+                __sampler__,
+                $(DynamicPPL.check_tilde_rhs)($right),
+                $left,
+                nothing,  # no varinfo in the logjoint function: `dot_tilde_observe` returns the logp
+            )
+        end
+    end
+
+    # Otherwise it is determined by the model or its value,
+    # if the LHS represents an observation
+    @gensym vn inds isassumption
+    return quote
+        $vn = $(varname(left))
+        $inds = $(vinds(left))
+        $isassumption = $(DynamicPPL.isassumption(left)) || $left === missing
+        if $isassumption
+            $(vsym(left)) = __variables__.$(vsym(left))
+        end
+        __lp__ += $(DynamicPPL.dot_tilde_observe)(
+            __context__,
+            __sampler__,
+            $(DynamicPPL.check_tilde_rhs)($right),
+            $left,
+            $vn,
+            $inds,
+            nothing,
+        )
+    end
+end
diff --git a/src/context_implementations.jl b/src/context_implementations.jl
index afc5e4da3..d0c8a3e4a 100644
--- a/src/context_implementations.jl
+++ b/src/context_implementations.jl
@@ -56,6 +56,27 @@ function tilde_assume(rng, ctx, sampler, right, vn, inds, vi)
     return value
 end
 
+function tilde_assume(
+    rng, ctx::EvaluationContext, sampler, right, vn, inds, vi::SimpleVarInfo{<:NamedTuple}
+)
+    value = _getindex(getfield(vi.θ, getsym(vn)), inds)
+
+    # The value is scored through `tilde_observe`, so contexts which behave differently
+    # for `assume` and `observe` are swapped here (`PriorContext` <-> `LikelihoodContext`);
+    # otherwise the observation-only behavior will be applied to `assume`.
+    # FIXME: The below doesn't necessarily work for nested contexts, e.g. if `ctx.ctx.ctx isa PriorContext`.
+    # This is a broader issue though, which should probably be fixed by introducing a `WrapperContext`.
+    if ctx.ctx isa PriorContext
+        tilde_observe(LikelihoodContext(), sampler, right, value, vn, inds, vi)
+    elseif ctx.ctx isa LikelihoodContext
+        # Need to make it so that this isn't computed.
+        tilde_observe(PriorContext(), sampler, right, value, vn, inds, vi)
+    else
+        tilde_observe(ctx.ctx, sampler, right, value, vn, inds, vi)
+    end
+    return value
+end
+
 function _tilde(rng, sampler, right, vn::VarName, vi)
     return assume(rng, sampler, right, vn, vi)
 end
@@ -76,6 +97,9 @@ end
 function tilde(ctx::MiniBatchContext, sampler, right, left, vi)
     return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi)
 end
+function tilde(ctx::EvaluationContext, sampler, right, left, vi)
+    return tilde(ctx.ctx, sampler, right, left, vi)
+end
 function tilde(ctx::PrefixContext, sampler, right, left, vi)
     return tilde(ctx.ctx, sampler, right, left, vi)
 end
@@ -91,8 +115,12 @@ and indices; if needed, these can be accessed through this function, though.
 """
 function tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
     logp = tilde(ctx, sampler, right, left, vi)
-    acclogp!(vi, logp)
-    return left
+    if vi === nothing
+        return logp
+    else
+        acclogp!(vi, logp)
+        return left
+    end
 end
 
 """
@@ -105,8 +133,12 @@ Falls back to `tilde(ctx, sampler, right, left, vi)`.
 """
 function tilde_observe(ctx, sampler, right, left, vi)
     logp = tilde(ctx, sampler, right, left, vi)
-    acclogp!(vi, logp)
-    return left
+    if vi === nothing
+        return logp
+    else
+        acclogp!(vi, logp)
+        return left
+    end
 end
 
 _tilde(sampler, right, left, vi) = observe(sampler, right, left, vi)
@@ -144,7 +176,7 @@ end
 function observe(
     spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, value, vi
 )
-    increment_num_produce!(vi)
+    (vi isa VarInfo) && increment_num_produce!(vi)
     return Distributions.loglikelihood(dist, value)
 end
 
@@ -195,6 +227,34 @@ function dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)
     return value
 end
 
+function dot_tilde_assume(
+    rng,
+    ctx::EvaluationContext,
+    sampler,
+    right,
+    left,
+    vn,
+    inds,
+    vi::SimpleVarInfo{<:NamedTuple},
+)
+    value = _getindex(getfield(vi.θ, getsym(vn)), inds)
+
+    # The value is scored through `dot_tilde_observe`, so contexts which behave differently
+    # for `assume` and `observe` are swapped here (`PriorContext` <-> `LikelihoodContext`);
+    # otherwise the observation-only behavior will be applied to `assume`.
+    # FIXME: The below doesn't necessarily work for nested contexts, e.g. if `ctx.ctx.ctx isa PriorContext`.
+    # This is a broader issue though, which should probably be fixed by introducing a `WrapperContext`.
+    if ctx.ctx isa PriorContext
+        dot_tilde_observe(LikelihoodContext(), sampler, right, value, vn, inds, vi)
+    elseif ctx.ctx isa LikelihoodContext
+        # Need to make it so that this isn't computed.
+        dot_tilde_observe(PriorContext(), sampler, right, value, vn, inds, vi)
+    else
+        dot_tilde_observe(ctx.ctx, sampler, right, value, vn, inds, vi)
+    end
+    return value
+end
+
 function get_vns_and_dist(dist::NamedDist, var, vn::VarName)
     return get_vns_and_dist(dist.dist, var, dist.name)
 end
@@ -360,6 +420,9 @@ end
 function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi)
     return ctx.loglike_scalar * dot_tilde(ctx.ctx, sampler, right, left, vi)
 end
+function dot_tilde(ctx::EvaluationContext, sampler, right, left, vi)
+    return dot_tilde(ctx.ctx, sampler, right, left, vi)
+end
 
 """
     dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
@@ -372,8 +435,12 @@ name and indices; if needed, these can be accessed through this function, though
 """
 function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi)
     logp = dot_tilde(ctx, sampler, right, left, vi)
-    acclogp!(vi, logp)
-    return left
+    if vi === nothing
+        return logp
+    else
+        acclogp!(vi, logp)
+        return left
+    end
 end
 
 """
@@ -386,8 +453,12 @@ Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
 """
 function dot_tilde_observe(ctx, sampler, right, left, vi)
     logp = dot_tilde(ctx, sampler, right, left, vi)
-    acclogp!(vi, logp)
-    return left
+    if vi === nothing
+        return logp
+    else
+        acclogp!(vi, logp)
+        return left
+    end
 end
 
 function _dot_tilde(sampler, right, left::AbstractArray, vi)
@@ -409,7 +480,7 @@ function dot_observe(
     value::AbstractMatrix,
     vi,
 )
-    increment_num_produce!(vi)
+    vi isa VarInfo && increment_num_produce!(vi)
     @debug "dist = $dist"
     @debug "value = $value"
     return Distributions.loglikelihood(dist, value)
@@ -420,7 +491,7 @@ function dot_observe(
     value::AbstractArray,
     vi,
 )
-    increment_num_produce!(vi)
+    vi isa VarInfo && increment_num_produce!(vi)
     @debug "dists = $dists"
     @debug "value = $value"
     return Distributions.loglikelihood(dists, value)
@@ -431,7 +502,7 @@ function dot_observe(
     value::AbstractArray,
     vi,
 )
-    increment_num_produce!(vi)
+    vi isa VarInfo && increment_num_produce!(vi)
     @debug "dists = $dists"
     @debug "value = $value"
     return sum(Distributions.loglikelihood.(dists, value))
diff --git a/src/contexts.jl b/src/contexts.jl
index 4d4f30bdc..757d73e7a 100644
--- a/src/contexts.jl
+++ b/src/contexts.jl
@@ -81,3 +81,8 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
         VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing)
     end
 end
+
+struct EvaluationContext{Ctx} <: AbstractContext
+    ctx::Ctx  # wrapped context; `tilde`/`dot_tilde` forward to it
+end
+EvaluationContext() = EvaluationContext(DefaultContext())
diff --git a/src/model.jl b/src/model.jl
index 7189b590e..c83d1f57b 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -32,12 +32,13 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition
 Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
 ```
 """
-struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <:
+struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Logjoint} <:
        AbstractProbabilisticProgram
     name::Symbol
     f::F
     args::NamedTuple{argnames,Targs}
     defaults::NamedTuple{defaultnames,Tdefaults}
+    logjoint::Logjoint
 
     """
         Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple)
@@ -50,10 +51,9 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <:
         f::F,
         args::NamedTuple{argnames,Targs},
         defaults::NamedTuple{defaultnames,Tdefaults},
-    ) where {missings,F,argnames,Targs,defaultnames,Tdefaults}
-        return new{F,argnames,defaultnames,missings,Targs,Tdefaults}(
-            name, f, args, defaults
-        )
+        logjoint::Logjoint = identity
+    ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Logjoint}
+        return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Logjoint}(name, f, args, defaults, logjoint)
     end
 end
 
@@ -67,10 +67,14 @@ Default arguments `defaults` are used internally when constructing instances of
 model with different arguments.
 """
 @generated function Model(
-    name::Symbol, f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple=NamedTuple()
+    name::Symbol,
+    f::F,
+    args::NamedTuple{argnames,Targs},
+    defaults::NamedTuple = NamedTuple(),
+    logjoint = identity
 ) where {F,argnames,Targs}
     missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing)
-    return :(Model{$missings}(name, f, args, defaults))
+    return :(Model{$missings}(name, f, args, defaults, logjoint))
 end
 
 """
@@ -187,6 +191,13 @@ function logjoint(model::Model, varinfo::AbstractVarInfo)
     return getlogp(varinfo)
 end
 
+function logjoint(model::Model, θ)
+    ctx = EvaluationContext(DefaultContext())
+    vi = SimpleVarInfo(θ)
+    model(vi, SampleFromPrior(), ctx)
+    return getlogp(vi)
+end
+
 """
     logprior(model::Model, varinfo::AbstractVarInfo)
 
@@ -199,6 +210,13 @@ function logprior(model::Model, varinfo::AbstractVarInfo)
     return getlogp(varinfo)
 end
 
+function logprior(model::Model, θ)
+    ctx = EvaluationContext(PriorContext())
+    vi = SimpleVarInfo(θ)
+    model(vi, SampleFromPrior(), ctx)
+    return getlogp(vi)
+end
+
 """
     loglikelihood(model::Model, varinfo::AbstractVarInfo)
 
@@ -211,6 +229,13 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo)
     return getlogp(varinfo)
 end
 
+function Distributions.loglikelihood(model::Model, θ)
+    ctx = EvaluationContext(LikelihoodContext())
+    vi = SimpleVarInfo(θ)
+    model(vi, SampleFromPrior(), ctx)
+    return getlogp(vi)
+end
+
 """
     generated_quantities(model::Model, chain::AbstractChains)
 
diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl
new file mode 100644
index 000000000..6d9fb72fd
--- /dev/null
+++ b/src/simple_varinfo.jl
@@ -0,0 +1,21 @@
+struct SimpleVarInfo{NT,T} <: AbstractVarInfo
+    θ::NT
+    logp::Base.RefValue{T}
+end
+
+SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, Ref(zero(T)))
+SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ)
+
+function setlogp!(vi::SimpleVarInfo, logp)
+    vi.logp[] = logp
+    return vi
+end
+
+function acclogp!(vi::SimpleVarInfo, logp)
+    vi.logp[] += logp
+    return vi
+end
+
+getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.θ
+getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.θ
+getindex(vi::SimpleVarInfo, spl::Sampler) = vi.θ