diff --git a/ext/DiffEqBaseChainRulesCoreExt.jl b/ext/DiffEqBaseChainRulesCoreExt.jl index 9d44a6bee..2a60d4103 100644 --- a/ext/DiffEqBaseChainRulesCoreExt.jl +++ b/ext/DiffEqBaseChainRulesCoreExt.jl @@ -12,19 +12,19 @@ ChainRulesCore.@non_differentiable DiffEqBase.checkkwargs(kwargshandle) function ChainRulesCore.frule(::typeof(DiffEqBase.solve_up), prob, sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, - u0, p, args...; + u0, p, args...; originator = SciMLBase.ChainRulesOriginator(), kwargs...) DiffEqBase._solve_forward( - prob, sensealg, u0, p, set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...; + prob, sensealg, u0, p, originator, args...; kwargs...) end function ChainRulesCore.rrule(::typeof(DiffEqBase.solve_up), prob::AbstractDEProblem, sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, - u0, p, args...; + u0, p, args...; originator = SciMLBase.ChainRulesOriginator(), kwargs...) DiffEqBase._solve_adjoint( - prob, sensealg, u0, p, set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...; + prob, sensealg, u0, p, originator, args...; kwargs...) end diff --git a/ext/DiffEqBaseMooncakeExt.jl b/ext/DiffEqBaseMooncakeExt.jl index 7e51ee12d..520d654ad 100644 --- a/ext/DiffEqBaseMooncakeExt.jl +++ b/ext/DiffEqBaseMooncakeExt.jl @@ -2,7 +2,7 @@ module DiffEqBaseMooncakeExt using DiffEqBase, Mooncake using DiffEqBase: SciMLBase -using SciMLBase: ADOriginator, MooncakeOriginator +using SciMLBase: ADOriginator, MooncakeOriginator, ChainRulesOriginator Mooncake.@from_rrule( Mooncake.MinimalCtx, Tuple{ @@ -17,6 +17,6 @@ Mooncake.@from_rrule( ) Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any} -Mooncake.@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::ADOriginator) = MooncakeOriginator +Mooncake.@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::ChainRulesOriginator) = MooncakeOriginator() end \ No newline at end of file diff --git a/src/solve.jl b/src/solve.jl index 829618556..8ce794229 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1168,14 +1168,15 @@ function solve(prob::NonlinearProblem, args...; sensealg = nothing, p = p !== nothing ? p : prob.p if wrap isa Val{true} - wrap_sol(solve_up(prob, sensealg, u0, p, args...; alias_u0 = alias_u0, kwargs...)) + wrap_sol(solve_up(prob, sensealg, u0, p, args...; alias_u0 = alias_u0, originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), kwargs...)) else - solve_up(prob, sensealg, u0, p, args...; alias_u0 = alias_u0, kwargs...) + solve_up(prob, sensealg, u0, p, args...; alias_u0 = alias_u0, originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), kwargs...) end end function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0, p, - args...; kwargs...) + args...; originator = SciMLBase.ChainRulesOriginator(), + kwargs...) alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs) if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0,