diff --git a/examples/transfer_tags_to_submesh.py b/examples/transfer_tags_to_submesh.py index 8095ce8..0642246 100644 --- a/examples/transfer_tags_to_submesh.py +++ b/examples/transfer_tags_to_submesh.py @@ -12,7 +12,7 @@ import os import sys import pyvista -from scifem import transfer_meshtags_to_submesh +from scifem import transfer_meshtags_to_submesh, compat from mpi4py import MPI import gmsh import dolfinx @@ -82,7 +82,7 @@ def plot_mesh(mesh: dolfinx.mesh.Mesh, values=None): plotter = pyvista.Plotter() V_linear = dolfinx.fem.functionspace(mesh, ("Lagrange", 1)) linear_grid = pyvista.UnstructuredGrid(*dolfinx.plot.vtk_mesh(V_linear)) - if mesh.geometry.cmap.degree > 1: + if compat.get_cmap(mesh).degree > 1: ugrid = pyvista.UnstructuredGrid(*dolfinx.plot.vtk_mesh(mesh)) if values is not None: ugrid.cell_data["Marker"] = values diff --git a/src/scifem/compat.py b/src/scifem/compat.py new file mode 100644 index 0000000..d29a9dc --- /dev/null +++ b/src/scifem/compat.py @@ -0,0 +1,11 @@ +"""Layer for small backward compatibility wrappers for DOLFINx""" + +import dolfinx + + +def get_cmap(mesh: dolfinx.mesh.Mesh) -> dolfinx.fem.CoordinateElement: + """Get the basix Cmap for the mesh.""" + if callable(mesh.geometry.cmap): + return mesh.geometry.cmap() + else: + return mesh.geometry.cmap diff --git a/src/scifem/eval.py b/src/scifem/eval.py index ce27517..7327340 100644 --- a/src/scifem/eval.py +++ b/src/scifem/eval.py @@ -10,6 +10,8 @@ import ufl from scipy.optimize import minimize from ufl.algorithms.signature import compute_expression_signature +from .compat import get_cmap + T = typing.TypeVar("T", int, float) MinMaxFunc = typing.Callable[[typing.Sequence[T]], T] @@ -155,6 +157,7 @@ def find_cell_extrema( _cell = np.array([cell], dtype=np.int32) mesh_nodes = mesh.geometry.x[mesh.geometry.dofmap[cell], : mesh.geometry.dim] _x_p = np.zeros(3) + cmap = get_cmap(mesh) def eval_J(x_ref): # Evaluating basis functions through {py:func}`dolfinx.fem.Function.eval` @@ -163,7 +166,8 @@ def eval_J(x_ref): # This could in theory be made even faster by taking out some of the eval code # However, quite a lot of work needs to be reimplemented for minimal gain # to do so, so we rather push forward, then let eval pull back again. - _x_p[: mesh.geometry.dim] = mesh.geometry.cmap.push_forward( + + _x_p[: mesh.geometry.dim] = cmap.push_forward( x_ref.reshape(-1, mesh.topology.dim), mesh_nodes )[0] try: @@ -233,7 +237,7 @@ def eval_dJ(x_ref): tol=tol, ) - X_phys = mesh.geometry.cmap.push_forward(result.x.reshape(-1, mesh.topology.dim), mesh_nodes)[0] + X_phys = cmap.push_forward(result.x.reshape(-1, mesh.topology.dim), mesh_nodes)[0] return X_phys, sign * result.fun diff --git a/src/scifem/point_source.py b/src/scifem/point_source.py index db23d60..cefc3c4 100644 --- a/src/scifem/point_source.py +++ b/src/scifem/point_source.py @@ -23,6 +23,7 @@ import numpy.typing as npt import ufl +from .compat import get_cmap from .utils import unroll_dofmap __all__ = ["PointSource"] @@ -121,7 +122,7 @@ def compute_cell_contributions(self): mesh = self._function_space.mesh # Pull owning points back to reference cell mesh_nodes = mesh.geometry.x - cmap = mesh.geometry.cmap + cmap = get_cmap(mesh) ref_x = np.zeros((len(self._cells), mesh.topology.dim), dtype=mesh.geometry.x.dtype) for i, (point, cell) in enumerate(zip(self._points, self._cells)): diff --git a/tests/test_point_source.py b/tests/test_point_source.py index decc6bf..764844c 100644 --- a/tests/test_point_source.py +++ b/tests/test_point_source.py @@ -4,7 +4,7 @@ import basix import numpy as np import ufl -from scifem import PointSource +from scifem import PointSource, compat def test_midpoint(): @@ -91,7 +91,7 @@ def test_outside(): if cells.offsets[-1] > 0: cell = cells.links(0)[0] geom_dofs = mesh.geometry.dofmap[cell] - ref_x = mesh.geometry.cmap.pull_back(point.reshape(-1, 3), mesh.geometry.x[geom_dofs]) + ref_x = compat.get_cmap(mesh).pull_back(point.reshape(-1, 3), mesh.geometry.x[geom_dofs]) ref_values = V.ufl_element().tabulate(0, ref_x).flatten() b_nonzero = np.flatnonzero(b.x.array) dofs = V.dofmap.cell_dofs(cell)