|
1 | 1 | using DynamicPPL: LogDensityFunction
|
2 | 2 |
|
3 |
| -@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin |
4 |
| - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS |
5 |
| - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) |
6 |
| - vns = DynamicPPL.TestUtils.varnames(m) |
7 |
| - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) |
| 3 | +@testset "Automatic differentiation" begin |
| 4 | + @testset "Unsupported backends" begin |
| 5 | + @model demo() = x ~ Normal() |
| 6 | + @test_logs (:warn, r"not officially supported") LogDensityFunction( |
| 7 | + demo(); adtype=AutoZygote() |
| 8 | + ) |
| 9 | + end |
| 10 | + |
| 11 | + @testset "Correctness: ForwardDiff, ReverseDiff, and Mooncake" begin |
| 12 | + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS |
| 13 | + rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) |
| 14 | + vns = DynamicPPL.TestUtils.varnames(m) |
| 15 | + varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) |
8 | 16 |
|
9 |
| - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos |
10 |
| - f = LogDensityFunction(m, varinfo) |
11 |
| - x = DynamicPPL.getparams(f) |
12 |
| - # Calculate reference logp + gradient of logp using ForwardDiff |
13 |
| - ref_adtype = ADTypes.AutoForwardDiff() |
14 |
| - ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype) |
15 |
| - ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) |
| 17 | + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos |
| 18 | + f = LogDensityFunction(m, varinfo) |
| 19 | + x = DynamicPPL.getparams(f) |
| 20 | + # Calculate reference logp + gradient of logp using ForwardDiff |
| 21 | + ref_adtype = ADTypes.AutoForwardDiff() |
| 22 | + ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype) |
| 23 | + ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) |
16 | 24 |
|
17 |
| - @testset "$adtype" for adtype in [ |
18 |
| - AutoReverseDiff(; compile=false), |
19 |
| - AutoReverseDiff(; compile=true), |
20 |
| - AutoMooncake(; config=nothing), |
21 |
| - ] |
22 |
| - @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype" |
| 25 | + @testset "$adtype" for adtype in [ |
| 26 | + AutoReverseDiff(; compile=false), |
| 27 | + AutoReverseDiff(; compile=true), |
| 28 | + AutoMooncake(; config=nothing), |
| 29 | + ] |
| 30 | + @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype" |
23 | 31 |
|
24 |
| - # Put predicates here to avoid long lines |
25 |
| - is_mooncake = adtype isa AutoMooncake |
26 |
| - is_1_10 = v"1.10" <= VERSION < v"1.11" |
27 |
| - is_1_11 = v"1.11" <= VERSION < v"1.12" |
28 |
| - is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} |
29 |
| - is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict} |
| 32 | + # Put predicates here to avoid long lines |
| 33 | + is_mooncake = adtype isa AutoMooncake |
| 34 | + is_1_10 = v"1.10" <= VERSION < v"1.11" |
| 35 | + is_1_11 = v"1.11" <= VERSION < v"1.12" |
| 36 | + is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} |
| 37 | + is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict} |
30 | 38 |
|
31 |
| - # Mooncake doesn't work with several combinations of SimpleVarInfo. |
32 |
| - if is_mooncake && is_1_11 && is_svi_vnv |
33 |
| - # https://github.com/compintell/Mooncake.jl/issues/470 |
34 |
| - @test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype) |
35 |
| - elseif is_mooncake && is_1_10 && is_svi_vnv |
36 |
| - # TODO: report upstream |
37 |
| - @test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype) |
38 |
| - elseif is_mooncake && is_1_10 && is_svi_od |
39 |
| - # TODO: report upstream |
40 |
| - @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.setadtype( |
41 |
| - ref_ldf, adtype |
42 |
| - ) |
43 |
| - else |
44 |
| - ldf = DynamicPPL.setadtype(ref_ldf, adtype) |
45 |
| - logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x) |
46 |
| - @test grad ≈ ref_grad |
47 |
| - @test logp ≈ ref_logp |
| 39 | + # Mooncake doesn't work with several combinations of SimpleVarInfo. |
| 40 | + if is_mooncake && is_1_11 && is_svi_vnv |
| 41 | + # https://github.com/compintell/Mooncake.jl/issues/470 |
| 42 | + @test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype) |
| 43 | + elseif is_mooncake && is_1_10 && is_svi_vnv |
| 44 | + # TODO: report upstream |
| 45 | + @test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype) |
| 46 | + elseif is_mooncake && is_1_10 && is_svi_od |
| 47 | + # TODO: report upstream |
| 48 | + @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.setadtype( |
| 49 | + ref_ldf, adtype |
| 50 | + ) |
| 51 | + else |
| 52 | + ldf = DynamicPPL.setadtype(ref_ldf, adtype) |
| 53 | + logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x) |
| 54 | + @test grad ≈ ref_grad |
| 55 | + @test logp ≈ ref_logp |
| 56 | + end |
48 | 57 | end
|
49 | 58 | end
|
50 | 59 | end
|
|
0 commit comments