From b337d9d9991dfe58fb39843b3eaa7d847a81031f Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 21 Mar 2025 21:48:50 +0000 Subject: [PATCH 1/2] LinearSolver: fix zero initial guess and update after error --- firedrake/linear_solver.py | 12 ++- .../regression/test_linear_solver.py | 80 +++++++++++++++++++ .../test_linear_solver_change_bc.py | 32 -------- 3 files changed, 89 insertions(+), 35 deletions(-) create mode 100644 tests/firedrake/regression/test_linear_solver.py delete mode 100644 tests/firedrake/regression/test_linear_solver_change_bc.py diff --git a/firedrake/linear_solver.py b/firedrake/linear_solver.py index 1e060ce906..708f1e4f65 100644 --- a/firedrake/linear_solver.py +++ b/firedrake/linear_solver.py @@ -83,7 +83,13 @@ def solve(self, x, b): if b.function_space() != self.b.function_space(): raise ValueError(f"b must be a Cofunction in {self.b.function_space()}.") - self.x.assign(x) self.b.assign(b) - super().solve() - x.assign(self.x) + if self.ksp.getInitialGuessNonzero(): + self.x.assign(x) + else: + self.x.zero() + try: + super().solve() + finally: + # Update x even when ConvergenceError is raised + x.assign(self.x) diff --git a/tests/firedrake/regression/test_linear_solver.py b/tests/firedrake/regression/test_linear_solver.py new file mode 100644 index 0000000000..57c511261a --- /dev/null +++ b/tests/firedrake/regression/test_linear_solver.py @@ -0,0 +1,80 @@ +from firedrake import * +from firedrake.petsc import PETSc +import numpy + + +def test_linear_solver_zero_initial_guess(): + mesh = UnitIntervalMesh(10) + space = FunctionSpace(mesh, "CG", 1) + test = TestFunction(space) + trial = TrialFunction(space) + + solver = LinearSolver(assemble(inner(trial, test) * dx), + solver_parameters={"ksp_type": "preonly", + "pc_type": "jacobi", + "ksp_max_it": 1, + "ksp_initial_guess_nonzero": False}) + b = assemble(inner(Constant(1), test) * dx) + + u1 = Function(space, name="u1") + u1.assign(0) + solver.solve(u1, b) + + u2 = Function(space, name="u2") + u2.assign(1) + solver.solve(u2, b) + assert numpy.allclose(u1.dat.data_ro, u2.dat.data_ro) + + +def test_linear_solver_update_after_error(): + mesh = UnitIntervalMesh(10) + space = FunctionSpace(mesh, "CG", 1) + test = TestFunction(space) + trial = TrialFunction(space) + + solver = LinearSolver(assemble(inner(trial, test) * dx), + solver_parameters={"ksp_type": "cg", + "pc_type": "none", + "ksp_max_it": 1, + "ksp_atol": 1.0e-2}) + b = assemble(inner(Constant(1), test) * dx) + + u = Function(space, name="u") + u.assign(-1) + uinit = Function(u, name="uinit") + try: + solver.solve(u, b) + except firedrake.exceptions.ConvergenceError: + assert solver.ksp.getConvergedReason() == PETSc.KSP.ConvergedReason.DIVERGED_MAX_IT + + assert not numpy.allclose(u.dat.data_ro, uinit.dat.data_ro) + + +def test_linear_solver_change_bc(): + mesh = UnitSquareMesh(4, 4, quadrilateral=False) + V = FunctionSpace(mesh, "P", 1) + u = TrialFunction(V) + v = TestFunction(V) + + a = inner(grad(u), grad(v))*dx + + bcval = Function(V) + x, y = SpatialCoordinate(mesh) + bcval.interpolate(1 + 2*y) + bc = DirichletBC(V, bcval, "on_boundary") + + A = assemble(a, bcs=bc) + b = Cofunction(V.dual()) + + solver = LinearSolver(A) + + uh = Function(V) + + solver.solve(uh, b) + + assert numpy.allclose(uh.dat.data_ro, bc.function_arg.dat.data_ro) + + bcval.interpolate(-(1 + 2*y)) + + solver.solve(uh, b) + assert numpy.allclose(uh.dat.data_ro, bc.function_arg.dat.data_ro) diff --git a/tests/firedrake/regression/test_linear_solver_change_bc.py b/tests/firedrake/regression/test_linear_solver_change_bc.py deleted file mode 100644 index 424a4148a9..0000000000 --- a/tests/firedrake/regression/test_linear_solver_change_bc.py +++ /dev/null @@ -1,32 +0,0 @@ -from firedrake import * -import numpy - - -def test_linear_solver_change_bc(): - mesh = UnitSquareMesh(4, 4, quadrilateral=False) - V = FunctionSpace(mesh, "P", 1) - u = TrialFunction(V) - v = TestFunction(V) - - a = inner(grad(u), grad(v))*dx - - bcval = Function(V) - x, y = SpatialCoordinate(mesh) - bcval.interpolate(1 + 2*y) - bc = DirichletBC(V, bcval, "on_boundary") - - A = assemble(a, bcs=bc) - b = Cofunction(V.dual()) - - solver = LinearSolver(A) - - uh = Function(V) - - solver.solve(uh, b) - - assert numpy.allclose(uh.dat.data_ro, bc.function_arg.dat.data_ro) - - bcval.interpolate(-(1 + 2*y)) - - solver.solve(uh, b) - assert numpy.allclose(uh.dat.data_ro, bc.function_arg.dat.data_ro) From c7f589f60ccd63a9df99f85a40e0b621a4df43b0 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 26 Mar 2025 14:32:43 +0000 Subject: [PATCH 2/2] Apply suggestions from code review --- firedrake/linear_solver.py | 5 +--- .../regression/test_linear_solver.py | 23 ------------------- 2 files changed, 1 insertion(+), 27 deletions(-) diff --git a/firedrake/linear_solver.py b/firedrake/linear_solver.py index 708f1e4f65..ee151977fd 100644 --- a/firedrake/linear_solver.py +++ b/firedrake/linear_solver.py @@ -84,10 +84,7 @@ def solve(self, x, b): raise ValueError(f"b must be a Cofunction in {self.b.function_space()}.") self.b.assign(b) - if self.ksp.getInitialGuessNonzero(): - self.x.assign(x) - else: - self.x.zero() + self.x.assign(x) try: super().solve() finally: diff --git a/tests/firedrake/regression/test_linear_solver.py b/tests/firedrake/regression/test_linear_solver.py index 57c511261a..227501bd28 100644 --- a/tests/firedrake/regression/test_linear_solver.py +++ b/tests/firedrake/regression/test_linear_solver.py @@ -3,29 +3,6 @@ import numpy -def test_linear_solver_zero_initial_guess(): - mesh = UnitIntervalMesh(10) - space = FunctionSpace(mesh, "CG", 1) - test = TestFunction(space) - trial = TrialFunction(space) - - solver = LinearSolver(assemble(inner(trial, test) * dx), - solver_parameters={"ksp_type": "preonly", - "pc_type": "jacobi", - "ksp_max_it": 1, - "ksp_initial_guess_nonzero": False}) - b = assemble(inner(Constant(1), test) * dx) - - u1 = Function(space, name="u1") - u1.assign(0) - solver.solve(u1, b) - - u2 = Function(space, name="u2") - u2.assign(1) - solver.solve(u2, b) - assert numpy.allclose(u1.dat.data_ro, u2.dat.data_ro) - - def test_linear_solver_update_after_error(): mesh = UnitIntervalMesh(10) space = FunctionSpace(mesh, "CG", 1)