diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 47221762c..4f13a499e 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -29,6 +29,7 @@ jobs: - "pre" group: - "Core" + - "DefaultsLoading" - "LinearSolveHYPRE" - "LinearSolvePardiso" - "LinearSolveBandedMatrices" diff --git a/ext/LinearSolveSparseArraysExt.jl b/ext/LinearSolveSparseArraysExt.jl index fe195bf38..24d9f3517 100644 --- a/ext/LinearSolveSparseArraysExt.jl +++ b/ext/LinearSolveSparseArraysExt.jl @@ -3,6 +3,7 @@ module LinearSolveSparseArraysExt using LinearSolve, LinearAlgebra using SparseArrays using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr +using LinearSolve: BLASELTYPES # Can't `using KLU` because cannot have a dependency in there without # requiring the user does `using KLU` @@ -39,7 +40,7 @@ function LinearSolve.handle_sparsematrixcsc_lu(A::AbstractSparseMatrixCSC) end function LinearSolve.defaultalg( - A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool}) + A::Symmetric{<:BLASELTYPES, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool}) LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.CHOLMODFactorization) end @@ -78,7 +79,7 @@ function LinearSolve.init_cacheval( end function LinearSolve.init_cacheval( - alg::UMFPACKFactorization, A::AbstractSparseArray, b, u, Pl, Pr, + alg::UMFPACKFactorization, A::AbstractSparseArray{Float64}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) @@ -136,7 +137,7 @@ function LinearSolve.init_cacheval( end function LinearSolve.init_cacheval( - alg::KLUFactorization, A::AbstractSparseArray, b, u, Pl, Pr, + alg::KLUFactorization, A::AbstractSparseArray{Float64}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) @@ -186,15 +187,15 @@ function LinearSolve.init_cacheval(alg::CHOLMODFactorization, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) where {T <: - Union{Float32, Float64}} + BLASELTYPES} PREALLOCATED_CHOLMOD end function LinearSolve.init_cacheval(alg::NormalCholeskyFactorization, - A::Union{AbstractSparseArray, LinearSolve.GPUArraysCore.AnyGPUArray, - Symmetric{<:Number, <:AbstractSparseArray}}, b, u, Pl, Pr, + A::Union{AbstractSparseArray{T}, LinearSolve.GPUArraysCore.AnyGPUArray, + Symmetric{T, <:AbstractSparseArray{T}}}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, - assumptions::OperatorAssumptions) + assumptions::OperatorAssumptions) where {T <: BLASELTYPES} LinearSolve.ArrayInterface.cholesky_instance(convert(AbstractMatrix, A)) end diff --git a/src/default.jl b/src/default.jl index e57d28762..2c86f899a 100644 --- a/src/default.jl +++ b/src/default.jl @@ -229,7 +229,7 @@ function algchoice_to_alg(alg::Symbol) elseif alg === :DirectLdiv! DirectLdiv!() elseif alg === :SparspakFactorization - SparspakFactorization() + SparspakFactorization(throwerror = false) elseif alg === :KLUFactorization KLUFactorization() elseif alg === :UMFPACKFactorization diff --git a/src/factorization.jl b/src/factorization.jl index 759c84675..cd8f35dd7 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -319,7 +319,8 @@ function init_cacheval(alg::CholeskyFactorization, A::GPUArraysCore.AnyGPUArray, cholesky(A; check = false) end -function init_cacheval(alg::CholeskyFactorization, A, b, u, Pl, Pr, +function init_cacheval( + alg::CholeskyFactorization, A::AbstractArray{<:BLASELTYPES}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot) end @@ -333,7 +334,7 @@ function init_cacheval(alg::CholeskyFactorization, A::Matrix{Float64}, b, u, Pl, end function init_cacheval(alg::CholeskyFactorization, - A::Union{Diagonal, AbstractSciMLOperator}, b, u, Pl, Pr, + A::Union{Diagonal, AbstractSciMLOperator, AbstractArray}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) nothing @@ -1044,8 +1045,17 @@ dispatch to route around standard BLAS routines in the case e.g. of arbitrary-pr floating point numbers or ForwardDiff.Dual. This e.g. allows for Automatic Differentiation (AD) of a sparse-matrix solve. """ -Base.@kwdef struct SparspakFactorization <: AbstractSparseFactorization - reuse_symbolic::Bool = true +struct SparspakFactorization <: AbstractSparseFactorization + reuse_symbolic::Bool + + function SparspakFactorization(; reuse_symbolic = true, throwerror = true) + ext = Base.get_extension(@__MODULE__, :LinearSolveSparspakExt) + if throwerror && ext === nothing + error("SparspakFactorization requires that Sparspak is loaded, i.e. `using Sparspak`") + else + new(reuse_symbolic) + end + end end function init_cacheval(alg::SparspakFactorization, diff --git a/test/defaults_loading.jl b/test/defaults_loading.jl new file mode 100644 index 000000000..a7b456f41 --- /dev/null +++ b/test/defaults_loading.jl @@ -0,0 +1,34 @@ +using SparseArrays +using LinearSolve +using Test + +n = 10 +dx = 1 / n +dx2 = dx^-2 +vals = Vector{BigFloat}(undef, 0) +cols = Vector{Int}(undef, 0) +rows = Vector{Int}(undef, 0) +for i in 1:n + if i != 1 + push!(vals, dx2) + push!(cols, i - 1) + push!(rows, i) + end + push!(vals, -2dx2) + push!(cols, i) + push!(rows, i) + if i != n + push!(vals, dx2) + push!(cols, i + 1) + push!(rows, i) + end +end +mat = sparse(rows, cols, vals, n, n) +rhs = big.(zeros(n)) +rhs[begin] = rhs[end] = -2 +prob = LinearProblem(mat, rhs) +@test_throws ["SparspakFactorization required", "using Sparspak"] sol=solve(prob).u + +using Sparspak +sol = solve(prob).u +@test sol isa Vector{BigFloat} \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 44be59058..677976feb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,6 +25,10 @@ if GROUP == "All" || GROUP == "Enzyme" @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl") end +if GROUP == "All" || GROUP == "DefaultsLoading" + @time @safetestset "Enzyme Derivative Rules" include("defaults_loading.jl") +end + if GROUP == "LinearSolveCUDA" Pkg.activate("gpu") Pkg.develop(PackageSpec(path = dirname(@__DIR__)))