Skip to content

Commit d343827

Browse files
committed
Warn if unsupported AD type is used
1 parent 293f18b commit d343827

File tree

2 files changed

+67
-41
lines changed

2 files changed

+67
-41
lines changed

src/logdensityfunction.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
import DifferentiationInterface as DI
22

3+
"""
4+
is_supported(adtype::AbstractADType)
5+
6+
Check if the given AD type is formally supported by DynamicPPL.
7+
8+
AD backends that are not formally supported can still be used for gradient
9+
calculation; it is just that the DynamicPPL developers do not commit to
10+
maintaining compatibility with them.
11+
"""
12+
is_supported(::ADTypes.AbstractADType) = false
13+
is_supported(::ADTypes.AutoForwardDiff) = true
14+
is_supported(::ADTypes.AutoMooncake) = true
15+
is_supported(::ADTypes.AutoReverseDiff) = true
16+
317
"""
418
LogDensityFunction(
519
model::Model,
@@ -105,6 +119,9 @@ struct LogDensityFunction{
105119
prep = nothing
106120
with_closure = false
107121
else
122+
# Check support
123+
is_supported(adtype) ||
124+
@warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
108125
# Get a set of dummy params to use for prep
109126
x = map(identity, varinfo[:])
110127
with_closure = use_closure(adtype)

test/ad.jl

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,59 @@
11
using DynamicPPL: LogDensityFunction
22

3-
@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin
4-
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
5-
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
6-
vns = DynamicPPL.TestUtils.varnames(m)
7-
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
3+
@testset "Automatic differentiation" begin
4+
@testset "Unsupported backends" begin
5+
@model demo() = x ~ Normal()
6+
@test_logs (:warn, r"not officially supported") LogDensityFunction(
7+
demo(); adtype=AutoZygote()
8+
)
9+
end
10+
11+
@testset "Correctness: ForwardDiff, ReverseDiff, and Mooncake" begin
12+
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
13+
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
14+
vns = DynamicPPL.TestUtils.varnames(m)
15+
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
816

9-
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
10-
f = LogDensityFunction(m, varinfo)
11-
x = DynamicPPL.getparams(f)
12-
# Calculate reference logp + gradient of logp using ForwardDiff
13-
ref_adtype = ADTypes.AutoForwardDiff()
14-
ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
15-
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
17+
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
18+
f = LogDensityFunction(m, varinfo)
19+
x = DynamicPPL.getparams(f)
20+
# Calculate reference logp + gradient of logp using ForwardDiff
21+
ref_adtype = ADTypes.AutoForwardDiff()
22+
ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
23+
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
1624

17-
@testset "$adtype" for adtype in [
18-
AutoReverseDiff(; compile=false),
19-
AutoReverseDiff(; compile=true),
20-
AutoMooncake(; config=nothing),
21-
]
22-
@info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype"
25+
@testset "$adtype" for adtype in [
26+
AutoReverseDiff(; compile=false),
27+
AutoReverseDiff(; compile=true),
28+
AutoMooncake(; config=nothing),
29+
]
30+
@info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype"
2331

24-
# Put predicates here to avoid long lines
25-
is_mooncake = adtype isa AutoMooncake
26-
is_1_10 = v"1.10" <= VERSION < v"1.11"
27-
is_1_11 = v"1.11" <= VERSION < v"1.12"
28-
is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
29-
is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict}
32+
# Put predicates here to avoid long lines
33+
is_mooncake = adtype isa AutoMooncake
34+
is_1_10 = v"1.10" <= VERSION < v"1.11"
35+
is_1_11 = v"1.11" <= VERSION < v"1.12"
36+
is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
37+
is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict}
3038

31-
# Mooncake doesn't work with several combinations of SimpleVarInfo.
32-
if is_mooncake && is_1_11 && is_svi_vnv
33-
# https://github.com/compintell/Mooncake.jl/issues/470
34-
@test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype)
35-
elseif is_mooncake && is_1_10 && is_svi_vnv
36-
# TODO: report upstream
37-
@test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype)
38-
elseif is_mooncake && is_1_10 && is_svi_od
39-
# TODO: report upstream
40-
@test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.setadtype(
41-
ref_ldf, adtype
42-
)
43-
else
44-
ldf = DynamicPPL.setadtype(ref_ldf, adtype)
45-
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
46-
@test grad ref_grad
47-
@test logp ref_logp
39+
# Mooncake doesn't work with several combinations of SimpleVarInfo.
40+
if is_mooncake && is_1_11 && is_svi_vnv
41+
# https://github.com/compintell/Mooncake.jl/issues/470
42+
@test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype)
43+
elseif is_mooncake && is_1_10 && is_svi_vnv
44+
# TODO: report upstream
45+
@test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype)
46+
elseif is_mooncake && is_1_10 && is_svi_od
47+
# TODO: report upstream
48+
@test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.setadtype(
49+
ref_ldf, adtype
50+
)
51+
else
52+
ldf = DynamicPPL.setadtype(ref_ldf, adtype)
53+
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
54+
@test grad ref_grad
55+
@test logp ref_logp
56+
end
4857
end
4958
end
5059
end

0 commit comments

Comments
 (0)