diff --git a/src/compiler.jl b/src/compiler.jl index c67da6f95..c3f77af4b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -240,6 +240,17 @@ function unwrap_right_left_vns( left::AbstractArray, vn::VarName, ) + # Need to check that we don't end up double-counting log-probabilities. + combined_axes = Broadcast.combine_axes(left, right) + if prod(length, combined_axes) > length(left) + throw( + ArgumentError( + "a `.~` statement cannot result in a broadcasted expression with more elements than the left-hand side", + ), + ) + end + + # Extract the sub-varnames. vns = map(CartesianIndices(left)) do i return Accessors.IndexLens(Tuple(i)) ∘ vn end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 462012676..c53daad3e 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -331,6 +331,72 @@ function dot_tilde_assume( ) end +# `FixedContext` +function dot_tilde_assume(context::FixedContext, right, left, vns, vi) + if !has_fixed_symbol(context, first(vns)) + # Defer to `childcontext`. + return dot_tilde_assume(childcontext(context), right, left, vns, vi) + end + + # If we're reached here, then we didn't hit the initial `getfixed` call in the model body. + # We _might_ also have some of the variables fixed, but not all. + logp = 0 + # TODO(torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`. + # If the `Symbol` is not present, we can just skip this check completely. Such a check can + # then be compiled away in cases where the `Symbol` is not present. + left_bc = Broadcast.broadcastable(left) + right_bc = Broadcast.broadcastable(right) + for I_left in Iterators.product(Broadcast.broadcast_axes(left_bc)...) + for I_right in Iterators.product(Broadcast.broadcast_axes(right_bc)...) + vn = vns[I_left...] + if hasfixed(context, vn) + left[I_left...] = getfixed(context, vn) + else + # Defer to `tilde_assume`. + left[I_left...], logp_inner, vi = tilde_assume( + childcontext(context), right_bc[I_right...], vn, vi + ) + logp += logp_inner + end + end + end + + return left, logp, vi +end + +function dot_tilde_assume( + rng::Random.AbstractRNG, context::FixedContext, sampler, right, left, vns, vi +) + if !has_fixed_symbol(context, first(vns)) + # Defer to `childcontext`. + return dot_tilde_assume(rng, childcontext(context), sampler, right, left, vns, vi) + end + # If we're reached here, then we didn't hit the initial `getfixed` call in the model body. + # So we need to check each of the vns. + logp = 0 + # TODO(torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`. + # If the `Symbol` is not present, we can just skip this check completely. Such a check can + # then be compiled away in cases where the `Symbol` is not present. + left_bc = Broadcast.broadcastable(left) + right_bc = Broadcast.broadcastable(right) + for I_left in Iterators.product(Broadcast.broadcast_axes(left_bc)...) + for I_right in Iterators.product(Broadcast.broadcast_axes(right_bc)...) + vn = vns[I_left...] + if hasfixed(context, vn) + left[I_left...] = getfixed(context, vn) + else + # Defer to `tilde_assume`. + left[I_left...], logp_inner, vi = tilde_assume( + rng, childcontext(context), sampler, right_bc[I_right...], vn, vi + ) + logp += logp_inner + end + end + end + + return left, logp, vi +end + """ dot_tilde_assume!!(context, right, left, vn, vi) diff --git a/src/contexts.jl b/src/contexts.jl index b337e4750..440687fa9 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -529,6 +529,13 @@ NodeTrait(::FixedContext) = IsParent() childcontext(context::FixedContext) = context.context setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child) +has_fixed_symbol(context::FixedContext, vn::VarName) = has_symbol(context.values, vn) + +has_symbol(d::AbstractDict, vn::VarName) = haskey(d, vn) +@generated function has_symbol(::NamedTuple{names}, ::VarName{sym}) where {names,sym} + return sym in names +end + """ hasfixed(context::AbstractContext, vn::VarName) diff --git a/test/compiler.jl b/test/compiler.jl index 4dc9fcb24..56c849f54 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -720,4 +720,13 @@ module Issue537 end res = model() @test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}()) end + + @testset "invalid .~ expressions" begin + @model function demo_with_invalid_dot_tilde() + m = Matrix{Float64}(undef, 1, 2) + return m .~ [Normal(); Normal()] + end + + @test_throws ArgumentError demo_with_invalid_dot_tilde()() + end end