Skip to content

Commit eed6869

Browse files
Merge pull request #621 from jClugstor/forwarddiff_overloads
Overloads for LinearProblems with ForwardDiff Dual numbers
2 parents a65fb46 + d690f1f commit eed6869

File tree

4 files changed

+326
-0
lines changed

4 files changed

+326
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
3535
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3636
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
3737
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
38+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3839
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
3940
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
4041
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
@@ -53,6 +54,7 @@ LinearSolveCUDSSExt = "CUDSS"
5354
LinearSolveEnzymeExt = "EnzymeCore"
5455
LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices"
5556
LinearSolveFastLapackInterfaceExt = "FastLapackInterface"
57+
LinearSolveForwardDiffExt = "ForwardDiff"
5658
LinearSolveHYPREExt = "HYPRE"
5759
LinearSolveIterativeSolversExt = "IterativeSolvers"
5860
LinearSolveKernelAbstractionsExt = "KernelAbstractions"

ext/LinearSolveForwardDiffExt.jl

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
module LinearSolveForwardDiffExt
2+
3+
using LinearSolve
4+
using LinearAlgebra
5+
using ForwardDiff
6+
using ForwardDiff: Dual, Partials
7+
using SciMLBase
8+
using RecursiveArrayTools
9+
10+
const DualLinearProblem = LinearProblem{
11+
<:Union{Number, <:AbstractArray, Nothing}, iip,
12+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
13+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
14+
<:Any
15+
} where {iip, T, V, P}
16+
17+
const DualALinearProblem = LinearProblem{
18+
<:Union{Number, <:AbstractArray, Nothing},
19+
iip,
20+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
21+
<:Union{Number, <:AbstractArray},
22+
<:Any
23+
} where {iip, T, V, P}
24+
25+
const DualBLinearProblem = LinearProblem{
26+
<:Union{Number, <:AbstractArray, Nothing},
27+
iip,
28+
<:Union{Number, <:AbstractArray},
29+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
30+
<:Any
31+
} where {iip, T, V, P}
32+
33+
const DualAbstractLinearProblem = Union{
34+
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
35+
36+
LinearSolve.@concrete mutable struct DualLinearCache
37+
linear_cache
38+
dual_type
39+
partials_A
40+
partials_b
41+
end
42+
43+
function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
44+
# Solve the primal problem
45+
dual_u0 = copy(cache.linear_cache.u)
46+
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
47+
primal_b = copy(cache.linear_cache.b)
48+
uu = sol.u
49+
50+
primal_sol = deepcopy(sol)
51+
52+
# Solves Dual partials separately
53+
∂_A = cache.partials_A
54+
∂_b = cache.partials_b
55+
56+
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
57+
58+
partial_cache = cache.linear_cache
59+
partial_cache.u = dual_u0
60+
61+
for i in eachindex(rhs_list)
62+
partial_cache.b = rhs_list[i]
63+
rhs_list[i] = copy(solve!(partial_cache, alg, args...; kwargs...).u)
64+
end
65+
66+
# Reset to the original `b`, users will expect that `b` doesn't change if they don't tell it to
67+
partial_cache.b = primal_b
68+
69+
partial_sols = rhs_list
70+
71+
primal_sol, partial_sols
72+
end
73+
74+
function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}},
75+
∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
76+
A_list = partials_to_list(∂_A)
77+
b_list = partials_to_list(∂_b)
78+
79+
Auu = [A * uu for A in A_list]
80+
81+
return b_list .- Auu
82+
end
83+
84+
function xp_linsolve_rhs(
85+
uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing)
86+
A_list = partials_to_list(∂_A)
87+
88+
Auu = [A * uu for A in A_list]
89+
90+
return -Auu
91+
end
92+
93+
function xp_linsolve_rhs(
94+
uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
95+
b_list = partials_to_list(∂_b)
96+
b_list
97+
end
98+
99+
function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
100+
return solve(prob, nothing, args...; kwargs...)
101+
end
102+
103+
function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
104+
assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
105+
return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...)
106+
end
107+
108+
function SciMLBase.solve(prob::DualAbstractLinearProblem,
109+
alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
110+
solve!(init(prob, alg, args...; kwargs...))
111+
end
112+
113+
function linearsolve_dual_solution(
114+
u::Number, partials, dual_type)
115+
return dual_type(u, partials)
116+
end
117+
118+
function linearsolve_dual_solution(
119+
u::AbstractArray, partials, dual_type)
120+
partials_list = RecursiveArrayTools.VectorOfArray(partials)
121+
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
122+
zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
123+
end
124+
125+
function SciMLBase.init(
126+
prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm,
127+
args...;
128+
alias = LinearAliasSpecifier(),
129+
abstol = LinearSolve.default_tol(real(eltype(prob.b))),
130+
reltol = LinearSolve.default_tol(real(eltype(prob.b))),
131+
maxiters::Int = length(prob.b),
132+
verbose::Bool = false,
133+
Pl = nothing,
134+
Pr = nothing,
135+
assumptions = OperatorAssumptions(issquare(prob.A)),
136+
sensealg = LinearSolveAdjoint(),
137+
kwargs...)
138+
139+
(; A, b, u0, p) = prob
140+
new_A = nodual_value(A)
141+
new_b = nodual_value(b)
142+
new_u0 = nodual_value(u0)
143+
144+
∂_A = partial_vals(A)
145+
∂_b = partial_vals(b)
146+
147+
#primal_prob = LinearProblem(new_A, new_b, u0 = new_u0)
148+
primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0)
149+
150+
if get_dual_type(prob.A) !== nothing
151+
dual_type = get_dual_type(prob.A)
152+
elseif get_dual_type(prob.b) !== nothing
153+
dual_type = get_dual_type(prob.b)
154+
end
155+
156+
non_partial_cache = init(
157+
primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
158+
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
159+
sensealg = sensealg, u0 = new_u0, kwargs...)
160+
return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b)
161+
end
162+
163+
function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
164+
sol,
165+
partials = linearsolve_forwarddiff_solve(
166+
cache::DualLinearCache, cache.alg, args...; kwargs...)
167+
168+
dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type)
169+
return SciMLBase.build_linear_solution(
170+
cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats
171+
)
172+
end
173+
174+
# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
175+
# Also "forwards" setproperty so that
176+
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
177+
# If the property is A or b, also update it in the LinearCache
178+
if sym === :A || sym === :b || sym === :u
179+
setproperty!(dc.linear_cache, sym, nodual_value(val))
180+
elseif hasfield(LinearSolve.LinearCache, sym)
181+
setproperty!(dc.linear_cache, sym, val)
182+
end
183+
184+
# Update the partials if setting A or b
185+
if sym === :A
186+
setfield!(dc, :partials_A, partial_vals(val))
187+
elseif sym === :b
188+
setfield!(dc, :partials_b, partial_vals(val))
189+
else
190+
setfield!(dc, sym, val)
191+
end
192+
end
193+
194+
# "Forwards" getproperty to LinearCache if necessary
195+
function Base.getproperty(dc::DualLinearCache, sym::Symbol)
196+
if hasfield(LinearSolve.LinearCache, sym)
197+
return getproperty(dc.linear_cache, sym)
198+
else
199+
return getfield(dc, sym)
200+
end
201+
end
202+
203+
204+
205+
# Helper functions for Dual numbers
206+
get_dual_type(x::Dual) = typeof(x)
207+
get_dual_type(x::AbstractArray{<:Dual}) = eltype(x)
208+
get_dual_type(x) = nothing
209+
210+
partial_vals(x::Dual) = ForwardDiff.partials(x)
211+
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x)
212+
partial_vals(x) = nothing
213+
214+
nodual_value(x) = x
215+
nodual_value(x::Dual) = ForwardDiff.value(x)
216+
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
217+
218+
219+
function partials_to_list(partial_matrix::Vector)
220+
p = eachindex(first(partial_matrix))
221+
[[partial[i] for partial in partial_matrix] for i in p]
222+
end
223+
224+
function partials_to_list(partial_matrix)
225+
p = length(first(partial_matrix))
226+
m, n = size(partial_matrix)
227+
res_list = fill(zeros(m, n), p)
228+
for k in 1:p
229+
res = zeros(m, n)
230+
for i in 1:m
231+
for j in 1:n
232+
res[i, j] = partial_matrix[i, j][k]
233+
end
234+
end
235+
res_list[k] = res
236+
end
237+
return res_list
238+
end
239+
240+
241+
end

test/forwarddiff_overloads.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
using LinearSolve
2+
using ForwardDiff
3+
using Test
4+
5+
function h(p)
6+
(A = [p[1] p[2]+1 p[2]^3;
7+
3*p[1] p[1]+5 p[2] * p[1]-4;
8+
p[2]^2 9*p[1] p[2]],
9+
b = [p[1] + 1, p[2] * 2, p[1]^2])
10+
end
11+
12+
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
13+
14+
prob = LinearProblem(A, b)
15+
overload_x_p = solve(prob)
16+
backslash_x_p = A \ b
17+
krylov_overload_x_p = solve(prob, KrylovJL_GMRES())
18+
@test (overload_x_p, backslash_x_p, rtol = 1e-9)
19+
@test (krylov_overload_x_p, backslash_x_p, rtol = 1e-9)
20+
21+
krylov_prob = LinearProblem(A, b, u0 = rand(3))
22+
krylov_u0_sol = solve(krylov_prob, KrylovJL_GMRES())
23+
24+
@test (krylov_u0_sol, backslash_x_p, rtol = 1e-9)
25+
26+
27+
A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
28+
backslash_x_p = A \ [6.0, 10.0, 25.0]
29+
prob = LinearProblem(A, [6.0, 10.0, 25.0])
30+
31+
@test (solve(prob).u, backslash_x_p, rtol = 1e-9)
32+
@test (solve(prob, KrylovJL_GMRES()).u, backslash_x_p, rtol = 1e-9)
33+
34+
_, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
35+
A = [5.0 6.0 125.0; 15.0 10.0 21.0; 25.0 45.0 5.0]
36+
backslash_x_p = A \ b
37+
prob = LinearProblem(A, b)
38+
39+
@test (solve(prob).u, backslash_x_p, rtol = 1e-9)
40+
@test (solve(prob, KrylovJL_GMRES()).u, backslash_x_p, rtol = 1e-9)
41+
42+
A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)])
43+
44+
prob = LinearProblem(A, b)
45+
cache = init(prob)
46+
47+
new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
48+
cache.A = new_A
49+
cache.b = new_b
50+
51+
x_p = solve!(cache)
52+
backslash_x_p = new_A \ new_b
53+
54+
@test (x_p, backslash_x_p, rtol = 1e-9)
55+
56+
# Just update A
57+
A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)])
58+
59+
prob = LinearProblem(A, b)
60+
cache = init(prob)
61+
62+
new_A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
63+
cache.A = new_A
64+
65+
x_p = solve!(cache)
66+
backslash_x_p = new_A \ b
67+
68+
@test (x_p, backslash_x_p, rtol = 1e-9)
69+
70+
# Just update b
71+
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
72+
73+
prob = LinearProblem(A, b)
74+
cache = init(prob)
75+
76+
_, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
77+
cache.b = new_b
78+
79+
x_p = solve!(cache)
80+
backslash_x_p = A \ new_b
81+
82+
@test (x_p, backslash_x_p, rtol = 1e-9)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ if GROUP == "All" || GROUP == "Core"
1616
@time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
1717
@time @safetestset "Default Alg Tests" include("default_algs.jl")
1818
@time @safetestset "Adjoint Sensitivity" include("adjoint.jl")
19+
@time @safetestset "ForwardDiff Overloads" include("forwarddiff_overloads.jl")
1920
@time @safetestset "Traits" include("traits.jl")
2021
@time @safetestset "BandedMatrices" include("banded.jl")
2122
end

0 commit comments

Comments
 (0)