-
Notifications
You must be signed in to change notification settings - Fork 24
Backend switching for Mooncake #768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1a389a6
08b176a
ba0c9e6
1340d92
2ce1ee2
08de6df
84f27c9
2e95299
13233e5
1e8df98
afdddd4
233c312
7a07127
f3e436d
6a0d937
e543958
2472ecc
c63c956
36da036
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray}} | ||
|
||
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) | ||
primal_func = primal(dw) | ||
primal_x = primal(x) | ||
(; f, backend) = primal_func | ||
y = zero_fcodual(f(primal_x)) | ||
Check warning on line 7 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
gdalle marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# output is a vector, so we need to use the vector pullback | ||
AstitvaAggarwal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
function pullback_array!!(dy::NoRData) | ||
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) | ||
@assert only(tx) isa rdata_type(typeof(primal_x)) | ||
return NoRData(), only(tx) | ||
Check warning on line 13 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
AstitvaAggarwal marked this conversation as resolved.
Show resolved
Hide resolved
AstitvaAggarwal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
|
||
# output is a scalar, so we can use the scalar pullback | ||
function pullback_scalar!!(dy::Number) | ||
tx = DI.pullback(f, backend, primal_x, (dy,)) | ||
@assert only(tx) isa rdata_type(typeof(primal_x)) | ||
return NoRData(), only(tx) | ||
Check warning on line 20 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
end | ||
|
||
return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!! | ||
Check warning on line 23 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
end | ||
|
||
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray}) | ||
primal_func = primal(dw) | ||
primal_x = primal(x) | ||
fdata_arg = fdata(x.dx) | ||
(; f, backend) = primal_func | ||
y = zero_fcodual(f(primal_x)) | ||
Check warning on line 31 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
|
||
# output is a vector, so we need to use the vector pullback | ||
function pullback_array!!(dy::NoRData) | ||
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) | ||
@assert first(only(tx)) isa rdata_type(typeof(first(primal_x))) | ||
fdata_arg .+= only(tx) | ||
return NoRData(), dy | ||
Check warning on line 38 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
end | ||
|
||
# output is a scalar, so we can use the scalar pullback | ||
function pullback_scalar!!(dy::Number) | ||
tx = DI.pullback(f, backend, primal_x, (dy,)) | ||
@assert first(only(tx)) isa rdata_type(typeof(first(primal_x))) | ||
fdata_arg .+= only(tx) | ||
return NoRData(), NoRData() | ||
Check warning on line 46 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
end | ||
|
||
return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!! | ||
Check warning on line 49 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,21 @@ | ||
using Pkg | ||
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote"]) | ||
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"]) | ||
|
||
using DifferentiationInterface, DifferentiationInterfaceTest | ||
import DifferentiationInterfaceTest as DIT | ||
using FiniteDiff: FiniteDiff | ||
using ForwardDiff: ForwardDiff | ||
using Zygote: Zygote | ||
using Mooncake: Mooncake | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this file we need to add tests that are specific to Mooncake. Ideally I should have done that with the other backends too. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yup, will check it out There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay so it fails for some of the general primal, functions. But for this PR its maybe okay? as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
To clarify, these (primal) functions are permitted by the DI interface, right? Assuming that is true, @gdalle, I think this is okay. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @AstitvaAggarwal is it possible to add these tests (excluding those not supported by DI) to this PR? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Many of them are not, and the ones that are-seem to be present in current tests. but as suggested I'll add them here as well, it would make this rrule better anyways |
||
using Test | ||
|
||
LOGGING = get(ENV, "CI", "false") == "false" | ||
|
||
function differentiatewith_scenarios() | ||
bad_scens = # these closurified scenarios have mutation and type constraints | ||
filter(default_scenarios(; include_normal=false, include_closurified=true)) do scen | ||
filter( | ||
DIT.default_scenarios(; include_normal=false, include_closurified=true) | ||
) do scen | ||
DIT.function_place(scen) == :out | ||
end | ||
good_scens = map(bad_scens) do scen | ||
|
@@ -22,7 +25,7 @@ function differentiatewith_scenarios() | |
end | ||
|
||
test_differentiation( | ||
[AutoForwardDiff(), AutoZygote()], | ||
[AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)], | ||
differentiatewith_scenarios(); | ||
excluded=SECOND_ORDER, | ||
logging=LOGGING, | ||
|
Uh oh!
There was an error while loading. Please reload this page.