Skip to content

SLATE: add support for ufl.replace #4093

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from firedrake.functionspace import MixedFunctionSpace
from firedrake.cofunction import Cofunction
from firedrake.matrix import AssembledMatrix
from firedrake import slate


def subspace(V, indices):
Expand Down Expand Up @@ -77,6 +78,10 @@ def split(self, form, argument_indices):
assert (len(idx) == 1 for idx in self.blocks.values())
assert (idx[0] == 0 for idx in self.blocks.values())
return form

if isinstance(form, slate.slate.TensorBase):
return slate.slate.Block(form, argument_indices)

# TODO find a way to distinguish empty Forms avoiding expand_derivatives
f = map_integrand_dags(self, form)
if expand_derivatives(f).empty():
Expand Down
41 changes: 38 additions & 3 deletions firedrake/slate/slate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
compiler, which interprets expressions and produces C++ kernel
functions to be executed within the Firedrake architecture.
"""
from abc import ABCMeta, abstractproperty, abstractmethod
from abc import abstractproperty, abstractmethod

from collections import OrderedDict, namedtuple, defaultdict

Expand Down Expand Up @@ -117,7 +117,7 @@ def __call__(self):
return self


class TensorBase(object, metaclass=ABCMeta):
class TensorBase(BaseForm):
"""An abstract Slate node class.

.. warning::
Expand Down Expand Up @@ -156,6 +156,15 @@ def _metakernel_cache(self):
def children(self):
return self.operands

@property
def ufl_operands(self):
return self.operands

def _ufl_expr_reconstruct_(self, *operands):
if len(operands) == 0:
return self
return self.reconstruct(*operands)

@cached_property
def expression_hash(self):
from firedrake.slate.slac.utils import traverse_dags
Expand Down Expand Up @@ -451,6 +460,12 @@ def __new__(cls, function):
raise TypeError("Expecting a BaseCoefficient or AssembledVector (not a %r)" %
type(function))

def reconstruct(self, form):
"""Reconstructs this TensorBase with new operands."""
if not isinstance(form, BaseCoefficient):
form = Function(self.form.function_space()).interpolate(form)
return as_slate(form)

@cached_property
def form(self):
return self._function
Expand All @@ -474,7 +489,11 @@ def _argument(self):

def arguments(self):
"""Returns a tuple of arguments associated with the tensor."""
return (self._argument,)
tensor = self._function
if isinstance(tensor, BaseForm):
return tensor.arguments()
else:
return (self._argument,)

def coefficients(self):
"""Returns a tuple of coefficients associated with the tensor."""
Expand Down Expand Up @@ -662,6 +681,10 @@ def __init__(self, tensor, indices):
self._blocks = dict(enumerate(map(as_tuple, indices)))
self._indices = indices

def reconstruct(self, tensor, indices=None):
"""Reconstructs this TensorBase with new operands."""
return Block(tensor, indices=indices or self._indices)

@cached_property
def terminal(self):
"""Blocks are only terminal when they sit on Tensors or AssembledVectors"""
Expand Down Expand Up @@ -788,6 +811,10 @@ def __init__(self, tensor, decomposition=None):
self.operands = (tensor,)
self.decomposition = decomposition

def reconstruct(self, tensor, decomposition=None):
"""Reconstructs this TensorBase with new operands."""
return Factorization(tensor, decomposition=decomposition or self.decomposition)

@cached_property
def arg_function_spaces(self):
"""Returns a tuple of function spaces that the tensor
Expand Down Expand Up @@ -894,6 +921,10 @@ def __init__(self, form, diagonal=False):
self.form = form
self.diagonal = diagonal

def reconstruct(self, form, diagonal=None):
"""Reconstructs this TensorBase with new operands."""
return Tensor(form, diagonal=diagonal or self.diagonal)

@cached_property
def arg_function_spaces(self):
"""Returns a tuple of function spaces that the tensor
Expand Down Expand Up @@ -961,6 +992,10 @@ def __init__(self, *operands):
super(TensorOp, self).__init__()
self.operands = tuple(operands)

def reconstruct(self, *operands):
"""Reconstructs this TensorBase with new operands."""
return type(self)(*operands)

def coefficients(self):
"""Returns the expected coefficients of the resulting tensor."""
coeffs = [op.coefficients() for op in self.operands]
Expand Down
4 changes: 3 additions & 1 deletion firedrake/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,12 @@ def _extract_args(*args, **kwargs):
near_nullspace = kwargs.get("near_nullspace", None)
# Extract parameters
form_compiler_parameters = kwargs.get("form_compiler_parameters", {})
solver_parameters = kwargs.get("solver_parameters", {})
solver_parameters = kwargs.get("solver_parameters", None)
options_prefix = kwargs.get("options_prefix", None)
restrict = kwargs.get("restrict", False)
pre_apply_bcs = kwargs.get("pre_apply_bcs", True)
if solver_parameters is None:
solver_parameters = {}

return eq, u, bcs, J, Jp, M, form_compiler_parameters, \
solver_parameters, nullspace, nullspace_T, near_nullspace, \
Expand Down
21 changes: 9 additions & 12 deletions firedrake/solving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from firedrake.exceptions import ConvergenceError
from firedrake.petsc import PETSc, DEFAULT_KSP_PARAMETERS
from firedrake.formmanipulation import ExtractSubBlock
from firedrake.ufl_expr import replace
from firedrake.utils import cached_property
from firedrake.logging import warning

Expand Down Expand Up @@ -233,8 +234,7 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None,
self._bc_residual = Function(self._x.function_space())
if problem.is_linear:
# Drop existing lifting term from the residual
assert isinstance(self.F, ufl.BaseForm)
self.F = ufl.replace(self.F, {self._x: ufl.zero(self._x.ufl_shape)})
self.F = replace(self.F, {self._x: ufl.zero(self._x.ufl_shape)})

self.F -= problem.compute_bc_lifting(self.J, self._bc_residual)

Expand Down Expand Up @@ -319,11 +319,11 @@ def set_nullspace(self, nullspace, ises=None, transpose=False, near=False):

@PETSc.Log.EventDecorator()
def split(self, fields):
from firedrake import replace, as_vector, split, zero
from firedrake import as_vector, split, zero
from firedrake import NonlinearVariationalProblem as NLVP
from firedrake.bcs import DirichletBC, EquationBC
fields = tuple(tuple(f) for f in fields)
splits = self._splits.get(tuple(fields))
splits = self._splits.get(fields)
if splits is not None:
return splits

Expand All @@ -334,7 +334,7 @@ def split(self, fields):
F = splitter.split(problem.F, argument_indices=(field, ))
J = splitter.split(problem.J, argument_indices=(field, field))
us = problem.u_restrict.subfunctions
V = F.arguments()[0].function_space()
V = J.arguments()[-1].function_space()
# Exposition:
# We are going to make a new solution Function on the sub
# mixed space defined by the relevant fields.
Expand All @@ -353,16 +353,13 @@ def split(self, fields):
# Split it apart to shove in the form.
subsplit = split(subu)
vec = []
for i, u in enumerate(us):
for i, ui in enumerate(us):
if i in field:
# If this is a field we're keeping, get it from
# the new function. Otherwise just point to the
# old data.
u = subsplit[field.index(i)]
if u.ufl_shape == ():
vec.append(u)
else:
vec.extend(u[idx] for idx in numpy.ndindex(u.ufl_shape))
ui = subsplit[field.index(i)]
vec.extend(ui[idx] for idx in numpy.ndindex(ui.ufl_shape))

# So now we have a new representation for the solution
# vector in the old problem. For the fields we're going
Expand Down Expand Up @@ -404,7 +401,7 @@ def split(self, fields):
appctx=self.appctx,
transfer_manager=self.transfer_manager,
pre_apply_bcs=self.pre_apply_bcs))
return self._splits.setdefault(tuple(fields), splits)
return self._splits.setdefault(fields, splits)

@staticmethod
def form_function(snes, X, F):
Expand Down
8 changes: 4 additions & 4 deletions firedrake/ufl_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
from ufl.split_functions import split
from ufl.algorithms import extract_arguments, extract_coefficients
from ufl.domain import as_domain
from ufl.algorithms.replace import replace

import firedrake
from firedrake import utils, function, cofunction
from firedrake.constant import Constant
from firedrake.petsc import PETSc


__all__ = ['Argument', 'Coargument', 'TestFunction', 'TrialFunction',
'TestFunctions', 'TrialFunctions',
'derivative', 'adjoint',
'action', 'CellSize', 'FacetNormal']
'action', 'replace', 'CellSize', 'FacetNormal']


class Argument(ufl.argument.Argument):
Expand Down Expand Up @@ -263,7 +263,7 @@ def derivative(form, u, du=None, coefficient_derivatives=None):
V = firedrake.FunctionSpace(mesh, "Real", 0)
x = ufl.Coefficient(V)
# TODO: Update this line when https://github.com/FEniCS/ufl/issues/171 is fixed
form = ufl.replace(form, {u: x})
form = replace(form, {u: x})
u_orig, u = u, x
else:
raise RuntimeError("Can't compute derivative for form")
Expand All @@ -286,7 +286,7 @@ def derivative(form, u, du=None, coefficient_derivatives=None):
if isinstance(uc, firedrake.Constant):
# If we replaced constants with ``x`` to differentiate,
# replace them back to the original symbolic constant
dform = ufl.replace(dform, {u: u_orig})
dform = replace(dform, {u: u_orig})
return dform


Expand Down
4 changes: 2 additions & 2 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
)
from firedrake.function import Function
from firedrake.matrix import MatrixBase
from firedrake.ufl_expr import TrialFunction, TestFunction, action
from firedrake.ufl_expr import TrialFunction, TestFunction, action, replace
from firedrake.bcs import DirichletBC, EquationBC, extract_subdomain_ids, restricted_function_space
from firedrake.adjoint_utils import NonlinearVariationalProblemMixin, NonlinearVariationalSolverMixin
from firedrake.__future__ import interpolate
from ufl import replace, Form
from ufl import Form

__all__ = ["LinearVariationalProblem",
"LinearVariationalSolver",
Expand Down
2 changes: 1 addition & 1 deletion tests/firedrake/slate/test_cg_poisson.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from firedrake import *
from firedrake.petsc import DEFAULT_DIRECT_SOLVER
import numpy as np


def run_CG_problem(r, degree, quads=False):
Expand Down Expand Up @@ -62,7 +63,6 @@ def run_CG_problem(r, degree, quads=False):
[(3, False, 3.75),
(5, True, 5.75)])
def test_cg_convergence(degree, quads, rate):
import numpy as np
diff = np.array([run_CG_problem(r, degree, quads) for r in range(2, 5)])
conv = np.log2(diff[:-1] / diff[1:])
assert (np.array(conv) > rate).all()
86 changes: 86 additions & 0 deletions tests/firedrake/slate/test_stabilized_stokes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest
from firedrake import *
from firedrake.petsc import DEFAULT_DIRECT_SOLVER
import numpy as np


def run_stabilized_stokes(r, degree, quads):
"""
Test that we can solve problems involving local projections in Slate.

We formulate Stokes with equal-order continuous Lagrange elements
and stabilize the discretizations by adding a mass matrix on the homogeneous
pressure subspace.

Reference: arxiv.org/abs/2407.07498
"""
msh = UnitSquareMesh(2**r, 2**r, quadrilateral=quads)
V = VectorFunctionSpace(msh, "CG", degree)
Q = FunctionSpace(msh, "CG", degree)
Z = V * Q

u, p = TrialFunctions(Z)
v, q = TestFunctions(Z)

# Stokes PDE
nu = Constant(1)
gamma = Constant(1E-1)
a = Tensor((inner(grad(u)*nu, grad(v)) + inner(div(u)*gamma, div(v))) * dx)
b = Tensor(-inner(div(u), q) * dx)

# Stabilization on DG space
X = FunctionSpace(msh, "DG", degree-1)
tau = TestFunction(X)
sig = TrialFunction(X)

h = 1/(nu + gamma)
mcc = Tensor(h * inner(p, q) * dx)
mdd = Tensor(h * inner(sig, tau) * dx)
mdc = Tensor(h * inner(p, tau) * dx)
mcd = Tensor(h * inner(sig, q) * dx)

s = mcc - mcd * Inverse(mdd) * mdc

# Saddle-point bilinear form
A = a + b.T + b - s

# Upper-triangular preconditioner
aP = a + b.T - (mcc + s)

x, y = SpatialCoordinate(msh)
bcs = [DirichletBC(Z.sub(0), as_vector([4*y*(1-y), 0]), (1,)),
DirichletBC(Z.sub(0), 0, (3, 4))]

solver_parameters = {
"mat_type": "nest",
"snes_type": "ksponly",
"ksp_type": "gmres",
"ksp_max_it": 40,
"ksp_rtol": 1E-10,
"ksp_monitor": None,
"pc_type": "fieldsplit",
"pc_fieldsplit_type": "schur",
"pc_fieldsplit_schur_type": "upper",
"fieldsplit_ksp_type": "preonly",
"fieldsplit_pc_type": "lu",
"fieldsplit_pc_factor_mat_solver_type": DEFAULT_DIRECT_SOLVER,
}
z = Function(Z)
L = 0
problem = LinearVariationalProblem(A, L, z, aP=aP, bcs=bcs)
solver = LinearVariationalSolver(problem, solver_parameters=solver_parameters)
solver.solve()

u = z.subfunctions[0]
return norm(div(u))


@pytest.mark.parametrize(('degree', 'quads', 'rate'),
[(1, True, 1.0),
(1, False, 1.0),
(2, True, 2.0)])
def test_stabilized_stokes(degree, quads, rate):
diff = np.array([run_stabilized_stokes(r, degree, quads) for r in range(3, 6)])
conv = np.log2(diff[:-1] / diff[1:])
tol = 1E-10
assert (c > rate or d < tol for d, c in zip(diff[1:], conv))
Loading