Skip to content

Commit e33a43d

Browse files
farhadrclassdpo
andauthored
solve_shifted_system! method for LBFGS solving step in recursive way (#338)
Co-authored-by: Dominique <[email protected]>
1 parent 4c92b94 commit e33a43d

File tree

5 files changed

+212
-1
lines changed

5 files changed

+212
-1
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ Function | Description
7878
`size` | Return the size of a linear operator
7979
`symmetric` | Determine whether the operator is symmetric
8080
`normest` | Estimate the 2-norm
81+
`solve_shifted_system!` | Solves linear system $(B + \sigma I) x = b$, where $B$ is a forward L-BFGS operator and $\sigma \geq 0$.
8182

8283

8384
## Other Operations on Operators

src/lbfgs.jl

+6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ mutable struct LBFGSData{T, I <: Integer}
1616
b::Vector{Vector{T}}
1717
insert::I
1818
Ax::Vector{T}
19+
shifted_p::Matrix{T} # Temporary matrix used in the computation solve_shifted_system!
20+
shifted_v::Vector{T}
21+
shifted_u::Vector{T}
1922
end
2023

2124
function LBFGSData(
@@ -43,6 +46,9 @@ function LBFGSData(
4346
inverse ? Vector{T}(undef, 0) : [zeros(T, n) for _ = 1:mem],
4447
1,
4548
Vector{T}(undef, n),
49+
Array{T}(undef, (n, 2*mem)),
50+
Vector{T}(undef, 2*mem),
51+
Vector{T}(undef, n)
4652
)
4753
end
4854

src/utilities.jl

+140-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
export check_ctranspose, check_hermitian, check_positive_definite, normest
1+
export check_ctranspose, check_hermitian, check_positive_definite, normest, solve_shifted_system!, ldiv!
2+
import LinearAlgebra.ldiv!
23

34
"""
45
normest(S) estimates the matrix 2-norm of S.
@@ -145,3 +146,141 @@ end
145146

146147
check_positive_definite(M::AbstractMatrix; kwargs...) =
147148
check_positive_definite(LinearOperator(M); kwargs...)
149+
150+
151+
"""
152+
solve_shifted_system!(x, B, b, σ)
153+
154+
Solve linear system (B + σI) x = b, where B is a forward L-BFGS operator and σ ≥ 0.
155+
156+
### Parameters
157+
158+
- `x::AbstractVector{T}`: preallocated vector of length n that is used to store the solution x.
159+
- `B::LBFGSOperator`: forward L-BFGS operator that models a matrix of size n x n.
160+
- `b::AbstractVector{T}`: right-hand side vector of length n.
161+
- `σ::T`: nonnegative shift.
162+
163+
### Returns
164+
165+
- `x::AbstractVector{T}`: solution vector `x` of length n.
166+
167+
### Method
168+
169+
The method uses a two-loop recursion-like approach with modifications to handle the shift `σ`.
170+
171+
### Example
172+
173+
```julia
174+
using Random
175+
176+
# Problem setup
177+
n = 100 # size of the problem
178+
mem = 10 # L-BFGS memory size
179+
scaling = true # enable scaling
180+
181+
# Create an L-BFGS operator
182+
B = LBFGSOperator(n, mem = mem, scaling = scaling)
183+
184+
# Add random {s, y} pairs to the L-BFGS operator
185+
for _ = 1:10
186+
s = rand(n)
187+
y = rand(n)
188+
push!(B, s, y) # Add the {s, y} pair to B
189+
end
190+
191+
# Prepare vectors for the system
192+
x = zeros(n) # Preallocated solution vector
193+
b = rand(n) # Right-hand side vector
194+
σ = 0.1 # Small shift value
195+
196+
# Solve the shifted system
197+
result = solve_shifted_system!(x, B, b, σ)
198+
199+
# Check that the solution is close enough (residual test)
200+
@assert norm(B * x + σ * x - b) / norm(b) < 1e-8
201+
```
202+
203+
### References
204+
205+
Erway, J. B., Jain, V., & Marcia, R. F. Shifted L-BFGS Systems. Optimization Methods and Software, 29(5), pp. 992-1004, 2014.
206+
"""
207+
function solve_shifted_system!(
208+
x::AbstractVector{T},
209+
B::LBFGSOperator{T, I, F1, F2, F3},
210+
b::AbstractVector{T},
211+
σ::T,
212+
) where {T, I, F1, F2, F3}
213+
214+
if σ < 0
215+
throw(ArgumentError("σ must be nonnegative"))
216+
end
217+
data = B.data
218+
insert = data.insert
219+
220+
γ_inv = 1 / data.scaling_factor
221+
x_0 = 1 / (γ_inv + σ)
222+
@. x = x_0 * b
223+
224+
max_i = 2 * data.mem
225+
sign_i = 1
226+
227+
for i = 1:max_i
228+
j = (i + 1) ÷ 2
229+
k = mod(insert + j - 1, data.mem) + 1
230+
data.shifted_u .= ((sign_i == -1) ? data.b[k] : data.a[k])
231+
232+
@. data.shifted_p[:, i] = x_0 * data.shifted_u
233+
234+
sign_t = 1
235+
for t = 1:(i - 1)
236+
c0 = dot(view(data.shifted_p, :, t), data.shifted_u)
237+
c1= sign_t .*data.shifted_v[t]
238+
c2 = c1 * c0
239+
view(data.shifted_p, :, i) .+= c2 .* view(data.shifted_p, :, t)
240+
sign_t = -sign_t
241+
end
242+
243+
data.shifted_v[i] = 1 / (1 - sign_i * dot(data.shifted_u, view(data.shifted_p, :, i)))
244+
x .+= sign_i *data.shifted_v[i] * (view(data.shifted_p, :, i)' * b) .* view(data.shifted_p, :, i)
245+
sign_i = -sign_i
246+
end
247+
return x
248+
end
249+
250+
251+
"""
252+
ldiv!(x, B, b)
253+
254+
Solves the linear system Bx = b.
255+
256+
### Arguments:
257+
258+
- `x::AbstractVector{T}`: preallocated vector of length n that is used to store the solution x.
259+
- `B::LBFGSOperator`: forward L-BFGS operator that models a matrix of size n x n.
260+
- `b::AbstractVector{T}`: right-hand side vector of length n.
261+
### Returns:
262+
263+
- `x::AbstractVector{T}`: The modified solution vector containing the solution to the linear system.
264+
265+
### Examples:
266+
267+
```julia
268+
269+
# Create an L-BFGS operator
270+
B = LBFGSOperator(10)
271+
272+
# Generate random vectors
273+
x = rand(10)
274+
b = rand(10)
275+
276+
# Solve the linear system
277+
ldiv!(x, B, b)
278+
279+
# The vector `x` now contains the solution
280+
"""
281+
282+
function ldiv!(x::AbstractVector{T}, B::LBFGSOperator{T, I, F1, F2, F3}, b::AbstractVector{T}) where {T, I, F1, F2, F3}
283+
# Call solve_shifted_system! with σ = 0
284+
solve_shifted_system!(x, B, b, T(0.0))
285+
return x
286+
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ include("test_deprecated.jl")
1515
include("test_normest.jl")
1616
include("test_diag.jl")
1717
include("test_chainrules.jl")
18+
include("test_solve_shifted_system.jl")

test/test_solve_shifted_system.jl

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using Test
2+
using LinearOperators
3+
using LinearAlgebra
4+
5+
function setup_test_val(; M = 5, n = 100, scaling = false, σ = 0.1)
6+
B = LBFGSOperator(n, mem = M, scaling = scaling)
7+
H = InverseLBFGSOperator(n, mem = M, scaling = false)
8+
9+
for _ = 1:10
10+
s = rand(n)
11+
y = rand(n)
12+
push!(B, s, y)
13+
push!(H, s, y)
14+
end
15+
16+
x = randn(n)
17+
b = B * x + σ .* x # so we know the true answer is x
18+
19+
return B, H , b, σ, zeros(n), x
20+
end
21+
22+
function test_solve_shifted_system()
23+
@testset "solve_shifted_system! Default setup test" begin
24+
# Setup Test Case 1: Default setup from setup_test_val
25+
B,_, b, σ, x_sol, x_true = setup_test_val(n = 100, M = 5)
26+
27+
result = solve_shifted_system!(x_sol, B, b, σ)
28+
29+
# Test 1: Check if result is a vector of the same size as z
30+
@test length(result) == length(b)
31+
32+
# Test 2: Verify that x_sol (result) is modified in place
33+
@test result === x_sol
34+
35+
# Test 3: Check if the function produces finite values
36+
@test all(isfinite, result)
37+
38+
# Test 4: Check if x_sol is close to the known solution x
39+
@test isapprox(x_sol, x_true, atol = 1e-6, rtol = 1e-6)
40+
end
41+
@testset "solve_shifted_system! Negative σ test" begin
42+
# Setup Test Case 2: Negative σ
43+
B,_, b, _, x_sol, _ = setup_test_val(n = 100, M = 5)
44+
σ = -0.1
45+
46+
# Expect an ArgumentError to be thrown
47+
@test_throws ArgumentError solve_shifted_system!(x_sol, B, b, σ)
48+
end
49+
50+
@testset "ldiv! test" begin
51+
# Setup Test Case 1: Default setup from setup_test_val
52+
B, H, b, _, x_sol, x_true = setup_test_val(n = 100, M = 5, σ = 0.0)
53+
54+
# Solve the system using solve_shifted_system!
55+
result = ldiv!(x_sol, B, b)
56+
57+
# Check consistency with operator-vector product using H
58+
x_H = H * b
59+
@test isapprox(x_sol, x_H, atol = 1e-6, rtol = 1e-6)
60+
@test isapprox(x_sol, x_true, atol = 1e-6, rtol = 1e-6)
61+
end
62+
end
63+
64+
test_solve_shifted_system()

0 commit comments

Comments
 (0)