Skip to content

Commit fcc7719

Browse files
Add FiniteElement python wrapper (#3542)
* Start to wrap FiniteElement * Ruff * Ufl naming * More wrapping * ruff * last one? * Misse done * nomatching * Add docstrings * Add tpye hints * Remove returns for properties * Add argument doc strings * more document args * Rename: dtype -> FiniteElement_dtype * finite_element -> finiteelement * Ruff * Apply suggestion * Switch to cached_propert for element access * FixreStructured text * Fix return type descriptions
1 parent 3aed8d5 commit fcc7719

File tree

3 files changed

+210
-42
lines changed

3 files changed

+210
-42
lines changed

python/dolfinx/fem/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
locate_dofs_topological,
3333
)
3434
from dolfinx.fem.dofmap import DofMap
35-
from dolfinx.fem.element import CoordinateElement, coordinate_element
35+
from dolfinx.fem.element import CoordinateElement, FiniteElement, coordinate_element, finiteelement
3636
from dolfinx.fem.forms import (
3737
Form,
3838
compile_form,
@@ -91,7 +91,11 @@ def create_interpolation_data(
9191
"""
9292
return _PointOwnershipData(
9393
_create_interpolation_data(
94-
V_to.mesh._cpp_object.geometry, V_to.element, V_from.mesh._cpp_object, cells, padding
94+
V_to.mesh._cpp_object.geometry,
95+
V_to.element._cpp_object,
96+
V_from.mesh._cpp_object,
97+
cells,
98+
padding,
9599
)
96100
)
97101

@@ -169,6 +173,7 @@ def compute_integration_domains(
169173
"DofMap",
170174
"ElementMetaData",
171175
"Expression",
176+
"FiniteElement",
172177
"Form",
173178
"Function",
174179
"FunctionSpace",
@@ -189,6 +194,7 @@ def compute_integration_domains(
189194
"dirichletbc",
190195
"discrete_gradient",
191196
"extract_function_spaces",
197+
"finiteelement",
192198
"form",
193199
"form_cpp_class",
194200
"functionspace",

python/dolfinx/fem/element.py

Lines changed: 191 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (C) 2024 Garth N. Wells
1+
# Copyright (C) 2024 Garth N. Wells and Paul T. Kühner
22
#
33
# This file is part of DOLFINx (https://www.fenicsproject.org)
44
#
@@ -12,6 +12,8 @@
1212
import numpy.typing as npt
1313

1414
import basix
15+
import ufl
16+
import ufl.finiteelement
1517
from dolfinx import cpp as _cpp
1618

1719

@@ -93,7 +95,7 @@ def pull_back(
9395
``shape=(num_points, geometrical_dimension)``.
9496
cell_geometry: Physical coordinates describing the cell,
9597
shape ``(num_of_geometry_basis_functions, geometrical_dimension)``
96-
They can be created by accessing `geometry.x[geometry.dofmap.cell_dofs(i)]`,
98+
They can be created by accessing ``geometry.x[geometry.dofmap.cell_dofs(i)]``,
9799
98100
Returns:
99101
Reference coordinates of the physical points ``x``.
@@ -160,3 +162,190 @@ def _(e: basix.finite_element.FiniteElement):
160162
return CoordinateElement(_cpp.fem.CoordinateElement_float32(e._e))
161163
except TypeError:
162164
return CoordinateElement(_cpp.fem.CoordinateElement_float64(e._e))
165+
166+
167+
class FiniteElement:
168+
_cpp_object: typing.Union[_cpp.fem.FiniteElement_float32, _cpp.fem.FiniteElement_float64]
169+
170+
def __init__(
171+
self,
172+
cpp_object: typing.Union[_cpp.fem.FiniteElement_float32, _cpp.fem.FiniteElement_float64],
173+
):
174+
"""Creates a Python wrapper for the exported finite element class.
175+
176+
Note:
177+
Do not use this constructor directly. Instead use :func:``finiteelement``.
178+
179+
Args:
180+
The underlying cpp instance that this object will wrap.
181+
"""
182+
self._cpp_object = cpp_object
183+
184+
def __eq__(self, other):
185+
return self._cpp_object == other._cpp_object
186+
187+
@property
188+
def dtype(self) -> np.dtype:
189+
"""Geometry type of the Mesh that the FunctionSpace is defined on."""
190+
return self._cpp_object.dtype
191+
192+
@property
193+
def basix_element(self) -> basix.finite_element.FiniteElement:
194+
"""Return underlying Basix C++ element (if it exists).
195+
196+
Raises:
197+
Runtime error if Basix element does not exist.
198+
"""
199+
return self._cpp_object.basix_element
200+
201+
@property
202+
def num_sub_elements(self) -> int:
203+
"""Number of sub elements (for a mixed or blocked element)."""
204+
return self._cpp_object.num_sub_elements
205+
206+
@property
207+
def value_shape(self) -> npt.NDArray[np.integer]:
208+
"""Value shape of the finite element field.
209+
210+
The value shape describes the shape of the finite element field, e.g. ``{}`` for a scalar,
211+
``{2}`` for a vector in 2D, ``{3, 3}`` for a rank-2 tensor in 3D, etc.
212+
"""
213+
return self._cpp_object.value_shape
214+
215+
@property
216+
def interpolation_points(self) -> npt.NDArray[np.floating]:
217+
"""Points on the reference cell at which an expression needs to be evaluated in order to
218+
interpolate the expression in the finite element space.
219+
220+
Interpolation point coordinates on the reference cell, returning the coordinates data
221+
(row-major) storage with shape ``(num_points, tdim)``.
222+
223+
Note:
224+
For Lagrange elements the points will just be the nodal positions. For other elements
225+
the points will typically be the quadrature points used to evaluate moment degrees of
226+
freedom.
227+
"""
228+
return self._cpp_object.interpolation_points
229+
230+
@property
231+
def interpolation_ident(self) -> bool:
232+
"""Check if interpolation into the finite element space is an identity operation given the
233+
evaluation on an expression at specific points, i.e. the degree-of-freedom are equal to
234+
point evaluations. The function will return `true` for Lagrange elements."""
235+
return self._cpp_object.interpolation_ident
236+
237+
@property
238+
def space_dimension(self) -> int:
239+
"""Dimension of the finite element function space (the number of degrees-of-freedom for the
240+
element).
241+
242+
For 'blocked' elements, this function returns the dimension of the full element rather than
243+
the dimension of the base element.
244+
"""
245+
return self._cpp_object.space_dimension
246+
247+
@property
248+
def needs_dof_transformations(self) -> bool:
249+
"""Check if DOF transformations are needed for this element.
250+
251+
DOF transformations will be needed for elements which might not be continuous when two
252+
neighbouring cells disagree on the orientation of a shared sub-entity, and when this cannot
253+
be corrected for by permuting the DOF numbering in the dofmap.
254+
255+
For example, Raviart-Thomas elements will need DOF transformations, as the neighbouring
256+
cells may disagree on the orientation of a basis function, and this orientation cannot be
257+
corrected for by permuting the DOF numbers on each cell.
258+
"""
259+
return self._cpp_object.needs_dof_transformations
260+
261+
@property
262+
def signature(self) -> str:
263+
"""String identifying the finite element."""
264+
return self._cpp_object.signature
265+
266+
def T_apply(self, x: npt.NDArray[np.floating], cell_permutations: np.int32, dim: int) -> None:
267+
"""Transform basis functions from the reference element ordering and orientation to the
268+
globally consistent physical element ordering and orientation.
269+
270+
Args:
271+
x: Data to transform (in place). The shape is ``(m, n)``, where `m` is the number of
272+
dgerees-of-freedom and the storage is row-major.
273+
cell_permutations: Permutation data for the cell.
274+
dim: Number of columns in ``data``.
275+
276+
Note:
277+
Exposed for testing. Function is not vectorised across multiple cells. Please see
278+
`basix.numba_helpers` for performant versions.
279+
"""
280+
self._cpp_object.T_apply(x, cell_permutations, dim)
281+
282+
def Tt_apply(self, x: npt.NDArray[np.floating], cell_permutations: np.int32, dim: int) -> None:
283+
"""Apply the transpose of the operator applied by T_apply().
284+
285+
Args:
286+
x: Data to transform (in place). The shape is ``(m, n)``, where `m` is the number of
287+
dgerees-of-freedom and the storage is row-major.
288+
cell_permutations: Permutation data for the cell.
289+
dim: Number of columns in `data`.
290+
291+
Note:
292+
Exposed for testing. Function is not vectorised across multiple cells. Please see
293+
`basix.numba_helpers` for performant versions.
294+
"""
295+
self._cpp_object.Tt_apply(x, cell_permutations, dim)
296+
297+
def Tt_inv_apply(
298+
self, x: npt.NDArray[np.floating], cell_permutations: np.int32, dim: int
299+
) -> None:
300+
"""Apply the inverse transpose of the operator applied by T_apply().
301+
302+
Args:
303+
x: Data to transform (in place). The shape is ``(m, n)``, where ``m`` is the number of
304+
dgerees-of-freedom and the storage is row-major.
305+
cell_permutations: Permutation data for the cell.
306+
dim: Number of columns in `data`.
307+
308+
Note:
309+
Exposed for testing. Function is not vectorised across multiple cells. Please see
310+
``basix.numba_helpers`` for performant versions.
311+
"""
312+
self._cpp_object.Tt_apply(x, cell_permutations, dim)
313+
314+
315+
def finiteelement(
316+
cell_type: _cpp.mesh.CellType,
317+
ufl_e: ufl.finiteelement,
318+
FiniteElement_dtype: np.dtype,
319+
) -> FiniteElement:
320+
"""Create a DOLFINx element from a basix.ufl element.
321+
322+
Args:
323+
cell_type: Element cell type, see ``mesh.CellType``
324+
ufl_e: UFL element, holding quadrature rule and other properties of the selected element.
325+
FiniteElement_dtype: Geometry type of the element.
326+
"""
327+
if np.issubdtype(FiniteElement_dtype, np.float32):
328+
CppElement = _cpp.fem.FiniteElement_float32
329+
elif np.issubdtype(FiniteElement_dtype, np.float64):
330+
CppElement = _cpp.fem.FiniteElement_float64
331+
else:
332+
raise ValueError(f"Unsupported dtype: {FiniteElement_dtype}")
333+
334+
if ufl_e.is_mixed:
335+
elements = [
336+
finiteelement(cell_type, e, FiniteElement_dtype)._cpp_object for e in ufl_e.sub_elements
337+
]
338+
return FiniteElement(CppElement(elements))
339+
elif ufl_e.is_quadrature:
340+
return FiniteElement(
341+
CppElement(
342+
cell_type,
343+
ufl_e.custom_quadrature()[0],
344+
ufl_e.reference_value_shape,
345+
ufl_e.is_symmetric,
346+
)
347+
)
348+
else:
349+
basix_e = ufl_e.basix_element._e
350+
value_shape = ufl_e.reference_value_shape if ufl_e.block_size > 1 else None
351+
return FiniteElement(CppElement(basix_e, value_shape, ufl_e.is_symmetric))

python/dolfinx/fem/function.py

Lines changed: 11 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from __future__ import annotations
99

1010
import typing
11-
from functools import singledispatch
11+
from functools import cached_property, singledispatch
1212

1313
import numpy as np
1414
import numpy.typing as npt
@@ -18,6 +18,7 @@
1818
from dolfinx import cpp as _cpp
1919
from dolfinx import default_scalar_type, jit, la
2020
from dolfinx.fem import dofmap
21+
from dolfinx.fem.element import FiniteElement, finiteelement
2122
from dolfinx.geometry import PointOwnershipData
2223

2324
if typing.TYPE_CHECKING:
@@ -461,7 +462,7 @@ def _(e0: Expression):
461462
# u0 is callable
462463
assert callable(u0)
463464
x = _cpp.fem.interpolation_coords(
464-
self._V.element, self._V.mesh.geometry._cpp_object, cells0
465+
self._V.element._cpp_object, self._V.mesh.geometry._cpp_object, cells0
465466
)
466467
self._cpp_object.interpolate(np.asarray(u0(x), dtype=self.dtype), cells0) # type: ignore
467468

@@ -560,32 +561,6 @@ class ElementMetaData(typing.NamedTuple):
560561
symmetry: typing.Optional[bool] = None
561562

562563

563-
def _create_dolfinx_element(
564-
cell_type: _cpp.mesh.CellType,
565-
ufl_e: ufl.FiniteElementBase,
566-
dtype: np.dtype,
567-
) -> typing.Union[_cpp.fem.FiniteElement_float32, _cpp.fem.FiniteElement_float64]:
568-
"""Create a DOLFINx element from a basix.ufl element."""
569-
if np.issubdtype(dtype, np.float32):
570-
CppElement = _cpp.fem.FiniteElement_float32
571-
elif np.issubdtype(dtype, np.float64):
572-
CppElement = _cpp.fem.FiniteElement_float64
573-
else:
574-
raise ValueError(f"Unsupported dtype: {dtype}")
575-
576-
if ufl_e.is_mixed:
577-
elements = [_create_dolfinx_element(cell_type, e, dtype) for e in ufl_e.sub_elements]
578-
return CppElement(elements)
579-
elif ufl_e.is_quadrature:
580-
return CppElement(
581-
cell_type, ufl_e.custom_quadrature()[0], ufl_e.reference_value_shape, ufl_e.is_symmetric
582-
)
583-
else:
584-
basix_e = ufl_e.basix_element._e
585-
value_shape = ufl_e.reference_value_shape if ufl_e.block_size > 1 else None
586-
return CppElement(basix_e, value_shape, ufl_e.is_symmetric)
587-
588-
589564
def functionspace(
590565
mesh: Mesh,
591566
element: typing.Union[ufl.FiniteElementBase, ElementMetaData, tuple[str, int, tuple, bool]],
@@ -614,18 +589,18 @@ def functionspace(
614589
raise ValueError("Non-matching UFL cell and mesh cell shapes.")
615590

616591
# Create DOLFINx objects
617-
cpp_element = _create_dolfinx_element(mesh.topology.cell_type, ufl_e, dtype)
618-
cpp_dofmap = _cpp.fem.create_dofmap(mesh.comm, mesh.topology._cpp_object, cpp_element)
592+
element = finiteelement(mesh.topology.cell_type, ufl_e, dtype)
593+
cpp_dofmap = _cpp.fem.create_dofmap(mesh.comm, mesh.topology._cpp_object, element._cpp_object)
619594

620595
assert np.issubdtype(
621-
mesh.geometry.x.dtype, cpp_element.dtype
596+
mesh.geometry.x.dtype, element.dtype
622597
), "Mesh and element dtype are not compatible."
623598

624599
# Initialize the cpp.FunctionSpace
625600
try:
626-
cppV = _cpp.fem.FunctionSpace_float64(mesh._cpp_object, cpp_element, cpp_dofmap)
601+
cppV = _cpp.fem.FunctionSpace_float64(mesh._cpp_object, element._cpp_object, cpp_dofmap)
627602
except TypeError:
628-
cppV = _cpp.fem.FunctionSpace_float32(mesh._cpp_object, cpp_element, cpp_dofmap)
603+
cppV = _cpp.fem.FunctionSpace_float32(mesh._cpp_object, element._cpp_object, cpp_dofmap)
629604

630605
return FunctionSpace(mesh, ufl_e, cppV)
631606

@@ -745,12 +720,10 @@ def ufl_function_space(self) -> ufl.FunctionSpace:
745720
"""UFL function space."""
746721
return self
747722

748-
@property
749-
def element(
750-
self,
751-
) -> typing.Union[_cpp.fem.FiniteElement_float32, _cpp.fem.FiniteElement_float64]:
723+
@cached_property
724+
def element(self) -> FiniteElement:
752725
"""Function space finite element."""
753-
return self._cpp_object.element # type: ignore
726+
return FiniteElement(self._cpp_object.element)
754727

755728
@property
756729
def dofmap(self) -> dofmap.DofMap:

0 commit comments

Comments
 (0)