From 9b2060b2b2ece541b70d4594ed7a75749c3c8d5e Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Sat, 7 Dec 2024 13:56:54 +0000 Subject: [PATCH 1/5] Mooncake ext --- Project.toml | 3 +++ ext/DynamicPPLMooncakeExt.jl | 9 +++++++++ test/Project.toml | 2 +- test/ext/DynamicPPLMooncakeExt.jl | 3 +++ test/runtests.jl | 2 ++ 5 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 ext/DynamicPPLMooncakeExt.jl create mode 100644 test/ext/DynamicPPLMooncakeExt.jl diff --git a/Project.toml b/Project.toml index 909be870f..f0c582fec 100644 --- a/Project.toml +++ b/Project.toml @@ -30,6 +30,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] @@ -37,6 +38,7 @@ DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLMCMCChainsExt = ["MCMCChains"] +DynamicPPLMooncakeExt = ["Mooncake"] DynamicPPLZygoteRulesExt = ["ZygoteRules"] [compat] @@ -58,6 +60,7 @@ LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6" MacroTools = "0.5.6" +Mooncake = "0.4.59" OrderedCollections = "1" Random = "1.6" Requires = "1" diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl new file mode 100644 index 000000000..97699a7cb --- /dev/null +++ b/ext/DynamicPPLMooncakeExt.jl @@ -0,0 +1,9 @@ +module DynamicPPLMooncakeExt + +using DynamicPPL: DynamicPPL, istrans +using Mooncake: Mooncake + +# This is purely an optimisation. +Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans), Vararg} + +end # module diff --git a/test/Project.toml b/test/Project.toml index 0d247c3ec..4f12f2015 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -45,7 +45,7 @@ LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6.0.4" MacroTools = "0.5.6" -Mooncake = "0.4.50" +Mooncake = "0.4.59" ReverseDiff = "1" StableRNGs = "1" Tracker = "0.2.23" diff --git a/test/ext/DynamicPPLMooncakeExt.jl b/test/ext/DynamicPPLMooncakeExt.jl new file mode 100644 index 000000000..8e1611c63 --- /dev/null +++ b/test/ext/DynamicPPLMooncakeExt.jl @@ -0,0 +1,3 @@ +@testset "DynamicPPLMooncakeExt" begin + Mooncake.TestUtils.test_rule(StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true) +end diff --git a/test/runtests.jl b/test/runtests.jl index dbfa319b0..a4fdabf22 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ using LogDensityProblems, LogDensityProblemsAD using MacroTools using MCMCChains using Mooncake: Mooncake +using StableRNGs using Tracker using ReverseDiff using Zygote @@ -77,6 +78,7 @@ include("test_util.jl") @testset "ad" begin include("ext/DynamicPPLForwardDiffExt.jl") + include("ext/DynamicPPLMooncakeExt.jl") include("ad.jl") end From 29c2c80b599299fbb771cfa46d6d8bc4408000b5 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Sat, 7 Dec 2024 13:57:12 +0000 Subject: [PATCH 2/5] Bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f0c582fec..21d501e9f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.31.3" +version = "0.31.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 21f4528af799bce035dff604957c5d2d657487fd Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Sat, 7 Dec 2024 14:01:26 +0000 Subject: [PATCH 3/5] Update ext/DynamicPPLMooncakeExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/DynamicPPLMooncakeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 97699a7cb..b86d807bc 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -4,6 +4,6 @@ using DynamicPPL: DynamicPPL, istrans using Mooncake: Mooncake # This is purely an optimisation. -Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans), Vararg} +Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg} end # module From 57f3de33c5cf4a694531605468d2c32637994396 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Sun, 8 Dec 2024 11:07:06 +0000 Subject: [PATCH 4/5] Update test/ext/DynamicPPLMooncakeExt.jl --- test/ext/DynamicPPLMooncakeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ext/DynamicPPLMooncakeExt.jl b/test/ext/DynamicPPLMooncakeExt.jl index 8e1611c63..8757ad67b 100644 --- a/test/ext/DynamicPPLMooncakeExt.jl +++ b/test/ext/DynamicPPLMooncakeExt.jl @@ -1,3 +1,3 @@ @testset "DynamicPPLMooncakeExt" begin - Mooncake.TestUtils.test_rule(StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true) + Mooncake.TestUtils.test_rule(StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true, interface_only=true) end From 689006a525bad83b8bffec6db2ab496b4cf7e240 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Sun, 8 Dec 2024 14:53:55 +0000 Subject: [PATCH 5/5] Update test/ext/DynamicPPLMooncakeExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/ext/DynamicPPLMooncakeExt.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/ext/DynamicPPLMooncakeExt.jl b/test/ext/DynamicPPLMooncakeExt.jl index 8757ad67b..986057da0 100644 --- a/test/ext/DynamicPPLMooncakeExt.jl +++ b/test/ext/DynamicPPLMooncakeExt.jl @@ -1,3 +1,5 @@ @testset "DynamicPPLMooncakeExt" begin - Mooncake.TestUtils.test_rule(StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true, interface_only=true) + Mooncake.TestUtils.test_rule( + StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true, interface_only=true + ) end