From c57c9d8fe85ac200fa73ac83e8140a3562c85971 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 29 May 2025 16:48:24 -0400 Subject: [PATCH 01/34] add LinearSolveForwardDiffExt.jl --- ext/LinearSolveForwardDiffExt.jl | 123 +++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 ext/LinearSolveForwardDiffExt.jl diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl new file mode 100644 index 000000000..2cf36809d --- /dev/null +++ b/ext/LinearSolveForwardDiffExt.jl @@ -0,0 +1,123 @@ +module LinearSolveForwardDiffExt + +const DualLinearProblem = LinearProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T,V,P}, <:AbstractArray{<:Dual{T,V,P}}}, + <:Union{<:Dual{T,V,P}, <:AbstractArray{<:Dual{T,V,P}}}, + <:Union{Number, <:AbstractArray} +} where {iip, T, V} + + +const DualALinearProblem = LinearProblem{ + <:Union{Number, <:AbstractArray}, + iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{Number, <:AbstractArray}, + <:Union{Number, <:AbstractArray} +} + +const DualBLinearProblem = LinearProblem{ + <:Union{Number, <:AbstractArray}, + iip, + <:Union{Number, <:AbstractArray}, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{Number, <:AbstractArray} +} + +const DualAbstractLinearProblem = Union{DualLinearProblem, DualALinearProblem, DualBLinearProblem} + + +function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs...) + new_A = nodual_value(prob.A) + new_b = nodual_value(prob.b) + + newprob = remake(prob; A = new_A, b = new_b) + + sol = solve(newprob, alg, args...; kwargs...) + uu = sol.u + + ∂_A = partial_vals(A) + ∂_b = partial_vals(b) + + + + if uu isa Number + + else + + end + +end + + + +partial_vals(x::Dual) = ForwardDiff.partials(x) +partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) +partial_vals(x) = nothing + +nodual_value(x) = x +nodual_value(x::Dual) = ForwardDiff.value(x) +nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) + + +function x_p_linsolve(new_A, uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) + A_list = partials_to_list(∂_A) + b_list = partials_to_list(∂_b) + + Auu = [A*uu for A in A_list] + + linsol_rhs = reduce(hcat, b_list .- Auu) + + new_A \ linsol_rhs +end + +function x_p_linsolve(new_A, uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing) + A_list = partials_to_list(∂_A) + + Auu = [A*uu for A in A_list] + + linsol_rhs = reduce(hcat, Auu) + + new_A \ linsol_rhs +end + +function x_p_linsolve(new_A, uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) + b_list = partials_to_list(∂_b) + + linsol_rhs = reduce(hcat, b_list) + + new_A \ linsol_rhs +end + + + +function partials_to_list(partial_matrix::Vector) + p = eachindex(first(partial_matrix)) + [[partial[i] for partial in partial_matrix] for i in p] +end + +function partials_to_list(partial_matrix) + p = length(first(partial_matrix)) + m,n = size(partial_matrix) + res_list = fill(zeros(m,n),p) + for k in 1:p + res = zeros(m,n) + for i in 1:m + for j in 1:n + res[i,j] = partial_matrix[i,j][k] + end + end + res_list[k] = res + end + return res_list +end + + + + + + + + + + From 677570f42055adc86e68b328e7802c0112174394 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 2 Jun 2025 10:02:46 -0400 Subject: [PATCH 02/34] add partial linsolve --- ext/LinearSolveForwardDiffExt.jl | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 2cf36809d..866e21aea 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -39,14 +39,12 @@ function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs ∂_A = partial_vals(A) ∂_b = partial_vals(b) - + rhs = xp_linsolve_rhs(uu, ∂_A, ∂_b) - if uu isa Number - - else - - end + partial_prob = remake(newprob, b = rhs) + partial_sol = solve(partial_prob, alg, args...; kwargs...) + sol, partial_sol end @@ -60,33 +58,27 @@ nodual_value(x::Dual) = ForwardDiff.value(x) nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) -function x_p_linsolve(new_A, uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) +function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) A_list = partials_to_list(∂_A) b_list = partials_to_list(∂_b) Auu = [A*uu for A in A_list] - linsol_rhs = reduce(hcat, b_list .- Auu) - - new_A \ linsol_rhs + reduce(hcat, b_list .- Auu) end -function x_p_linsolve(new_A, uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing) +function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing) A_list = partials_to_list(∂_A) Auu = [A*uu for A in A_list] - linsol_rhs = reduce(hcat, Auu) - - new_A \ linsol_rhs + reduce(hcat, Auu) end -function x_p_linsolve(new_A, uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) +function xp_linsolve_rhs(uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) b_list = partials_to_list(∂_b) - linsol_rhs = reduce(hcat, b_list) - - new_A \ linsol_rhs + reduce(hcat, b_list) end From c154d25b80cb7215a82973709c678588c12ee360 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 2 Jun 2025 15:16:56 -0400 Subject: [PATCH 03/34] fix up the linear dual solution --- ext/LinearSolveForwardDiffExt.jl | 85 ++++++++++++++++++++++---------- 1 file changed, 60 insertions(+), 25 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 866e21aea..5c64be08a 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -1,56 +1,91 @@ module LinearSolveForwardDiffExt const DualLinearProblem = LinearProblem{ - <:Union{Number, <:AbstractArray}, iip, - <:Union{<:Dual{T,V,P}, <:AbstractArray{<:Dual{T,V,P}}}, - <:Union{<:Dual{T,V,P}, <:AbstractArray{<:Dual{T,V,P}}}, - <:Union{Number, <:AbstractArray} -} where {iip, T, V} + <:Union{Number,<:AbstractArray, Nothing},iip, + <:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}}, + <:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}}, + <:Union{Number,<:AbstractArray, SciMLBase.NullParameters} +} where {iip, T, V, P} const DualALinearProblem = LinearProblem{ - <:Union{Number, <:AbstractArray}, + <:Union{Number,<:AbstractArray, Nothing}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, - <:Union{Number, <:AbstractArray}, - <:Union{Number, <:AbstractArray} -} + <:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}}, + <:Union{Number,<:AbstractArray}, + <:Union{Number,<:AbstractArray, SciMLBase.NullParameters} +} where {iip, T, V, P} const DualBLinearProblem = LinearProblem{ - <:Union{Number, <:AbstractArray}, + <:Union{Number,<:AbstractArray, Nothing}, iip, - <:Union{Number, <:AbstractArray}, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, - <:Union{Number, <:AbstractArray} -} + <:Union{Number,<:AbstractArray}, + <:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}}, + <:Union{Number,<:AbstractArray, SciMLBase.NullParameters} +} where {iip, T, V, P} const DualAbstractLinearProblem = Union{DualLinearProblem, DualALinearProblem, DualBLinearProblem} - function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs...) new_A = nodual_value(prob.A) new_b = nodual_value(prob.b) - newprob = remake(prob; A = new_A, b = new_b) + newprob = remake(prob; A=new_A, b=new_b) sol = solve(newprob, alg, args...; kwargs...) uu = sol.u + + # Solves Dual partials separately ∂_A = partial_vals(A) ∂_b = partial_vals(b) - rhs = xp_linsolve_rhs(uu, ∂_A, ∂_b) + rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) - partial_prob = remake(newprob, b = rhs) - partial_sol = solve(partial_prob, alg, args...; kwargs...) + partial_sols = map(rhs_list) do rhs + partial_prob = remake(newprob, b=rhs) + solve(partial_prob, alg, args...; kwargs...).u + end - sol, partial_sol + sol, partial_sols end +function __solve(prob::DualAbstractLinearProblem, alg, args...; kwargs...) + sol, partials = linearsolve_forwarddiff_solve( + prob, alg, args...; kwargs... + ) + + if get_dual_type(prob.A) !== nothing + dual_type = get_dual_type(prob.A) + elseif get_dual_type(prob.b) !== nothing + dual_type = get_dual_type(prob.b) + return sol + end + + linearsolve_dual_solution(sol.u, partials, dual_type) + +end + + +function linearsolve_dual_solution( + u::Number, partials, dual_type) + return dual_type(u, partials) +end + +function linearsolve_dual_solution( + u::AbstractArray, partials, dual_type) + partials_list = RecursiveArrayTools.VectorOfArray(partials) + return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), zip(u, partials_list[i, :] for i in 1:length(partials_list[1]))) +end + + +get_dual_type(x::Dual) = typeof(x) +get_dual_type(x::AbstractArray{<:Dual}) = eltype(x) +get_dual_type(x) = nothing partial_vals(x::Dual) = ForwardDiff.partials(x) -partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) +partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x) partial_vals(x) = nothing nodual_value(x) = x @@ -64,7 +99,7 @@ function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials Auu = [A*uu for A in A_list] - reduce(hcat, b_list .- Auu) + b_list .- Auu end function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing) @@ -72,13 +107,13 @@ function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials Auu = [A*uu for A in A_list] - reduce(hcat, Auu) + Auu end function xp_linsolve_rhs(uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) b_list = partials_to_list(∂_b) - reduce(hcat, b_list) + b_list end From fc8c4b59ce9927937e80acf9f90175abf5d049dc Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 2 Jun 2025 15:19:30 -0400 Subject: [PATCH 04/34] add ForwardDiffExt to project --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 96e7145af..43e4a84f4 100644 --- a/Project.toml +++ b/Project.toml @@ -53,6 +53,7 @@ LinearSolveCUDSSExt = "CUDSS" LinearSolveEnzymeExt = "EnzymeCore" LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices" LinearSolveFastLapackInterfaceExt = "FastLapackInterface" +LinearSolveForwardDiffExt = "ForwardDiff" LinearSolveHYPREExt = "HYPRE" LinearSolveIterativeSolversExt = "IterativeSolvers" LinearSolveKernelAbstractionsExt = "KernelAbstractions" From c419c4828c59b19e6ee607a9bbfc451cef9dccd9 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 2 Jun 2025 15:56:10 -0400 Subject: [PATCH 05/34] use real solve --- ext/LinearSolveForwardDiffExt.jl | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 5c64be08a..4e333e7fe 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -50,7 +50,16 @@ function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs sol, partial_sols end -function __solve(prob::DualAbstractLinearProblem, alg, args...; kwargs...) +function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...) + return solve(prob, nothing, args...; kwargs...) +end + +function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...; + assump = OperatorAssumptions(issquare(prob.A)), kwargs...) + return solve(prob, defaultalg(prob.A, prob.b, assump), args...; kwargs...) +end + +function SciMLBase.solve(prob::DualAbstractLinearProblem, alg, args...; kwargs...) sol, partials = linearsolve_forwarddiff_solve( prob, alg, args...; kwargs... ) @@ -59,10 +68,14 @@ function __solve(prob::DualAbstractLinearProblem, alg, args...; kwargs...) dual_type = get_dual_type(prob.A) elseif get_dual_type(prob.b) !== nothing dual_type = get_dual_type(prob.b) - return sol end - linearsolve_dual_solution(sol.u, partials, dual_type) + dual_sol = linearsolve_dual_solution(sol.u, partials, dual_type) + + return SciMLBase.build_linear_solution( + alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats + ) + end From d6bddf9577e40e4e60c82b48fa51217c808017bc Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 2 Jun 2025 21:47:59 -0400 Subject: [PATCH 06/34] add ForwardDiff as weakdep --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 43e4a84f4..f4fcc2604 100644 --- a/Project.toml +++ b/Project.toml @@ -35,6 +35,7 @@ CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e" FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" From 81030fa795167a42fd3f1d86cd2c31ef5191d99f Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 2 Jun 2025 21:48:17 -0400 Subject: [PATCH 07/34] add imports and fix partial_val --- ext/LinearSolveForwardDiffExt.jl | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 4e333e7fe..5f04b6d57 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -1,5 +1,11 @@ module LinearSolveForwardDiffExt +using LinearSolve +using ForwardDiff +using ForwardDiff: Dual, Partials +using SciMLBase +using RecursiveArrayTools + const DualLinearProblem = LinearProblem{ <:Union{Number,<:AbstractArray, Nothing},iip, <:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}}, @@ -27,6 +33,7 @@ const DualBLinearProblem = LinearProblem{ const DualAbstractLinearProblem = Union{DualLinearProblem, DualALinearProblem, DualBLinearProblem} function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs...) + @info "here!" new_A = nodual_value(prob.A) new_b = nodual_value(prob.b) @@ -37,8 +44,8 @@ function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs # Solves Dual partials separately - ∂_A = partial_vals(A) - ∂_b = partial_vals(b) + ∂_A = partial_vals(prob.A) + ∂_b = partial_vals(prob.b) rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) @@ -56,10 +63,10 @@ end function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...; assump = OperatorAssumptions(issquare(prob.A)), kwargs...) - return solve(prob, defaultalg(prob.A, prob.b, assump), args...; kwargs...) + return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...) end -function SciMLBase.solve(prob::DualAbstractLinearProblem, alg, args...; kwargs...) +function SciMLBase.solve(prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...) sol, partials = linearsolve_forwarddiff_solve( prob, alg, args...; kwargs... ) @@ -152,7 +159,7 @@ function partials_to_list(partial_matrix) return res_list end - +end From d7c56a82d4b3c8476f0b466ecbd771ac6b00d2ba Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 3 Jun 2025 10:22:30 -0400 Subject: [PATCH 08/34] add test --- test/forwarddiff_overloads.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 test/forwarddiff_overloads.jl diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl new file mode 100644 index 000000000..11ac684b2 --- /dev/null +++ b/test/forwarddiff_overloads.jl @@ -0,0 +1,17 @@ +using LinearSolve +using ForwardDiff + + +function h(p) + (A = [p[1] p[2]+1 p[2]^3; + 3*p[1] p[1]+5 p[2] * p[1]-4; + p[2]^2 9*p[1] p[2]], + b = [p[1] + 1, p[2] * 2, p[1]^2]) +end + +A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) + +prob = LinearProblem(A, b) +solve(prob) + + From f706c8f1691da67e2f4b211171a24b1ea54cd66e Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 3 Jun 2025 11:00:29 -0400 Subject: [PATCH 09/34] add tests to runtest --- test/forwarddiff_overloads.jl | 26 +++++++++++++++++++++----- test/runtests.jl | 1 + 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 11ac684b2..114783cf2 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -1,17 +1,33 @@ using LinearSolve using ForwardDiff +using Test function h(p) - (A = [p[1] p[2]+1 p[2]^3; - 3*p[1] p[1]+5 p[2] * p[1]-4; - p[2]^2 9*p[1] p[2]], - b = [p[1] + 1, p[2] * 2, p[1]^2]) + (A=[p[1] p[2]+1 p[2]^3; + 3*p[1] p[1]+5 p[2]*p[1]-4; + p[2]^2 9*p[1] p[2]], + b=[p[1] + 1, p[2] * 2, p[1]^2]) end A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) prob = LinearProblem(A, b) -solve(prob) +overload_x_p = solve(prob) +original_x_p = solve!(init(prob)) + +@test overload_x_p ≈ original_x_p + + +A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) +prob = LinearProblem(A, [6.0, 10.0, 25.0]) +@test solve(prob).retcode == ReturnCode.Default + +_, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) +A = [5.0 6.0 125.0; 15.0 10.0 21.0; 25.0 45.0 5.0] +prob = LinearProblem(A,b) +@test solve(prob).retcode == ReturnCode.Default + + diff --git a/test/runtests.jl b/test/runtests.jl index 0d994f787..2133bcd4a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,6 +16,7 @@ if GROUP == "All" || GROUP == "Core" @time @safetestset "SparseVector b Tests" include("sparse_vector.jl") @time @safetestset "Default Alg Tests" include("default_algs.jl") @time @safetestset "Adjoint Sensitivity" include("adjoint.jl") + @time @safetestset "ForwardDiff Overloads" include("forwarddiff_overloads.jl") @time @safetestset "Traits" include("traits.jl") @time @safetestset "BandedMatrices" include("banded.jl") end From f95b799fabbf211fb3146a86ab5a8438ff6662e0 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 3 Jun 2025 11:07:50 -0400 Subject: [PATCH 10/34] format --- ext/LinearSolveForwardDiffExt.jl | 88 ++++++++++++++------------------ test/forwarddiff_overloads.jl | 20 +++----- 2 files changed, 45 insertions(+), 63 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 5f04b6d57..8a95eaf54 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -1,4 +1,4 @@ -module LinearSolveForwardDiffExt +module LinearSolveForwardDiffExt using LinearSolve using ForwardDiff @@ -7,42 +7,41 @@ using SciMLBase using RecursiveArrayTools const DualLinearProblem = LinearProblem{ - <:Union{Number,<:AbstractArray, Nothing},iip, - <:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}}, - <:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}}, - <:Union{Number,<:AbstractArray, SciMLBase.NullParameters} + <:Union{Number, <:AbstractArray, Nothing}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{Number, <:AbstractArray, SciMLBase.NullParameters} } where {iip, T, V, P} - const DualALinearProblem = LinearProblem{ - <:Union{Number,<:AbstractArray, Nothing}, + <:Union{Number, <:AbstractArray, Nothing}, iip, - <:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}}, - <:Union{Number,<:AbstractArray}, - <:Union{Number,<:AbstractArray, SciMLBase.NullParameters} + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{Number, <:AbstractArray}, + <:Union{Number, <:AbstractArray, SciMLBase.NullParameters} } where {iip, T, V, P} const DualBLinearProblem = LinearProblem{ - <:Union{Number,<:AbstractArray, Nothing}, + <:Union{Number, <:AbstractArray, Nothing}, iip, - <:Union{Number,<:AbstractArray}, - <:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}}, - <:Union{Number,<:AbstractArray, SciMLBase.NullParameters} + <:Union{Number, <:AbstractArray}, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{Number, <:AbstractArray, SciMLBase.NullParameters} } where {iip, T, V, P} -const DualAbstractLinearProblem = Union{DualLinearProblem, DualALinearProblem, DualBLinearProblem} +const DualAbstractLinearProblem = Union{ + DualLinearProblem, DualALinearProblem, DualBLinearProblem} function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs...) @info "here!" new_A = nodual_value(prob.A) new_b = nodual_value(prob.b) - newprob = remake(prob; A=new_A, b=new_b) + newprob = remake(prob; A = new_A, b = new_b) sol = solve(newprob, alg, args...; kwargs...) uu = sol.u - # Solves Dual partials separately ∂_A = partial_vals(prob.A) ∂_b = partial_vals(prob.b) @@ -50,7 +49,7 @@ function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) partial_sols = map(rhs_list) do rhs - partial_prob = remake(newprob, b=rhs) + partial_prob = remake(newprob, b = rhs) solve(partial_prob, alg, args...; kwargs...).u end @@ -66,7 +65,8 @@ function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...; return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...) end -function SciMLBase.solve(prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...) +function SciMLBase.solve(prob::DualAbstractLinearProblem, + alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...) sol, partials = linearsolve_forwarddiff_solve( prob, alg, args...; kwargs... ) @@ -82,28 +82,24 @@ function SciMLBase.solve(prob::DualAbstractLinearProblem, alg::LinearSolve.SciML return SciMLBase.build_linear_solution( alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats ) - - end - function linearsolve_dual_solution( - u::Number, partials, dual_type) + u::Number, partials, dual_type) return dual_type(u, partials) end function linearsolve_dual_solution( - u::AbstractArray, partials, dual_type) + u::AbstractArray, partials, dual_type) partials_list = RecursiveArrayTools.VectorOfArray(partials) - return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), zip(u, partials_list[i, :] for i in 1:length(partials_list[1]))) + return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), + zip(u, partials_list[i, :] for i in 1:length(partials_list[1]))) end - get_dual_type(x::Dual) = typeof(x) get_dual_type(x::AbstractArray{<:Dual}) = eltype(x) get_dual_type(x) = nothing - partial_vals(x::Dual) = ForwardDiff.partials(x) partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x) partial_vals(x) = nothing @@ -112,46 +108,46 @@ nodual_value(x) = x nodual_value(x::Dual) = ForwardDiff.value(x) nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) - -function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) +function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, + ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) A_list = partials_to_list(∂_A) - b_list = partials_to_list(∂_b) + b_list = partials_to_list(∂_b) - Auu = [A*uu for A in A_list] + Auu = [A * uu for A in A_list] b_list .- Auu end -function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing) +function xp_linsolve_rhs( + uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing) A_list = partials_to_list(∂_A) - Auu = [A*uu for A in A_list] + Auu = [A * uu for A in A_list] Auu end -function xp_linsolve_rhs(uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) +function xp_linsolve_rhs( + uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) b_list = partials_to_list(∂_b) b_list end - - function partials_to_list(partial_matrix::Vector) p = eachindex(first(partial_matrix)) - [[partial[i] for partial in partial_matrix] for i in p] + [[partial[i] for partial in partial_matrix] for i in p] end function partials_to_list(partial_matrix) p = length(first(partial_matrix)) - m,n = size(partial_matrix) - res_list = fill(zeros(m,n),p) + m, n = size(partial_matrix) + res_list = fill(zeros(m, n), p) for k in 1:p - res = zeros(m,n) + res = zeros(m, n) for i in 1:m for j in 1:n - res[i,j] = partial_matrix[i,j][k] + res[i, j] = partial_matrix[i, j][k] end end res_list[k] = res @@ -159,12 +155,4 @@ function partials_to_list(partial_matrix) return res_list end -end - - - - - - - - +end \ No newline at end of file diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 114783cf2..87d4d6aeb 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -2,32 +2,26 @@ using LinearSolve using ForwardDiff using Test - function h(p) - (A=[p[1] p[2]+1 p[2]^3; - 3*p[1] p[1]+5 p[2]*p[1]-4; - p[2]^2 9*p[1] p[2]], - b=[p[1] + 1, p[2] * 2, p[1]^2]) + (A = [p[1] p[2]+1 p[2]^3; + 3*p[1] p[1]+5 p[2] * p[1]-4; + p[2]^2 9*p[1] p[2]], + b = [p[1] + 1, p[2] * 2, p[1]^2]) end A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) prob = LinearProblem(A, b) overload_x_p = solve(prob) -original_x_p = solve!(init(prob)) +original_x_p = solve!(init(prob)) @test overload_x_p ≈ original_x_p - A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) prob = LinearProblem(A, [6.0, 10.0, 25.0]) @test solve(prob).retcode == ReturnCode.Default _, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) A = [5.0 6.0 125.0; 15.0 10.0 21.0; 25.0 45.0 5.0] -prob = LinearProblem(A,b) -@test solve(prob).retcode == ReturnCode.Default - - - - +prob = LinearProblem(A, b) +@test solve(prob).retcode == ReturnCode.Default \ No newline at end of file From 6352024e45a53c719c15b0bc8b1097b307ac6b75 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 3 Jun 2025 11:10:57 -0400 Subject: [PATCH 11/34] rm debug message --- ext/LinearSolveForwardDiffExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 8a95eaf54..e4cebb245 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -33,7 +33,6 @@ const DualAbstractLinearProblem = Union{ DualLinearProblem, DualALinearProblem, DualBLinearProblem} function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs...) - @info "here!" new_A = nodual_value(prob.A) new_b = nodual_value(prob.b) From e6cda650fb58e82e96312e2ac4d4bca61af1d75a Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 3 Jun 2025 12:05:01 -0400 Subject: [PATCH 12/34] use inits and caches --- ext/LinearSolveForwardDiffExt.jl | 84 ++++++++++++++++++++++---------- 1 file changed, 59 insertions(+), 25 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index e4cebb245..d01aee581 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -32,23 +32,18 @@ const DualBLinearProblem = LinearProblem{ const DualAbstractLinearProblem = Union{ DualLinearProblem, DualALinearProblem, DualBLinearProblem} -function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs...) - new_A = nodual_value(prob.A) - new_b = nodual_value(prob.b) - - newprob = remake(prob; A = new_A, b = new_b) - - sol = solve(newprob, alg, args...; kwargs...) +function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) + sol = solve!(cache, alg, args...; kwargs...) uu = sol.u # Solves Dual partials separately - ∂_A = partial_vals(prob.A) - ∂_b = partial_vals(prob.b) + ∂_A = cache.partials_A + ∂_b = cache.partials_b rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) partial_sols = map(rhs_list) do rhs - partial_prob = remake(newprob, b = rhs) + partial_prob = remake(partial_prob, b = rhs) solve(partial_prob, alg, args...; kwargs...).u end @@ -66,21 +61,7 @@ end function SciMLBase.solve(prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...) - sol, partials = linearsolve_forwarddiff_solve( - prob, alg, args...; kwargs... - ) - - if get_dual_type(prob.A) !== nothing - dual_type = get_dual_type(prob.A) - elseif get_dual_type(prob.b) !== nothing - dual_type = get_dual_type(prob.b) - end - - dual_sol = linearsolve_dual_solution(sol.u, partials, dual_type) - - return SciMLBase.build_linear_solution( - alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats - ) + solve!(init(prob, alg, args...; kwargs...)) end function linearsolve_dual_solution( @@ -154,4 +135,57 @@ function partials_to_list(partial_matrix) return res_list end +function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAlgorithm, + args...; + alias = LinearAliasSpecifier(), + abstol = default_tol(real(eltype(prob.b))), + reltol = default_tol(real(eltype(prob.b))), + maxiters::Int = length(prob.b), + verbose::Bool = false, + Pl = nothing, + Pr = nothing, + assumptions = OperatorAssumptions(issquare(prob.A)), + sensealg = LinearSolveAdjoint(), + kwargs...) + + new_A = nodual_value(prob.A) + new_b = nodual_value(prob.b) + + ∂_A = partial_vals(prob.A) + ∂_b = partial_vals(prob.b) + + newprob = remake(prob; A = new_A, b = new_b) + + non_partial_cache = init(newprob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, + maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, + sensealg = sensealg, kwargs...) + + return DualLinearCache(non_partial_cache, prob, alg, ∂_A, ∂_b) +end + +mutable struct DualLinearCache + cache + prob + alg + partials_A + partials_b +end + +function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) + + sol, partials = linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) + + if get_dual_type(cache.prob.A) !== nothing + dual_type = get_dual_type(prob.A) + elseif get_dual_type(cache.prob.b) !== nothing + dual_type = get_dual_type(prob.b) + end + + dual_sol = linearsolve_dual_solution(sol.u, partials, dual_type) + + return SciMLBase.build_linear_solution( + alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats + ) +end + end \ No newline at end of file From 501f07d64c30aebaad839c756f2b7275b6d46ffd Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 3 Jun 2025 13:05:50 -0400 Subject: [PATCH 13/34] rearrange --- ext/LinearSolveForwardDiffExt.jl | 68 +++++++++++++++++--------------- 1 file changed, 37 insertions(+), 31 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index d01aee581..c0993536c 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -32,8 +32,16 @@ const DualBLinearProblem = LinearProblem{ const DualAbstractLinearProblem = Union{ DualLinearProblem, DualALinearProblem, DualBLinearProblem} +LinearSolve.@concrete mutable struct DualLinearCache + cache + prob + alg + partials_A + partials_b +end + function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) - sol = solve!(cache, alg, args...; kwargs...) + sol = solve!(cache.cache, alg, args...; kwargs...) uu = sol.u # Solves Dual partials separately @@ -42,11 +50,16 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) - partial_sols = map(rhs_list) do rhs - partial_prob = remake(partial_prob, b = rhs) - solve(partial_prob, alg, args...; kwargs...).u + partial_prob = LinearProblem(cache.cache.A, rhs_list[1]) + partial_cache = init(partial_prob, alg, args...; kwargs...) + + for i in eachindex(rhs_list) + partial_cache.b = rhs_list[i] + rhs_list[i] = copy(solve!(partial_cache, alg).u) end + partial_sols = rhs_list + sol, partial_sols end @@ -135,19 +148,19 @@ function partials_to_list(partial_matrix) return res_list end -function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAlgorithm, - args...; - alias = LinearAliasSpecifier(), - abstol = default_tol(real(eltype(prob.b))), - reltol = default_tol(real(eltype(prob.b))), - maxiters::Int = length(prob.b), - verbose::Bool = false, - Pl = nothing, - Pr = nothing, - assumptions = OperatorAssumptions(issquare(prob.A)), - sensealg = LinearSolveAdjoint(), - kwargs...) - +function SciMLBase.init( + prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm, + args...; + alias = LinearAliasSpecifier(), + abstol = LinearSolve.default_tol(real(eltype(prob.b))), + reltol = LinearSolve.default_tol(real(eltype(prob.b))), + maxiters::Int = length(prob.b), + verbose::Bool = false, + Pl = nothing, + Pr = nothing, + assumptions = OperatorAssumptions(issquare(prob.A)), + sensealg = LinearSolveAdjoint(), + kwargs...) new_A = nodual_value(prob.A) new_b = nodual_value(prob.b) @@ -156,35 +169,28 @@ function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAl newprob = remake(prob; A = new_A, b = new_b) - non_partial_cache = init(newprob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, + non_partial_cache = init( + newprob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, sensealg = sensealg, kwargs...) return DualLinearCache(non_partial_cache, prob, alg, ∂_A, ∂_b) end -mutable struct DualLinearCache - cache - prob - alg - partials_A - partials_b -end - function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) - - sol, partials = linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) + sol, partials = linearsolve_forwarddiff_solve( + cache::DualLinearCache, cache.alg, args...; kwargs...) if get_dual_type(cache.prob.A) !== nothing - dual_type = get_dual_type(prob.A) + dual_type = get_dual_type(cache.prob.A) elseif get_dual_type(cache.prob.b) !== nothing - dual_type = get_dual_type(prob.b) + dual_type = get_dual_type(cache.prob.b) end dual_sol = linearsolve_dual_solution(sol.u, partials, dual_type) return SciMLBase.build_linear_solution( - alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats + cache.alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats ) end From 9aa8b19049b05f4a15fddce67aedf26881686627 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 3 Jun 2025 13:15:18 -0400 Subject: [PATCH 14/34] format --- ext/LinearSolveForwardDiffExt.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index c0993536c..72e962b09 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -178,7 +178,8 @@ function SciMLBase.init( end function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) - sol, partials = linearsolve_forwarddiff_solve( + sol, + partials = linearsolve_forwarddiff_solve( cache::DualLinearCache, cache.alg, args...; kwargs...) if get_dual_type(cache.prob.A) !== nothing @@ -194,4 +195,4 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) ) end -end \ No newline at end of file +end From 277c4f8c29f7331b9f238fdc248ba61471cb125e Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 3 Jun 2025 15:10:03 -0400 Subject: [PATCH 15/34] bring in linalg, add tols to tests --- ext/LinearSolveForwardDiffExt.jl | 8 +++++--- test/forwarddiff_overloads.jl | 8 ++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 72e962b09..08a269a46 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -1,6 +1,7 @@ module LinearSolveForwardDiffExt using LinearSolve +using LinearAlgebra using ForwardDiff using ForwardDiff: Dual, Partials using SciMLBase @@ -53,6 +54,8 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa partial_prob = LinearProblem(cache.cache.A, rhs_list[1]) partial_cache = init(partial_prob, alg, args...; kwargs...) + Main.@infiltrate + for i in eachindex(rhs_list) partial_cache.b = rhs_list[i] rhs_list[i] = copy(solve!(partial_cache, alg).u) @@ -107,7 +110,6 @@ function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials b_list = partials_to_list(∂_b) Auu = [A * uu for A in A_list] - b_list .- Auu end @@ -117,13 +119,13 @@ function xp_linsolve_rhs( Auu = [A * uu for A in A_list] - Auu + -Auu end function xp_linsolve_rhs( uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) b_list = partials_to_list(∂_b) - + Main.@infiltrate b_list end diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 87d4d6aeb..caeb92617 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -13,15 +13,15 @@ A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) prob = LinearProblem(A, b) overload_x_p = solve(prob) -original_x_p = solve!(init(prob)) +original_x_p = A \ b -@test overload_x_p ≈ original_x_p +@test ≈(overload_x_p, original_x_p, rtol = 1e-9) A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) prob = LinearProblem(A, [6.0, 10.0, 25.0]) -@test solve(prob).retcode == ReturnCode.Default +@test ≈(solve(prob).u, A \ [6.0, 10.0, 25.0], rtol = 1e-9) _, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) A = [5.0 6.0 125.0; 15.0 10.0 21.0; 25.0 45.0 5.0] prob = LinearProblem(A, b) -@test solve(prob).retcode == ReturnCode.Default \ No newline at end of file +@test ≈(solve(prob).u, A \ b, rtol = 1e-9) \ No newline at end of file From 313e28656b165540e336abe20fd5b405d08c9acc Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 3 Jun 2025 16:14:49 -0400 Subject: [PATCH 16/34] make sure using nonmutated A --- ext/LinearSolveForwardDiffExt.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 08a269a46..f93ec8803 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -51,11 +51,10 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) - partial_prob = LinearProblem(cache.cache.A, rhs_list[1]) + new_A = nodual_value(cache.prob.A) + partial_prob = LinearProblem(new_A, rhs_list[1]) partial_cache = init(partial_prob, alg, args...; kwargs...) - Main.@infiltrate - for i in eachindex(rhs_list) partial_cache.b = rhs_list[i] rhs_list[i] = copy(solve!(partial_cache, alg).u) @@ -110,7 +109,8 @@ function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials b_list = partials_to_list(∂_b) Auu = [A * uu for A in A_list] - b_list .- Auu + + return b_list .- Auu end function xp_linsolve_rhs( @@ -119,13 +119,12 @@ function xp_linsolve_rhs( Auu = [A * uu for A in A_list] - -Auu + return -Auu end function xp_linsolve_rhs( uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) b_list = partials_to_list(∂_b) - Main.@infiltrate b_list end From 922f7ec29dfad697d3ce730a7a325bc808c8b9b2 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 3 Jun 2025 17:47:51 -0400 Subject: [PATCH 17/34] dual cache should have original A and b --- ext/LinearSolveForwardDiffExt.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index f93ec8803..656e9b87b 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -37,6 +37,8 @@ LinearSolve.@concrete mutable struct DualLinearCache cache prob alg + A + b partials_A partials_b end @@ -175,7 +177,7 @@ function SciMLBase.init( maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, sensealg = sensealg, kwargs...) - return DualLinearCache(non_partial_cache, prob, alg, ∂_A, ∂_b) + return DualLinearCache(non_partial_cache, prob, alg, new_A, new_b, ∂_A, ∂_b) end function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) From 3547ec7708ae77fc860cb959e2ed61ff9d7f084f Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 5 Jun 2025 16:04:30 -0400 Subject: [PATCH 18/34] rearrange, make sure that dualcache works --- ext/LinearSolveForwardDiffExt.jl | 184 +++++++++++++++++++++---------- 1 file changed, 123 insertions(+), 61 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 656e9b87b..d676b6019 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -34,17 +34,15 @@ const DualAbstractLinearProblem = Union{ DualLinearProblem, DualALinearProblem, DualBLinearProblem} LinearSolve.@concrete mutable struct DualLinearCache - cache + linear_cache prob alg - A - b partials_A partials_b end function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) - sol = solve!(cache.cache, alg, args...; kwargs...) + sol = solve!(cache.linear_cache, alg, args...; kwargs...) uu = sol.u # Solves Dual partials separately @@ -53,7 +51,7 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) - new_A = nodual_value(cache.prob.A) + new_A = nodual_value(cache.A) partial_prob = LinearProblem(new_A, rhs_list[1]) partial_cache = init(partial_prob, alg, args...; kwargs...) @@ -67,44 +65,6 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa sol, partial_sols end -function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...) - return solve(prob, nothing, args...; kwargs...) -end - -function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...; - assump = OperatorAssumptions(issquare(prob.A)), kwargs...) - return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...) -end - -function SciMLBase.solve(prob::DualAbstractLinearProblem, - alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...) - solve!(init(prob, alg, args...; kwargs...)) -end - -function linearsolve_dual_solution( - u::Number, partials, dual_type) - return dual_type(u, partials) -end - -function linearsolve_dual_solution( - u::AbstractArray, partials, dual_type) - partials_list = RecursiveArrayTools.VectorOfArray(partials) - return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), - zip(u, partials_list[i, :] for i in 1:length(partials_list[1]))) -end - -get_dual_type(x::Dual) = typeof(x) -get_dual_type(x::AbstractArray{<:Dual}) = eltype(x) -get_dual_type(x) = nothing - -partial_vals(x::Dual) = ForwardDiff.partials(x) -partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x) -partial_vals(x) = nothing - -nodual_value(x) = x -nodual_value(x::Dual) = ForwardDiff.value(x) -nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) - function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) A_list = partials_to_list(∂_A) @@ -130,25 +90,30 @@ function xp_linsolve_rhs( b_list end -function partials_to_list(partial_matrix::Vector) - p = eachindex(first(partial_matrix)) - [[partial[i] for partial in partial_matrix] for i in p] +function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...) + return solve(prob, nothing, args...; kwargs...) end -function partials_to_list(partial_matrix) - p = length(first(partial_matrix)) - m, n = size(partial_matrix) - res_list = fill(zeros(m, n), p) - for k in 1:p - res = zeros(m, n) - for i in 1:m - for j in 1:n - res[i, j] = partial_matrix[i, j][k] - end - end - res_list[k] = res - end - return res_list +function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...; + assump = OperatorAssumptions(issquare(prob.A)), kwargs...) + return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...) +end + +function SciMLBase.solve(prob::DualAbstractLinearProblem, + alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...) + solve!(init(prob, alg, args...; kwargs...)) +end + +function linearsolve_dual_solution( + u::Number, partials, dual_type) + return dual_type(u, partials) +end + +function linearsolve_dual_solution( + u::AbstractArray, partials, dual_type) + partials_list = RecursiveArrayTools.VectorOfArray(partials) + return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), + zip(u, partials_list[i, :] for i in 1:length(partials_list[1]))) end function SciMLBase.init( @@ -164,6 +129,7 @@ function SciMLBase.init( assumptions = OperatorAssumptions(issquare(prob.A)), sensealg = LinearSolveAdjoint(), kwargs...) + new_A = nodual_value(prob.A) new_b = nodual_value(prob.b) @@ -177,7 +143,7 @@ function SciMLBase.init( maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, sensealg = sensealg, kwargs...) - return DualLinearCache(non_partial_cache, prob, alg, new_A, new_b, ∂_A, ∂_b) + return DualLinearCache(non_partial_cache, prob, alg, ∂_A, ∂_b) end function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) @@ -198,4 +164,100 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) ) end +# If setting A or b for DualLinearCache, also set it for the underlying LinearCache +function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) + # If the property is A or b, also update it in the LinearCache + if sym === :A || sym === :b + if hasproperty(dc, :linear_cache) + setproperty!(dc.linear_cache, sym, nodual_value(val)) + end + end + + # Update the partials if setting A or b + if sym === :A + setfield!(dc, :partials_A, partial_vals(val)) + elseif sym === :b + setfield!(dc, :partials_b, partial_vals(val)) + end + + return val +end + +function Base.getproperty(dc::DualLinearCache, sym::Symbol) + if sym === :A + return dc.linear_cache.A + elseif sym === :b + return dc.linear_cache.b + else + getfield(dc,sym) + end +end + +function SciMLBase.reinit!(cache::DualLinearCache; + A = nothing, + b = cache.b, + u = cache.u, + p = nothing, + reuse_precs = false) + (; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache + + isfresh = !isnothing(A) + precsisfresh = !reuse_precs && (isfresh || !isnothing(p)) + isfresh |= cache.isfresh + precsisfresh |= cache.precsisfresh + + A = isnothing(A) ? cache.A : A + b = isnothing(b) ? cache.b : b + u = isnothing(u) ? cache.u : u + p = isnothing(p) ? cache.p : p + Pl = cache.Pl + Pr = cache.Pr + + cache.A = A + cache.b = b + cache.u = u + cache.p = p + cache.Pl = Pl + cache.Pr = Pr + cache.isfresh = true + cache.precsisfresh = precsisfresh + nothing +end + +# Helper functions for Dual numbers +get_dual_type(x::Dual) = typeof(x) +get_dual_type(x::AbstractArray{<:Dual}) = eltype(x) +get_dual_type(x) = nothing + +partial_vals(x::Dual) = ForwardDiff.partials(x) +partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x) +partial_vals(x) = nothing + +nodual_value(x) = x +nodual_value(x::Dual) = ForwardDiff.value(x) +nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) + + +function partials_to_list(partial_matrix::Vector) + p = eachindex(first(partial_matrix)) + [[partial[i] for partial in partial_matrix] for i in p] +end + +function partials_to_list(partial_matrix) + p = length(first(partial_matrix)) + m, n = size(partial_matrix) + res_list = fill(zeros(m, n), p) + for k in 1:p + res = zeros(m, n) + for i in 1:m + for j in 1:n + res[i, j] = partial_matrix[i, j][k] + end + end + res_list[k] = res + end + return res_list +end + + end From 9cd4e1960f3cf73cc190619b3dff12e675c946be Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 5 Jun 2025 16:05:08 -0400 Subject: [PATCH 19/34] reinit! not needed for now --- ext/LinearSolveForwardDiffExt.jl | 31 +------------------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index d676b6019..d5ad512c5 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -193,36 +193,7 @@ function Base.getproperty(dc::DualLinearCache, sym::Symbol) end end -function SciMLBase.reinit!(cache::DualLinearCache; - A = nothing, - b = cache.b, - u = cache.u, - p = nothing, - reuse_precs = false) - (; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache - - isfresh = !isnothing(A) - precsisfresh = !reuse_precs && (isfresh || !isnothing(p)) - isfresh |= cache.isfresh - precsisfresh |= cache.precsisfresh - - A = isnothing(A) ? cache.A : A - b = isnothing(b) ? cache.b : b - u = isnothing(u) ? cache.u : u - p = isnothing(p) ? cache.p : p - Pl = cache.Pl - Pr = cache.Pr - - cache.A = A - cache.b = b - cache.u = u - cache.p = p - cache.Pl = Pl - cache.Pr = Pr - cache.isfresh = true - cache.precsisfresh = precsisfresh - nothing -end + # Helper functions for Dual numbers get_dual_type(x::Dual) = typeof(x) From b2a42912f2039d4f7d56abbfbdfb3161a3be3849 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 5 Jun 2025 16:15:23 -0400 Subject: [PATCH 20/34] correct setproperty! for DualLinearCache --- ext/LinearSolveForwardDiffExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index d5ad512c5..f3b0f7981 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -178,9 +178,9 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) setfield!(dc, :partials_A, partial_vals(val)) elseif sym === :b setfield!(dc, :partials_b, partial_vals(val)) + else + setfield!(dc, sym, val) end - - return val end function Base.getproperty(dc::DualLinearCache, sym::Symbol) From 680aec655ada08e0797638800181b1e7920317cf Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 5 Jun 2025 16:15:36 -0400 Subject: [PATCH 21/34] add tests for updating cache --- test/forwarddiff_overloads.jl | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index caeb92617..78fa32549 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -24,4 +24,18 @@ prob = LinearProblem(A, [6.0, 10.0, 25.0]) _, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) A = [5.0 6.0 125.0; 15.0 10.0 21.0; 25.0 45.0 5.0] prob = LinearProblem(A, b) -@test ≈(solve(prob).u, A \ b, rtol = 1e-9) \ No newline at end of file +@test ≈(solve(prob).u, A \ b, rtol = 1e-9) + +A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)]) + +prob = LinearProblem(A, b) +cache = init(prob) + +new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) +cache.A = new_A +cache.b = new_b + +x_p = solve!(cache) +other_x_p = new_A \ new_b + +@test ≈(x_p, other_x_p, rtol = 1e-9) \ No newline at end of file From f9cd2fe9eaf2b3ff53a25f9cca6d5c59d3b22eb4 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 5 Jun 2025 17:48:48 -0400 Subject: [PATCH 22/34] enable dual u0 --- ext/LinearSolveForwardDiffExt.jl | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index f3b0f7981..ede5857d8 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -37,6 +37,7 @@ LinearSolve.@concrete mutable struct DualLinearCache linear_cache prob alg + dual_u0 partials_A partials_b end @@ -48,12 +49,13 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa # Solves Dual partials separately ∂_A = cache.partials_A ∂_b = cache.partials_b + dual_u0 = only(partials_to_list(cache.dual_u0)) rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) new_A = nodual_value(cache.A) partial_prob = LinearProblem(new_A, rhs_list[1]) - partial_cache = init(partial_prob, alg, args...; kwargs...) + partial_cache = init(partial_prob, alg, args...; u0 = dual_u0, kwargs...) for i in eachindex(rhs_list) partial_cache.b = rhs_list[i] @@ -130,20 +132,23 @@ function SciMLBase.init( sensealg = LinearSolveAdjoint(), kwargs...) - new_A = nodual_value(prob.A) - new_b = nodual_value(prob.b) + (; A, b, u0, p) = prob - ∂_A = partial_vals(prob.A) - ∂_b = partial_vals(prob.b) + new_A = nodual_value(A) + new_b = nodual_value(b) + new_u0 = nodual_value(u0) + + ∂_A = partial_vals(A) + ∂_b = partial_vals(b) + dual_u0 = partial_vals(u0) newprob = remake(prob; A = new_A, b = new_b) non_partial_cache = init( newprob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, - sensealg = sensealg, kwargs...) - - return DualLinearCache(non_partial_cache, prob, alg, ∂_A, ∂_b) + sensealg = sensealg, u0 = new_u0, kwargs...) + return DualLinearCache(non_partial_cache, prob, alg, dual_u0, ∂_A, ∂_b) end function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) From 1b486662105f162fa631f4ee5c7dd30510d9861a Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 5 Jun 2025 17:58:07 -0400 Subject: [PATCH 23/34] use new_u0 --- ext/LinearSolveForwardDiffExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index ede5857d8..8c0f09202 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -142,7 +142,7 @@ function SciMLBase.init( ∂_b = partial_vals(b) dual_u0 = partial_vals(u0) - newprob = remake(prob; A = new_A, b = new_b) + newprob = remake(prob; A = new_A, b = new_b, u0 = new_u0) non_partial_cache = init( newprob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, From b39ce8735fada9146af4db9823be301da5d82493 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 5 Jun 2025 19:10:46 -0400 Subject: [PATCH 24/34] reuse primal cache for Dual computation --- ext/LinearSolveForwardDiffExt.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 8c0f09202..7972f6a63 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -46,6 +46,8 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa sol = solve!(cache.linear_cache, alg, args...; kwargs...) uu = sol.u + primal_sol = deepcopy(sol) + # Solves Dual partials separately ∂_A = cache.partials_A ∂_b = cache.partials_b @@ -54,9 +56,8 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) new_A = nodual_value(cache.A) - partial_prob = LinearProblem(new_A, rhs_list[1]) - partial_cache = init(partial_prob, alg, args...; u0 = dual_u0, kwargs...) - + partial_cache = cache.linear_cache + partial_cache.u0 = dual_u0 for i in eachindex(rhs_list) partial_cache.b = rhs_list[i] rhs_list[i] = copy(solve!(partial_cache, alg).u) @@ -64,7 +65,7 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa partial_sols = rhs_list - sol, partial_sols + primal_sol, partial_sols end function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, From e5761c853b3c14e93d23df3f017a6c70fbf51c19 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 9 Jun 2025 12:08:08 -0400 Subject: [PATCH 25/34] redundant line --- ext/LinearSolveForwardDiffExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 7972f6a63..b5455fea0 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -55,7 +55,6 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) - new_A = nodual_value(cache.A) partial_cache = cache.linear_cache partial_cache.u0 = dual_u0 for i in eachindex(rhs_list) From 9b69358b35c7e68412adfeca7a596f41986eab7b Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 9 Jun 2025 12:59:30 -0400 Subject: [PATCH 26/34] make sure u0 is correct type --- ext/LinearSolveForwardDiffExt.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index b5455fea0..a5fdb4d4d 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -51,12 +51,12 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa # Solves Dual partials separately ∂_A = cache.partials_A ∂_b = cache.partials_b - dual_u0 = only(partials_to_list(cache.dual_u0)) + dual_u0 = !isnothing(cache.dual_u0) ? only(partials_to_list(cache.dual_u0)) : cache.linear_cache.u rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) partial_cache = cache.linear_cache - partial_cache.u0 = dual_u0 + partial_cache.u = dual_u0 for i in eachindex(rhs_list) partial_cache.b = rhs_list[i] rhs_list[i] = copy(solve!(partial_cache, alg).u) @@ -142,7 +142,8 @@ function SciMLBase.init( ∂_b = partial_vals(b) dual_u0 = partial_vals(u0) - newprob = remake(prob; A = new_A, b = new_b, u0 = new_u0) + newprob = LinearProblem(new_A, new_b, u0 = new_u0) + #remake(prob; A = new_A, b = new_b, u0 = new_u0) non_partial_cache = init( newprob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, From f55639a559691c49805def5f6d11be99fff89166 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 9 Jun 2025 13:23:24 -0400 Subject: [PATCH 27/34] add tests for iterative and u0 --- test/forwarddiff_overloads.jl | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 78fa32549..ba3b94dc2 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -13,18 +13,31 @@ A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) prob = LinearProblem(A, b) overload_x_p = solve(prob) -original_x_p = A \ b +backslash_x_p = A \ b +krylov_overload_x_p = solve(prob, KrylovJL_GMRES()) +@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) +@test ≈(krylov_overload_x_p, backslash_x_p, rtol = 1e-9) + +krylov_prob = LinearProblem(A, b, u0 = rand(3)) +krylov_u0_sol = solve(krylov_prob, KrylovJL_GMRES()) + +@test ≈(krylov_u0_sol, original_x_p, rtol = 1e-9) -@test ≈(overload_x_p, original_x_p, rtol = 1e-9) A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) +backslash_x_p = A \ [6.0, 10.0, 25.0] prob = LinearProblem(A, [6.0, 10.0, 25.0]) -@test ≈(solve(prob).u, A \ [6.0, 10.0, 25.0], rtol = 1e-9) + +@test ≈(solve(prob).u, backslash_x_p, rtol = 1e-9) +@test ≈(solve(prob, KrylovJL_GMRES()).u, backslash_x_p, rtol = 1e-9) _, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) A = [5.0 6.0 125.0; 15.0 10.0 21.0; 25.0 45.0 5.0] +backslash_x_p = A \ b prob = LinearProblem(A, b) -@test ≈(solve(prob).u, A \ b, rtol = 1e-9) + +@test ≈(solve(prob).u, backslash_x_p, rtol = 1e-9) +@test ≈(solve(prob, KrylovJL_GMRES()).u, backslash_x_p, rtol = 1e-9) A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)]) @@ -36,6 +49,6 @@ cache.A = new_A cache.b = new_b x_p = solve!(cache) -other_x_p = new_A \ new_b +backslash_x_p = new_A \ new_b -@test ≈(x_p, other_x_p, rtol = 1e-9) \ No newline at end of file +@test ≈(x_p, backslash_x_p, rtol = 1e-9) \ No newline at end of file From 51ce056f327e8aca31c28d6e1f2c6cd498e82240 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 9 Jun 2025 15:06:24 -0400 Subject: [PATCH 28/34] make sure that linearcache.b is reset after dual solve --- ext/LinearSolveForwardDiffExt.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index a5fdb4d4d..ed505b7fc 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -44,6 +44,7 @@ end function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) sol = solve!(cache.linear_cache, alg, args...; kwargs...) + primal_b = copy(cache.linear_cache.b) uu = sol.u primal_sol = deepcopy(sol) @@ -57,11 +58,15 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa partial_cache = cache.linear_cache partial_cache.u = dual_u0 + for i in eachindex(rhs_list) partial_cache.b = rhs_list[i] - rhs_list[i] = copy(solve!(partial_cache, alg).u) + rhs_list[i] = copy(solve!(partial_cache, alg, args...; kwargs...).u) end + # Reset to the original `b`, users will expect that `b` doesn't change if they don't tell it to + partial_cache.b = primal_b + partial_sols = rhs_list primal_sol, partial_sols @@ -173,7 +178,7 @@ end # If setting A or b for DualLinearCache, also set it for the underlying LinearCache function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) # If the property is A or b, also update it in the LinearCache - if sym === :A || sym === :b + if sym === :A || sym === :b || sym === :u if hasproperty(dc, :linear_cache) setproperty!(dc.linear_cache, sym, nodual_value(val)) end From 3565d9b1c2eb23c31b550f982777f2b86c74d8aa Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 10 Jun 2025 11:39:31 -0400 Subject: [PATCH 29/34] fix test --- test/forwarddiff_overloads.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index ba3b94dc2..7ed8101a6 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -21,7 +21,7 @@ krylov_overload_x_p = solve(prob, KrylovJL_GMRES()) krylov_prob = LinearProblem(A, b, u0 = rand(3)) krylov_u0_sol = solve(krylov_prob, KrylovJL_GMRES()) -@test ≈(krylov_u0_sol, original_x_p, rtol = 1e-9) +@test ≈(krylov_u0_sol, backslash_x_p, rtol = 1e-9) A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) From 6f33486aed15a879733a2a2980cd9602fca8f97f Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 10 Jun 2025 12:51:49 -0400 Subject: [PATCH 30/34] forward steproperty and getproperty more --- ext/LinearSolveForwardDiffExt.jl | 41 ++++++++++++++++---------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index ed505b7fc..d967dfa29 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -35,8 +35,7 @@ const DualAbstractLinearProblem = Union{ LinearSolve.@concrete mutable struct DualLinearCache linear_cache - prob - alg + dual_type dual_u0 partials_A partials_b @@ -147,14 +146,20 @@ function SciMLBase.init( ∂_b = partial_vals(b) dual_u0 = partial_vals(u0) - newprob = LinearProblem(new_A, new_b, u0 = new_u0) + primal_prob = LinearProblem(new_A, new_b, u0 = new_u0) #remake(prob; A = new_A, b = new_b, u0 = new_u0) + if get_dual_type(prob.A) !== nothing + dual_type = get_dual_type(prob.A) + elseif get_dual_type(prob.b) !== nothing + dual_type = get_dual_type(prob.b) + end + non_partial_cache = init( - newprob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, + primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, sensealg = sensealg, u0 = new_u0, kwargs...) - return DualLinearCache(non_partial_cache, prob, alg, dual_u0, ∂_A, ∂_b) + return DualLinearCache(non_partial_cache, dual_type, dual_u0, ∂_A, ∂_b) end function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) @@ -162,26 +167,21 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) partials = linearsolve_forwarddiff_solve( cache::DualLinearCache, cache.alg, args...; kwargs...) - if get_dual_type(cache.prob.A) !== nothing - dual_type = get_dual_type(cache.prob.A) - elseif get_dual_type(cache.prob.b) !== nothing - dual_type = get_dual_type(cache.prob.b) - end - - dual_sol = linearsolve_dual_solution(sol.u, partials, dual_type) + dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) return SciMLBase.build_linear_solution( cache.alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats ) end -# If setting A or b for DualLinearCache, also set it for the underlying LinearCache +# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache +# Also "forwards" setproperty so that function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) # If the property is A or b, also update it in the LinearCache if sym === :A || sym === :b || sym === :u - if hasproperty(dc, :linear_cache) - setproperty!(dc.linear_cache, sym, nodual_value(val)) - end + setproperty!(dc.linear_cache, sym, nodual_value(val)) + elseif hasfield(LinearSolve.LinearCache, sym) + setproperty!(dc.linear_cache, sym, val) end # Update the partials if setting A or b @@ -194,13 +194,12 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) end end +# "Forwards" getproperty to LinearCache if necessary function Base.getproperty(dc::DualLinearCache, sym::Symbol) - if sym === :A - return dc.linear_cache.A - elseif sym === :b - return dc.linear_cache.b + if hasfield(LinearSolve.LinearCache, sym) + return getproperty(dc.linear_cache, sym) else - getfield(dc,sym) + return getfield(dc, sym) end end From d05ad09c338ea9888e5755cbf68041e62916f7c1 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 10 Jun 2025 17:49:29 -0400 Subject: [PATCH 31/34] use correct u0 --- ext/LinearSolveForwardDiffExt.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index d967dfa29..e66d8d4ee 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -36,12 +36,13 @@ const DualAbstractLinearProblem = Union{ LinearSolve.@concrete mutable struct DualLinearCache linear_cache dual_type - dual_u0 partials_A partials_b end function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) + # Solve the primal problem + dual_u0 = copy(cache.linear_cache.u) sol = solve!(cache.linear_cache, alg, args...; kwargs...) primal_b = copy(cache.linear_cache.b) uu = sol.u @@ -51,7 +52,6 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa # Solves Dual partials separately ∂_A = cache.partials_A ∂_b = cache.partials_b - dual_u0 = !isnothing(cache.dual_u0) ? only(partials_to_list(cache.dual_u0)) : cache.linear_cache.u rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) @@ -137,14 +137,12 @@ function SciMLBase.init( kwargs...) (; A, b, u0, p) = prob - new_A = nodual_value(A) new_b = nodual_value(b) new_u0 = nodual_value(u0) ∂_A = partial_vals(A) ∂_b = partial_vals(b) - dual_u0 = partial_vals(u0) primal_prob = LinearProblem(new_A, new_b, u0 = new_u0) #remake(prob; A = new_A, b = new_b, u0 = new_u0) @@ -159,7 +157,7 @@ function SciMLBase.init( primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, sensealg = sensealg, u0 = new_u0, kwargs...) - return DualLinearCache(non_partial_cache, dual_type, dual_u0, ∂_A, ∂_b) + return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b) end function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) @@ -168,9 +166,8 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) cache::DualLinearCache, cache.alg, args...; kwargs...) dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) - return SciMLBase.build_linear_solution( - cache.alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats + cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats ) end From fb0626fc1af8687769aac3a6ce2969ef55b27ce9 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 10 Jun 2025 23:50:48 -0400 Subject: [PATCH 32/34] add test for updating one of A or b --- test/forwarddiff_overloads.jl | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 7ed8101a6..eb66c64dc 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -51,4 +51,32 @@ cache.b = new_b x_p = solve!(cache) backslash_x_p = new_A \ new_b +@test ≈(x_p, backslash_x_p, rtol = 1e-9) + +# Just update A +A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)]) + +prob = LinearProblem(A, b) +cache = init(prob) + +new_A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) +cache.A = new_A + +x_p = solve!(cache) +backslash_x_p = new_A \ b + +@test ≈(x_p, backslash_x_p, rtol = 1e-9) + +# Just update b +A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) + +prob = LinearProblem(A, b) +cache = init(prob) + +_, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) +cache.b = new_b + +x_p = solve!(cache) +backslash_x_p = A \ new_b + @test ≈(x_p, backslash_x_p, rtol = 1e-9) \ No newline at end of file From 613e9aacfdfcf20ec8b5c1d14c7a668b8bcd6f42 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 11 Jun 2025 00:42:46 -0400 Subject: [PATCH 33/34] p can be Any --- ext/LinearSolveForwardDiffExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index e66d8d4ee..cab154af6 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -11,7 +11,7 @@ const DualLinearProblem = LinearProblem{ <:Union{Number, <:AbstractArray, Nothing}, iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, - <:Union{Number, <:AbstractArray, SciMLBase.NullParameters} + <:Any } where {iip, T, V, P} const DualALinearProblem = LinearProblem{ @@ -19,7 +19,7 @@ const DualALinearProblem = LinearProblem{ iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, <:Union{Number, <:AbstractArray}, - <:Union{Number, <:AbstractArray, SciMLBase.NullParameters} + <:Any } where {iip, T, V, P} const DualBLinearProblem = LinearProblem{ @@ -27,7 +27,7 @@ const DualBLinearProblem = LinearProblem{ iip, <:Union{Number, <:AbstractArray}, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, - <:Union{Number, <:AbstractArray, SciMLBase.NullParameters} + <:Any } where {iip, T, V, P} const DualAbstractLinearProblem = Union{ From d690f1fe38551b11debb94bd31d92f70ce05ed77 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 12 Jun 2025 12:14:18 -0400 Subject: [PATCH 34/34] use remake instead --- ext/LinearSolveForwardDiffExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index cab154af6..f2137eccb 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -144,8 +144,8 @@ function SciMLBase.init( ∂_A = partial_vals(A) ∂_b = partial_vals(b) - primal_prob = LinearProblem(new_A, new_b, u0 = new_u0) - #remake(prob; A = new_A, b = new_b, u0 = new_u0) + #primal_prob = LinearProblem(new_A, new_b, u0 = new_u0) + primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0) if get_dual_type(prob.A) !== nothing dual_type = get_dual_type(prob.A)