diff --git a/python/demo/demo_lagrange_variants.py b/python/demo/demo_lagrange_variants.py index 209e7312c6..135f12a863 100644 --- a/python/demo/demo_lagrange_variants.py +++ b/python/demo/demo_lagrange_variants.py @@ -26,6 +26,7 @@ from mpi4py import MPI import matplotlib.pylab as plt +import numpy as np import basix import basix.ufl @@ -142,11 +143,11 @@ def saw_tooth(x): uh.interpolate(lambda x: saw_tooth(x[0])) if MPI.COMM_WORLD.size == 1: # Skip this plotting in parallel pts: list[list[float]] = [] - cells: list[int] = [] + cells = np.empty((0,), dtype=np.int32) for cell in range(N): for i in range(51): pts.append([cell / N + i / 50 / N, 0, 0]) - cells.append(cell) + cells = np.append(cells, [cell]) values = uh.eval(pts, cells) plt.plot(pts, [saw_tooth(i[0]) for i in pts], "k--") plt.plot(pts, values, "r-") diff --git a/python/dolfinx/__init__.py b/python/dolfinx/__init__.py index d0cbe1344b..2fc0d7b4e0 100644 --- a/python/dolfinx/__init__.py +++ b/python/dolfinx/__init__.py @@ -29,7 +29,7 @@ from dolfinx import common from dolfinx import cpp as _cpp -from dolfinx import fem, geometry, graph, io, jit, la, log, mesh, nls, plot +from dolfinx import fem, geometry, graph, io, jit, la, log, mesh, nls, plot, typing from dolfinx.common import ( git_commit_hash, @@ -79,6 +79,7 @@ def get_include(user=False): "mesh", "nls", "plot", + "typing", "git_commit_hash", "hardware_concurrency", "has_adios2", diff --git a/python/dolfinx/fem/bcs.py b/python/dolfinx/fem/bcs.py index d4abcaa8a7..b8bb1be1ae 100644 --- a/python/dolfinx/fem/bcs.py +++ b/python/dolfinx/fem/bcs.py @@ -13,6 +13,7 @@ from __future__ import annotations from collections.abc import Callable, Iterable +from typing import Generic import numpy as np import numpy.typing as npt @@ -20,6 +21,7 @@ import dolfinx from dolfinx import cpp as _cpp from dolfinx.fem.function import Constant, Function, FunctionSpace +from dolfinx.typing import Scalar def locate_dofs_geometrical( @@ -90,7 +92,7 @@ def locate_dofs_topological( return _cpp.fem.locate_dofs_topological(_V, entity_dim, _entities, remote) -class DirichletBC: +class DirichletBC(Generic[Scalar]): """Representation of Dirichlet boundary conditions. The conditions are imposed on a linear system. @@ -119,8 +121,9 @@ class initialiser. This class is combined with different self._cpp_object = bc @property - def g(self) -> Function | Constant | np.ndarray: + def g(self) -> Function | Constant: """The boundary condition value(s).""" + # TODO: needs to be wrapped into Function or Constant return self._cpp_object.value @property @@ -129,7 +132,7 @@ def function_space(self) -> dolfinx.fem.FunctionSpace: return self._cpp_object.function_space def set( - self, x: npt.NDArray, x0: npt.NDArray[np.int32] | None = None, alpha: float = 1 + self, x: npt.NDArray[Scalar], x0: npt.NDArray[Scalar] | None = None, alpha: float = 1 ) -> None: """Set array entries that are constrained by a Dirichlet condition. @@ -171,10 +174,10 @@ def dof_indices(self) -> tuple[npt.NDArray[np.int32], int]: def dirichletbc( - value: Function | Constant | np.ndarray | float | complex, + value: Function | Constant | npt.NDArray[Scalar] | float | complex, dofs: npt.NDArray[np.int32], V: dolfinx.fem.FunctionSpace | None = None, -) -> DirichletBC: +) -> DirichletBC[Scalar]: """Representation of Dirichlet boundary condition. Args: @@ -231,8 +234,8 @@ def dirichletbc( def bcs_by_block( - spaces: Iterable[FunctionSpace | None], bcs: Iterable[DirichletBC] -) -> list[list[DirichletBC]]: + spaces: Iterable[FunctionSpace | None], bcs: Iterable[DirichletBC[Scalar]] +) -> list[list[DirichletBC[Scalar]]]: """Arrange boundary conditions by the space that they constrain. Given a sequence of function spaces ``spaces`` and a sequence of diff --git a/python/dolfinx/fem/element.py b/python/dolfinx/fem/element.py index 55ad6626f1..b33244bf7d 100644 --- a/python/dolfinx/fem/element.py +++ b/python/dolfinx/fem/element.py @@ -6,6 +6,7 @@ """Finite elements.""" from functools import singledispatch +from typing import Generic import numpy as np import numpy.typing as npt @@ -13,9 +14,10 @@ import basix import basix.ufl from dolfinx import cpp as _cpp +from dolfinx.typing import Real -class CoordinateElement: +class CoordinateElement(Generic[Real]): """Coordinate element describing the geometry map for mesh cells.""" _cpp_object: _cpp.fem.CoordinateElement_float32 | _cpp.fem.CoordinateElement_float64 @@ -61,10 +63,8 @@ def create_dof_layout(self) -> _cpp.fem.ElementDofLayout: return self._cpp_object.create_dof_layout() def push_forward( - self, - X: npt.NDArray[np.float32] | npt.NDArray[np.float64], - cell_geometry: npt.NDArray[np.float32] | npt.NDArray[np.float64], - ) -> npt.NDArray[np.float32] | npt.NDArray[np.float64]: + self, X: npt.NDArray[Real], cell_geometry: npt.NDArray[Real] + ) -> npt.NDArray[Real]: """Push points on the reference cell forward to the physical cell. Args: @@ -82,11 +82,11 @@ def push_forward( def pull_back( self, - x: npt.NDArray[np.float32] | npt.NDArray[np.float64], - cell_geometry: npt.NDArray[np.float32] | npt.NDArray[np.float64], + x: npt.NDArray[Real], + cell_geometry: npt.NDArray[Real], tol: float = 1.0e-6, maxit: int = 15, - ) -> npt.NDArray[np.float32] | npt.NDArray[np.float64]: + ) -> npt.NDArray[Real]: """Pull points on the physical cell back to the reference cell. For non-affine cells, the pull-back is a nonlinear operation. @@ -130,7 +130,7 @@ def coordinate_element( degree: int, variant=int(basix.LagrangeVariant.unset), dtype: npt.DTypeLike = np.float64, -): +) -> CoordinateElement: """Create a Lagrange CoordinateElement from element metadata. Coordinate elements are typically used to create meshes. @@ -153,7 +153,7 @@ def coordinate_element( @coordinate_element.register(basix.finite_element.FiniteElement) -def _(e: basix.finite_element.FiniteElement): +def _(e: basix.finite_element.FiniteElement) -> CoordinateElement: """Create a Lagrange CoordinateElement from a Basix finite element. Coordinate elements are typically used when creating meshes. @@ -170,7 +170,7 @@ def _(e: basix.finite_element.FiniteElement): return CoordinateElement(_cpp.fem.CoordinateElement_float64(e._e)) -class FiniteElement: +class FiniteElement(Generic[Real]): """A finite element.""" _cpp_object: _cpp.fem.FiniteElement_float32 | _cpp.fem.FiniteElement_float64 @@ -224,7 +224,7 @@ def value_shape(self) -> npt.NDArray[np.integer]: return self._cpp_object.value_shape @property - def interpolation_points(self) -> npt.NDArray[np.floating]: + def interpolation_points(self) -> npt.NDArray[Real]: """Points at which to evaluate the function to be interpolated. Interpolation point coordinates on the reference cell, returning @@ -282,7 +282,7 @@ def signature(self) -> str: return self._cpp_object.signature def T_apply( - self, x: npt.NDArray[np.floating], cell_permutations: npt.NDArray[np.uint32], dim: int + self, x: npt.NDArray[Real], cell_permutations: npt.NDArray[np.uint32], dim: int ) -> None: """Transform basis from reference to physical ordering/orientation. @@ -305,7 +305,7 @@ def T_apply( self._cpp_object.T_apply(x, cell_permutations, dim) def Tt_apply( - self, x: npt.NDArray[np.floating], cell_permutations: npt.NDArray[np.uint32], dim: int + self, x: npt.NDArray[Real], cell_permutations: npt.NDArray[np.uint32], dim: int ) -> None: """Apply the transpose of the operator applied by T_apply(). @@ -319,7 +319,7 @@ def Tt_apply( self._cpp_object.Tt_apply(x, cell_permutations, dim) def Tt_inv_apply( - self, x: npt.NDArray[np.floating], cell_permutations: npt.NDArray[np.uint32], dim: int + self, x: npt.NDArray[Real], cell_permutations: npt.NDArray[np.uint32], dim: int ) -> None: """Apply the inverse transpose of T_apply(). @@ -336,7 +336,7 @@ def Tt_inv_apply( def finiteelement( cell_type: _cpp.mesh.CellType, ufl_e: basix.ufl._ElementBase, - FiniteElement_dtype: np.dtype, + FiniteElement_dtype: npt.DTypeLike, ) -> FiniteElement: """Create a DOLFINx element from a basix.ufl element. diff --git a/python/dolfinx/fem/forms.py b/python/dolfinx/fem/forms.py index 77434579d4..23ca331e17 100644 --- a/python/dolfinx/fem/forms.py +++ b/python/dolfinx/fem/forms.py @@ -25,6 +25,7 @@ from dolfinx import default_scalar_type, jit from dolfinx.fem import IntegralType from dolfinx.fem.function import Constant, Function, FunctionSpace +from dolfinx.typing import Scalar if typing.TYPE_CHECKING: # import dolfinx.mesh just when doing type checking to avoid @@ -33,7 +34,7 @@ from dolfinx.mesh import Mesh, MeshTags -class Form: +class Form(typing.Generic[Scalar]): """A finite element form.""" _cpp_object: ( diff --git a/python/dolfinx/fem/function.py b/python/dolfinx/fem/function.py index bbbfdd7cdb..58d9640365 100644 --- a/python/dolfinx/fem/function.py +++ b/python/dolfinx/fem/function.py @@ -11,6 +11,7 @@ import typing from collections.abc import Callable, Sequence from functools import cached_property, singledispatch +from typing import Generic import numpy as np import numpy.typing as npt @@ -22,6 +23,7 @@ from dolfinx.fem.dofmap import DofMap from dolfinx.fem.element import FiniteElement, finiteelement from dolfinx.geometry import PointOwnershipData +from dolfinx.typing import Real, Scalar if typing.TYPE_CHECKING: from mpi4py import MPI as _MPI @@ -29,7 +31,7 @@ from dolfinx.mesh import Mesh -class Constant(ufl.Constant): +class Constant(ufl.Constant, Generic[Scalar]): """A constant with respect to a domain.""" _cpp_object: ( @@ -72,30 +74,30 @@ def value(self): return self._cpp_object.value @value.setter - def value(self, v): + def value(self, v: npt.NDArray[Scalar]) -> None: np.copyto(self._cpp_object.value, np.asarray(v)) @property - def dtype(self) -> np.dtype: + def dtype(self) -> npt.DTypeLike: """Value dtype of the constant.""" return np.dtype(self._cpp_object.dtype) - def __float__(self): + def __float__(self) -> float: """Real representation of the constant.""" if self.ufl_shape or self.ufl_free_indices: raise TypeError("Cannot evaluate a nonscalar expression to a scalar value.") - else: - return float(self.value) - def __complex__(self): + return float(self.value) + + def __complex__(self) -> complex: """Complex representation of the constant.""" if self.ufl_shape or self.ufl_free_indices: raise TypeError("Cannot evaluate a nonscalar expression to a scalar value.") - else: - return complex(self.value) + return complex(self.value) -class Expression: + +class Expression(Generic[Scalar]): """An object for evaluating functions of finite element functions. Represents a mathematical expression evaluated at a pre-defined set @@ -210,9 +212,9 @@ def _create_expression(dtype): def eval( self, mesh: Mesh, - entities: np.ndarray, - values: np.ndarray | None = None, - ) -> np.ndarray: + entities: npt.NDArray[np.int32], + values: npt.NDArray[Scalar] | None = None, + ) -> npt.NDArray[Scalar]: """Evaluate Expression on entities. Args: @@ -268,12 +270,12 @@ def eval( ) return values - def X(self) -> np.ndarray: + def X(self) -> npt.NDArray: """Evaluation points on the reference cell.""" return self._cpp_object.X() @property - def ufl_expression(self): + def ufl_expression(self) -> ufl.core.expr.Expr: """Original UFL Expression.""" return self._ufl_expression @@ -303,12 +305,12 @@ def code(self) -> str: return self._code @property - def dtype(self) -> np.dtype: + def dtype(self) -> npt.DTypeLike: """Expression value dtype.""" return np.dtype(self._cpp_object.dtype) -class Function(ufl.Coefficient): +class Function(ufl.Coefficient, Generic[Scalar]): """A finite element function. A finite element function is represented by a function space @@ -324,10 +326,12 @@ class Function(ufl.Coefficient): | _cpp.fem.Function_float64 ) + _x: la.Vector[Scalar] + def __init__( self, V: FunctionSpace, - x: la.Vector | None = None, + x: la.Vector[Scalar] | None = None, name: str | None = None, dtype: npt.DTypeLike | None = None, ): @@ -395,11 +399,11 @@ def function_space(self) -> FunctionSpace: def eval( self, x: npt.ArrayLike, - cells: npt.ArrayLike, - u: None | npt.NDArray[np.float32 | np.float64 | np.complex128 | np.complex64] = None, + cells: npt.NDArray[np.int32], + u: None | npt.NDArray[Scalar] = None, tol: float = 1.0e-6, maxit: int = 15, - ) -> np.ndarray: + ) -> npt.NDArray[Scalar]: """Evaluate Function at points x. Args: @@ -443,7 +447,7 @@ def eval( def interpolate_nonmatching( self, - u0: Function, + u0: Function[Scalar], cells: npt.NDArray[np.int32], interpolation_data: PointOwnershipData, tol: float = 1e-6, @@ -469,9 +473,9 @@ def interpolate_nonmatching( def interpolate( self, - u0: Callable | Expression | Function, - cells0: np.ndarray | None = None, - cells1: np.ndarray | None = None, + u0: Callable | Expression[Scalar] | Function[Scalar], + cells0: npt.NDArray[np.int32] | None = None, + cells1: npt.NDArray[np.int32] | None = None, ) -> None: """Interpolate an expression. @@ -517,7 +521,7 @@ def _(e0: Expression): ) self._cpp_object.interpolate_f(np.asarray(u0(x), dtype=self.dtype), cells0) - def copy(self) -> Function: + def copy(self) -> Function[Scalar]: """Create a copy of the Function. The function space is shared and the degree-of-freedom vector is @@ -533,12 +537,12 @@ def copy(self) -> Function: ) @property - def x(self) -> la.Vector: + def x(self) -> la.Vector[Scalar]: """Vector holding the degrees-of-freedom.""" return self._x @property - def dtype(self) -> np.dtype: + def dtype(self) -> npt.DTypeLike: """Function value dtype.""" return np.dtype(self._cpp_object.x.array.dtype) @@ -551,11 +555,11 @@ def name(self) -> str: def name(self, name): self._cpp_object.name = name - def __str__(self): + def __str__(self) -> str: """Pretty print representation.""" return self.name - def sub(self, i: int) -> Function: + def sub(self, i: int) -> Function[Scalar]: """Return a sub-function (a view into the ``Function``). Sub-functions are indexed ``i = 0, ..., N-1``, where ``N`` is @@ -574,7 +578,7 @@ def sub(self, i: int) -> Function: """ return Function(self._V.sub(i), self.x, name=f"{self!s}_{i}") - def split(self) -> tuple[Function, ...]: + def split(self) -> tuple[Function[Scalar], ...]: """Extract (any) sub-functions. A sub-function can be extracted from a discrete function that is @@ -589,7 +593,7 @@ def split(self) -> tuple[Function, ...]: raise RuntimeError("No subfunctions to extract") return tuple(self.sub(i) for i in range(num_sub_spaces)) - def collapse(self) -> Function: + def collapse(self) -> Function[Scalar]: """Create a collapsed version of this Function.""" u_collapsed = self._cpp_object.collapse() # type: ignore V_collapsed = FunctionSpace( @@ -680,11 +684,11 @@ def functionspace( return FunctionSpace(mesh, ufl_e, cppV) -class FunctionSpace(ufl.FunctionSpace): +class FunctionSpace(ufl.FunctionSpace, Generic[Real]): """A space on which Functions (fields) can be defined.""" _cpp_object: _cpp.fem.FunctionSpace_float32 | _cpp.fem.FunctionSpace_float64 - _mesh: Mesh + _mesh: Mesh[Real] def __init__( self, @@ -711,7 +715,7 @@ def __init__( self._mesh = mesh super().__init__(ufl_domain, element) - def clone(self) -> FunctionSpace: + def clone(self) -> FunctionSpace[Real]: """Create a FunctionSpace which shares data with this space. The new space has a different unique integer ID. @@ -746,7 +750,7 @@ def num_sub_spaces(self) -> int: """Number of sub spaces.""" return self.element.num_sub_elements - def sub(self, i: int) -> FunctionSpace: + def sub(self, i: int) -> FunctionSpace[Real]: """Return the i-th sub space. Args: @@ -794,7 +798,7 @@ def ufl_function_space(self) -> ufl.FunctionSpace: return self @cached_property - def element(self) -> FiniteElement: + def element(self) -> FiniteElement[Real]: """Function space finite element.""" return FiniteElement(self._cpp_object.element) @@ -808,22 +812,22 @@ def dofmaps(self, idx: int) -> DofMap: return DofMap(self._cpp_object.dofmaps(idx)) @property - def mesh(self) -> Mesh: + def mesh(self) -> Mesh[Real]: """Mesh on which the function space is defined.""" return self._mesh - def collapse(self) -> tuple[FunctionSpace, np.ndarray]: + def collapse(self) -> tuple[FunctionSpace[Real], list[npt.NDArray[np.int32]]]: """Create a new function space by collapsing a subspace. Returns: A new function space and the map from new to old degrees-of-freedom. """ - cpp_space, dofs = self._cpp_object.collapse() # type: ignore + cpp_space, dofs = self._cpp_object.collapse() V = FunctionSpace(self._mesh, self.ufl_element(), cpp_space) return V, dofs - def tabulate_dof_coordinates(self) -> npt.NDArray[np.float64]: + def tabulate_dof_coordinates(self) -> npt.NDArray[Real]: """Tabulate coordinates of function space degrees-of-freedom. Returns: @@ -833,4 +837,4 @@ def tabulate_dof_coordinates(self) -> npt.NDArray[np.float64]: This method is only for elements with point evaluation degrees-of-freedom. """ - return self._cpp_object.tabulate_dof_coordinates() # type: ignore + return self._cpp_object.tabulate_dof_coordinates() diff --git a/python/dolfinx/geometry.py b/python/dolfinx/geometry.py index 9c9d734797..5c3d6e40d0 100644 --- a/python/dolfinx/geometry.py +++ b/python/dolfinx/geometry.py @@ -18,6 +18,7 @@ from dolfinx import cpp as _cpp from dolfinx.graph import AdjacencyList +from dolfinx.typing import Real __all__ = [ "BoundingBoxTree", @@ -35,7 +36,7 @@ ] -class PointOwnershipData: +class PointOwnershipData(typing.Generic[Real]): """Class for storing data related to the ownership of points.""" _cpp_object: _cpp.geometry.PointOwnershipData_float32 | _cpp.geometry.PointOwnershipData_float64 @@ -55,7 +56,7 @@ def dest_owner(self) -> npt.NDArray[np.int32]: return self._cpp_object.dest_owners @property - def dest_points(self) -> npt.NDArray[np.floating]: + def dest_points(self) -> npt.NDArray[Real]: """Points owned by current rank.""" return self._cpp_object.dest_points @@ -65,7 +66,7 @@ def dest_cells(self) -> npt.NDArray[np.int32]: return self._cpp_object.dest_cells -class BoundingBoxTree: +class BoundingBoxTree(typing.Generic[Real]): """Bounding box trees used in collision detection.""" _cpp_object: _cpp.geometry.BoundingBoxTree_float32 | _cpp.geometry.BoundingBoxTree_float64 @@ -86,7 +87,7 @@ def num_bboxes(self) -> int: return self._cpp_object.num_bboxes @property - def bbox_coordinates(self) -> npt.NDArray[np.float32] | npt.NDArray[np.float64]: + def bbox_coordinates(self) -> npt.NDArray[Real]: """Coordinates of lower and upper corners of bounding boxes. Note: @@ -95,7 +96,7 @@ def bbox_coordinates(self) -> npt.NDArray[np.float32] | npt.NDArray[np.float64]: """ return self._cpp_object.bbox_coordinates - def get_bbox(self, i) -> npt.NDArray[np.floating]: + def get_bbox(self, i) -> npt.NDArray[Real]: """Get lower and upper corners of the ith bounding box. Args: @@ -108,18 +109,18 @@ def get_bbox(self, i) -> npt.NDArray[np.floating]: """ return self._cpp_object.get_bbox(i) - def create_global_tree(self, comm) -> BoundingBoxTree: + def create_global_tree(self, comm) -> BoundingBoxTree[Real]: """Create a global bounding box tree.""" return BoundingBoxTree(self._cpp_object.create_global_tree(comm)) def bb_tree( - mesh: Mesh, + mesh: Mesh[Real], dim: int, *, padding: float = 0.0, entities: npt.NDArray[np.int32] | None = None, -) -> BoundingBoxTree: +) -> BoundingBoxTree[Real]: """Create a bounding box tree for use in collision detection. Args: @@ -151,7 +152,7 @@ def bb_tree( def compute_collisions_trees( - tree0: BoundingBoxTree, tree1: BoundingBoxTree + tree0: BoundingBoxTree[Real], tree1: BoundingBoxTree[Real] ) -> npt.NDArray[np.int32]: """Compute all collisions between two bounding box trees. @@ -167,7 +168,7 @@ def compute_collisions_trees( return _cpp.geometry.compute_collisions_trees(tree0._cpp_object, tree1._cpp_object) -def compute_collisions_points(tree: BoundingBoxTree, x: npt.NDArray[np.floating]) -> AdjacencyList: +def compute_collisions_points(tree: BoundingBoxTree[Real], x: npt.NDArray[Real]) -> AdjacencyList: """Compute collisions between points and leaf bounding boxes. Bounding boxes can overlap, therefore points can collide with more @@ -186,10 +187,10 @@ def compute_collisions_points(tree: BoundingBoxTree, x: npt.NDArray[np.floating] def compute_closest_entity( - tree: BoundingBoxTree, - midpoint_tree: BoundingBoxTree, - mesh: Mesh, - points: npt.NDArray[np.floating], + tree: BoundingBoxTree[Real], + midpoint_tree: BoundingBoxTree[Real], + mesh: Mesh[Real], + points: npt.NDArray[Real], ) -> npt.NDArray[np.int32]: """Compute closest mesh entity to a point. @@ -211,7 +212,9 @@ def compute_closest_entity( ) -def create_midpoint_tree(mesh: Mesh, dim: int, entities: npt.NDArray[np.int32]) -> BoundingBoxTree: +def create_midpoint_tree( + mesh: Mesh[Real], dim: int, entities: npt.NDArray[np.int32] +) -> BoundingBoxTree[Real]: """Create bounding box tree for the midpoints of a subset of entities. Args: @@ -226,7 +229,7 @@ def create_midpoint_tree(mesh: Mesh, dim: int, entities: npt.NDArray[np.int32]) def compute_colliding_cells( - msh: Mesh, candidates: AdjacencyList, x: npt.NDArray[np.floating] + msh: Mesh[Real], candidates: AdjacencyList, x: npt.NDArray[Real] ) -> AdjacencyList: """From a mesh, find which cells collide with a set of points. @@ -247,8 +250,8 @@ def compute_colliding_cells( def squared_distance( - mesh: Mesh, dim: int, entities: npt.NDArray[np.int32], points: npt.NDArray[np.floating] -) -> npt.NDArray[np.floating]: + mesh: Mesh[Real], dim: int, entities: npt.NDArray[np.int32], points: npt.NDArray[Real] +) -> npt.NDArray[Real]: """Compute the squared distance between a point and a mesh entity. The distance is computed between the ith input points and the ith @@ -268,9 +271,7 @@ def squared_distance( return _cpp.geometry.squared_distance(mesh._cpp_object, dim, entities, points) -def compute_distance_gjk( - p: npt.NDArray[np.floating], q: npt.NDArray[np.floating] -) -> npt.NDArray[np.floating]: +def compute_distance_gjk(p: npt.NDArray[Real], q: npt.NDArray[Real]) -> npt.NDArray[Real]: """Compute the distance between two convex bodies. Each body is defined by a set of points. Uses the @@ -292,8 +293,8 @@ def compute_distance_gjk( def compute_distances_gjk( - bodies: list[npt.NDArray[np.floating]], q: npt.NDArray[np.floating], num_threads: int -) -> npt.NDArray[np.floating]: + bodies: list[npt.NDArray[Real]], q: npt.NDArray[Real], num_threads: int +) -> npt.NDArray[Real]: """Compute the distance between a set of convex bodies. For each convex body defined in `bodies`; @@ -321,10 +322,10 @@ def compute_distances_gjk( def determine_point_ownership( mesh: Mesh, - points: npt.NDArray[np.floating], + points: npt.NDArray[Real], padding: float, cells: npt.NDArray[np.int32] | None = None, -) -> PointOwnershipData: +) -> PointOwnershipData[Real]: """Build point ownership data for a mesh-points pair. First, potential collisions are found by computing intersections diff --git a/python/dolfinx/graph.py b/python/dolfinx/graph.py index af8ec7afea..2f3f22ea4a 100644 --- a/python/dolfinx/graph.py +++ b/python/dolfinx/graph.py @@ -5,13 +5,14 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Graph representations and operations on graphs.""" -from typing import Generic, TypeVar +from typing import Generic import numpy as np import numpy.typing as npt from dolfinx import cpp as _cpp from dolfinx.cpp.graph import partitioner +from dolfinx.typing import Index # Import graph partitioners, which may or may not be available # (dependent on build configuration) @@ -39,10 +40,7 @@ ] -_T = TypeVar("_T", np.int32, np.int64) - - -class AdjacencyList(Generic[_T]): +class AdjacencyList(Generic[Index]): """Adjacency list representation of a graph.""" _cpp_object: ( @@ -74,7 +72,7 @@ def __repr__(self): """String representation of the adjacency list.""" return self._cpp_object.__repr__() - def links(self, node: int) -> npt.NDArray[_T]: + def links(self, node: int) -> npt.NDArray[Index]: """Retrieve the links of a node. Note: @@ -90,7 +88,7 @@ def links(self, node: int) -> npt.NDArray[_T]: return self._cpp_object.links(node) @property - def array(self) -> npt.NDArray[_T]: + def array(self) -> npt.NDArray[Index]: """Array representation of the adjacency list. Note: @@ -122,8 +120,8 @@ def num_nodes(self) -> np.int32: def adjacencylist( - data: npt.NDArray[_T], offsets: npt.NDArray[np.int32] | None = None -) -> AdjacencyList[_T]: + data: npt.NDArray[Index], offsets: npt.NDArray[np.int32] | None = None +) -> AdjacencyList[Index]: """Create an :class:`AdjacencyList` for `int32` or `int64` datasets. Args: diff --git a/python/dolfinx/la/__init__.py b/python/dolfinx/la/__init__.py index 34cb6b48d1..120fadef7e 100644 --- a/python/dolfinx/la/__init__.py +++ b/python/dolfinx/la/__init__.py @@ -5,6 +5,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Linear algebra functionality.""" +from typing import Generic, TypeVar + import numpy as np import numpy.typing as npt @@ -12,6 +14,7 @@ from dolfinx import cpp as _cpp from dolfinx.cpp.common import IndexMap from dolfinx.cpp.la import BlockMode, InsertMode, Norm +from dolfinx.typing import Scalar __all__ = [ "InsertMode", @@ -26,7 +29,10 @@ ] -class Vector: +_T = TypeVar("_T", np.float32, np.float64, np.complex64, np.complex128, np.int8, np.int32, np.int64) + + +class Vector(Generic[_T]): """Distributed vector object.""" _cpp_object: ( @@ -79,7 +85,7 @@ def block_size(self) -> int: return self._cpp_object.bs @property - def array(self) -> np.ndarray: + def array(self) -> npt.NDArray[_T]: """Local representation of the vector.""" return self._cpp_object.array @@ -117,7 +123,7 @@ def scatter_reverse(self, mode: InsertMode) -> None: self._cpp_object.scatter_reverse(mode) -class MatrixCSR: +class MatrixCSR(Generic[Scalar]): """Distributed compressed sparse row matrix.""" _cpp_object: ( @@ -155,7 +161,7 @@ def index_map(self, i: int) -> IndexMap: """ return self._cpp_object.index_map(i) - def mult(self, x: Vector, y: Vector, transpose: bool = False) -> None: + def mult(self, x: Vector[Scalar], y: Vector[Scalar], transpose: bool = False) -> None: """Compute ``y += Ax`` or ``y += A^T x``. Args: @@ -200,7 +206,7 @@ def block_size(self) -> list: def add( self, - x: npt.NDArray[np.floating], + x: npt.NDArray[Scalar], rows: npt.NDArray[np.int32], cols: npt.NDArray[np.int32], bs: int = 1, @@ -210,7 +216,7 @@ def add( def set( self, - x: npt.NDArray[np.floating], + x: npt.NDArray[Scalar], rows: npt.NDArray[np.int32], cols: npt.NDArray[np.int32], bs: int = 1, @@ -218,7 +224,7 @@ def set( """Set a block of values in the matrix.""" self._cpp_object.set(x, rows, cols, bs) - def set_value(self, x: np.floating) -> None: + def set_value(self, x: Scalar) -> None: """Set all non-zero entries to a value. Args: @@ -239,7 +245,7 @@ def squared_norm(self) -> np.floating: return self._cpp_object.squared_norm() @property - def data(self) -> npt.NDArray[np.floating]: + def data(self) -> npt.NDArray[Scalar]: """Underlying matrix entry data.""" return self._cpp_object.data @@ -253,7 +259,7 @@ def indptr(self) -> npt.NDArray[np.int64]: """Local row pointers.""" return self._cpp_object.indptr - def to_dense(self) -> npt.NDArray[np.floating]: + def to_dense(self) -> npt.NDArray[Scalar]: """Copy to a dense 2D array. Note: @@ -364,17 +370,17 @@ def vector(map, bs=1, dtype: npt.DTypeLike = np.float64) -> Vector: return Vector(vtype(map, bs)) -def orthonormalize(basis: list[Vector]): +def orthonormalize(basis: list[Vector[_T]]) -> None: """Orthogonalise set of vectors in-place.""" _cpp.la.orthonormalize([x._cpp_object for x in basis]) -def is_orthonormal(basis: list[Vector], eps: float = 1.0e-12) -> bool: +def is_orthonormal(basis: list[Vector[_T]], eps: float = 1.0e-12) -> bool: """Check that list of vectors are orthonormal.""" return _cpp.la.is_orthonormal([x._cpp_object for x in basis], eps) -def norm(x: Vector, type: _cpp.la.Norm = _cpp.la.Norm.l2) -> np.floating: +def norm(x: Vector[_T], type: _cpp.la.Norm = _cpp.la.Norm.l2) -> np.floating: """Compute a norm of the vector. Args: diff --git a/python/dolfinx/la/superlu_dist.py b/python/dolfinx/la/superlu_dist.py index 72ea934c43..e51ebd8782 100644 --- a/python/dolfinx/la/superlu_dist.py +++ b/python/dolfinx/la/superlu_dist.py @@ -12,6 +12,8 @@ Users with advanced linear solver requirements should use PETSc/petsc4py. """ +from typing import Generic, TypeVar + import numpy as np import numpy.typing as npt @@ -22,8 +24,12 @@ __all__ = ["SuperLUDistMatrix", "SuperLUDistSolver", "superlu_dist_matrix", "superlu_dist_solver"] +# As of 2026, SuperLU_DIST only supports these types, so the general Scalar +# type including np.complex64 cannot be used. +_T = TypeVar("_T", np.float32, np.float64, np.complex128) + -class SuperLUDistMatrix: +class SuperLUDistMatrix(Generic[_T]): """SuperLU_DIST matrix.""" _cpp_object: ( @@ -52,7 +58,7 @@ def dtype(self) -> npt.DTypeLike: return self._cpp_object.dtype -def superlu_dist_matrix(A: dolfinx.la.MatrixCSR) -> SuperLUDistMatrix: +def superlu_dist_matrix(A: dolfinx.la.MatrixCSR[_T]) -> SuperLUDistMatrix[_T]: """Create a SuperLU_DIST matrix. Deep copies all required data from ``A``. @@ -75,7 +81,7 @@ def superlu_dist_matrix(A: dolfinx.la.MatrixCSR) -> SuperLUDistMatrix: return SuperLUDistMatrix(stype(A._cpp_object)) -class SuperLUDistSolver: +class SuperLUDistSolver(Generic[_T]): """SuperLU_DIST solver.""" _cpp_object: ( @@ -112,7 +118,7 @@ def set_option(self, name: str, value: str): """ self._cpp_object.set_option(name, value) - def set_A(self, A: SuperLUDistMatrix): + def set_A(self, A: SuperLUDistMatrix[_T]): """Set assembled left-hand side matrix. For advanced use with SuperLU_DIST option `Factor` allowing use of @@ -123,7 +129,7 @@ def set_A(self, A: SuperLUDistMatrix): """ self._cpp_object.set_A(A._cpp_object) - def solve(self, b: dolfinx.la.Vector, u: dolfinx.la.Vector) -> int: + def solve(self, b: dolfinx.la.Vector[_T], u: dolfinx.la.Vector[_T]) -> int: """Solve linear system :math:`Au = b`. Note: @@ -155,7 +161,7 @@ def solve(self, b: dolfinx.la.Vector, u: dolfinx.la.Vector) -> int: return self._cpp_object.solve(b._cpp_object, u._cpp_object) -def superlu_dist_solver(A: SuperLUDistMatrix) -> SuperLUDistSolver: +def superlu_dist_solver(A: SuperLUDistMatrix[_T]) -> SuperLUDistSolver[_T]: """Create a SuperLU_DIST linear solver. Solve linear system :math:`Au = b` via LU decomposition. diff --git a/python/dolfinx/mesh.py b/python/dolfinx/mesh.py index 4cfaeb39c5..cec93ec3ee 100644 --- a/python/dolfinx/mesh.py +++ b/python/dolfinx/mesh.py @@ -43,6 +43,7 @@ from dolfinx.fem import CoordinateElement as _CoordinateElement from dolfinx.fem import coordinate_element as _coordinate_element from dolfinx.graph import AdjacencyList +from dolfinx.typing import Real __all__ = [ "CellType", @@ -270,7 +271,7 @@ def cell_type(self) -> CellType: return self._cpp_object.cell_type -class Geometry: +class Geometry(typing.Generic[Real]): """The geometry of a :class:`dolfinx.mesh.Mesh`.""" _cpp_object: _cpp.mesh.Geometry_float32 | _cpp.mesh.Geometry_float64 @@ -316,7 +317,7 @@ def input_global_indices(self) -> npt.NDArray[np.int64]: return self._cpp_object.input_global_indices @property - def x(self) -> npt.NDArray[np.float32] | npt.NDArray[np.float64]: + def x(self) -> npt.NDArray[Real]: """Geometry coordinate points. Shape is ``shape=(num_points, 3)``. @@ -324,12 +325,12 @@ def x(self) -> npt.NDArray[np.float32] | npt.NDArray[np.float64]: return self._cpp_object.x -class Mesh: +class Mesh(typing.Generic[Real]): """A mesh.""" _mesh: _cpp.mesh.Mesh_float32 | _cpp.mesh.Mesh_float64 _topology: Topology - _geometry: Geometry + _geometry: Geometry[Real] _ufl_domain: ufl.Mesh | None def __init__( @@ -413,7 +414,7 @@ def topology(self) -> Topology: return self._topology @property - def geometry(self) -> Geometry: + def geometry(self) -> Geometry[Real]: """Mesh geometry.""" return self._geometry diff --git a/python/dolfinx/typing.py b/python/dolfinx/typing.py new file mode 100644 index 0000000000..78101bad1c --- /dev/null +++ b/python/dolfinx/typing.py @@ -0,0 +1,18 @@ +# Copyright (C) 2026 Paul T. Kühner +# +# This file is part of DOLFINx (https://www.fenicsproject.org) +# +# SPDX-License-Identifier: LGPL-3.0-or-later + +"""Common typing functionality.""" + +from typing import TypeVar + +import numpy as np + +__all__ = ["Index", "Real", "Scalar"] + +Index = TypeVar("Index", np.int32, np.int64) + +Real = TypeVar("Real", np.float32, np.float64) +Scalar = TypeVar("Scalar", np.float32, np.float64, np.complex64, np.complex128)