Skip to content

Backend switching for Mooncake #768

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -95,7 +95,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel
The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends.
It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use.
In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself.
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend.
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)).

## Implementations

Original file line number Diff line number Diff line change
@@ -111,4 +111,5 @@ There are, however, translation utilities:
### Backend switch

Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend.
Right now it only targets ForwardDiff.jl and ChainRulesCore.jl, but PRs are welcome to define Enzyme.jl and Mooncake.jl rules for this object.

Right now, it only targets [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](), [ChainRules.jl](https://juliadiff.org/ChainRulesCore.jl/stable/)-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object.
Original file line number Diff line number Diff line change
@@ -3,14 +3,22 @@ module DifferentiationInterfaceMooncakeExt
using ADTypes: ADTypes, AutoMooncake
import DifferentiationInterface as DI
using Mooncake:
Mooncake,
CoDual,
Config,
prepare_gradient_cache,
prepare_pullback_cache,
tangent_type,
value_and_gradient!!,
value_and_pullback!!,
zero_tangent
zero_tangent,
rdata_type,
@is_primitive,
zero_fcodual,
MinimalCtx,
NoRData,
fdata,
primal

DI.check_available(::AutoMooncake) = true

@@ -26,5 +34,6 @@ mycopy(x) = deepcopy(x)

include("onearg.jl")
include("twoarg.jl")
include("differentiate_with.jl")

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray}}

function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
primal_func = primal(dw)
primal_x = primal(x)
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))

Check warning on line 7 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L3-L7

Added lines #L3 - L7 were not covered by tests

# output is a vector, so we need to use the vector pullback
function pullback_array!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
@assert only(tx) isa rdata_type(typeof(primal_x))
return NoRData(), only(tx)

Check warning on line 13 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L10-L13

Added lines #L10 - L13 were not covered by tests
end

# output is a scalar, so we can use the scalar pullback
function pullback_scalar!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
@assert only(tx) isa rdata_type(typeof(primal_x))
return NoRData(), only(tx)

Check warning on line 20 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L17-L20

Added lines #L17 - L20 were not covered by tests
end

return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!!

Check warning on line 23 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L23

Added line #L23 was not covered by tests
end

function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray})
primal_func = primal(dw)
primal_x = primal(x)
fdata_arg = fdata(x.dx)
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))

Check warning on line 31 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L26-L31

Added lines #L26 - L31 were not covered by tests

# output is a vector, so we need to use the vector pullback
function pullback_array!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
@assert first(only(tx)) isa rdata_type(typeof(first(primal_x)))
fdata_arg .+= only(tx)
return NoRData(), dy

Check warning on line 38 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L34-L38

Added lines #L34 - L38 were not covered by tests
end

# output is a scalar, so we can use the scalar pullback
function pullback_scalar!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
@assert first(only(tx)) isa rdata_type(typeof(first(primal_x)))
fdata_arg .+= only(tx)
return NoRData(), NoRData()

Check warning on line 46 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L42-L46

Added lines #L42 - L46 were not covered by tests
end

return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!!

Check warning on line 49 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L49

Added line #L49 was not covered by tests
end
6 changes: 5 additions & 1 deletion DifferentiationInterface/src/misc/differentiate_with.jl
Original file line number Diff line number Diff line change
@@ -13,9 +13,13 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be

!!! warning
`DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments.
It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake](https://github.com/chalk-lab/Mooncake.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper).

!!! warning
When using Mooncake as a substitute backend via `DifferentiateWith(f, AutoMooncake())`. The function `f` must not close over any active data.
As of now, we cannot differentiate with respect to parameters stored inside `f`.

# Fields

- `f`: the function in question, with signature `f(x)`
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
using Pkg
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote"])
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"])

using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Zygote: Zygote
using Mooncake: Mooncake
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this file we need to add tests that are specific to Mooncake. Ideally I should have done that with the other backends too.
Basically, what we test now is the projection of the Mooncake rule you wrote onto the subset of stuff that DI cares about. But we should also check that the rule is correct from the Mooncake perspective. Probably the best tool for that is Mooncake.TestUtils.test_rule?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup, will check it out

Copy link
Author

@AstitvaAggarwal AstitvaAggarwal May 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay so it fails for some of the general primal, functions. But for this PR its maybe okay? as DifferentiateWith is exclusive to DI, so the user is anyways limited to DI when using the Mooncake substitute backend.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it fails for some of the general primal, functions.

To clarify, these (primal) functions are permitted by the DI interface, right?

Assuming that is true, @gdalle, I think this is okay.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AstitvaAggarwal is it possible to add these tests (excluding those not supported by DI) to this PR?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many of them are not, and the ones that are-seem to be present in current tests. but as suggested I'll add them here as well, it would make this rrule better anyways

using Test

LOGGING = get(ENV, "CI", "false") == "false"

function differentiatewith_scenarios()
bad_scens = # these closurified scenarios have mutation and type constraints
filter(default_scenarios(; include_normal=false, include_closurified=true)) do scen
filter(
DIT.default_scenarios(; include_normal=false, include_closurified=true)
) do scen
DIT.function_place(scen) == :out
end
good_scens = map(bad_scens) do scen
@@ -22,7 +25,7 @@ function differentiatewith_scenarios()
end

test_differentiation(
[AutoForwardDiff(), AutoZygote()],
[AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)],
differentiatewith_scenarios();
excluded=SECOND_ORDER,
logging=LOGGING,