Skip to content

Commit 020ea8c

Browse files
committed
Merge branch 'main' into jf/document-precs
2 parents 1582645 + de8b1c1 commit 020ea8c

File tree

11 files changed

+181
-119
lines changed

11 files changed

+181
-119
lines changed

docs/src/basics/FAQ.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,5 +83,5 @@ Pl = LinearSolve.ComposePreconditioner(LinearSolve.InvPreconditioner(Diagonal(we
8383
Pr = Diagonal(weights)
8484
8585
prob = LinearProblem(A, b)
86-
sol = solve(prob, KrylovJL_GMRES(precs=Returns((Pl,Pr))))
86+
sol = solve(prob, KrylovJL_GMRES(precs = Returns((Pl, Pr))))
8787
```

ext/LinearSolvePardisoExt.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,14 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::PardisoJL; kwargs
134134
if cache.isfresh
135135
phase = alg.cache_analysis ? Pardiso.NUM_FACT : Pardiso.ANALYSIS_NUM_FACT
136136
Pardiso.set_phase!(cache.cacheval, phase)
137-
Pardiso.pardiso(cache.cacheval, SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), eltype(A)[])
137+
Pardiso.pardiso(cache.cacheval,
138+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
139+
eltype(A)[])
138140
cache.isfresh = false
139141
end
140142
Pardiso.set_phase!(cache.cacheval, Pardiso.SOLVE_ITERATIVE_REFINE)
141-
Pardiso.pardiso(cache.cacheval, u, SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), b)
143+
Pardiso.pardiso(cache.cacheval, u,
144+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), b)
142145
return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
143146
end
144147

src/LinearSolve.jl

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,47 +5,48 @@ if isdefined(Base, :Experimental) &&
55
end
66

77
import PrecompileTools
8-
using ArrayInterface
9-
using RecursiveFactorization
10-
using Base: cache_dependencies, Bool
11-
using LinearAlgebra
12-
using SparseArrays
13-
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
14-
using LazyArrays: @~, BroadcastArray
15-
using SciMLBase: AbstractLinearAlgorithm
16-
using SciMLOperators
17-
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
18-
using Setfield
19-
using UnPack
20-
using KLU
21-
using Sparspak
22-
using FastLapackInterface
23-
using DocStringExtensions
24-
using EnumX
25-
using Markdown
26-
using ChainRulesCore
27-
import InteractiveUtils
28-
29-
import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix
30-
31-
using LinearAlgebra: BlasInt, LU
32-
using LinearAlgebra.LAPACK: require_one_based_indexing,
33-
chkfinite, chkstride1,
34-
@blasfunc, chkargsok
35-
36-
import GPUArraysCore
37-
import Preferences
38-
import ConcreteStructs: @concrete
39-
40-
# wrap
41-
import Krylov
42-
using SciMLBase
43-
import Preferences
8+
using ArrayInterface
9+
using RecursiveFactorization
10+
using Base: cache_dependencies, Bool
11+
using LinearAlgebra
12+
using SparseArrays
13+
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
14+
using LazyArrays: @~, BroadcastArray
15+
using SciMLBase: AbstractLinearAlgorithm
16+
using SciMLOperators
17+
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
18+
using Setfield
19+
using UnPack
20+
using KLU
21+
using Sparspak
22+
using FastLapackInterface
23+
using DocStringExtensions
24+
using EnumX
25+
using Markdown
26+
using ChainRulesCore
27+
import InteractiveUtils
28+
29+
import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix
30+
31+
using LinearAlgebra: BlasInt, LU
32+
using LinearAlgebra.LAPACK: require_one_based_indexing,
33+
chkfinite, chkstride1,
34+
@blasfunc, chkargsok
35+
36+
import GPUArraysCore
37+
import Preferences
38+
import ConcreteStructs: @concrete
39+
40+
# wrap
41+
import Krylov
42+
using SciMLBase
43+
import Preferences
4444

4545
const CRC = ChainRulesCore
4646

4747
@static if Sys.ARCH === :x86_64 || Sys.ARCH === :i686
48-
if Preferences.@load_preference("LoadMKL_JLL", !occursin("EPYC", Sys.cpu_info()[1].model))
48+
if Preferences.@load_preference("LoadMKL_JLL",
49+
!occursin("EPYC", Sys.cpu_info()[1].model))
4950
using MKL_jll
5051
const usemkl = MKL_jll.is_available()
5152
else

src/common.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
150150
assumptions = OperatorAssumptions(issquare(prob.A)),
151151
sensealg = LinearSolveAdjoint(),
152152
kwargs...)
153-
(;A, b, u0, p) = prob
153+
(; A, b, u0, p) = prob
154154

155155
A = if alias_A || A isa SMatrix
156156
A
@@ -206,22 +206,21 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
206206

207207
cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
208208
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
209-
typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol,
209+
typeof(sensealg)}(
210+
A, b, u0_, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol,
210211
maxiters, verbose, assumptions, sensealg)
211212
return cache
212213
end
213214

214-
215215
function SciMLBase.reinit!(cache::LinearCache;
216-
A = nothing,
217-
b = cache.b,
218-
u = cache.u,
219-
p = nothing,
220-
reinit_cache = false,
221-
reuse_precs = false)
216+
A = nothing,
217+
b = cache.b,
218+
u = cache.u,
219+
p = nothing,
220+
reinit_cache = false,
221+
reuse_precs = false)
222222
(; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache
223223

224-
225224
isfresh = !isnothing(A)
226225
precsisfresh = !reuse_precs && (isfresh || !isnothing(p))
227226
isfresh |= cache.isfresh
@@ -234,9 +233,11 @@ function SciMLBase.reinit!(cache::LinearCache;
234233
Pl = cache.Pl
235234
Pr = cache.Pr
236235
if reinit_cache
237-
return LinearCache{typeof(A), typeof(b), typeof(u), typeof(p), typeof(alg), typeof(cacheval),
236+
return LinearCache{
237+
typeof(A), typeof(b), typeof(u), typeof(p), typeof(alg), typeof(cacheval),
238238
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
239-
typeof(sensealg)}(A, b, u, p, alg, cacheval, precsisfresh, isfresh, Pl, Pr, abstol, reltol,
239+
typeof(sensealg)}(
240+
A, b, u, p, alg, cacheval, precsisfresh, isfresh, Pl, Pr, abstol, reltol,
240241
maxiters, verbose, assumptions, sensealg)
241242
else
242243
cache.A = A

src/default.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,8 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool})
179179
__conditioning(assump) === OperatorCondition.WellConditioned)
180180
if length(b) <= 10
181181
DefaultAlgorithmChoice.RFLUFactorization
182-
elseif appleaccelerate_isavailable() && b isa Array &&
183-
eltype(b) <: Union{Float32, Float64, ComplexF32, ComplexF64}
182+
elseif appleaccelerate_isavailable() && b isa Array &&
183+
eltype(b) <: Union{Float32, Float64, ComplexF32, ComplexF64}
184184
DefaultAlgorithmChoice.AppleAccelerateLUFactorization
185185
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) ||
186186
(usemkl && length(b) <= 200)) &&
@@ -189,8 +189,8 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool})
189189
DefaultAlgorithmChoice.RFLUFactorization
190190
#elseif A === nothing || A isa Matrix
191191
# alg = FastLUFactorization()
192-
elseif usemkl && b isa Array &&
193-
eltype(b) <: Union{Float32, Float64, ComplexF32, ComplexF64}
192+
elseif usemkl && b isa Array &&
193+
eltype(b) <: Union{Float32, Float64, ComplexF32, ComplexF64}
194194
DefaultAlgorithmChoice.MKLLUFactorization
195195
else
196196
DefaultAlgorithmChoice.LUFactorization

src/extension_algs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ All values default to `nothing` and the solver internally determines the values
217217
given the input types, and these keyword arguments are only for overriding the
218218
default handling process. This should not be required by most users.
219219
"""
220-
struct PardisoJL{T1, T2} <: AbstractSparseFactorization
220+
struct PardisoJL{T1, T2} <: AbstractSparseFactorization
221221
nprocs::Union{Int, Nothing}
222222
solver_type::T1
223223
matrix_type::T2

test/basictests.jl

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -267,12 +267,12 @@ end
267267

268268
@testset "KrylovJL" begin
269269
kwargs = (; gmres_restart = 5)
270-
precs = (A,p=nothing) -> (BlockJacobiPreconditioner(A, 2), I)
270+
precs = (A, p = nothing) -> (BlockJacobiPreconditioner(A, 2), I)
271271
algorithms = (
272272
("Default", KrylovJL(kwargs...)),
273273
("CG", KrylovJL_CG(kwargs...)),
274274
("GMRES", KrylovJL_GMRES(kwargs...)),
275-
("GMRES_prec", KrylovJL_GMRES(;precs, ldiv=false, kwargs...)),
275+
("GMRES_prec", KrylovJL_GMRES(; precs, ldiv = false, kwargs...)),
276276
# ("BICGSTAB",KrylovJL_BICGSTAB(kwargs...)),
277277
("MINRES", KrylovJL_MINRES(kwargs...))
278278
)
@@ -579,28 +579,27 @@ end
579579
# test default algorithn
580580
@time "solve MySparseMatrixCSC" u=solve(pr)
581581
@test norm(u - u0, Inf) < 1.0e-13
582-
582+
583583
# test Krylov algorithm with reinit!
584584
pr = LinearProblem(B, b)
585-
solver=KrylovJL_CG()
586-
cache=init(pr,solver,maxiters=1000,reltol=1.0e-10)
587-
u=solve!(cache)
585+
solver = KrylovJL_CG()
586+
cache = init(pr, solver, maxiters = 1000, reltol = 1.0e-10)
587+
u = solve!(cache)
588588
A1 = spdiagm(1 => -ones(N - 1), 0 => fill(100.0, N), -1 => -ones(N - 1))
589-
b1=A1*u0
590-
B1= MySparseMatrixCSC(A1)
589+
b1 = A1 * u0
590+
B1 = MySparseMatrixCSC(A1)
591591
@test norm(u - u0, Inf) < 1.0e-8
592-
reinit!(cache; A=B1, b=b1)
593-
u=solve!(cache)
592+
reinit!(cache; A = B1, b = b1)
593+
u = solve!(cache)
594594
@test norm(u - u0, Inf) < 1.0e-8
595-
595+
596596
# test factorization with reinit!
597597
pr = LinearProblem(B, b)
598-
solver=SparspakFactorization()
599-
cache=init(pr,solver)
600-
u=solve!(cache)
598+
solver = SparspakFactorization()
599+
cache = init(pr, solver)
600+
u = solve!(cache)
601601
@test norm(u - u0, Inf) < 1.0e-8
602-
reinit!(cache; A=B1, b=b1)
603-
u=solve!(cache)
602+
reinit!(cache; A = B1, b = b1)
603+
u = solve!(cache)
604604
@test norm(u - u0, Inf) < 1.0e-8
605-
606605
end

test/enzyme.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,9 @@ end
207207
@show fd_jac
208208

209209
en_jac = map(onehot(A)) do dA
210-
return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice,
211-
Duplicated(A, dA), Const(b1), Const(alg)))
212-
end |> collect
210+
return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice,
211+
Duplicated(A, dA), Const(b1), Const(alg)))
212+
end |> collect
213213
@show en_jac
214214

215215
@test en_jacfd_jac rtol=1e-4

0 commit comments

Comments
 (0)