Backend switching for Mooncake #768
base: main
Changes from all commits
1a389a6
08b176a
ba0c9e6
1340d92
2ce1ee2
08de6df
84f27c9
2e95299
13233e5
1e8df98
afdddd4
233c312
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray}} | ||
|
||
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) | ||
primal_func = primal(dw) | ||
What if there are derivatives inside
I'm not sure I understand. dw would be a function, and if derivatives are used inside, I think that would be handled by the substitute backend?
What I mean is that outside of this code, Mooncake may be trying to differentiate something with respect to some parameters that are stored inside
||
primal_x = primal(x) | ||
(; f, backend) = primal_func | ||
y = zero_fcodual(f(primal_x)) | ||
Check warning on line 7 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
gdalle marked this conversation as resolved.
|
||
|
||
# output is a vector, so we need to use the vector pullback | ||
AstitvaAggarwal marked this conversation as resolved.
|
||
function pullback_array!!(dy::NoRData) | ||
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) | ||
return NoRData(), only(tx) | ||
|
||
We're not sure that
Not all callable objects are functions though.
julia> struct Multiplier
           a::Float64
       end
julia> (m::Multiplier)(x) = m.a * x
julia> m = Multiplier(2)
Multiplier(2.0)
julia> m(3)
6.0
julia> m isa Function
false
And I'm not even convinced the current Mooncake behavior is correct, see chalk-lab/Mooncake.jl#557
Nice catch, Mooncake treats all callable objects as a function, giving
See my remark above, let's assume for simplicity that
We should probably ensure that
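The concern raised in this thread can be illustrated in plain Julia, without any AD package. The sketch below reuses the `Multiplier` struct from the comment above and checks, by central finite differences, that the output depends on the parameter stored inside the callable. The helper `dparam` is hypothetical and exists only for this illustration; it is not part of Mooncake or DifferentiationInterface.

```julia
# A callable struct: it can be called like a function but is not a `Function`.
struct Multiplier
    a::Float64
end
(m::Multiplier)(x) = m.a * x

# Hypothetical helper: central finite-difference estimate of the derivative of
# m(x) with respect to the stored field `a`, at fixed input `x`.
function dparam(a, x; h=1e-6)
    return (Multiplier(a + h)(x) - Multiplier(a - h)(x)) / (2h)
end

m = Multiplier(2.0)
is_function = m isa Function        # false: callable, yet not a `Function`
is_callable = applicable(m, 3.0)    # true: a call method exists for this argument
sensitivity = dparam(2.0, 3.0)      # close to x = 3.0, since d(a * x)/da = x
```

Since the sensitivity with respect to `a` is nonzero, a rule that reports no reverse data for the callable itself would drop that contribution whenever an outer differentiation tracks parameters stored inside it.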
||
end | ||
|
||
# output is a scalar, so we can use the scalar pullback | ||
function pullback_scalar!!(dy::Number) | ||
tx = DI.pullback(f, backend, primal_x, (dy,)) | ||
return NoRData(), only(tx) | ||
|
||
end | ||
|
||
return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!! | ||
|
||
end | ||
|
||
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray}) | ||
primal_func = primal(dw) | ||
primal_x = primal(x) | ||
fdata_arg = fdata(x.dx) | ||
(; f, backend) = primal_func | ||
y = zero_fcodual(f(primal_x)) | ||
|
||
|
||
# output is a vector, so we need to use the vector pullback | ||
function pullback_array!!(dy::NoRData) | ||
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) | ||
fdata_arg .+= only(tx) | ||
return NoRData(), dy | ||
|
||
end | ||
|
||
# output is a scalar, so we can use the scalar pullback | ||
function pullback_scalar!!(dy::Number) | ||
tx = DI.pullback(f, backend, primal_x, (dy,)) | ||
fdata_arg .+= only(tx) | ||
return NoRData(), NoRData() | ||
|
||
end | ||
|
||
return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!! | ||
|
||
end |
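For orientation, here is a rough sketch of how the rule in this file would be exercised from user code, assuming the extension is loaded and that ForwardDiff and Mooncake are installed. The function name `inner` and the input values are hypothetical. The call pattern mirrors the test change in this PR, where a `DifferentiateWith` wrapper is differentiated directly with `AutoMooncake(; config=nothing)`.

```julia
using DifferentiationInterface
import ForwardDiff, Mooncake

# Hypothetical inner function: array input, scalar output.
inner(x) = sum(abs2, x)

# Ask for `inner` to be differentiated with ForwardDiff, whatever the outer backend is.
inner_fd = DifferentiateWith(inner, AutoForwardDiff())

x = [1.0, 2.0, 3.0]

# The outer backend is Mooncake; the wrapper is expected to dispatch to the
# array-input rule, which delegates the pullback to ForwardDiff.
grad = gradient(inner_fd, AutoMooncake(; config=nothing), x)
```

With an array input and a scalar output, this case should go through the `pullback_scalar!!` branch of the second method.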
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,12 @@ | ||
using Pkg | ||
- Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote"]) | ||
+ Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"]) | ||
|
||
using DifferentiationInterface, DifferentiationInterfaceTest | ||
import DifferentiationInterfaceTest as DIT | ||
using FiniteDiff: FiniteDiff | ||
using ForwardDiff: ForwardDiff | ||
using Zygote: Zygote | ||
using Mooncake: Mooncake | ||
In this file we need to add tests that are specific to Mooncake. Ideally I should have done that with the other backends too.
||
using Test | ||
|
||
LOGGING = get(ENV, "CI", "false") == "false" | ||
|
@@ -24,7 +25,7 @@ function differentiatewith_scenarios() | |
end | ||
|
||
test_differentiation( | ||
- [AutoForwardDiff(), AutoZygote()], | ||
+ [AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)], | ||
differentiatewith_scenarios(); | ||
excluded=SECOND_ORDER, | ||
logging=LOGGING, | ||
|
Update to
chalk-lab