Skip to content

Commit f17f70d

Browse files
committed
LinearSolver: fix zero initial guess and update after error
1 parent 22ff4d1 commit f17f70d

File tree

2 files changed

+62
-3
lines changed

2 files changed

+62
-3
lines changed

firedrake/linear_solver.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from firedrake.petsc import PETSc
55
from pyop2.mpi import internal_comm
66
from firedrake.variational_solver import LinearVariationalProblem, LinearVariationalSolver
7+
from firedrake.solving_utils import ConvergenceError
8+
79

810
__all__ = ["LinearSolver"]
911

@@ -83,7 +85,14 @@ def solve(self, x, b):
8385
if b.function_space() != self.b.function_space():
8486
raise ValueError(f"b must be a Cofunction in {self.b.function_space()}.")
8587

86-
self.x.assign(x)
8788
self.b.assign(b)
88-
super().solve()
89-
x.assign(self.x)
89+
if self.ksp.getInitialGuessNonzero():
90+
self.x.assign(x)
91+
else:
92+
self.x.zero()
93+
try:
94+
super().solve()
95+
x.assign(self.x)
96+
except ConvergenceError as e:
97+
x.assign(self.x)
98+
raise e
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from firedrake import *
2+
from firedrake.petsc import PETSc
3+
import numpy
4+
5+
6+
def test_zero_initial_guess():
7+
mesh = UnitIntervalMesh(10)
8+
space = FunctionSpace(mesh, "CG", 1)
9+
test = TestFunction(space)
10+
trial = TrialFunction(space)
11+
12+
solver = LinearSolver(assemble(inner(trial, test) * dx),
13+
solver_parameters={"ksp_type": "preonly",
14+
"pc_type": "jacobi",
15+
"ksp_max_it": 1,
16+
"ksp_initial_guess_nonzero": False})
17+
b = assemble(inner(Constant(1), test) * dx)
18+
19+
u1 = Function(space, name="u1")
20+
u1.assign(0)
21+
solver.solve(u1, b)
22+
23+
u2 = Function(space, name="u2")
24+
u2.assign(1)
25+
solver.solve(u2, b)
26+
assert numpy.allclose(u1.dat.data_ro, u2.dat.data_ro)
27+
28+
29+
def test_convergence_error_update():
30+
mesh = UnitIntervalMesh(10)
31+
space = FunctionSpace(mesh, "CG", 1)
32+
test = TestFunction(space)
33+
trial = TrialFunction(space)
34+
35+
solver = LinearSolver(assemble(inner(trial, test) * dx),
36+
solver_parameters={"ksp_type": "cg",
37+
"pc_type": "none",
38+
"ksp_max_it": 1,
39+
"ksp_atol": 1.0e-2})
40+
b = assemble(inner(Constant(1), test) * dx)
41+
42+
u = Function(space, name="u")
43+
u.assign(-1)
44+
uinit = Function(u, name="uinit")
45+
try:
46+
solver.solve(u, b)
47+
except firedrake.exceptions.ConvergenceError:
48+
assert solver.ksp.getConvergedReason() == PETSc.KSP.ConvergedReason.DIVERGED_MAX_IT
49+
50+
assert not numpy.allclose(u.dat.data_ro, uinit.dat.data_ro)

0 commit comments

Comments
 (0)