Skip to content

Python type FieldsplitSNES #4139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions firedrake/preconditioners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
from firedrake.preconditioners.hiptmair import * # noqa: F401
from firedrake.preconditioners.facet_split import * # noqa: F401
from firedrake.preconditioners.bddc import * # noqa: F401
from firedrake.preconditioners.fieldsplit_snes import * # noqa: F401
83 changes: 83 additions & 0 deletions firedrake/preconditioners/fieldsplit_snes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from firedrake.preconditioners.base import SNESBase
from firedrake.petsc import PETSc
from firedrake.dmhooks import get_appctx, get_function_space
from firedrake.function import Function

__all__ = ("FieldsplitSNES",)


class FieldsplitSNES(SNESBase):
prefix = "fieldsplit_"

# TODO:
# - Allow setting field grouping/ordering like fieldsplit

def initialize(self, snes):
from firedrake.variational_solver import NonlinearVariationalSolver # ImportError if we do this at file level
ctx = get_appctx(snes.dm)
W = get_function_space(snes.dm)
self.sol = ctx._problem.u_restrict

# buffer to save solution to outer problem during solve
self.sol_outer = Function(self.sol.function_space())

# buffers for shuffling solutions during solve
self.sol_current = Function(self.sol.function_space())
self.sol_new = Function(self.sol.function_space())

# options for setting up the fieldsplit
snes_prefix = snes.getOptionsPrefix() + 'snes_' + self.prefix
# options for each field
sub_prefix = snes.getOptionsPrefix() + self.prefix

snes_options = PETSc.Options(snes_prefix)
self.fieldsplit_type = snes_options.getString('type', 'additive')
if self.fieldsplit_type not in ('additive', 'multiplicative'):
raise ValueError(
'FieldsplitSNES option snes_fieldsplit_type must be'
' "additive" or "multiplicative"')

split_ctxs = ctx.split([(i,) for i in range(len(W))])

self.solvers = tuple(
NonlinearVariationalSolver(
ctx._problem, appctx=ctx.appctx,
options_prefix=sub_prefix+str(i))
for i, ctx in enumerate(split_ctxs)
)

def update(self, snes):
pass

def step(self, snes, x, f, y):
# store current value of outer solution
self.sol_outer.assign(self.sol)

# the full form in ctx now has the most up to date solution
with self.sol_current.dat.vec_wo as vec:
x.copy(vec)
self.sol.assign(self.sol_current)

# The current snes solution x is held in sol_current, and we
# will place the new solution in sol_new.
# The solvers evaluate forms containing sol, so for each
# splitting type sol needs to hold:
# - additive: all fields need to hold sol_current values
# - multiplicative: fields need to hold sol_current before
# they are are solved for, and keep the updated sol_new
# values afterwards.
for solver, u, ucurr, unew in zip(self.solvers,
self.sol.subfunctions,
self.sol_current.subfunctions,
self.sol_new.subfunctions):
solver.solve()
unew.assign(u)
if self.fieldsplit_type == 'additive':
u.assign(ucurr)

with self.sol_new.dat.vec_ro as vec:
vec.copy(y)
y.aypx(-1, x)

# restore outer solution
self.sol.assign(self.sol_outer)
95 changes: 95 additions & 0 deletions tests/firedrake/regression/test_fieldsplit_snes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from firedrake import *


def test_fieldsplit_snes():
re = Constant(100)
nu = Constant(1/re)

nx = 50
dt = Constant(0.1) # CFL = dt*nx

mesh = PeriodicUnitIntervalMesh(nx)
x, = SpatialCoordinate(mesh)

Vu = VectorFunctionSpace(mesh, "CG", 2)
Vq = FunctionSpace(mesh, "DG", 1)
W = Vu*Vq

w0 = Function(W)
u0, q0 = w0.subfunctions
u0.project(as_vector([0.5 + 1.0*sin(2*pi*x)]))
q0.interpolate(cos(2*pi*x))

def M(u, v):
return inner(u, v)*dx

def Aburgers(u, v, nu):
return (
inner(dot(u, nabla_grad(u)), v)*dx
+ nu*inner(grad(u), grad(v))*dx
)

def Ascalar(q, p, u):
n = FacetNormal(mesh)
un = 0.5*(dot(u, n) + abs(dot(u, n)))
return (- q*div(u*p)*dx
+ jump(un*q)*jump(p)*dS)

# current and next timestep
w = Function(W)
wn = Function(W)

u, q = split(w)
un, qn = split(wn)

v, p = TestFunctions(W)

# Trapezium rule
F = (
M(un - u, v) + 0.5*dt*(Aburgers(un, v, nu) + Aburgers(u, v, nu))
+ M(qn - q, p) + 0.5*dt*(Ascalar(qn, p, un) + Ascalar(q, p, u))
)

common_params = {
'snes_converged_reason': None,
'snes_monitor': None,
'snes_rtol': 1e-8,
'snes_atol': 1e-12,
'ksp_converged_reason': None,
'ksp_monitor': None,
}

newton_params = {
'snes_type': 'newtonls',
'mat_type': 'aij',
'ksp_type': 'preonly',
'pc_type': 'lu',
}

uparams = common_params | newton_params
qparams = common_params | newton_params | {'snes_type': 'ksponly'}

python_params = {
'snes_type': 'nrichardson',
'npc_snes_type': 'python',
'npc_snes_python_type': 'firedrake.FieldsplitSNES',
'npc_snes_fieldsplit_type': 'additive',
'npc_fieldsplit_0': uparams,
'npc_fieldsplit_1': qparams,
}

params = common_params | python_params

w.assign(w0)
wn.assign(w0)
u, q = w.subfunctions
un, qn = wn.subfunctions
solver = NonlinearVariationalSolver(
NonlinearVariationalProblem(F, wn),
solver_parameters=params,
options_prefix="")

nsteps = 2
for i in range(nsteps):
w.assign(wn)
solver.solve()
Loading