Skip to content

Commit 13ed275

Browse files
Merge pull request #520 from SciML/noncrete_solve
Setup solve for adjoints to deprecate concrete_solve
2 parents 00ad2e3 + a65c5ce commit 13ed275

File tree

12 files changed

+218
-115
lines changed

12 files changed

+218
-115
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqBase"
22
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "6.35.2"
4+
version = "6.36.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/interpolation.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,41 @@ struct ConstantInterpolation{T1,T2} <: AbstractDiffEqInterpolation
2323
u::T2
2424
end
2525

26+
"""
27+
$(TYPEDEF)
28+
"""
29+
struct SensitivityInterpolation{T1,T2} <: AbstractDiffEqInterpolation
30+
t::T1
31+
u::T2
32+
end
33+
2634
interp_summary(::AbstractDiffEqInterpolation) = "Unknown"
2735
interp_summary(::HermiteInterpolation) = "3rd order Hermite"
2836
interp_summary(::LinearInterpolation) = "1st order linear"
2937
interp_summary(::ConstantInterpolation) = "Piecewise constant interpolation"
3038
interp_summary(::Nothing) = "No interpolation"
39+
interp_summary(::SensitivityInterpolation) = "Interpolation disabled due to sensitivity analysis"
3140
interp_summary(sol::DESolution) = interp_summary(sol.interp)
3241

42+
const SENSITIVITY_INTERP_MESSAGE =
43+
"""
44+
Standard interpolation is disabled due to sensitivity analysis being
45+
used for the gradients. Only linear and constant interpolations are
46+
compatible with non-AD sensitivity analysis calculations. Either
47+
utilize tooling like saveat to avoid post-solution interpolation, use
48+
the keyword argument dense=false for linear or constant interpolations,
49+
or use the keyword argument sensealg=SensitivityADPassThrough() to revert
50+
to AD-based derivatives.
51+
"""
52+
3353
(id::HermiteInterpolation)(tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation(tvals,id,idxs,deriv,p,continuity)
3454
(id::HermiteInterpolation)(val,tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation!(val,tvals,id,idxs,deriv,p,continuity)
3555
(id::LinearInterpolation)(tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation(tvals,id,idxs,deriv,p,continuity)
3656
(id::LinearInterpolation)(val,tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation!(val,tvals,id,idxs,deriv,p,continuity)
3757
(id::ConstantInterpolation)(tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation(tvals,id,idxs,deriv,p,continuity)
3858
(id::ConstantInterpolation)(val,tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation!(val,tvals,id,idxs,deriv,p,continuity)
59+
(id::SensitivityInterpolation)(tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation(tvals,id,idxs,deriv,p,continuity)
60+
(id::SensitivityInterpolation)(val,tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation!(val,tvals,id,idxs,deriv,p,continuity)
3961

4062
@inline function interpolation(tvals,id,idxs,deriv,p,continuity::Symbol=:left)
4163
t = id.t; u = id.u
@@ -72,6 +94,7 @@ interp_summary(sol::DESolution) = interp_summary(sol.interp)
7294
vals[j] = u[i-1][idxs]
7395
end
7496
else
97+
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
7598
dt = t[i] - t[i-1]
7699
Θ = (tval-t[i-1])/dt
77100
idxs_internal = idxs
@@ -119,6 +142,7 @@ times t (sorted), with values u and derivatives ks
119142
vals[j] = u[i-1][idxs]
120143
end
121144
else
145+
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
122146
dt = t[i] - t[i-1]
123147
Θ = (tval-t[i-1])/dt
124148
idxs_internal = idxs
@@ -169,6 +193,7 @@ times t (sorted), with values u and derivatives ks
169193
val = u[i-1][idxs]
170194
end
171195
else
196+
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
172197
dt = t[i] - t[i-1]
173198
Θ = (tval-t[i-1])/dt
174199
idxs_internal = idxs
@@ -211,6 +236,7 @@ times t (sorted), with values u and derivatives ks
211236
copy!(out,u[i-1][idxs])
212237
end
213238
else
239+
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
214240
dt = t[i] - t[i-1]
215241
Θ = (tval-t[i-1])/dt
216242
idxs_internal = idxs

src/reversediff.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,16 @@
1-
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p::ReverseDiff.TrackedArray,args...;
2-
sensealg=nothing,kwargs...)
3-
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
1+
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p::ReverseDiff.TrackedArray,args...;kwargs...)
2+
ReverseDiff.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
43
end
54

6-
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0,p::ReverseDiff.TrackedArray,args...;
7-
sensealg=nothing,kwargs...)
8-
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
5+
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0,p::ReverseDiff.TrackedArray,args...;kwargs...)
6+
ReverseDiff.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
97
end
108

11-
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p,args...;
12-
sensealg=nothing,kwargs...)
13-
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
9+
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p,args...;kwargs...)
10+
ReverseDiff.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
1411
end
1512

16-
ReverseDiff.@grad function concrete_solve(prob,alg,u0,p,args...;
17-
sensealg=nothing,kwargs...)
18-
out = _concrete_solve_adjoint(prob,alg,sensealg,ReverseDiff.value(u0),ReverseDiff.value(p),args...;kwargs...)
13+
ReverseDiff.@grad function solve_up(prob,sensealg,u0,p,args...;kwargs...)
14+
out = _solve_adjoint(prob,sensealg,ReverseDiff.value(u0),ReverseDiff.value(p),args...;kwargs...)
1915
Array(out[1]),out[2]
2016
end

src/solutions/ode_solutions.jl

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -100,22 +100,42 @@ function solution_new_retcode(sol::AbstractODESolution{T,N},retcode) where {T,N}
100100
sol.alg,sol.interp,sol.dense,sol.tslocation,sol.destats,retcode)
101101
end
102102

103-
function solution_new_tslocation(sol::AbstractODESolution{T,N},tslocation) where {T,N}
104-
ODESolution{T,N,typeof(sol.u),typeof(sol.u_analytic),typeof(sol.errors),
105-
typeof(sol.t),typeof(sol.k),
106-
typeof(sol.prob),typeof(sol.alg),typeof(sol.interp),typeof(sol.destats)}(
107-
sol.u,sol.u_analytic,sol.errors,sol.t,sol.k,sol.prob,
108-
sol.alg,sol.interp,sol.dense,tslocation,sol.destats,sol.retcode)
103+
function solution_new_tslocation(sol::AbstractODESolution{T,N},tslocation) where {T,N}
104+
ODESolution{T,N,typeof(sol.u),typeof(sol.u_analytic),typeof(sol.errors),
105+
typeof(sol.t),typeof(sol.k),
106+
typeof(sol.prob),typeof(sol.alg),typeof(sol.interp),typeof(sol.destats)}(
107+
sol.u,sol.u_analytic,sol.errors,sol.t,sol.k,sol.prob,
108+
sol.alg,sol.interp,sol.dense,tslocation,sol.destats,sol.retcode)
109+
end
110+
111+
function solution_slice(sol::AbstractODESolution{T,N},I) where {T,N}
112+
ODESolution{T,N,typeof(sol.u),typeof(sol.u_analytic),typeof(sol.errors),
113+
typeof(sol.t),typeof(sol.k),
114+
typeof(sol.prob),typeof(sol.alg),typeof(sol.interp),typeof(sol.destats)}(
115+
sol.u[I],
116+
sol.u_analytic === nothing ? nothing : sol.u_analytic[I],
117+
sol.errors,sol.t[I],
118+
sol.dense ? sol.k[I] : sol.k,
119+
sol.prob,
120+
sol.alg,sol.interp,false,sol.tslocation,sol.destats,sol.retcode)
121+
end
122+
123+
function sensitivity_solution(sol::AbstractODESolution,u,t)
124+
T = eltype(eltype(u))
125+
N = length((size(sol.prob.u0)..., length(u)))
126+
interp = if typeof(sol.interp) <: LinearInterpolation
127+
LinearInterpolation(t,u)
128+
elseif typeof(sol.interp) <: ConstantInterpolation
129+
ConstantInterpolation(t,u)
130+
else
131+
SensitivityInterpolation(t,u)
109132
end
110133

111-
function solution_slice(sol::AbstractODESolution{T,N},I) where {T,N}
112-
ODESolution{T,N,typeof(sol.u),typeof(sol.u_analytic),typeof(sol.errors),
113-
typeof(sol.t),typeof(sol.k),
114-
typeof(sol.prob),typeof(sol.alg),typeof(sol.interp),typeof(sol.destats)}(
115-
sol.u[I],
116-
sol.u_analytic === nothing ? nothing : sol.u_analytic[I],
117-
sol.errors,sol.t[I],
118-
sol.dense ? sol.k[I] : sol.k,
119-
sol.prob,
120-
sol.alg,sol.interp,false,sol.tslocation,sol.destats,sol.retcode)
121-
end
134+
ODESolution{T,N,typeof(u),typeof(sol.u_analytic),typeof(sol.errors),
135+
typeof(t),Nothing,typeof(sol.prob),typeof(sol.alg),
136+
typeof(interp),typeof(sol.destats)}(
137+
u,sol.u_analytic,sol.errors,t,nothing,sol.prob,
138+
sol.alg,interp,
139+
sol.dense,sol.tslocation,
140+
sol.destats,sol.retcode)
141+
end

src/solutions/rode_solutions.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,24 @@ function solution_slice(sol::AbstractRODESolution{T,N},I) where {T,N}
120120
false,sol.tslocation,sol.destats,
121121
sol.retcode,sol.seed)
122122
end
123+
124+
function sensitivity_solution(sol::AbstractRODESolution,u,t)
125+
T = eltype(eltype(u))
126+
N = length((size(sol.prob.u0)..., length(u)))
127+
interp = if typeof(sol.interp) <: LinearInterpolation
128+
LinearInterpolation(t,u)
129+
elseif typeof(sol.interp) <: ConstantInterpolation
130+
ConstantInterpolation(t,u)
131+
else
132+
SensitivityInterpolation(t,u)
133+
end
134+
135+
RODESolution{T,N,typeof(u),typeof(sol.u_analytic),
136+
typeof(sol.errors),typeof(t),
137+
typeof(nothing),typeof(sol.prob),typeof(sol.alg),
138+
typeof(sol.interp),typeof(sol.destats)}(
139+
u,sol.u_analytic,sol.errors,t,nothing,sol.prob,
140+
sol.alg,sol.interp,
141+
sol.dense,sol.tslocation,sol.destats,
142+
sol.retcode,sol.seed)
143+
end

src/solutions/steady_state_solutions.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,12 @@ function build_solution(prob::AbstractSteadyStateProblem,
1818

1919
SteadyStateSolution{T,N,typeof(u),typeof(resid),typeof(prob),typeof(alg)}(u,resid,prob,alg,retcode)
2020
end
21+
22+
function sensitivity_solution(sol::AbstractSteadyStateSolution,u)
23+
T = eltype(eltype(u))
24+
N = length((size(sol.prob.u0)...,))
25+
26+
SteadyStateSolution{T,N,typeof(u),typeof(sol.resid),
27+
typeof(sol.prob),typeof(sol.alg)}(
28+
u,sol.resid,sol.prob,sol.alg,sol.retcode)
29+
end

0 commit comments

Comments
 (0)