Skip to content

Commit d757452

Browse files
Merge pull request #564 from jClugstor/alias_API
Use Aliasing API for alias_A and alias_b
2 parents 378f67f + ebda656 commit d757452

File tree

12 files changed

+100
-31
lines changed

12 files changed

+100
-31
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ RecursiveArrayTools = "3.8"
105105
RecursiveFactorization = "0.2.14"
106106
Reexport = "1"
107107
SafeTestsets = "0.1"
108-
SciMLBase = "2.26.3"
108+
SciMLBase = "2.70"
109109
SciMLOperators = "0.3.7"
110110
Setfield = "1"
111111
SparseArrays = "1.10"

benchmarks/applelu.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ for i in 1:length(ns)
3939
for j in 1:length(algs)
4040
bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy(A),
4141
copy(b);
42-
u0 = copy(u0),
43-
alias_A = true,
44-
alias_b = true))
42+
u0 = copy(u0),
43+
alias = LinearAliasSpecifier(alias_A = true, alias_b = true)
44+
))
4545
push!(res[j], luflop(n) / bt / 1e9)
4646
end
4747
end

benchmarks/cudalu.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ for i in 1:length(ns)
3434
bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy(A),
3535
copy(b);
3636
u0 = copy(u0),
37-
alias_A = true,
38-
alias_b = true))
37+
alias = LinearAliasSpecifier(alias_A = true, alias_b = true)))
3938
push!(res[j], luflop(n) / bt / 1e9)
4039
end
4140
end

benchmarks/lu.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ for i in 1:length(ns)
4141
bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy(A),
4242
copy(b);
4343
u0 = copy(u0),
44-
alias_A = true,
45-
alias_b = true))
44+
alias = LinearAliasSpecifier(alias_A = true, alias_b = true)))
4645
push!(res[j], luflop(n) / bt / 1e9)
4746
end
4847
end

benchmarks/metallu.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ for i in 1:length(ns)
3434
bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy(A),
3535
copy(b);
3636
u0 = copy(u0),
37-
alias_A = true,
38-
alias_b = true))
37+
alias = LinearAliasSpecifier(alias_A = true, alias_b = true)))
3938
GC.gc()
4039
push!(res[j], luflop(n) / bt / 1e9)
4140
end

benchmarks/sparselu.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,7 @@ function run_and_plot(; dims = [1, 2, 3], kmax = 12)
6969
copy($A),
7070
copy($b);
7171
u0 = copy($u0),
72-
alias_A = true,
73-
alias_b = true))
72+
alias = LinearAliasSpecifier(alias_A = true, alias_b = true)))
7473
push!(res[dim][j], bt)
7574
end
7675
end

docs/src/basics/common_solver_opts.md

+5-9
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,11 @@ in order to give composability. These are also the options taken at `init` time.
66
The following are the options these algorithms take, along with their defaults.
77

88
## General Controls
9-
10-
- `alias_A::Bool`: Whether to alias the matrix `A` or use a copy by default. When `true`,
11-
algorithms like LU-factorization can be faster by reusing the memory via `lu!`,
12-
but care must be taken as the original input will be modified. Default is `true` if the
13-
algorithm is known not to modify `A`, otherwise is `false`.
14-
- `alias_b::Bool`: Whether to alias the matrix `b` or use a copy by default. When `true`,
15-
algorithms can write and change `b` upon usage. Care must be taken as the
16-
original input will be modified. Default is `true` if the algorithm is known not to
17-
modify `b`, otherwise `false`.
9+
- `alias::LinearAliasSpecifier`: Holds the fields `alias_A` and `alias_b` which specify
10+
whether to alias the matrices `A` and `b` respectively. When these fields are `true`,
11+
`A` and `b` can be written to and changed by the solver algorithm. When fields are `nothing`
12+
the default behavior is used, which is to default to `true` when the algorithm is known
13+
not to modify the matrices, and false otherwise.
1814
- `verbose`: Whether to print extra information. Defaults to `false`.
1915
- `assumptions`: Sets the assumptions of the operator in order to effect the default
2016
choice algorithm. See the [Operator Assumptions page for more details](@ref assumptions).

ext/LinearSolveHYPREExt.jl

+41-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector
66
using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve,
77
OperatorAssumptions, default_tol, init_cacheval, __issquare,
88
__conditioning, LinearSolveAdjoint
9-
using SciMLBase: LinearProblem, SciMLBase
9+
using SciMLBase: LinearProblem, LinearAliasSpecifier, SciMLBase
1010
using UnPack: @unpack
1111
using Setfield: @set!
1212

@@ -55,7 +55,7 @@ end
5555

5656
function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
5757
args...;
58-
alias_A = false, alias_b = false,
58+
alias = LinearAliasSpecifier(),
5959
# TODO: Implement eltype for HYPREMatrix in HYPRE.jl? Looks useful
6060
# even if it is not AbstractArray.
6161
abstol = default_tol(prob.A isa HYPREMatrix ? HYPRE_Complex :
@@ -72,6 +72,45 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
7272
kwargs...)
7373
@unpack A, b, u0, p = prob
7474

75+
if haskey(kwargs, :alias_A) || haskey(kwargs, :alias_b)
76+
aliases = LinearAliasSpecifier()
77+
78+
if haskey(kwargs, :alias_A)
79+
message = "`alias_A` keyword argument is deprecated, to set `alias_A`,
80+
please use an ODEAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_A = true))"
81+
Base.depwarn(message, :init)
82+
Base.depwarn(message, :solve)
83+
aliases = LinearAliasSpecifier(alias_A = values(kwargs).alias_A)
84+
end
85+
86+
if haskey(kwargs, :alias_b)
87+
message = "`alias_b` keyword argument is deprecated, to set `alias_b`,
88+
please use an ODEAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_b = true))"
89+
Base.depwarn(message, :init)
90+
Base.depwarn(message, :solve)
91+
aliases = LinearAliasSpecifier(
92+
alias_A = aliases.alias_A, alias_b = values(kwargs).alias_b)
93+
end
94+
else
95+
if alias isa Bool
96+
aliases = LinearAliasSpecifier(alias = alias)
97+
else
98+
aliases = alias
99+
end
100+
end
101+
102+
if isnothing(aliases.alias_A)
103+
alias_A = false
104+
else
105+
alias_A = aliases.alias_A
106+
end
107+
108+
if isnothing(aliases.alias_b)
109+
alias_b = false
110+
else
111+
alias_b = aliases.alias_b
112+
end
113+
75114
A = A isa HYPREMatrix ? A : HYPREMatrix(A)
76115
b = b isa HYPREVector ? b : HYPREVector(b)
77116
u0 = u0 isa HYPREVector ? u0 : (u0 === nothing ? nothing : HYPREVector(u0))

src/LinearSolve.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using LinearAlgebra
1212
using SparseArrays
1313
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
1414
using LazyArrays: @~, BroadcastArray
15-
using SciMLBase: AbstractLinearAlgorithm
15+
using SciMLBase: AbstractLinearAlgorithm, LinearAliasSpecifier
1616
using SciMLOperators
1717
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
1818
using Setfield

src/common.jl

+40-2
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,7 @@ __init_u0_from_Ab(::SMatrix{S1, S2}, b) where {S1, S2} = zeros(SVector{S2, eltyp
139139

140140
function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
141141
args...;
142-
alias_A = default_alias_A(alg, prob.A, prob.b),
143-
alias_b = default_alias_b(alg, prob.A, prob.b),
142+
alias = LinearAliasSpecifier(),
144143
abstol = default_tol(real(eltype(prob.b))),
145144
reltol = default_tol(real(eltype(prob.b))),
146145
maxiters::Int = length(prob.b),
@@ -152,6 +151,45 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
152151
kwargs...)
153152
(; A, b, u0, p) = prob
154153

154+
if haskey(kwargs,:alias_A) || haskey(kwargs,:alias_b)
155+
aliases = LinearAliasSpecifier()
156+
157+
if haskey(kwargs, :alias_A)
158+
message = "`alias_A` keyword argument is deprecated, to set `alias_A`,
159+
please use an ODEAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_A = true))"
160+
Base.depwarn(message, :init)
161+
Base.depwarn(message, :solve)
162+
aliases = LinearAliasSpecifier(alias_A = values(kwargs).alias_A)
163+
end
164+
165+
if haskey(kwargs, :alias_b)
166+
message = "`alias_b` keyword argument is deprecated, to set `alias_b`,
167+
please use an ODEAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_b = true))"
168+
Base.depwarn(message, :init)
169+
Base.depwarn(message, :solve)
170+
aliases = LinearAliasSpecifier(alias_A = aliases.alias_A, alias_b = values(kwargs).alias_b)
171+
end
172+
else
173+
if alias isa Bool
174+
aliases = LinearAliasSpecifier(alias = alias)
175+
else
176+
aliases = alias
177+
end
178+
end
179+
180+
if isnothing(aliases.alias_A)
181+
alias_A = default_alias_A(alg, prob.A, prob.b)
182+
else
183+
alias_A = aliases.alias_A
184+
end
185+
186+
if isnothing(aliases.alias_b)
187+
alias_b = default_alias_b(alg, prob.A, prob.b)
188+
else
189+
alias_b = aliases.alias_b
190+
end
191+
192+
155193
A = if alias_A || A isa SMatrix
156194
A
157195
elseif A isa Array

test/gpu/cuda.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ prob2 = LinearProblem(transpose(A), b)
8484

8585
@testset "Adjoint/Transpose Type: $(alg)" for alg in (NormalCholeskyFactorization(),
8686
CholeskyFactorization(), LUFactorization(), QRFactorization(), nothing)
87-
sol = solve(prob1, alg; alias_A = false)
87+
sol = solve(prob1, alg; alias = LinearAliasSpecifier(alias = LinearAliasSpecifier(alias_A = false)))
8888
@test norm(A' * sol.u .- b) < 1e-5
8989

90-
sol = solve(prob2, alg; alias_A = false)
90+
sol = solve(prob2, alg; alias = LinearAliasSpecifier(alias_A = false))
9191
@test norm(transpose(A) * sol.u .- b) < 1e-5
9292
end

test/resolve.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
2727
alg in [LDLtFactorization] && (A = SymTridiagonal(A))
2828
b = [1.0, 2.0]
2929
prob = LinearProblem(A, b)
30-
linsolve = init(prob, alg(), alias_A = false, alias_b = false)
30+
linsolve = init(prob, alg(), alias = LinearAliasSpecifier(alias_A = false, alias_b = false))
3131
@test solve!(linsolve).u [-2.0, 1.5]
3232
@test !linsolve.isfresh
3333
@test solve!(linsolve).u [-2.0, 1.5]
@@ -48,7 +48,7 @@ end
4848
A = Diagonal([1.0, 4.0])
4949
b = [1.0, 2.0]
5050
prob = LinearProblem(A, b)
51-
linsolve = init(prob, DiagonalFactorization(), alias_A = false, alias_b = false)
51+
linsolve = init(prob, DiagonalFactorization(), alias = LinearAliasSpecifier(alias_A = false, alias_b = false))
5252
@test solve!(linsolve).u [1.0, 0.5]
5353
@test solve!(linsolve).u [1.0, 0.5]
5454
A = Diagonal([1.0, 4.0])
@@ -59,7 +59,7 @@ A = Symmetric([1.0 2.0
5959
2.0 1.0])
6060
b = [1.0, 2.0]
6161
prob = LinearProblem(A, b)
62-
linsolve = init(prob, BunchKaufmanFactorization(), alias_A = false, alias_b = false)
62+
linsolve = init(prob, BunchKaufmanFactorization(), alias = LinearAliasSpecifier(alias_A = false, alias_b = false))
6363
@test solve!(linsolve).u [1.0, 0.0]
6464
@test solve!(linsolve).u [1.0, 0.0]
6565
A = Symmetric([1.0 2.0

0 commit comments

Comments
 (0)