Skip to content

Backend switching for Mooncake #768

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -95,7 +95,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel
The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends.
It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use.
In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself.
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend.
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/compintell/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update to chalk-lab


## Implementations

Original file line number Diff line number Diff line change
@@ -111,4 +111,4 @@ There are, however, translation utilities:
### Backend switch

Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend.
Right now it only targets ForwardDiff.jl and ChainRulesCore.jl, but PRs are welcome to define Enzyme.jl and Mooncake.jl rules for this object.
Right now, it only targets ForwardDiff.jl, Mooncake.jl, ChainRules.jl-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object.
Original file line number Diff line number Diff line change
@@ -3,14 +3,21 @@ module DifferentiationInterfaceMooncakeExt
using ADTypes: ADTypes, AutoMooncake
import DifferentiationInterface as DI
using Mooncake:
Mooncake,
CoDual,
Config,
prepare_gradient_cache,
prepare_pullback_cache,
tangent_type,
value_and_gradient!!,
value_and_pullback!!,
zero_tangent
zero_tangent,
@is_primitive,
zero_fcodual,
MinimalCtx,
NoRData,
fdata,
primal

DI.check_available(::AutoMooncake) = true

@@ -26,5 +33,6 @@ mycopy(x) = deepcopy(x)

include("onearg.jl")
include("twoarg.jl")
include("differentiate_with.jl")

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray}}

function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
primal_func = primal(dw)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if there are derivatives inside dw?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

im not sure i understand. dw would be a function and if derivatives are used inside i think that would be handled by the substitute backend?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I mean is that outside of this code, Mooncake may be trying to differentiate something with respect to some parameters that are stored inside dw. We could exclude it, and say that functions passed to DifferentiateWith must not close over any active data (that is what I implicitly assume for the chain rule). But let's document it then

primal_x = primal(x)
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))

Check warning on line 7 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L3-L7

Added lines #L3 - L7 were not covered by tests

# output is a vector, so we need to use the vector pullback
function pullback_array!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
return NoRData(), only(tx)

Check warning on line 12 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L10-L12

Added lines #L10 - L12 were not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're not sure that f has no RData

Copy link
Author

@AstitvaAggarwal AstitvaAggarwal Apr 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
I think all functions would have no RData

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not all callable objects are functions though.

julia> struct Multiplier
           a::Float64
       end

julia> (m::Multiplier)(x) = m.a * x

julia> m = Multiplier(2)
Multiplier(2.0)

julia> m(3)
6.0

julia> m isa Function
false

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And I'm not even convinced the current Mooncake behavior is correct, see chalk-lab/Mooncake.jl#557

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch, Mooncake treats all callable objects as a function, giving NoRData(). The issue you have opened in mooncake, i think would need to be figured out independent of DI (atleast this PR) as of now (comes in when Mooncake is a substitute backend as well).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my remark above, let's assume for simplicity that f contains no differentiated data, but with a corresponding warning in the docstring of DifferentiateWith

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably ensure that only(tx) isa rdata_type(typeof(x)), and at least error otherwise. For simple numbers and arrays it holds, but not in general

end

# output is a scalar, so we can use the scalar pullback
function pullback_scalar!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
return NoRData(), only(tx)

Check warning on line 18 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L16-L18

Added lines #L16 - L18 were not covered by tests
end

return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!!

Check warning on line 21 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L21

Added line #L21 was not covered by tests
end

function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray})
primal_func = primal(dw)
primal_x = primal(x)
fdata_arg = fdata(x.dx)
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))

Check warning on line 29 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L24-L29

Added lines #L24 - L29 were not covered by tests

# output is a vector, so we need to use the vector pullback
function pullback_array!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
fdata_arg .+= only(tx)
return NoRData(), dy

Check warning on line 35 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L32-L35

Added lines #L32 - L35 were not covered by tests
end

# output is a scalar, so we can use the scalar pullback
function pullback_scalar!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
fdata_arg .+= only(tx)
return NoRData(), NoRData()

Check warning on line 42 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L39-L42

Added lines #L39 - L42 were not covered by tests
end

return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!!

Check warning on line 45 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L45

Added line #L45 was not covered by tests
end
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
using Pkg
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote"])
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"])

using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Zygote: Zygote
using Mooncake: Mooncake
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this file we need to add tests that are specific to Mooncake. Ideally I should have done that with the other backends too.
Basically, what we test now is the projection of the Mooncake rule you wrote onto the subset of stuff that DI cares about. But we should also check that the rule is correct from the Mooncake perspective. Probably the best tool for that is Mooncake.TestUtils.test_rule?

using Test

LOGGING = get(ENV, "CI", "false") == "false"
@@ -24,7 +25,7 @@ function differentiatewith_scenarios()
end

test_differentiation(
[AutoForwardDiff(), AutoZygote()],
[AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)],
differentiatewith_scenarios();
excluded=SECOND_ORDER,
logging=LOGGING,