From e265f0a77ef12124d2f9558324343eb96d99e1d1 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 21 Mar 2025 12:09:59 +0000 Subject: [PATCH 1/4] initial python type FieldsplitSNES implementation --- firedrake/preconditioners/__init__.py | 1 + firedrake/preconditioners/fieldsplit_snes.py | 82 ++++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 firedrake/preconditioners/fieldsplit_snes.py diff --git a/firedrake/preconditioners/__init__.py b/firedrake/preconditioners/__init__.py index 491a73657b..1a74065aa1 100644 --- a/firedrake/preconditioners/__init__.py +++ b/firedrake/preconditioners/__init__.py @@ -13,3 +13,4 @@ from firedrake.preconditioners.hiptmair import * # noqa: F401 from firedrake.preconditioners.facet_split import * # noqa: F401 from firedrake.preconditioners.bddc import * # noqa: F401 +from firedrake.preconditioners.fieldsplit_snes import * # noqa: F401 diff --git a/firedrake/preconditioners/fieldsplit_snes.py b/firedrake/preconditioners/fieldsplit_snes.py new file mode 100644 index 0000000000..579318594f --- /dev/null +++ b/firedrake/preconditioners/fieldsplit_snes.py @@ -0,0 +1,82 @@ +from firedrake.preconditioners.base import SNESBase +from firedrake.petsc import PETSc +from firedrake.dmhooks import get_appctx +from firedrake.function import Function + +__all__ = ("FieldsplitSNES",) + + +class FieldsplitSNES(SNESBase): + prefix = "fieldsplit_" + + # TODO: + # - Allow setting field grouping/ordering like fieldsplit + + def initialize(self, snes): + from firedrake.variational_solver import NonlinearVariationalSolver # ImportError if we do this at file level + ctx = get_appctx(snes.dm) + self.sol = ctx._problem.u_restrict + + # buffer to save solution to outer problem during solve + self.sol_outer = Function(self.sol.function_space()) + + # buffers for shuffling solutions during solve + self.sol_current = Function(self.sol.function_space()) + self.sol_new = Function(self.sol.function_space()) + + # options for setting up the fieldsplit + snes_prefix = snes.getOptionsPrefix() + 'snes_' + self.prefix + # options for each field + sub_prefix = snes.getOptionsPrefix() + self.prefix + + snes_options = PETSc.Options(snes_prefix) + self.fieldsplit_type = snes_options.getString('type', 'additive') + if self.fieldsplit_type not in ('additive', 'multiplicative'): + raise ValueError( + 'FieldsplitSNES option snes_fieldsplit_type must be' + ' "additive" or "multiplicative"') + + split_ctxs = ctx.split([(i,) for i in range(len(self.sol))]) + + self.solvers = tuple( + NonlinearVariationalSolver( + ctx._problem, appctx=ctx.appctx, + options_prefix=sub_prefix+str(i)) + for i, ctx in enumerate(split_ctxs) + ) + + def update(self, snes): + pass + + def step(self, snes, x, f, y): + # store current value of outer solution + self.sol_outer.assign(self.sol) + + # the full form in ctx now has the most up to date solution + with self.sol_current.dat.vec_wo as vec: + x.copy(vec) + self.sol.assign(self.sol_current) + + # The current snes solution x is held in sol_current, and we + # will place the new solution in sol_new. + # The solvers evaluate forms containing sol, so for each + # splitting type sol needs to hold: + # - additive: all fields need to hold sol_current values + # - multiplicative: fields need to hold sol_current before + # they are are solved for, and keep the updated sol_new + # values afterwards. + for solver, u, ucurr, unew in zip(self.solvers, + self.sol.subfunctions, + self.sol_current.subfunctions, + self.sol_new.subfunctions): + solver.solve() + unew.assign(u) + if self.fieldsplit_type == 'additive': + u.assign(ucurr) + + with self.sol_new.dat.vec_ro as vec: + vec.copy(y) + y.aypx(-1, x) + + # restore outer solution + self.sol.assign(self.sol_outer) From 136eb7bbd1baca628a5ecd1ebffd513ba8611c1a Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 21 Mar 2025 12:23:56 +0000 Subject: [PATCH 2/4] initial test to check python type FieldsplitSNES doesn't crash --- .../regression/test_fieldsplit_snes.py | 95 +++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 tests/firedrake/regression/test_fieldsplit_snes.py diff --git a/tests/firedrake/regression/test_fieldsplit_snes.py b/tests/firedrake/regression/test_fieldsplit_snes.py new file mode 100644 index 0000000000..d489fea805 --- /dev/null +++ b/tests/firedrake/regression/test_fieldsplit_snes.py @@ -0,0 +1,95 @@ +from firedrake import * + + +def test_fieldsplit_snes(): + re = Constant(100) + nu = Constant(1/re) + + nx = 50 + dt = Constant(0.1) # CFL = dt*nx + + mesh = PeriodicUnitIntervalMesh(nx) + x, = SpatialCoordinate(mesh) + + Vu = VectorFunctionSpace(mesh, "CG", 2) + Vq = FunctionSpace(mesh, "DG", 1) + W = Vu*Vq + + w0 = Function(W) + u0, q0 = w0.subfunctions + u0.project(as_vector([0.5 + 1.0*sin(2*pi*x)])) + q0.interpolate(cos(2*pi*x)) + + def M(u, v): + return inner(u, v)*dx + + def Aburgers(u, v, nu): + return ( + inner(dot(u, nabla_grad(u)), v)*dx + + nu*inner(grad(u), grad(v))*dx + ) + + def Ascalar(q, p, u): + n = FacetNormal(mesh) + un = 0.5*(dot(u, n) + abs(dot(u, n))) + return (- q*div(p*u)*dx + + jump(p)*jump(un*q)*dS) + + # current and next timestep + w = Function(W) + wn = Function(W) + + u, q = split(w) + un, qn = split(wn) + + v, p = TestFunctions(W) + + # Trapezium rule + F = ( + M(un - u, v) + 0.5*dt*(Aburgers(un, v, nu) + Aburgers(u, v, nu)) + + M(qn - q, p) + 0.5*dt*(Ascalar(qn, p, un) + Ascalar(q, p, u)) + ) + + common_params = { + 'snes_converged_reason': None, + 'snes_monitor': None, + 'snes_rtol': 1e-8, + 'snes_atol': 1e-12, + 'ksp_converged_reason': None, + 'ksp_monitor': None, + } + + newton_params = { + 'snes_type': 'newtonls', + 'mat_type': 'aij', + 'ksp_type': 'preonly', + 'pc_type': 'lu', + } + + uparams = common_params | newton_params + qparams = common_params | newton_params | {'snes_type': 'ksponly'} + + python_params = { + 'snes_type': 'nrichardson', + 'npc_snes_type': 'python', + 'npc_snes_python_type': 'firedrake.FieldsplitSNES', + 'npc_snes_fieldsplit_type': 'additive', + 'npc_fieldsplit_0': uparams, + 'npc_fieldsplit_1': qparams, + } + + params = common_params | python_params + + w.assign(w0) + wn.assign(w0) + u, q = w.subfunctions + un, qn = wn.subfunctions + solver = NonlinearVariationalSolver( + NonlinearVariationalProblem(F, wn), + solver_parameters=params, + options_prefix="") + + nsteps = 2 + for i in range(nsteps): + w.assign(wn) + solver.solve() From 3645598fc4d3a9c1c73280593c82a477aaede411 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Wed, 16 Apr 2025 11:52:14 -0600 Subject: [PATCH 3/4] snesfieldsplit fix for vector function spaces --- firedrake/preconditioners/fieldsplit_snes.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/firedrake/preconditioners/fieldsplit_snes.py b/firedrake/preconditioners/fieldsplit_snes.py index 579318594f..0a514cb29b 100644 --- a/firedrake/preconditioners/fieldsplit_snes.py +++ b/firedrake/preconditioners/fieldsplit_snes.py @@ -1,6 +1,6 @@ from firedrake.preconditioners.base import SNESBase from firedrake.petsc import PETSc -from firedrake.dmhooks import get_appctx +from firedrake.dmhooks import get_appctx, get_function_space from firedrake.function import Function __all__ = ("FieldsplitSNES",) @@ -15,6 +15,7 @@ class FieldsplitSNES(SNESBase): def initialize(self, snes): from firedrake.variational_solver import NonlinearVariationalSolver # ImportError if we do this at file level ctx = get_appctx(snes.dm) + W = get_function_space(snes.dm) self.sol = ctx._problem.u_restrict # buffer to save solution to outer problem during solve @@ -36,7 +37,7 @@ def initialize(self, snes): 'FieldsplitSNES option snes_fieldsplit_type must be' ' "additive" or "multiplicative"') - split_ctxs = ctx.split([(i,) for i in range(len(self.sol))]) + split_ctxs = ctx.split([(i,) for i in range(len(W))]) self.solvers = tuple( NonlinearVariationalSolver( From e0e042540dfb4db51ffdc5c81e291d0475cc3c47 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Wed, 16 Apr 2025 15:45:18 -0600 Subject: [PATCH 4/4] Update tests/firedrake/regression/test_fieldsplit_snes.py --- tests/firedrake/regression/test_fieldsplit_snes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/firedrake/regression/test_fieldsplit_snes.py b/tests/firedrake/regression/test_fieldsplit_snes.py index d489fea805..2078699294 100644 --- a/tests/firedrake/regression/test_fieldsplit_snes.py +++ b/tests/firedrake/regression/test_fieldsplit_snes.py @@ -32,8 +32,8 @@ def Aburgers(u, v, nu): def Ascalar(q, p, u): n = FacetNormal(mesh) un = 0.5*(dot(u, n) + abs(dot(u, n))) - return (- q*div(p*u)*dx - + jump(p)*jump(un*q)*dS) + return (- q*div(u*p)*dx + + jump(un*q)*jump(p)*dS) # current and next timestep w = Function(W)