Skip to content

Overloads for LinearProblems with ForwardDiff Dual numbers #621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
c57c9d8
add LinearSolveForwardDiffExt.jl
jClugstor May 29, 2025
677570f
add partial linsolve
jClugstor Jun 2, 2025
c154d25
fix up the linear dual solution
jClugstor Jun 2, 2025
fc8c4b5
add ForwardDiffExt to project
jClugstor Jun 2, 2025
c419c48
use real solve
jClugstor Jun 2, 2025
d6bddf9
add ForwardDiff as weakdep
jClugstor Jun 3, 2025
81030fa
add imports and fix partial_val
jClugstor Jun 3, 2025
d7c56a8
add test
jClugstor Jun 3, 2025
f706c8f
add tests to runtest
jClugstor Jun 3, 2025
f95b799
format
jClugstor Jun 3, 2025
6352024
rm debug message
jClugstor Jun 3, 2025
e6cda65
use inits and caches
jClugstor Jun 3, 2025
501f07d
rearrange
jClugstor Jun 3, 2025
9aa8b19
format
jClugstor Jun 3, 2025
277c4f8
bring in linalg, add tols to tests
jClugstor Jun 3, 2025
313e286
make sure using nonmutated A
jClugstor Jun 3, 2025
922f7ec
dual cache should have original A and b
jClugstor Jun 3, 2025
3547ec7
rearrange, make sure that dualcache works
jClugstor Jun 5, 2025
9cd4e19
reinit! not needed for now
jClugstor Jun 5, 2025
b2a4291
correct setproperty! for DualLinearCache
jClugstor Jun 5, 2025
680aec6
add tests for updating cache
jClugstor Jun 5, 2025
f9cd2fe
enable dual u0
jClugstor Jun 5, 2025
1b48666
use new_u0
jClugstor Jun 5, 2025
b39ce87
reuse primal cache for Dual computation
jClugstor Jun 5, 2025
e5761c8
redundant line
jClugstor Jun 9, 2025
9b69358
make sure u0 is correct type
jClugstor Jun 9, 2025
f55639a
add tests for iterative and u0
jClugstor Jun 9, 2025
51ce056
make sure that linearcache.b is reset after dual solve
jClugstor Jun 9, 2025
3565d9b
fix test
jClugstor Jun 10, 2025
6f33486
forward steproperty and getproperty more
jClugstor Jun 10, 2025
d05ad09
use correct u0
jClugstor Jun 10, 2025
fb0626f
add test for updating one of A or b
jClugstor Jun 11, 2025
613e9aa
p can be Any
jClugstor Jun 11, 2025
d690f1f
use remake instead
jClugstor Jun 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Expand All @@ -53,6 +54,7 @@ LinearSolveCUDSSExt = "CUDSS"
LinearSolveEnzymeExt = "EnzymeCore"
LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices"
LinearSolveFastLapackInterfaceExt = "FastLapackInterface"
LinearSolveForwardDiffExt = "ForwardDiff"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
Expand Down
241 changes: 241 additions & 0 deletions ext/LinearSolveForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
module LinearSolveForwardDiffExt

using LinearSolve
using LinearAlgebra
using ForwardDiff
using ForwardDiff: Dual, Partials
using SciMLBase
using RecursiveArrayTools

const DualLinearProblem = LinearProblem{
<:Union{Number, <:AbstractArray, Nothing}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
<:Any
} where {iip, T, V, P}

const DualALinearProblem = LinearProblem{
<:Union{Number, <:AbstractArray, Nothing},
iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
<:Union{Number, <:AbstractArray},
<:Any
} where {iip, T, V, P}

const DualBLinearProblem = LinearProblem{
<:Union{Number, <:AbstractArray, Nothing},
iip,
<:Union{Number, <:AbstractArray},
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
<:Any
} where {iip, T, V, P}

const DualAbstractLinearProblem = Union{
DualLinearProblem, DualALinearProblem, DualBLinearProblem}

LinearSolve.@concrete mutable struct DualLinearCache
linear_cache
dual_type
partials_A
partials_b
end

function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
# Solve the primal problem
dual_u0 = copy(cache.linear_cache.u)
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
primal_b = copy(cache.linear_cache.b)
uu = sol.u

primal_sol = deepcopy(sol)

# Solves Dual partials separately
∂_A = cache.partials_A
∂_b = cache.partials_b

rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)

partial_cache = cache.linear_cache
partial_cache.u = dual_u0

for i in eachindex(rhs_list)
partial_cache.b = rhs_list[i]
rhs_list[i] = copy(solve!(partial_cache, alg, args...; kwargs...).u)
end

# Reset to the original `b`, users will expect that `b` doesn't change if they don't tell it to
partial_cache.b = primal_b

partial_sols = rhs_list

primal_sol, partial_sols
end

function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}},
∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
A_list = partials_to_list(∂_A)
b_list = partials_to_list(∂_b)

Auu = [A * uu for A in A_list]

return b_list .- Auu
end

function xp_linsolve_rhs(
uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing)
A_list = partials_to_list(∂_A)

Auu = [A * uu for A in A_list]

return -Auu
end

function xp_linsolve_rhs(
uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
b_list = partials_to_list(∂_b)
b_list
end

function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
return solve(prob, nothing, args...; kwargs...)
end

function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...)
end

function SciMLBase.solve(prob::DualAbstractLinearProblem,
alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
solve!(init(prob, alg, args...; kwargs...))
end

function linearsolve_dual_solution(
u::Number, partials, dual_type)
return dual_type(u, partials)
end

function linearsolve_dual_solution(
u::AbstractArray, partials, dual_type)
partials_list = RecursiveArrayTools.VectorOfArray(partials)
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
end

function SciMLBase.init(
prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm,
args...;
alias = LinearAliasSpecifier(),
abstol = LinearSolve.default_tol(real(eltype(prob.b))),
reltol = LinearSolve.default_tol(real(eltype(prob.b))),
maxiters::Int = length(prob.b),
verbose::Bool = false,
Pl = nothing,
Pr = nothing,
assumptions = OperatorAssumptions(issquare(prob.A)),
sensealg = LinearSolveAdjoint(),
kwargs...)

(; A, b, u0, p) = prob
new_A = nodual_value(A)
new_b = nodual_value(b)
new_u0 = nodual_value(u0)

∂_A = partial_vals(A)
∂_b = partial_vals(b)

#primal_prob = LinearProblem(new_A, new_b, u0 = new_u0)
primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0)

if get_dual_type(prob.A) !== nothing
dual_type = get_dual_type(prob.A)
elseif get_dual_type(prob.b) !== nothing
dual_type = get_dual_type(prob.b)
end

non_partial_cache = init(
primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
sensealg = sensealg, u0 = new_u0, kwargs...)
return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b)
end

function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
sol,
partials = linearsolve_forwarddiff_solve(
cache::DualLinearCache, cache.alg, args...; kwargs...)

dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type)
return SciMLBase.build_linear_solution(
cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats
)
end

# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
# Also "forwards" setproperty so that
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
# If the property is A or b, also update it in the LinearCache
if sym === :A || sym === :b || sym === :u
setproperty!(dc.linear_cache, sym, nodual_value(val))
elseif hasfield(LinearSolve.LinearCache, sym)
setproperty!(dc.linear_cache, sym, val)
end

# Update the partials if setting A or b
if sym === :A
setfield!(dc, :partials_A, partial_vals(val))
elseif sym === :b
setfield!(dc, :partials_b, partial_vals(val))
else
setfield!(dc, sym, val)
end
end

# "Forwards" getproperty to LinearCache if necessary
function Base.getproperty(dc::DualLinearCache, sym::Symbol)
if hasfield(LinearSolve.LinearCache, sym)
return getproperty(dc.linear_cache, sym)
else
return getfield(dc, sym)
end
end



# Helper functions for Dual numbers
get_dual_type(x::Dual) = typeof(x)
get_dual_type(x::AbstractArray{<:Dual}) = eltype(x)
get_dual_type(x) = nothing

partial_vals(x::Dual) = ForwardDiff.partials(x)
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x)
partial_vals(x) = nothing

nodual_value(x) = x
nodual_value(x::Dual) = ForwardDiff.value(x)
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)


function partials_to_list(partial_matrix::Vector)
p = eachindex(first(partial_matrix))
[[partial[i] for partial in partial_matrix] for i in p]
end

function partials_to_list(partial_matrix)
p = length(first(partial_matrix))
m, n = size(partial_matrix)
res_list = fill(zeros(m, n), p)
for k in 1:p
res = zeros(m, n)
for i in 1:m
for j in 1:n
res[i, j] = partial_matrix[i, j][k]
end
end
res_list[k] = res
end
return res_list
end


end
82 changes: 82 additions & 0 deletions test/forwarddiff_overloads.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
using LinearSolve
using ForwardDiff
using Test

function h(p)
(A = [p[1] p[2]+1 p[2]^3;
3*p[1] p[1]+5 p[2] * p[1]-4;
p[2]^2 9*p[1] p[2]],
b = [p[1] + 1, p[2] * 2, p[1]^2])
end

A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])

prob = LinearProblem(A, b)
overload_x_p = solve(prob)
backslash_x_p = A \ b
krylov_overload_x_p = solve(prob, KrylovJL_GMRES())
@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9)
@test ≈(krylov_overload_x_p, backslash_x_p, rtol = 1e-9)

krylov_prob = LinearProblem(A, b, u0 = rand(3))
krylov_u0_sol = solve(krylov_prob, KrylovJL_GMRES())

@test ≈(krylov_u0_sol, backslash_x_p, rtol = 1e-9)


A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
backslash_x_p = A \ [6.0, 10.0, 25.0]
prob = LinearProblem(A, [6.0, 10.0, 25.0])

@test ≈(solve(prob).u, backslash_x_p, rtol = 1e-9)
@test ≈(solve(prob, KrylovJL_GMRES()).u, backslash_x_p, rtol = 1e-9)

_, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
A = [5.0 6.0 125.0; 15.0 10.0 21.0; 25.0 45.0 5.0]
backslash_x_p = A \ b
prob = LinearProblem(A, b)

@test ≈(solve(prob).u, backslash_x_p, rtol = 1e-9)
@test ≈(solve(prob, KrylovJL_GMRES()).u, backslash_x_p, rtol = 1e-9)

A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)])

prob = LinearProblem(A, b)
cache = init(prob)

new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
cache.A = new_A
cache.b = new_b

x_p = solve!(cache)
backslash_x_p = new_A \ new_b

@test ≈(x_p, backslash_x_p, rtol = 1e-9)

# Just update A
A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)])

prob = LinearProblem(A, b)
cache = init(prob)

new_A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
cache.A = new_A

x_p = solve!(cache)
backslash_x_p = new_A \ b

@test ≈(x_p, backslash_x_p, rtol = 1e-9)

# Just update b
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])

prob = LinearProblem(A, b)
cache = init(prob)

_, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
cache.b = new_b

x_p = solve!(cache)
backslash_x_p = A \ new_b

@test ≈(x_p, backslash_x_p, rtol = 1e-9)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ if GROUP == "All" || GROUP == "Core"
@time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
@time @safetestset "Default Alg Tests" include("default_algs.jl")
@time @safetestset "Adjoint Sensitivity" include("adjoint.jl")
@time @safetestset "ForwardDiff Overloads" include("forwarddiff_overloads.jl")
@time @safetestset "Traits" include("traits.jl")
@time @safetestset "BandedMatrices" include("banded.jl")
end
Expand Down
Loading