Skip to content

[WIP] Add pure logdensity to Model #242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 18 commits into from
Closed

[WIP] Add pure logdensity to Model #242

wants to merge 18 commits into from

Conversation

torfjelde
Copy link
Member

I'm currently working on a project where I need to "avoid" `VarInfo` due to inefficiencies + I only need HMC samplers for a very high-dimensional latent space. Therefore I've used the following workflow:

  1. Define naive Turing.jl model.
  2. Use this to get stuff like parameters, bijector, etc.
  3. Implement a make_logjoint which takes the model constructor as an argument and returns a logjoint which I've implemented by hand.
    • This impl is tested against the "naive" Turing.jl implementation for correctness.
  4. Use this hand-coded version.

But I thought "Why not just add this to Model? Would be super-useful for a lot of standard samplers." So here we are. Is very early stage though, so would love some feedback!

using DynamicPPL
using Distributions
using Test
using BenchmarkTools

@model function demo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, s)
    for i in eachindex(x)
        x[i] ~ Normal(m, s)
    end
end

ctx = DefaultContext()
spl = SampleFromPrior()

m = demo(fill(missing, 2))
# Works ✓
θ_nt = (s = 1.0, m = 0.0, x = [1.0, 1.0])
m.logπ(m, spl, ctx, θ_nt, Float32, m.args...)
-4.5595910222777984
# Breaks since missing `x` ✓
θ_nt = (s = 1.0, m = 0.0)
m.logπ(m, spl, ctx, θ_nt, Float32, m.args...)
type NamedTuple has no field x

Stacktrace:
 [1] getproperty(x::NamedTuple{(:s, :m), Tuple{Float64, Float64}}, f::Symbol)
   @ Base ./Base.jl:33
 [2] macro expansion
   @ ./In[14]:9 [inlined]
 [3] (::var"#logπ#14")(__model__::Model{var"#12#13", (:x,), (), (), Tuple{Vector{Missing}}, Tuple{}, var"#logπ#14"}, __sampler__::SampleFromPrior, __context__::DefaultContext, __variables__::NamedTuple{(:s, :m), Tuple{Float64, Float64}}, T#398::Type, x::Vector{Missing})
   @ Main ~/Projects/public/DynamicPPL.jl/src/compiler.jl:358
 [4] top-level scope
   @ In[15]:3
 [5] eval
   @ ./boot.jl:360 [inlined]
 [6] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
   @ Base ./loading.jl:1094
# Works since `x` is now given ✓
m = demo([1.0, 1.0])
m.logπ(m, spl, ctx, θ_nt, Float32, m.args...)
-4.5595910222777984
# Impl by hand
function g(θ, x)
    s = θ.s
    m = θ.m

    lp = logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, s), m)

    for i in eachindex(x)
        lp += logpdf(Normal(m, s), x[i])
    end
    return lp
end

@test g(θ_nt, [1.0, 1.0])  m.logπ(m, spl, ctx, θ_nt, Float64, m.args...)
Test Passed

Why is performance worse?

WAZ GOING ON?! Might the nothing in the tilde_observe not being propagated properly?

@benchmark $g($θ_nt, $([1.0, 1.0]))
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     44.812 ns (0.00% GC)
  median time:      45.023 ns (0.00% GC)
  mean time:        46.820 ns (0.00% GC)
  maximum time:     188.450 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     988
@benchmark $(m.logπ)($m, $spl, $ctx, $θ_nt, $(Float64), $(m.args...))
BenchmarkTools.Trial: 
  memory estimate:  704 bytes
  allocs estimate:  22
  --------------
  minimum time:     1.356 μs (0.00% GC)
  median time:      1.481 μs (0.00% GC)
  mean time:        1.643 μs (2.72% GC)
  maximum time:     448.430 μs (99.51% GC)
  --------------
  samples:          10000
  evals/sample:     10
@code_typed g(θ_nt, [1.0, 1.0])
CodeInfo(
1 ── %1   = Base.getfield(θ, :s)::Float64
│    %2   = Base.getfield(θ, :m)::Float64
└───        goto #3 if not true
2 ──        nothing::Nothing
3 ┄─ %5   = Base.sle_int(1, 1)::Bool
└───        goto #5 if not %5
4 ── %7   = Base.sle_int(1, 0)::Bool
└───        goto #6
5 ──        nothing::Nothing
6 ┄─ %10  = φ (#4 => %7, #5 => false)::Bool
└───        goto #8 if not %10
7 ──        invoke Base.getindex(()::Tuple, 1::Int64)::Union{}
└───        unreachable
8 ──        goto #9
9 ──        goto #10
10 ─        goto #11
11 ─        goto #12
12 ─        goto #13
13 ─ %19  = Base.sle_int(1, 1)::Bool
└───        goto #15 if not %19
14 ─ %21  = Base.sle_int(1, 0)::Bool
└───        goto #16
15 ─        nothing::Nothing
16 ┄ %24  = φ (#14 => %21, #15 => false)::Bool
└───        goto #18 if not %24
17 ─        invoke Base.getindex(()::Tuple, 1::Int64)::Union{}
└───        unreachable
18 ─        goto #19
19 ─        goto #20
20 ─        goto #21
21 ─        goto #22
22 ─        goto #23
23 ─        goto #24
24 ─ %34  = %new(InverseGamma{Float64}, $(QuoteNode(Gamma{Float64}(α=2.0, θ=0.3333333333333333))), 3.0)::InverseGamma{Float64}
└───        goto #25
25 ─        goto #26
26 ─        goto #27
27 ─        goto #28
28 ─ %39  = invoke Main.logpdf(%34::InverseGamma{Float64}, %1::Float64)::Float64
│    %40  = Base.lt_float(%1, 0.0)::Bool
└───        goto #30 if not %40
29 ─        invoke Base.Math.throw_complex_domainerror(:sqrt::Symbol, %1::Float64)::Union{}
└───        unreachable
30 ─ %44  = Base.Math.sqrt_llvm(%1)::Float64
└───        goto #31
31 ─        goto #35 if not true
32 ─ %47  = Base.le_float(0.0, %44)::Bool
│    %48  = Base.not_int(%47)::Bool
└───        goto #34 if not %48
33 ─ %50  = Distributions.string("Normal", ": the condition ", "σ >= zero(σ)", " is not satisfied.")::Any
│    %51  = Distributions.ArgumentError(%50)::Any
│           Distributions.throw(%51)::Union{}
└───        unreachable
34 ─        nothing::Nothing
35 ┄ %55  = %new(Normal{Float64}, 0.0, %44)::Normal{Float64}
└───        goto #36
36 ─        goto #37
37 ─        goto #38
38 ─ %59  = invoke Main.logpdf(%55::Normal{Float64}, %2::Float64)::Float64
│    %60  = Base.add_float(%39, %59)::Float64
│    %61  = Base.arraysize(x, 1)::Int64
│    %62  = Base.slt_int(%61, 0)::Bool
│    %63  = Base.ifelse(%62, 0, %61)::Int64
│    %64  = Base.slt_int(%63, 1)::Bool
└───        goto #40 if not %64
39 ─        goto #41
40 ─        goto #41
41 ┄ %68  = φ (#39 => true, #40 => false)::Bool
│    %69  = φ (#40 => 1)::Int64
│    %70  = φ (#40 => 1)::Int64
│    %71  = Base.not_int(%68)::Bool
└───        goto #56 if not %71
42 ┄ %73  = φ (#41 => %69, #55 => %102)::Int64
│    %74  = φ (#41 => %70, #55 => %103)::Int64
│    %75  = φ (#41 => %60, #55 => %96)::Float64
│    %76  = Base.lt_float(%1, 0.0)::Bool
└───        goto #44 if not %76
43 ─        invoke Base.Math.throw_complex_domainerror(:sqrt::Symbol, %1::Float64)::Union{}
└───        unreachable
44 ─ %80  = Base.Math.sqrt_llvm(%1)::Float64
└───        goto #45
45 ─        goto #49 if not true
46 ─ %83  = Base.le_float(0.0, %80)::Bool
│    %84  = Base.not_int(%83)::Bool
└───        goto #48 if not %84
47 ─ %86  = Distributions.string("Normal", ": the condition ", "σ >= zero(σ)", " is not satisfied.")::Any
│    %87  = Distributions.ArgumentError(%86)::Any
│           Distributions.throw(%87)::Union{}
└───        unreachable
48 ─        nothing::Nothing
49 ┄ %91  = %new(Normal{Float64}, %2, %80)::Normal{Float64}
└───        goto #50
50 ─        goto #51
51 ─ %94  = Base.arrayref(true, x, %73)::Float64
│    %95  = invoke Main.logpdf(%91::Normal{Float64}, %94::Float64)::Float64
│    %96  = Base.add_float(%75, %95)::Float64
│    %97  = (%74 === %63)::Bool
└───        goto #53 if not %97
52 ─        goto #54
53 ─ %100 = Base.add_int(%74, 1)::Int64
└───        goto #54
54 ┄ %102 = φ (#53 => %100)::Int64
│    %103 = φ (#53 => %100)::Int64
│    %104 = φ (#52 => true, #53 => false)::Bool
│    %105 = Base.not_int(%104)::Bool
└───        goto #56 if not %105
55 ─        goto #42
56 ┄ %108 = φ (#54 => %96, #41 => %60)::Float64
└───        return %108
) => Float64
@code_typed m.logπ(m, spl, ctx, θ_nt, Float64, m.args...)
CodeInfo(
1 ──        goto #3 if not true
2 ──        nothing::Nothing
3 ┄─ %3   = Base.sle_int(1, 1)::Bool
└───        goto #5 if not %3
4 ── %5   = Base.sle_int(1, 0)::Bool
└───        goto #6
5 ──        nothing::Nothing
6 ┄─ %8   = φ (#4 => %5, #5 => false)::Bool
└───        goto #8 if not %8
7 ──        invoke Base.getindex(()::Tuple, 1::Int64)::Union{}
└───        unreachable
8 ──        goto #9
9 ──        goto #10
10 ─        goto #11
11 ─        goto #12
12 ─        goto #13
13 ─ %17  = Base.sle_int(1, 1)::Bool
└───        goto #15 if not %17
14 ─ %19  = Base.sle_int(1, 0)::Bool
└───        goto #16
15 ─        nothing::Nothing
16 ┄ %22  = φ (#14 => %19, #15 => false)::Bool
└───        goto #18 if not %22
17 ─        invoke Base.getindex(()::Tuple, 1::Int64)::Union{}
└───        unreachable
18 ─        goto #19
19 ─        goto #20
20 ─        goto #21
21 ─        goto #22
22 ─        goto #23
23 ─        goto #24
24 ─ %32  = %new(InverseGamma{Float64}, $(QuoteNode(Gamma{Float64}(α=2.0, θ=0.3333333333333333))), 3.0)::InverseGamma{Float64}
└───        goto #25
25 ─        goto #26
26 ─        goto #27
27 ─        goto #28
28 ─        invoke Core.TypeVar(Symbol("#s69")::Symbol, Distribution::Any)::Core.Compiler.PartialTypeVar(var"#s69"<:Distribution, true, true)
│           nothing::Nothing
│           nothing::Nothing
│    %40  = π (true, Core.Const(true))
└───        goto #30 if not %40
29 ─ %42  = Base.getfield(__variables__, :s)::Float64
30 ┄ %43  = φ (#29 => %42, #28 => #undef)::Core.Compiler.MaybeUndef(Float64)
│    %44  = invoke Distributions.logpdf(%32::InverseGamma{Float64}, %43::Float64)::Float64
│    %45  = Base.add_float(0.0, %44)::Float64
│    %46  = Base.lt_float(%43, 0.0)::Bool
└───        goto #32 if not %46
31 ─        invoke Base.Math.throw_complex_domainerror(:sqrt::Symbol, %43::Float64)::Union{}
└───        unreachable
32 ─ %50  = Base.Math.sqrt_llvm(%43)::Float64
└───        goto #33
33 ─        goto #37 if not true
34 ─ %53  = Base.le_float(0.0, %50)::Bool
│    %54  = Base.not_int(%53)::Bool
└───        goto #36 if not %54
35 ─ %56  = Distributions.string("Normal", ": the condition ", "σ >= zero(σ)", " is not satisfied.")::Any
│    %57  = Distributions.ArgumentError(%56)::Any
│           Distributions.throw(%57)::Union{}
└───        unreachable
36 ─        nothing::Nothing
37 ┄ %61  = %new(Normal{Float64}, 0.0, %50)::Normal{Float64}
└───        goto #38
38 ─        goto #39
39 ─        goto #40
40 ─        invoke Core.TypeVar(Symbol("#s68")::Symbol, Distribution::Any)::Core.Compiler.PartialTypeVar(var"#s68"<:Distribution, true, true)
│           nothing::Nothing
│           nothing::Nothing
│    %68  = π (true, Core.Const(true))
└───        goto #42 if not %68
41 ─ %70  = Base.getfield(__variables__, :m)::Float64
42 ┄ %71  = φ (#41 => %70, #40 => #undef)::Core.Compiler.MaybeUndef(Float64)
│    %72  = invoke Distributions.logpdf(%61::Normal{Float64}, %71::Float64)::Float64
│    %73  = Base.add_float(%45, %72)::Float64
│    %74  = Base.arraysize(x@_7, 1)::Int64
│    %75  = Base.slt_int(%74, 0)::Bool
│    %76  = Base.ifelse(%75, 0, %74)::Int64
│    %77  = Base.slt_int(%76, 1)::Bool
└───        goto #44 if not %77
43 ─        goto #45
44 ─        goto #45
45 ┄ %81  = φ (#43 => true, #44 => false)::Bool
│    %82  = φ (#44 => 1)::Int64
│    %83  = φ (#44 => 1)::Int64
│    %84  = Base.not_int(%81)::Bool
└───        goto #68 if not %84
46 ┄ %86  = φ (#45 => %82, #66 => %129)::Int64
│    %87  = φ (#45 => %83, #66 => %130)::Int64
│    %88  = φ (#45 => %73, #66 => %123)::Float64
│    %89  = Base.lt_float(%43, 0.0)::Bool
└───        goto #48 if not %89
47 ─        invoke Base.Math.throw_complex_domainerror(:sqrt::Symbol, %43::Float64)::Union{}
└───        unreachable
48 ─ %93  = Base.Math.sqrt_llvm(%43)::Float64
└───        goto #49
49 ─        goto #53 if not true
50 ─ %96  = Base.le_float(0.0, %93)::Bool
│    %97  = Base.not_int(%96)::Bool
└───        goto #52 if not %97
51 ─ %99  = Distributions.string("Normal", ": the condition ", "σ >= zero(σ)", " is not satisfied.")::Any
│    %100 = Distributions.ArgumentError(%99)::Any
│           Distributions.throw(%100)::Union{}
└───        unreachable
52 ─        nothing::Nothing
53 ┄ %104 = %new(Normal{Float64}, %71, %93)::Normal{Float64}
└───        goto #54
54 ─        goto #55
55 ─ %107 = invoke Core.TypeVar(Symbol("#s66")::Symbol, Distribution::Any)::TypeVar
│    %108 = Core.apply_type(Main.AbstractVector, %107)::Type{AbstractVector{_A}} where _A
│    %109 = Core.UnionAll(%107, %108)::Any
│    %110 = Core.apply_type(Main.Union, Distribution, %109)::Type
│    %111 = (%104 isa %110)::Bool
└───        goto #67 if not %111
56 ─        nothing::Nothing
└───        goto #58 if not false
57 ─        nothing::Nothing
58 ┄        goto #60 if not false
59 ─        nothing::Nothing
60 ┄        Base.arrayref(true, x@_7, %86)::Float64
└───        goto #62 if not false
61 ─        nothing::Nothing
62 ┄ %121 = Base.arrayref(true, x@_7, %86)::Float64
│    %122 = invoke Distributions.logpdf(%104::Normal{Float64}, %121::Float64)::Float64
│    %123 = Base.add_float(%88, %122)::Float64
│    %124 = (%87 === %76)::Bool
└───        goto #64 if not %124
63 ─        goto #65
64 ─ %127 = Base.add_int(%87, 1)::Int64
└───        goto #65
65 ┄ %129 = φ (#64 => %127)::Int64
│    %130 = φ (#64 => %127)::Int64
│    %131 = φ (#63 => true, #64 => false)::Bool
│    %132 = Base.not_int(%131)::Bool
└───        goto #68 if not %132
66 ─        goto #46
67 ─ %135 = Main.ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions.")::Any
│           Main.throw(%135)::Union{}
└───        unreachable
68 ┄ %138 = φ (#65 => %123, #45 => %73)::Float64
└───        return %138
) => Float64

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it actually necessary to build a separate function? What are the advantages over defining logdensity similar to evaluate but with a special context that only accumulates the log density?

src/compiler.jl Outdated
@@ -70,13 +70,19 @@ end

function model(mod, linenumbernode, expr, warn)
modelinfo = build_model_info(expr)
modelinfo_logπ = deepcopy(modelinfo)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know we use it everywhere but I really don't like the name logπ 😕 It's even a constant in LogExpFunctions/StatsFuns... Could we use something that is more descriptive? Maybe logdensity?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah yeah, happy to do so! Generally I'm opposed to using unicode-symbols in structs anyways because people might not always have access to that.

@torfjelde
Copy link
Member Author

After removing the possible exceptions being thrown when RHS isn't a distribution. Still not quite there. You got any clue @devmotion ?

@benchmark $g($θ_nt, $([1.0, 1.0]))
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     44.732 ns (0.00% GC)
  median time:      46.147 ns (0.00% GC)
  mean time:        46.831 ns (0.00% GC)
  maximum time:     116.459 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     988
@benchmark $(m.logπ)($m, $spl, $ctx, $θ_nt, $(Float64), $(m.args...))
BenchmarkTools.Trial: 
  memory estimate:  128 bytes
  allocs estimate:  8
  --------------
  minimum time:     203.054 ns (0.00% GC)
  median time:      210.341 ns (0.00% GC)
  mean time:        221.010 ns (2.03% GC)
  maximum time:     3.239 μs (93.14% GC)
  --------------
  samples:          10000
  evals/sample:     575
@code_typed g(θ_nt, [1.0, 1.0])
CodeInfo(
1 ── %1   = Base.getfield(θ, :s)::Float64
│    %2   = Base.getfield(θ, :m)::Float64
└───        goto #3 if not true
2 ──        nothing::Nothing
3 ┄─ %5   = Base.sle_int(1, 1)::Bool
└───        goto #5 if not %5
4 ── %7   = Base.sle_int(1, 0)::Bool
└───        goto #6
5 ──        nothing::Nothing
6 ┄─ %10  = φ (#4 => %7, #5 => false)::Bool
└───        goto #8 if not %10
7 ──        invoke Base.getindex(()::Tuple, 1::Int64)::Union{}
└───        unreachable
8 ──        goto #9
9 ──        goto #10
10 ─        goto #11
11 ─        goto #12
12 ─        goto #13
13 ─ %19  = Base.sle_int(1, 1)::Bool
└───        goto #15 if not %19
14 ─ %21  = Base.sle_int(1, 0)::Bool
└───        goto #16
15 ─        nothing::Nothing
16 ┄ %24  = φ (#14 => %21, #15 => false)::Bool
└───        goto #18 if not %24
17 ─        invoke Base.getindex(()::Tuple, 1::Int64)::Union{}
└───        unreachable
18 ─        goto #19
19 ─        goto #20
20 ─        goto #21
21 ─        goto #22
22 ─        goto #23
23 ─        goto #24
24 ─ %34  = %new(InverseGamma{Float64}, $(QuoteNode(Gamma{Float64}(α=2.0, θ=0.3333333333333333))), 3.0)::InverseGamma{Float64}
└───        goto #25
25 ─        goto #26
26 ─        goto #27
27 ─        goto #28
28 ─ %39  = invoke Main.logpdf(%34::InverseGamma{Float64}, %1::Float64)::Float64
│    %40  = Base.lt_float(%1, 0.0)::Bool
└───        goto #30 if not %40
29 ─        invoke Base.Math.throw_complex_domainerror(:sqrt::Symbol, %1::Float64)::Union{}
└───        unreachable
30 ─ %44  = Base.Math.sqrt_llvm(%1)::Float64
└───        goto #31
31 ─        goto #35 if not true
32 ─ %47  = Base.le_float(0.0, %44)::Bool
│    %48  = Base.not_int(%47)::Bool
└───        goto #34 if not %48
33 ─ %50  = Distributions.string("Normal", ": the condition ", "σ >= zero(σ)", " is not satisfied.")::Any
│    %51  = Distributions.ArgumentError(%50)::Any
│           Distributions.throw(%51)::Union{}
└───        unreachable
34 ─        nothing::Nothing
35 ┄ %55  = %new(Normal{Float64}, 0.0, %44)::Normal{Float64}
└───        goto #36
36 ─        goto #37
37 ─        goto #38
38 ─ %59  = invoke Main.logpdf(%55::Normal{Float64}, %2::Float64)::Float64
│    %60  = Base.add_float(%39, %59)::Float64
│    %61  = Base.arraysize(x, 1)::Int64
│    %62  = Base.slt_int(%61, 0)::Bool
│    %63  = Base.ifelse(%62, 0, %61)::Int64
│    %64  = Base.slt_int(%63, 1)::Bool
└───        goto #40 if not %64
39 ─        goto #41
40 ─        goto #41
41 ┄ %68  = φ (#39 => true, #40 => false)::Bool
│    %69  = φ (#40 => 1)::Int64
│    %70  = φ (#40 => 1)::Int64
│    %71  = Base.not_int(%68)::Bool
└───        goto #56 if not %71
42 ┄ %73  = φ (#41 => %69, #55 => %102)::Int64
│    %74  = φ (#41 => %70, #55 => %103)::Int64
│    %75  = φ (#41 => %60, #55 => %96)::Float64
│    %76  = Base.lt_float(%1, 0.0)::Bool
└───        goto #44 if not %76
43 ─        invoke Base.Math.throw_complex_domainerror(:sqrt::Symbol, %1::Float64)::Union{}
└───        unreachable
44 ─ %80  = Base.Math.sqrt_llvm(%1)::Float64
└───        goto #45
45 ─        goto #49 if not true
46 ─ %83  = Base.le_float(0.0, %80)::Bool
│    %84  = Base.not_int(%83)::Bool
└───        goto #48 if not %84
47 ─ %86  = Distributions.string("Normal", ": the condition ", "σ >= zero(σ)", " is not satisfied.")::Any
│    %87  = Distributions.ArgumentError(%86)::Any
│           Distributions.throw(%87)::Union{}
└───        unreachable
48 ─        nothing::Nothing
49 ┄ %91  = %new(Normal{Float64}, %2, %80)::Normal{Float64}
└───        goto #50
50 ─        goto #51
51 ─ %94  = Base.arrayref(true, x, %73)::Float64
│    %95  = invoke Main.logpdf(%91::Normal{Float64}, %94::Float64)::Float64
│    %96  = Base.add_float(%75, %95)::Float64
│    %97  = (%74 === %63)::Bool
└───        goto #53 if not %97
52 ─        goto #54
53 ─ %100 = Base.add_int(%74, 1)::Int64
└───        goto #54
54 ┄ %102 = φ (#53 => %100)::Int64
│    %103 = φ (#53 => %100)::Int64
│    %104 = φ (#52 => true, #53 => false)::Bool
│    %105 = Base.not_int(%104)::Bool
└───        goto #56 if not %105
55 ─        goto #42
56 ┄ %108 = φ (#54 => %96, #41 => %60)::Float64
└───        return %108
) => Float64
@code_typed m.logπ(m, spl, ctx, θ_nt, Float64, m.args...)
CodeInfo(
1 ──        goto #3 if not true
2 ──        nothing::Nothing
3 ┄─ %3   = Base.sle_int(1, 1)::Bool
└───        goto #5 if not %3
4 ── %5   = Base.sle_int(1, 0)::Bool
└───        goto #6
5 ──        nothing::Nothing
6 ┄─ %8   = φ (#4 => %5, #5 => false)::Bool
└───        goto #8 if not %8
7 ──        invoke Base.getindex(()::Tuple, 1::Int64)::Union{}
└───        unreachable
8 ──        goto #9
9 ──        goto #10
10 ─        goto #11
11 ─        goto #12
12 ─        goto #13
13 ─ %17  = Base.sle_int(1, 1)::Bool
└───        goto #15 if not %17
14 ─ %19  = Base.sle_int(1, 0)::Bool
└───        goto #16
15 ─        nothing::Nothing
16 ┄ %22  = φ (#14 => %19, #15 => false)::Bool
└───        goto #18 if not %22
17 ─        invoke Base.getindex(()::Tuple, 1::Int64)::Union{}
└───        unreachable
18 ─        goto #19
19 ─        goto #20
20 ─        goto #21
21 ─        goto #22
22 ─        goto #23
23 ─        goto #24
24 ─ %32  = %new(InverseGamma{Float64}, $(QuoteNode(Gamma{Float64}(α=2.0, θ=0.3333333333333333))), 3.0)::InverseGamma{Float64}
└───        goto #25
25 ─        goto #26
26 ─        goto #27
27 ─        goto #28
28 ─        nothing::Nothing
│           nothing::Nothing
│    %39  = π (true, Core.Const(true))
└───        goto #30 if not %39
29 ─ %41  = Base.getfield(__variables__, :s)::Float64
30 ┄ %42  = φ (#29 => %41, #28 => #undef)::Core.Compiler.MaybeUndef(Float64)
│    %43  = invoke Distributions.logpdf(%32::InverseGamma{Float64}, %42::Float64)::Float64
│    %44  = Base.add_float(0.0, %43)::Float64
│    %45  = Base.lt_float(%42, 0.0)::Bool
└───        goto #32 if not %45
31 ─        invoke Base.Math.throw_complex_domainerror(:sqrt::Symbol, %42::Float64)::Union{}
└───        unreachable
32 ─ %49  = Base.Math.sqrt_llvm(%42)::Float64
└───        goto #33
33 ─        goto #37 if not true
34 ─ %52  = Base.le_float(0.0, %49)::Bool
│    %53  = Base.not_int(%52)::Bool
└───        goto #36 if not %53
35 ─ %55  = Distributions.string("Normal", ": the condition ", "σ >= zero(σ)", " is not satisfied.")::Any
│    %56  = Distributions.ArgumentError(%55)::Any
│           Distributions.throw(%56)::Union{}
└───        unreachable
36 ─        nothing::Nothing
37 ┄ %60  = %new(Normal{Float64}, 0.0, %49)::Normal{Float64}
└───        goto #38
38 ─        goto #39
39 ─        goto #40
40 ─        nothing::Nothing
│           nothing::Nothing
│    %66  = π (true, Core.Const(true))
└───        goto #42 if not %66
41 ─ %68  = Base.getfield(__variables__, :m)::Float64
42 ┄ %69  = φ (#41 => %68, #40 => #undef)::Core.Compiler.MaybeUndef(Float64)
│    %70  = invoke Distributions.logpdf(%60::Normal{Float64}, %69::Float64)::Float64
│    %71  = Base.add_float(%44, %70)::Float64
│    %72  = Base.arraysize(x@_7, 1)::Int64
│    %73  = Base.slt_int(%72, 0)::Bool
│    %74  = Base.ifelse(%73, 0, %72)::Int64
│    %75  = Base.slt_int(%74, 1)::Bool
└───        goto #44 if not %75
43 ─        goto #45
44 ─        goto #45
45 ┄ %79  = φ (#43 => true, #44 => false)::Bool
│    %80  = φ (#44 => 1)::Int64
│    %81  = φ (#44 => 1)::Int64
│    %82  = Base.not_int(%79)::Bool
└───        goto #66 if not %82
46 ┄ %84  = φ (#45 => %80, #65 => %120)::Int64
│    %85  = φ (#45 => %81, #65 => %121)::Int64
│    %86  = φ (#45 => %71, #65 => %114)::Float64
│    %87  = Base.lt_float(%42, 0.0)::Bool
└───        goto #48 if not %87
47 ─        invoke Base.Math.throw_complex_domainerror(:sqrt::Symbol, %42::Float64)::Union{}
└───        unreachable
48 ─ %91  = Base.Math.sqrt_llvm(%42)::Float64
└───        goto #49
49 ─        goto #53 if not true
50 ─ %94  = Base.le_float(0.0, %91)::Bool
│    %95  = Base.not_int(%94)::Bool
└───        goto #52 if not %95
51 ─ %97  = Distributions.string("Normal", ": the condition ", "σ >= zero(σ)", " is not satisfied.")::Any
│    %98  = Distributions.ArgumentError(%97)::Any
│           Distributions.throw(%98)::Union{}
└───        unreachable
52 ─        nothing::Nothing
53 ┄ %102 = %new(Normal{Float64}, %69, %91)::Normal{Float64}
└───        goto #54
54 ─        goto #55
55 ─        goto #57 if not false
56 ─        nothing::Nothing
57 ┄        goto #59 if not false
58 ─        nothing::Nothing
59 ┄        Base.arrayref(true, x@_7, %84)::Float64
└───        goto #61 if not false
60 ─        nothing::Nothing
61 ┄ %112 = Base.arrayref(true, x@_7, %84)::Float64
│    %113 = invoke Distributions.logpdf(%102::Normal{Float64}, %112::Float64)::Float64
│    %114 = Base.add_float(%86, %113)::Float64
│    %115 = (%85 === %74)::Bool
└───        goto #63 if not %115
62 ─        goto #64
63 ─ %118 = Base.add_int(%85, 1)::Int64
└───        goto #64
64 ┄ %120 = φ (#63 => %118)::Int64
│    %121 = φ (#63 => %118)::Int64
│    %122 = φ (#62 => true, #63 => false)::Bool
│    %123 = Base.not_int(%122)::Bool
└───        goto #66 if not %123
65 ─        goto #46
66 ┄ %126 = φ (#64 => %114, #45 => %71)::Float64
└───        return %126
) => Float64

@torfjelde
Copy link
Member Author

torfjelde commented May 16, 2021

Is it actually necessary to build a separate function? What are the advantages over defining logdensity similar to evaluate but with a special context that only accumulates the log density?

Hmm, but you mean while still not passing in the VarInfo or making it nothing? Maybe that would be better yeah. That should even work with stuff like @submodel and so on without issues. I just wasn't quite certain if I would have to inherit anything else by doing so.

But I still don't want to keep the variables in VarInfo. I don't want the linearized representation, nor the linking-impl from VarInfo.

@torfjelde
Copy link
Member Author

torfjelde commented May 16, 2021

It seems like the possible redefinition of the variable causes the following statements:

│    %66  = π (true, Core.Const(true))
└───        goto #42 if not %66
41 ─ %68  = Base.getfield(__variables__, :m)::Float64
42 ┄ %69  = φ (#41 => %68, #40 => #undef)::Core.Compiler.MaybeUndef(Float64)

to show up. I'm not entirely certain why or what's going on here, but it seems to cause the additional allocations.

EDIT: Nvm. I read a bit about SSA and AFAIK the above is a Pi-statement which represents information known at compile-time, and a Phi-statement which is a conditional depending on the referenced blocks. So in this case the Pi + Phi will be compiled away. I can confirm this by removing the if $isassumption check that we do, which still results in the same runtime above.

@torfjelde
Copy link
Member Author

Okay, so now I'm even more confused. I've removed the if $isassumption statements and the likelihood in the model, and the two now generates almost identical lowered code with the exception of how the parameters are obtained BUT STILL performance is worse! Where are these allocations coming from?

@benchmark $g($θ_nt, $([1.0, 1.0]))
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     27.531 ns (0.00% GC)
  median time:      28.338 ns (0.00% GC)
  mean time:        28.911 ns (0.00% GC)
  maximum time:     90.618 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     994
@benchmark $(m.logπ)($m, $spl, $ctx, $θ_nt, $(Float64), $(m.args.x))
BenchmarkTools.Trial: 
  memory estimate:  64 bytes
  allocs estimate:  4
  --------------
  minimum time:     150.570 ns (0.00% GC)
  median time:      152.370 ns (0.00% GC)
  mean time:        158.550 ns (1.53% GC)
  maximum time:     3.160 μs (94.99% GC)
  --------------
  samples:          10000
  evals/sample:     767
@code_typed g(θ_nt, [1.0, 1.0])
CodeInfo(
1 ── %1  = Base.getfield(θ, :s)::Float64
│    %2  = Base.getfield(θ, :m)::Float64
│    %3  = π (0, Core.Const(0))
└───       goto #3 if not true
2 ──       nothing::Nothing
3 ┄─ %6  = Base.sle_int(1, 1)::Bool
└───       goto #5 if not %6
4 ── %8  = Base.sle_int(1, 0)::Bool
└───       goto #6
5 ──       nothing::Nothing
6 ┄─ %11 = φ (#4 => %8, #5 => false)::Bool
└───       goto #8 if not %11
7 ──       invoke Base.getindex(()::Tuple, 1::Int64)::Union{}
└───       unreachable
8 ──       goto #9
9 ──       goto #10
10 ─       goto #11
11 ─       goto #12
12 ─       goto #13
13 ─ %20 = Base.sle_int(1, 1)::Bool
└───       goto #15 if not %20
14 ─ %22 = Base.sle_int(1, 0)::Bool
└───       goto #16
15 ─       nothing::Nothing
16 ┄ %25 = φ (#14 => %22, #15 => false)::Bool
└───       goto #18 if not %25
17 ─       invoke Base.getindex(()::Tuple, 1::Int64)::Union{}
└───       unreachable
18 ─       goto #19
19 ─       goto #20
20 ─       goto #21
21 ─       goto #22
22 ─       goto #23
23 ─       goto #24
24 ─ %35 = %new(InverseGamma{Float64}, $(QuoteNode(Gamma{Float64}(α=2.0, θ=0.3333333333333333))), 3.0)::InverseGamma{Float64}
└───       goto #25
25 ─       goto #26
26 ─       goto #27
27 ─       goto #28
28 ─ %40 = invoke Distributions.logpdf(%35::InverseGamma{Float64}, %1::Float64)::Float64
│    %41 = Base.sitofp(Float64, %3)::Float64
│    %42 = Base.add_float(%41, %40)::Float64
│    %43 = Base.lt_float(%1, 0.0)::Bool
└───       goto #30 if not %43
29 ─       invoke Base.Math.throw_complex_domainerror(:sqrt::Symbol, %1::Float64)::Union{}
└───       unreachable
30 ─ %47 = Base.Math.sqrt_llvm(%1)::Float64
└───       goto #31
31 ─       goto #35 if not true
32 ─ %50 = Base.le_float(0.0, %47)::Bool
│    %51 = Base.not_int(%50)::Bool
└───       goto #34 if not %51
33 ─ %53 = Distributions.string("Normal", ": the condition ", "σ >= zero(σ)", " is not satisfied.")::Any
│    %54 = Distributions.ArgumentError(%53)::Any
│          Distributions.throw(%54)::Union{}
└───       unreachable
34 ─       nothing::Nothing
35 ┄ %58 = %new(Normal{Float64}, 0.0, %47)::Normal{Float64}
└───       goto #36
36 ─       goto #37
37 ─       goto #38
38 ─ %62 = invoke Distributions.logpdf(%58::Normal{Float64}, %2::Float64)::Float64
│    %63 = Base.add_float(%42, %62)::Float64
└───       return %63
) => Float64
@code_typed m.logπ(m, spl, ctx, θ_nt, Float64, m.args...)
CodeInfo(
1 ──       goto #3 if not true
2 ──       nothing::Nothing
3 ┄─ %3  = Base.sle_int(1, 1)::Bool
└───       goto #5 if not %3
4 ── %5  = Base.sle_int(1, 0)::Bool
└───       goto #6
5 ──       nothing::Nothing
6 ┄─ %8  = φ (#4 => %5, #5 => false)::Bool
└───       goto #8 if not %8
7 ──       invoke Base.getindex(()::Tuple, 1::Int64)::Union{}
└───       unreachable
8 ──       goto #9
9 ──       goto #10
10 ─       goto #11
11 ─       goto #12
12 ─       goto #13
13 ─ %17 = Base.sle_int(1, 1)::Bool
└───       goto #15 if not %17
14 ─ %19 = Base.sle_int(1, 0)::Bool
└───       goto #16
15 ─       nothing::Nothing
16 ┄ %22 = φ (#14 => %19, #15 => false)::Bool
└───       goto #18 if not %22
17 ─       invoke Base.getindex(()::Tuple, 1::Int64)::Union{}
└───       unreachable
18 ─       goto #19
19 ─       goto #20
20 ─       goto #21
21 ─       goto #22
22 ─       goto #23
23 ─       goto #24
24 ─ %32 = %new(InverseGamma{Float64}, $(QuoteNode(Gamma{Float64}(α=2.0, θ=0.3333333333333333))), 3.0)::InverseGamma{Float64}
└───       goto #25
25 ─       goto #26
26 ─       goto #27
27 ─       goto #28
28 ─ %37 = Base.getfield(__variables__, :s)::Float64
│    %38 = invoke Distributions.logpdf(%32::InverseGamma{Float64}, %37::Float64)::Float64
│    %39 = Base.add_float(0.0, %38)::Float64
│    %40 = Base.lt_float(%37, 0.0)::Bool
└───       goto #30 if not %40
29 ─       invoke Base.Math.throw_complex_domainerror(:sqrt::Symbol, %37::Float64)::Union{}
└───       unreachable
30 ─ %44 = Base.Math.sqrt_llvm(%37)::Float64
└───       goto #31
31 ─       goto #35 if not true
32 ─ %47 = Base.le_float(0.0, %44)::Bool
│    %48 = Base.not_int(%47)::Bool
└───       goto #34 if not %48
33 ─ %50 = Distributions.string("Normal", ": the condition ", "σ >= zero(σ)", " is not satisfied.")::Any
│    %51 = Distributions.ArgumentError(%50)::Any
│          Distributions.throw(%51)::Union{}
└───       unreachable
34 ─       nothing::Nothing
35 ┄ %55 = %new(Normal{Float64}, 0.0, %44)::Normal{Float64}
└───       goto #36
36 ─       goto #37
37 ─       goto #38
38 ─ %59 = Base.getfield(__variables__, :m)::Float64
│    %60 = invoke Distributions.logpdf(%55::Normal{Float64}, %59::Float64)::Float64
│    %61 = Base.add_float(%39, %60)::Float64
└───       return %61
) => Float64

@devmotion
Copy link
Member

Is it actually necessary to build a separate function? What are the advantages over defining logdensity similar to evaluate but with a special context that only accumulates the log density?

Hmm, but you mean while still not passing in the VarInfo or making it nothing? Maybe that would be better yeah. That should even work with stuff like @submodel and so on without issues. I just wasn't quite certain if I would have to inherit anything else by doing so.

But I still don't want to keep the variables in VarInfo. I don't want the linearized representation, nor the linking-impl from VarInfo.

Yes, the idea would be to neither pass in VarInfo nor add variables to VarInfo.

@torfjelde
Copy link
Member Author

Yes, the idea would be to neither pass in VarInfo nor add variables to VarInfo.

Sounds like a good idea; I'll give this a try!

Any thoughts on why we're still seeing perf loss compared to manual impl?

torfjelde and others added 3 commits May 20, 2021 06:48
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@torfjelde
Copy link
Member Author

torfjelde commented May 20, 2021

So now I've added a SimpleVarInfo which only contains the parameters (currently in NamedTuple form but this can easily be implemented for linearized parameters) and the logp. Together with EvaluationContext this leads to a very simple yet efficient approach. It ends up being as efficient as the generated logjoint for evaluation but it seems slightly (not much though) slower in the adjoint-computation (using Zygote in a simple example).

Also, I've added an implementation of Bijectors.bijector(::VarInfo) that I've been using quite heavily in my own projects. Now that we have the SimpleVarInfo, the bijector resuting from bijector(VarInfo(model)) can be used to transform the entire SimpleVarInfo. I have also been using a optimize_bijector to vectorize where it makes sense, etc. I'll make a PR to Bijectors.jl for this.

I'm still confused about about why there's still an overhead for the generated logjoint though 😕

EDIT: Also, there is a benefit to the generated logjoint: not mutation occurs, i.e. we can use Zygote even for cases such as x[i] ~ Normal(). Whether it's worth it or not is a different question 😕

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@@ -0,0 +1,33 @@
@generated function _bijector(md::NamedTuple{names}; tuplify=false) where {names}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it would make sense to move the bijector changes to a separate PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, you might be right here. But the reason why I added it is because we sort of "need" it if we support NamedTuple evaluations without using VarInfo, since there's no "linking" going on here. Therefore I figured I should just add it together with the introduction of SimpleVarInfo or w/e we end up going with.

With that being said, I'm fine with making it a separate PR too:) Just wanted to explain my reasoning.

expr = Expr(:tuple)
for n in names
e = quote
if length(md.$n.dists) == 1 &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't be type stable, it seems?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, but AFAIK there's no way to ensure that since dists can be a vector of anything. The aim isn't for this to be type-stable either; it's meant be called once and then you re-use the resulting bijector wherever you need. So mby we should also just remove the "@generated`?

for n in names
e = quote
if length(md.$n.dists) == 1 &&
md.$n.dists[1] isa $(Distributions.UnivariateDistribution)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should not be necessary to qualify the types and functions in the generated function (BTW maybe make it if @generated?).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, didn't know; thanks!

And regarding @generated; starting to think it shouldn't even be a @generated...


Returns a `NamedBijector` which can transform different variants of `varinfo`.

If `tuplify` is true, then a type-stable bijector will be returned.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe one could choose a name that indicates more clearly that it helps with type stability? It almost sounds like it would return a tuple of bijectors. Maybe also astuple would be a more "formal" API?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, yeah sorry; that was leftover from earlier. I also hate tuplify:)

astuple maybe, or maybe static? Or typed?

Comment on lines +501 to +571
function generate_mainbody_logdensity(mod, expr, warn)
return generate_mainbody_logdensity!(mod, Symbol[], expr, warn)
end

generate_mainbody_logdensity!(mod, found, x, warn) = x
function generate_mainbody_logdensity!(mod, found, sym::Symbol, warn)
if sym in DEPRECATED_INTERNALNAMES
newsym = Symbol(:_, sym, :__)
Base.depwarn(
"internal variable `$sym` is deprecated, use `$newsym` instead.",
:generate_mainbody_logdensity!,
)
return generate_mainbody_logdensity!(mod, found, newsym, warn)
end

if warn && sym in INTERNALNAMES && sym ∉ found
@warn "you are using the internal variable `$sym`"
push!(found, sym)
end

return sym
end
function generate_mainbody_logdensity!(mod, found, expr::Expr, warn)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
return generate_mainbody_logdensity!(
mod, found, macroexpand(mod, expr; recursive=true), warn
)
end

# If it's a return, we instead return `__lp__`.
if Meta.isexpr(expr, :return)
returnbody = Expr(
:block,
map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)...,
)
return :($(returnbody); return __lp__)
end

# Modify dotted tilde operators.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
left = generate_mainbody_logdensity!(mod, found, L, warn)
return Base.remove_linenums!(
generate_dot_tilde_logdensity(
left, generate_mainbody_logdensity!(mod, found, R, warn)
),
)
end

# Modify tilde operators.
args_tilde = getargs_tilde(expr)
if args_tilde !== nothing
L, R = args_tilde
left = generate_mainbody_logdensity!(mod, found, L, warn)
return Base.remove_linenums!(
generate_tilde_logdensity(
left, generate_mainbody_logdensity!(mod, found, R, warn)
),
)
end

return Expr(
expr.head,
map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)...,
)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't check the details but it seems this code is very similar to the existing generate_mainbody!. Maybe one could just pass unify them by passing the inner functions (generate_dot_tilde/generate_dot_tilde_logdensity and generate_tilde/generate_tilde_logdensity) as first argument?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, I had the exact same thought as I was writing it out 😅 Left it as is for now though since I'm still on the edge whether or not we should have this vs. EvaluationContext with SimpleVarInfo. If we do keep it, I 100% agree.

Comment on lines +69 to +76
if ctx.ctx isa PriorContext
tilde_observe(LikelihoodContext(), sampler, right, value, vn, inds, vi)
elseif ctx.ctx isa LikelihoodContext
# Need to make it so that this isn't computed.
tilde_observe(PriorContext(), sampler, right, value, vn, inds, vi)
else
tilde_observe(ctx, sampler, right, value, vn, inds, vi)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear to me immediately why this is needed, and why we use LikelihoodContext for PriorContext and vice versa.

In general, I think implementation-wise it would be better to dispatch on EvaluationContext and to not hardcode the type checks in the function body.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason why I use the if here is because what I mentioned in #249 where it seemed more natural to have a if here for readability. But I've added LeafCtx in that PR too, so we could dispatch on this.

Comment on lines +118 to +123
if vi === nothing
return logp
else
acclogp!(vi, logp)
return left
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here and below, I think it is cleaner to not add the check in the function body but to dispatch on ::Nothing. Then one can move all these special definitions together and it can be explained and understood easier why they are needed.

@@ -409,7 +480,7 @@ function dot_observe(
value::AbstractMatrix,
vi,
)
increment_num_produce!(vi)
vi isa VarInfo && increment_num_produce!(vi)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also here, I think it might be cleaner to just dispatch on vi::VarInfo.

@@ -0,0 +1,21 @@
struct SimpleVarInfo{NT,T} <: AbstractVarInfo
θ::NT
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems it is necessary to save the parameters?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what you mean here 😕

SimpleVarInfo is used in conjuction with EvaluationContext; it is not used for in the generated logjoint. So SimpleVarInfo just keeps the variables for access in the assume and observe statements. Does that help?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha yeah now it's clear, I think I just mixed up the two approaches, the generated logjoint and the VarInfo/EvaluationContext one 😄

:(__variables__),
],
modelinfo[:allargs_exprs],
[Expr(:kw, :(::Type{$T}), :Float64), ]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
[Expr(:kw, :(::Type{$T}), :Float64), ]
[Expr(:kw, :(::Type{$T}), :Float64)],

@torfjelde torfjelde mentioned this pull request Jun 18, 2021
3 tasks
@torfjelde
Copy link
Member Author

Closed in favour of #242 .

@torfjelde torfjelde closed this Jul 10, 2021
@yebai yebai deleted the tor/logdensity branch January 28, 2022 20:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants