From d556e4721ebcb9768f7c259d29dbfb6fb3db560f Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Thu, 19 Mar 2026 10:50:48 +0100 Subject: [PATCH 01/19] Add typing generics to vector --- python/dolfinx/la/__init__.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/dolfinx/la/__init__.py b/python/dolfinx/la/__init__.py index 34cb6b48d1..4ab7388acc 100644 --- a/python/dolfinx/la/__init__.py +++ b/python/dolfinx/la/__init__.py @@ -5,6 +5,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Linear algebra functionality.""" +from typing import Generic, TypeVar + import numpy as np import numpy.typing as npt @@ -26,7 +28,10 @@ ] -class Vector: +_T = TypeVar("_T", np.float32, np.float64, np.complex64, np.complex128, np.int8, np.int32, np.int64) + + +class Vector(Generic[_T]): """Distributed vector object.""" _cpp_object: ( @@ -79,7 +84,7 @@ def block_size(self) -> int: return self._cpp_object.bs @property - def array(self) -> np.ndarray: + def array(self) -> npt.NDArray[_T]: """Local representation of the vector.""" return self._cpp_object.array @@ -364,17 +369,17 @@ def vector(map, bs=1, dtype: npt.DTypeLike = np.float64) -> Vector: return Vector(vtype(map, bs)) -def orthonormalize(basis: list[Vector]): +def orthonormalize(basis: list[Vector[_T]]) -> None: """Orthogonalise set of vectors in-place.""" _cpp.la.orthonormalize([x._cpp_object for x in basis]) -def is_orthonormal(basis: list[Vector], eps: float = 1.0e-12) -> bool: +def is_orthonormal(basis: list[Vector[_T]], eps: float = 1.0e-12) -> bool: """Check that list of vectors are orthonormal.""" return _cpp.la.is_orthonormal([x._cpp_object for x in basis], eps) -def norm(x: Vector, type: _cpp.la.Norm = _cpp.la.Norm.l2) -> np.floating: +def norm(x: Vector[_T], type: _cpp.la.Norm = _cpp.la.Norm.l2) -> np.floating: """Compute a norm of the vector. Args: From 9b1984b7a545b8d403be20c7795a31392e9c8ddf Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Thu, 19 Mar 2026 10:58:20 +0100 Subject: [PATCH 02/19] Add typing generics to matrix --- python/dolfinx/la/__init__.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/python/dolfinx/la/__init__.py b/python/dolfinx/la/__init__.py index 4ab7388acc..662684e610 100644 --- a/python/dolfinx/la/__init__.py +++ b/python/dolfinx/la/__init__.py @@ -122,7 +122,10 @@ def scatter_reverse(self, mode: InsertMode) -> None: self._cpp_object.scatter_reverse(mode) -class MatrixCSR: +_MT = TypeVar("_MT", np.float32, np.float64, np.complex64, np.complex128) + + +class MatrixCSR(Generic[_MT]): """Distributed compressed sparse row matrix.""" _cpp_object: ( @@ -160,7 +163,7 @@ def index_map(self, i: int) -> IndexMap: """ return self._cpp_object.index_map(i) - def mult(self, x: Vector, y: Vector, transpose: bool = False) -> None: + def mult(self, x: Vector[_MT], y: Vector[_MT], transpose: bool = False) -> None: """Compute ``y += Ax`` or ``y += A^T x``. Args: @@ -205,7 +208,7 @@ def block_size(self) -> list: def add( self, - x: npt.NDArray[np.floating], + x: npt.NDArray[_MT], rows: npt.NDArray[np.int32], cols: npt.NDArray[np.int32], bs: int = 1, @@ -215,7 +218,7 @@ def add( def set( self, - x: npt.NDArray[np.floating], + x: npt.NDArray[_MT], rows: npt.NDArray[np.int32], cols: npt.NDArray[np.int32], bs: int = 1, @@ -223,7 +226,7 @@ def set( """Set a block of values in the matrix.""" self._cpp_object.set(x, rows, cols, bs) - def set_value(self, x: np.floating) -> None: + def set_value(self, x: _MT) -> None: """Set all non-zero entries to a value. Args: @@ -244,7 +247,7 @@ def squared_norm(self) -> np.floating: return self._cpp_object.squared_norm() @property - def data(self) -> npt.NDArray[np.floating]: + def data(self) -> npt.NDArray[_MT]: """Underlying matrix entry data.""" return self._cpp_object.data @@ -258,7 +261,7 @@ def indptr(self) -> npt.NDArray[np.int64]: """Local row pointers.""" return self._cpp_object.indptr - def to_dense(self) -> npt.NDArray[np.floating]: + def to_dense(self) -> npt.NDArray[_MT]: """Copy to a dense 2D array. Note: From b8440c15ef42796522519a7f99b42d022378743e Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Thu, 19 Mar 2026 11:40:06 +0100 Subject: [PATCH 03/19] Add typing generics to superludist{matrix, solver} --- python/dolfinx/la/superlu_dist.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/python/dolfinx/la/superlu_dist.py b/python/dolfinx/la/superlu_dist.py index 72ea934c43..4a5d8e9c8b 100644 --- a/python/dolfinx/la/superlu_dist.py +++ b/python/dolfinx/la/superlu_dist.py @@ -12,6 +12,8 @@ Users with advanced linear solver requirements should use PETSc/petsc4py. """ +from typing import Generic, TypeVar + import numpy as np import numpy.typing as npt @@ -22,8 +24,10 @@ __all__ = ["SuperLUDistMatrix", "SuperLUDistSolver", "superlu_dist_matrix", "superlu_dist_solver"] +_T = TypeVar("_T", np.float32, np.float64, np.complex128) + -class SuperLUDistMatrix: +class SuperLUDistMatrix(Generic[_T]): """SuperLU_DIST matrix.""" _cpp_object: ( @@ -52,7 +56,7 @@ def dtype(self) -> npt.DTypeLike: return self._cpp_object.dtype -def superlu_dist_matrix(A: dolfinx.la.MatrixCSR) -> SuperLUDistMatrix: +def superlu_dist_matrix(A: dolfinx.la.MatrixCSR[_T]) -> SuperLUDistMatrix[_T]: """Create a SuperLU_DIST matrix. Deep copies all required data from ``A``. @@ -75,7 +79,7 @@ def superlu_dist_matrix(A: dolfinx.la.MatrixCSR) -> SuperLUDistMatrix: return SuperLUDistMatrix(stype(A._cpp_object)) -class SuperLUDistSolver: +class SuperLUDistSolver(Generic[_T]): """SuperLU_DIST solver.""" _cpp_object: ( @@ -112,7 +116,7 @@ def set_option(self, name: str, value: str): """ self._cpp_object.set_option(name, value) - def set_A(self, A: SuperLUDistMatrix): + def set_A(self, A: SuperLUDistMatrix[_T]): """Set assembled left-hand side matrix. For advanced use with SuperLU_DIST option `Factor` allowing use of @@ -123,7 +127,7 @@ def set_A(self, A: SuperLUDistMatrix): """ self._cpp_object.set_A(A._cpp_object) - def solve(self, b: dolfinx.la.Vector, u: dolfinx.la.Vector) -> int: + def solve(self, b: dolfinx.la.Vector[_T], u: dolfinx.la.Vector[_T]) -> int: """Solve linear system :math:`Au = b`. Note: @@ -155,7 +159,7 @@ def solve(self, b: dolfinx.la.Vector, u: dolfinx.la.Vector) -> int: return self._cpp_object.solve(b._cpp_object, u._cpp_object) -def superlu_dist_solver(A: SuperLUDistMatrix) -> SuperLUDistSolver: +def superlu_dist_solver(A: SuperLUDistMatrix[_T]) -> SuperLUDistSolver[_T]: """Create a SuperLU_DIST linear solver. Solve linear system :math:`Au = b` via LU decomposition. From f196a42f6fbf5c83fc55b316eb80f29068cb0bbd Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Thu, 19 Mar 2026 14:49:52 +0100 Subject: [PATCH 04/19] Add typing generics to CoordinateElement --- python/dolfinx/fem/element.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/python/dolfinx/fem/element.py b/python/dolfinx/fem/element.py index 55ad6626f1..d0d1479f81 100644 --- a/python/dolfinx/fem/element.py +++ b/python/dolfinx/fem/element.py @@ -6,6 +6,7 @@ """Finite elements.""" from functools import singledispatch +from typing import Generic, TypeVar import numpy as np import numpy.typing as npt @@ -14,8 +15,10 @@ import basix.ufl from dolfinx import cpp as _cpp +_T = TypeVar("_T", np.float32, np.float64) -class CoordinateElement: + +class CoordinateElement(Generic[_T]): """Coordinate element describing the geometry map for mesh cells.""" _cpp_object: _cpp.fem.CoordinateElement_float32 | _cpp.fem.CoordinateElement_float64 @@ -60,11 +63,7 @@ def create_dof_layout(self) -> _cpp.fem.ElementDofLayout: """Compute and return the dof layout.""" return self._cpp_object.create_dof_layout() - def push_forward( - self, - X: npt.NDArray[np.float32] | npt.NDArray[np.float64], - cell_geometry: npt.NDArray[np.float32] | npt.NDArray[np.float64], - ) -> npt.NDArray[np.float32] | npt.NDArray[np.float64]: + def push_forward(self, X: npt.NDArray[_T], cell_geometry: npt.NDArray[_T]) -> npt.NDArray[_T]: """Push points on the reference cell forward to the physical cell. Args: @@ -82,11 +81,11 @@ def push_forward( def pull_back( self, - x: npt.NDArray[np.float32] | npt.NDArray[np.float64], - cell_geometry: npt.NDArray[np.float32] | npt.NDArray[np.float64], + x: npt.NDArray[_T], + cell_geometry: npt.NDArray[_T], tol: float = 1.0e-6, maxit: int = 15, - ) -> npt.NDArray[np.float32] | npt.NDArray[np.float64]: + ) -> npt.NDArray[_T]: """Pull points on the physical cell back to the reference cell. For non-affine cells, the pull-back is a nonlinear operation. @@ -130,7 +129,7 @@ def coordinate_element( degree: int, variant=int(basix.LagrangeVariant.unset), dtype: npt.DTypeLike = np.float64, -): +) -> CoordinateElement: """Create a Lagrange CoordinateElement from element metadata. Coordinate elements are typically used to create meshes. @@ -153,7 +152,7 @@ def coordinate_element( @coordinate_element.register(basix.finite_element.FiniteElement) -def _(e: basix.finite_element.FiniteElement): +def _(e: basix.finite_element.FiniteElement) -> CoordinateElement: """Create a Lagrange CoordinateElement from a Basix finite element. Coordinate elements are typically used when creating meshes. From 40b733194854de9fee5130f7d353023010d86c9d Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Thu, 19 Mar 2026 14:52:32 +0100 Subject: [PATCH 05/19] Add typing generics to finiteelement --- python/dolfinx/fem/element.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/dolfinx/fem/element.py b/python/dolfinx/fem/element.py index d0d1479f81..9e13b11f8f 100644 --- a/python/dolfinx/fem/element.py +++ b/python/dolfinx/fem/element.py @@ -169,7 +169,7 @@ def _(e: basix.finite_element.FiniteElement) -> CoordinateElement: return CoordinateElement(_cpp.fem.CoordinateElement_float64(e._e)) -class FiniteElement: +class FiniteElement(Generic[_T]): """A finite element.""" _cpp_object: _cpp.fem.FiniteElement_float32 | _cpp.fem.FiniteElement_float64 @@ -223,7 +223,7 @@ def value_shape(self) -> npt.NDArray[np.integer]: return self._cpp_object.value_shape @property - def interpolation_points(self) -> npt.NDArray[np.floating]: + def interpolation_points(self) -> npt.NDArray[_T]: """Points at which to evaluate the function to be interpolated. Interpolation point coordinates on the reference cell, returning @@ -281,7 +281,7 @@ def signature(self) -> str: return self._cpp_object.signature def T_apply( - self, x: npt.NDArray[np.floating], cell_permutations: npt.NDArray[np.uint32], dim: int + self, x: npt.NDArray[_T], cell_permutations: npt.NDArray[np.uint32], dim: int ) -> None: """Transform basis from reference to physical ordering/orientation. @@ -304,7 +304,7 @@ def T_apply( self._cpp_object.T_apply(x, cell_permutations, dim) def Tt_apply( - self, x: npt.NDArray[np.floating], cell_permutations: npt.NDArray[np.uint32], dim: int + self, x: npt.NDArray[_T], cell_permutations: npt.NDArray[np.uint32], dim: int ) -> None: """Apply the transpose of the operator applied by T_apply(). @@ -318,7 +318,7 @@ def Tt_apply( self._cpp_object.Tt_apply(x, cell_permutations, dim) def Tt_inv_apply( - self, x: npt.NDArray[np.floating], cell_permutations: npt.NDArray[np.uint32], dim: int + self, x: npt.NDArray[_T], cell_permutations: npt.NDArray[np.uint32], dim: int ) -> None: """Apply the inverse transpose of T_apply(). @@ -335,7 +335,7 @@ def Tt_inv_apply( def finiteelement( cell_type: _cpp.mesh.CellType, ufl_e: basix.ufl._ElementBase, - FiniteElement_dtype: np.dtype, + FiniteElement_dtype: npt.DTypeLike, ) -> FiniteElement: """Create a DOLFINx element from a basix.ufl element. From c475a42b89d1d56df0bbbe873afcf32e2d238c07 Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Thu, 19 Mar 2026 14:59:05 +0100 Subject: [PATCH 06/19] Add typing generics to constant --- python/dolfinx/fem/function.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/python/dolfinx/fem/function.py b/python/dolfinx/fem/function.py index bbbfdd7cdb..13775e4210 100644 --- a/python/dolfinx/fem/function.py +++ b/python/dolfinx/fem/function.py @@ -11,6 +11,7 @@ import typing from collections.abc import Callable, Sequence from functools import cached_property, singledispatch +from typing import Generic, TypeVar import numpy as np import numpy.typing as npt @@ -28,8 +29,10 @@ from dolfinx.mesh import Mesh +_S = TypeVar("_S", np.float32, np.float64, np.complex64, np.complex128) # scalar -class Constant(ufl.Constant): + +class Constant(ufl.Constant, Generic[_S]): """A constant with respect to a domain.""" _cpp_object: ( @@ -72,27 +75,27 @@ def value(self): return self._cpp_object.value @value.setter - def value(self, v): + def value(self, v: npt.NDArray[_S]) -> None: np.copyto(self._cpp_object.value, np.asarray(v)) @property - def dtype(self) -> np.dtype: + def dtype(self) -> npt.DTypeLike: """Value dtype of the constant.""" return np.dtype(self._cpp_object.dtype) - def __float__(self): + def __float__(self) -> float: """Real representation of the constant.""" if self.ufl_shape or self.ufl_free_indices: raise TypeError("Cannot evaluate a nonscalar expression to a scalar value.") - else: - return float(self.value) - def __complex__(self): + return float(self.value) + + def __complex__(self) -> complex: """Complex representation of the constant.""" if self.ufl_shape or self.ufl_free_indices: raise TypeError("Cannot evaluate a nonscalar expression to a scalar value.") - else: - return complex(self.value) + + return complex(self.value) class Expression: From 14063a7079b40c1a14f2fef937ffb10014609ca4 Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Thu, 19 Mar 2026 16:25:08 +0100 Subject: [PATCH 07/19] Add typing generics dirichletbc --- python/dolfinx/fem/bcs.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/python/dolfinx/fem/bcs.py b/python/dolfinx/fem/bcs.py index d4abcaa8a7..5d32f7e015 100644 --- a/python/dolfinx/fem/bcs.py +++ b/python/dolfinx/fem/bcs.py @@ -13,6 +13,7 @@ from __future__ import annotations from collections.abc import Callable, Iterable +from typing import Generic, TypeVar import numpy as np import numpy.typing as npt @@ -90,7 +91,10 @@ def locate_dofs_topological( return _cpp.fem.locate_dofs_topological(_V, entity_dim, _entities, remote) -class DirichletBC: +_T = TypeVar("_T", np.float32, np.float64, np.complex64, np.complex128) + + +class DirichletBC(Generic[_T]): """Representation of Dirichlet boundary conditions. The conditions are imposed on a linear system. @@ -119,8 +123,9 @@ class initialiser. This class is combined with different self._cpp_object = bc @property - def g(self) -> Function | Constant | np.ndarray: + def g(self) -> Function | Constant: """The boundary condition value(s).""" + # TODO: needs to be wrapped into Function or Constant return self._cpp_object.value @property @@ -128,9 +133,7 @@ def function_space(self) -> dolfinx.fem.FunctionSpace: """Function space on which the boundary condition is defined.""" return self._cpp_object.function_space - def set( - self, x: npt.NDArray, x0: npt.NDArray[np.int32] | None = None, alpha: float = 1 - ) -> None: + def set(self, x: npt.NDArray[_T], x0: npt.NDArray[_T] | None = None, alpha: float = 1) -> None: """Set array entries that are constrained by a Dirichlet condition. Entries in ``x`` that are constrained by a Dirichlet boundary @@ -171,10 +174,10 @@ def dof_indices(self) -> tuple[npt.NDArray[np.int32], int]: def dirichletbc( - value: Function | Constant | np.ndarray | float | complex, + value: Function | Constant | npt.NDArray[_T] | float | complex, dofs: npt.NDArray[np.int32], V: dolfinx.fem.FunctionSpace | None = None, -) -> DirichletBC: +) -> DirichletBC[_T]: """Representation of Dirichlet boundary condition. Args: @@ -231,8 +234,8 @@ def dirichletbc( def bcs_by_block( - spaces: Iterable[FunctionSpace | None], bcs: Iterable[DirichletBC] -) -> list[list[DirichletBC]]: + spaces: Iterable[FunctionSpace | None], bcs: Iterable[DirichletBC[_T]] +) -> list[list[DirichletBC[_T]]]: """Arrange boundary conditions by the space that they constrain. Given a sequence of function spaces ``spaces`` and a sequence of From 587f68a63f076d35d2811aa2ee921de27106bad7 Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Thu, 19 Mar 2026 16:35:32 +0100 Subject: [PATCH 08/19] Add typing generics Expression --- python/dolfinx/fem/function.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/dolfinx/fem/function.py b/python/dolfinx/fem/function.py index 13775e4210..7e3fb660c8 100644 --- a/python/dolfinx/fem/function.py +++ b/python/dolfinx/fem/function.py @@ -98,7 +98,7 @@ def __complex__(self) -> complex: return complex(self.value) -class Expression: +class Expression(Generic[_S]): """An object for evaluating functions of finite element functions. Represents a mathematical expression evaluated at a pre-defined set @@ -213,9 +213,9 @@ def _create_expression(dtype): def eval( self, mesh: Mesh, - entities: np.ndarray, - values: np.ndarray | None = None, - ) -> np.ndarray: + entities: npt.NDArray[np.int32], + values: npt.NDArray[_S] | None = None, + ) -> npt.NDArray[_S]: """Evaluate Expression on entities. Args: @@ -271,12 +271,12 @@ def eval( ) return values - def X(self) -> np.ndarray: + def X(self) -> npt.NDArray: """Evaluation points on the reference cell.""" return self._cpp_object.X() @property - def ufl_expression(self): + def ufl_expression(self) -> ufl.core.expr.Expr: """Original UFL Expression.""" return self._ufl_expression @@ -306,7 +306,7 @@ def code(self) -> str: return self._code @property - def dtype(self) -> np.dtype: + def dtype(self) -> npt.DTypeLike: """Expression value dtype.""" return np.dtype(self._cpp_object.dtype) From 10d59023e2b9a9f7efe211afb21a8f60eab71b6a Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Thu, 19 Mar 2026 16:40:13 +0100 Subject: [PATCH 09/19] Add typing generics Function --- python/dolfinx/fem/function.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/python/dolfinx/fem/function.py b/python/dolfinx/fem/function.py index 7e3fb660c8..1e310cc6d1 100644 --- a/python/dolfinx/fem/function.py +++ b/python/dolfinx/fem/function.py @@ -311,7 +311,7 @@ def dtype(self) -> npt.DTypeLike: return np.dtype(self._cpp_object.dtype) -class Function(ufl.Coefficient): +class Function(ufl.Coefficient, Generic[_S]): """A finite element function. A finite element function is represented by a function space @@ -327,10 +327,12 @@ class Function(ufl.Coefficient): | _cpp.fem.Function_float64 ) + _x: la.Vector[_S] + def __init__( self, V: FunctionSpace, - x: la.Vector | None = None, + x: la.Vector[_S] | None = None, name: str | None = None, dtype: npt.DTypeLike | None = None, ): @@ -398,11 +400,11 @@ def function_space(self) -> FunctionSpace: def eval( self, x: npt.ArrayLike, - cells: npt.ArrayLike, - u: None | npt.NDArray[np.float32 | np.float64 | np.complex128 | np.complex64] = None, + cells: npt.NDArray[np.int32], + u: None | npt.NDArray[_S] = None, tol: float = 1.0e-6, maxit: int = 15, - ) -> np.ndarray: + ) -> npt.NDArray[_S]: """Evaluate Function at points x. Args: @@ -446,7 +448,7 @@ def eval( def interpolate_nonmatching( self, - u0: Function, + u0: Function[_S], cells: npt.NDArray[np.int32], interpolation_data: PointOwnershipData, tol: float = 1e-6, @@ -472,9 +474,9 @@ def interpolate_nonmatching( def interpolate( self, - u0: Callable | Expression | Function, - cells0: np.ndarray | None = None, - cells1: np.ndarray | None = None, + u0: Callable | Expression[_S] | Function[_S], + cells0: npt.NDArray[np.int32] | None = None, + cells1: npt.NDArray[np.int32] | None = None, ) -> None: """Interpolate an expression. @@ -520,7 +522,7 @@ def _(e0: Expression): ) self._cpp_object.interpolate_f(np.asarray(u0(x), dtype=self.dtype), cells0) - def copy(self) -> Function: + def copy(self) -> Function[_S]: """Create a copy of the Function. The function space is shared and the degree-of-freedom vector is @@ -536,12 +538,12 @@ def copy(self) -> Function: ) @property - def x(self) -> la.Vector: + def x(self) -> la.Vector[_S]: """Vector holding the degrees-of-freedom.""" return self._x @property - def dtype(self) -> np.dtype: + def dtype(self) -> npt.DTypeLike: """Function value dtype.""" return np.dtype(self._cpp_object.x.array.dtype) @@ -554,11 +556,11 @@ def name(self) -> str: def name(self, name): self._cpp_object.name = name - def __str__(self): + def __str__(self) -> str: """Pretty print representation.""" return self.name - def sub(self, i: int) -> Function: + def sub(self, i: int) -> Function[_S]: """Return a sub-function (a view into the ``Function``). Sub-functions are indexed ``i = 0, ..., N-1``, where ``N`` is @@ -577,7 +579,7 @@ def sub(self, i: int) -> Function: """ return Function(self._V.sub(i), self.x, name=f"{self!s}_{i}") - def split(self) -> tuple[Function, ...]: + def split(self) -> tuple[Function[_S], ...]: """Extract (any) sub-functions. A sub-function can be extracted from a discrete function that is @@ -592,7 +594,7 @@ def split(self) -> tuple[Function, ...]: raise RuntimeError("No subfunctions to extract") return tuple(self.sub(i) for i in range(num_sub_spaces)) - def collapse(self) -> Function: + def collapse(self) -> Function[_S]: """Create a collapsed version of this Function.""" u_collapsed = self._cpp_object.collapse() # type: ignore V_collapsed = FunctionSpace( From da4ed558cf6fe945df35df64b3b5aa958b94f19e Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Tue, 24 Mar 2026 18:30:13 +0100 Subject: [PATCH 10/19] Fix: demo_lagrange --- python/demo/demo_lagrange_variants.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/demo/demo_lagrange_variants.py b/python/demo/demo_lagrange_variants.py index 209e7312c6..135f12a863 100644 --- a/python/demo/demo_lagrange_variants.py +++ b/python/demo/demo_lagrange_variants.py @@ -26,6 +26,7 @@ from mpi4py import MPI import matplotlib.pylab as plt +import numpy as np import basix import basix.ufl @@ -142,11 +143,11 @@ def saw_tooth(x): uh.interpolate(lambda x: saw_tooth(x[0])) if MPI.COMM_WORLD.size == 1: # Skip this plotting in parallel pts: list[list[float]] = [] - cells: list[int] = [] + cells = np.empty((0,), dtype=np.int32) for cell in range(N): for i in range(51): pts.append([cell / N + i / 50 / N, 0, 0]) - cells.append(cell) + cells = np.append(cells, [cell]) values = uh.eval(pts, cells) plt.plot(pts, [saw_tooth(i[0]) for i in pts], "k--") plt.plot(pts, values, "r-") From 218b82e6474073493168ac28ad88778d807e0039 Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Tue, 24 Mar 2026 18:37:20 +0100 Subject: [PATCH 11/19] Add typing generics Geometry --- python/dolfinx/mesh.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/dolfinx/mesh.py b/python/dolfinx/mesh.py index 4cfaeb39c5..e56e2290b9 100644 --- a/python/dolfinx/mesh.py +++ b/python/dolfinx/mesh.py @@ -85,6 +85,9 @@ ] +_T = typing.TypeVar("_T", np.float32, np.float64) + + @singledispatch def create_cell_partitioner( part: Callable, mode: GhostMode, max_facet_to_cell_links: int @@ -270,7 +273,7 @@ def cell_type(self) -> CellType: return self._cpp_object.cell_type -class Geometry: +class Geometry(typing.Generic[_T]): """The geometry of a :class:`dolfinx.mesh.Mesh`.""" _cpp_object: _cpp.mesh.Geometry_float32 | _cpp.mesh.Geometry_float64 @@ -316,7 +319,7 @@ def input_global_indices(self) -> npt.NDArray[np.int64]: return self._cpp_object.input_global_indices @property - def x(self) -> npt.NDArray[np.float32] | npt.NDArray[np.float64]: + def x(self) -> npt.NDArray[_T]: """Geometry coordinate points. Shape is ``shape=(num_points, 3)``. From 46e9166625d2e33d90dcb5098dd08e428ccf15c3 Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Tue, 24 Mar 2026 18:38:54 +0100 Subject: [PATCH 12/19] Add typing generics Mesh --- python/dolfinx/mesh.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/dolfinx/mesh.py b/python/dolfinx/mesh.py index e56e2290b9..3d12c09bdf 100644 --- a/python/dolfinx/mesh.py +++ b/python/dolfinx/mesh.py @@ -327,12 +327,12 @@ def x(self) -> npt.NDArray[_T]: return self._cpp_object.x -class Mesh: +class Mesh(typing.Generic[_T]): """A mesh.""" _mesh: _cpp.mesh.Mesh_float32 | _cpp.mesh.Mesh_float64 _topology: Topology - _geometry: Geometry + _geometry: Geometry[_T] _ufl_domain: ufl.Mesh | None def __init__( @@ -416,7 +416,7 @@ def topology(self) -> Topology: return self._topology @property - def geometry(self) -> Geometry: + def geometry(self) -> Geometry[_T]: """Mesh geometry.""" return self._geometry From e1cf96a03d90ebd99dd0234a23aad70cf280b549 Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Tue, 24 Mar 2026 18:41:21 +0100 Subject: [PATCH 13/19] Add typing generics PointOwnershipData --- python/dolfinx/geometry.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/dolfinx/geometry.py b/python/dolfinx/geometry.py index 9c9d734797..9600551042 100644 --- a/python/dolfinx/geometry.py +++ b/python/dolfinx/geometry.py @@ -34,8 +34,10 @@ "squared_distance", ] +_T = typing.TypeVar("_T", np.float32, np.float64) -class PointOwnershipData: + +class PointOwnershipData(typing.Generic[_T]): """Class for storing data related to the ownership of points.""" _cpp_object: _cpp.geometry.PointOwnershipData_float32 | _cpp.geometry.PointOwnershipData_float64 @@ -55,7 +57,7 @@ def dest_owner(self) -> npt.NDArray[np.int32]: return self._cpp_object.dest_owners @property - def dest_points(self) -> npt.NDArray[np.floating]: + def dest_points(self) -> npt.NDArray[_T]: """Points owned by current rank.""" return self._cpp_object.dest_points @@ -321,10 +323,10 @@ def compute_distances_gjk( def determine_point_ownership( mesh: Mesh, - points: npt.NDArray[np.floating], + points: npt.NDArray[_T], padding: float, cells: npt.NDArray[np.int32] | None = None, -) -> PointOwnershipData: +) -> PointOwnershipData[_T]: """Build point ownership data for a mesh-points pair. First, potential collisions are found by computing intersections From a7e4d06e4c90cb9e8cc177c7c427c900a5c522ae Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Tue, 24 Mar 2026 18:45:39 +0100 Subject: [PATCH 14/19] Add typing generics BoundingBoxTree --- python/dolfinx/geometry.py | 42 +++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/python/dolfinx/geometry.py b/python/dolfinx/geometry.py index 9600551042..900bb5b1f9 100644 --- a/python/dolfinx/geometry.py +++ b/python/dolfinx/geometry.py @@ -67,7 +67,7 @@ def dest_cells(self) -> npt.NDArray[np.int32]: return self._cpp_object.dest_cells -class BoundingBoxTree: +class BoundingBoxTree(typing.Generic[_T]): """Bounding box trees used in collision detection.""" _cpp_object: _cpp.geometry.BoundingBoxTree_float32 | _cpp.geometry.BoundingBoxTree_float64 @@ -88,7 +88,7 @@ def num_bboxes(self) -> int: return self._cpp_object.num_bboxes @property - def bbox_coordinates(self) -> npt.NDArray[np.float32] | npt.NDArray[np.float64]: + def bbox_coordinates(self) -> npt.NDArray[_T]: """Coordinates of lower and upper corners of bounding boxes. Note: @@ -97,7 +97,7 @@ def bbox_coordinates(self) -> npt.NDArray[np.float32] | npt.NDArray[np.float64]: """ return self._cpp_object.bbox_coordinates - def get_bbox(self, i) -> npt.NDArray[np.floating]: + def get_bbox(self, i) -> npt.NDArray[_T]: """Get lower and upper corners of the ith bounding box. Args: @@ -110,18 +110,18 @@ def get_bbox(self, i) -> npt.NDArray[np.floating]: """ return self._cpp_object.get_bbox(i) - def create_global_tree(self, comm) -> BoundingBoxTree: + def create_global_tree(self, comm) -> BoundingBoxTree[_T]: """Create a global bounding box tree.""" return BoundingBoxTree(self._cpp_object.create_global_tree(comm)) def bb_tree( - mesh: Mesh, + mesh: Mesh[_T], dim: int, *, padding: float = 0.0, entities: npt.NDArray[np.int32] | None = None, -) -> BoundingBoxTree: +) -> BoundingBoxTree[_T]: """Create a bounding box tree for use in collision detection. Args: @@ -153,7 +153,7 @@ def bb_tree( def compute_collisions_trees( - tree0: BoundingBoxTree, tree1: BoundingBoxTree + tree0: BoundingBoxTree[_T], tree1: BoundingBoxTree[_T] ) -> npt.NDArray[np.int32]: """Compute all collisions between two bounding box trees. @@ -169,7 +169,7 @@ def compute_collisions_trees( return _cpp.geometry.compute_collisions_trees(tree0._cpp_object, tree1._cpp_object) -def compute_collisions_points(tree: BoundingBoxTree, x: npt.NDArray[np.floating]) -> AdjacencyList: +def compute_collisions_points(tree: BoundingBoxTree[_T], x: npt.NDArray[_T]) -> AdjacencyList: """Compute collisions between points and leaf bounding boxes. Bounding boxes can overlap, therefore points can collide with more @@ -188,10 +188,10 @@ def compute_collisions_points(tree: BoundingBoxTree, x: npt.NDArray[np.floating] def compute_closest_entity( - tree: BoundingBoxTree, - midpoint_tree: BoundingBoxTree, - mesh: Mesh, - points: npt.NDArray[np.floating], + tree: BoundingBoxTree[_T], + midpoint_tree: BoundingBoxTree[_T], + mesh: Mesh[_T], + points: npt.NDArray[_T], ) -> npt.NDArray[np.int32]: """Compute closest mesh entity to a point. @@ -213,7 +213,9 @@ def compute_closest_entity( ) -def create_midpoint_tree(mesh: Mesh, dim: int, entities: npt.NDArray[np.int32]) -> BoundingBoxTree: +def create_midpoint_tree( + mesh: Mesh[_T], dim: int, entities: npt.NDArray[np.int32] +) -> BoundingBoxTree[_T]: """Create bounding box tree for the midpoints of a subset of entities. Args: @@ -228,7 +230,7 @@ def create_midpoint_tree(mesh: Mesh, dim: int, entities: npt.NDArray[np.int32]) def compute_colliding_cells( - msh: Mesh, candidates: AdjacencyList, x: npt.NDArray[np.floating] + msh: Mesh[_T], candidates: AdjacencyList, x: npt.NDArray[_T] ) -> AdjacencyList: """From a mesh, find which cells collide with a set of points. @@ -249,8 +251,8 @@ def compute_colliding_cells( def squared_distance( - mesh: Mesh, dim: int, entities: npt.NDArray[np.int32], points: npt.NDArray[np.floating] -) -> npt.NDArray[np.floating]: + mesh: Mesh[_T], dim: int, entities: npt.NDArray[np.int32], points: npt.NDArray[_T] +) -> npt.NDArray[_T]: """Compute the squared distance between a point and a mesh entity. The distance is computed between the ith input points and the ith @@ -270,9 +272,7 @@ def squared_distance( return _cpp.geometry.squared_distance(mesh._cpp_object, dim, entities, points) -def compute_distance_gjk( - p: npt.NDArray[np.floating], q: npt.NDArray[np.floating] -) -> npt.NDArray[np.floating]: +def compute_distance_gjk(p: npt.NDArray[_T], q: npt.NDArray[_T]) -> npt.NDArray[_T]: """Compute the distance between two convex bodies. Each body is defined by a set of points. Uses the @@ -294,8 +294,8 @@ def compute_distance_gjk( def compute_distances_gjk( - bodies: list[npt.NDArray[np.floating]], q: npt.NDArray[np.floating], num_threads: int -) -> npt.NDArray[np.floating]: + bodies: list[npt.NDArray[_T]], q: npt.NDArray[_T], num_threads: int +) -> npt.NDArray[_T]: """Compute the distance between a set of convex bodies. For each convex body defined in `bodies`; From bce56c5fbebbc90a14f360e1c14f6c44508c031a Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Tue, 24 Mar 2026 18:50:11 +0100 Subject: [PATCH 15/19] Add typing generics FunctionSpace --- python/dolfinx/fem/function.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/python/dolfinx/fem/function.py b/python/dolfinx/fem/function.py index 1e310cc6d1..dafccb0eba 100644 --- a/python/dolfinx/fem/function.py +++ b/python/dolfinx/fem/function.py @@ -29,6 +29,7 @@ from dolfinx.mesh import Mesh +_T = TypeVar("_T", np.float32, np.float64) _S = TypeVar("_S", np.float32, np.float64, np.complex64, np.complex128) # scalar @@ -685,11 +686,11 @@ def functionspace( return FunctionSpace(mesh, ufl_e, cppV) -class FunctionSpace(ufl.FunctionSpace): +class FunctionSpace(ufl.FunctionSpace, Generic[_T]): """A space on which Functions (fields) can be defined.""" _cpp_object: _cpp.fem.FunctionSpace_float32 | _cpp.fem.FunctionSpace_float64 - _mesh: Mesh + _mesh: Mesh[_T] def __init__( self, @@ -716,7 +717,7 @@ def __init__( self._mesh = mesh super().__init__(ufl_domain, element) - def clone(self) -> FunctionSpace: + def clone(self) -> FunctionSpace[_T]: """Create a FunctionSpace which shares data with this space. The new space has a different unique integer ID. @@ -751,7 +752,7 @@ def num_sub_spaces(self) -> int: """Number of sub spaces.""" return self.element.num_sub_elements - def sub(self, i: int) -> FunctionSpace: + def sub(self, i: int) -> FunctionSpace[_T]: """Return the i-th sub space. Args: @@ -799,7 +800,7 @@ def ufl_function_space(self) -> ufl.FunctionSpace: return self @cached_property - def element(self) -> FiniteElement: + def element(self) -> FiniteElement[_T]: """Function space finite element.""" return FiniteElement(self._cpp_object.element) @@ -813,22 +814,22 @@ def dofmaps(self, idx: int) -> DofMap: return DofMap(self._cpp_object.dofmaps(idx)) @property - def mesh(self) -> Mesh: + def mesh(self) -> Mesh[_T]: """Mesh on which the function space is defined.""" return self._mesh - def collapse(self) -> tuple[FunctionSpace, np.ndarray]: + def collapse(self) -> tuple[FunctionSpace[_T], list[npt.NDArray[np.int32]]]: """Create a new function space by collapsing a subspace. Returns: A new function space and the map from new to old degrees-of-freedom. """ - cpp_space, dofs = self._cpp_object.collapse() # type: ignore + cpp_space, dofs = self._cpp_object.collapse() V = FunctionSpace(self._mesh, self.ufl_element(), cpp_space) return V, dofs - def tabulate_dof_coordinates(self) -> npt.NDArray[np.float64]: + def tabulate_dof_coordinates(self) -> npt.NDArray[_T]: """Tabulate coordinates of function space degrees-of-freedom. Returns: @@ -838,4 +839,4 @@ def tabulate_dof_coordinates(self) -> npt.NDArray[np.float64]: This method is only for elements with point evaluation degrees-of-freedom. """ - return self._cpp_object.tabulate_dof_coordinates() # type: ignore + return self._cpp_object.tabulate_dof_coordinates() From 349edca850d43fbe6f2bfceda12fdde1f47b6703 Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Tue, 24 Mar 2026 18:51:45 +0100 Subject: [PATCH 16/19] Add typing generics Form --- python/dolfinx/fem/forms.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/dolfinx/fem/forms.py b/python/dolfinx/fem/forms.py index 1ab16e02d5..450d4f6717 100644 --- a/python/dolfinx/fem/forms.py +++ b/python/dolfinx/fem/forms.py @@ -32,8 +32,10 @@ from dolfinx.mesh import EntityMap as _EntityMap from dolfinx.mesh import Mesh, MeshTags +_S = typing.TypeVar("_S", np.float32, np.float64, np.complex64, np.complex128) # scalar -class Form: + +class Form(typing.Generic[_S]): """A finite element form.""" _cpp_object: ( From a0d66b7addb1d9f4b2650fbdfbf865a77b45be1d Mon Sep 17 00:00:00 2001 From: schnellerhase <56360279+schnellerhase@users.noreply.github.com> Date: Thu, 30 Apr 2026 17:35:39 +0200 Subject: [PATCH 17/19] Centralise common names --- python/dolfinx/__init__.py | 3 +- python/dolfinx/fem/bcs.py | 20 ++++++------ python/dolfinx/fem/element.py | 27 ++++++++-------- python/dolfinx/fem/forms.py | 5 ++- python/dolfinx/fem/function.py | 56 ++++++++++++++++------------------ python/dolfinx/geometry.py | 51 +++++++++++++++---------------- python/dolfinx/graph.py | 16 +++++----- python/dolfinx/la/__init__.py | 18 +++++------ python/dolfinx/mesh.py | 14 ++++----- python/dolfinx/typing.py | 18 +++++++++++ 10 files changed, 119 insertions(+), 109 deletions(-) create mode 100644 python/dolfinx/typing.py diff --git a/python/dolfinx/__init__.py b/python/dolfinx/__init__.py index d0cbe1344b..2fc0d7b4e0 100644 --- a/python/dolfinx/__init__.py +++ b/python/dolfinx/__init__.py @@ -29,7 +29,7 @@ from dolfinx import common from dolfinx import cpp as _cpp -from dolfinx import fem, geometry, graph, io, jit, la, log, mesh, nls, plot +from dolfinx import fem, geometry, graph, io, jit, la, log, mesh, nls, plot, typing from dolfinx.common import ( git_commit_hash, @@ -79,6 +79,7 @@ def get_include(user=False): "mesh", "nls", "plot", + "typing", "git_commit_hash", "hardware_concurrency", "has_adios2", diff --git a/python/dolfinx/fem/bcs.py b/python/dolfinx/fem/bcs.py index 5d32f7e015..b8bb1be1ae 100644 --- a/python/dolfinx/fem/bcs.py +++ b/python/dolfinx/fem/bcs.py @@ -13,7 +13,7 @@ from __future__ import annotations from collections.abc import Callable, Iterable -from typing import Generic, TypeVar +from typing import Generic import numpy as np import numpy.typing as npt @@ -21,6 +21,7 @@ import dolfinx from dolfinx import cpp as _cpp from dolfinx.fem.function import Constant, Function, FunctionSpace +from dolfinx.typing import Scalar def locate_dofs_geometrical( @@ -91,10 +92,7 @@ def locate_dofs_topological( return _cpp.fem.locate_dofs_topological(_V, entity_dim, _entities, remote) -_T = TypeVar("_T", np.float32, np.float64, np.complex64, np.complex128) - - -class DirichletBC(Generic[_T]): +class DirichletBC(Generic[Scalar]): """Representation of Dirichlet boundary conditions. The conditions are imposed on a linear system. @@ -133,7 +131,9 @@ def function_space(self) -> dolfinx.fem.FunctionSpace: """Function space on which the boundary condition is defined.""" return self._cpp_object.function_space - def set(self, x: npt.NDArray[_T], x0: npt.NDArray[_T] | None = None, alpha: float = 1) -> None: + def set( + self, x: npt.NDArray[Scalar], x0: npt.NDArray[Scalar] | None = None, alpha: float = 1 + ) -> None: """Set array entries that are constrained by a Dirichlet condition. Entries in ``x`` that are constrained by a Dirichlet boundary @@ -174,10 +174,10 @@ def dof_indices(self) -> tuple[npt.NDArray[np.int32], int]: def dirichletbc( - value: Function | Constant | npt.NDArray[_T] | float | complex, + value: Function | Constant | npt.NDArray[Scalar] | float | complex, dofs: npt.NDArray[np.int32], V: dolfinx.fem.FunctionSpace | None = None, -) -> DirichletBC[_T]: +) -> DirichletBC[Scalar]: """Representation of Dirichlet boundary condition. Args: @@ -234,8 +234,8 @@ def dirichletbc( def bcs_by_block( - spaces: Iterable[FunctionSpace | None], bcs: Iterable[DirichletBC[_T]] -) -> list[list[DirichletBC[_T]]]: + spaces: Iterable[FunctionSpace | None], bcs: Iterable[DirichletBC[Scalar]] +) -> list[list[DirichletBC[Scalar]]]: """Arrange boundary conditions by the space that they constrain. Given a sequence of function spaces ``spaces`` and a sequence of diff --git a/python/dolfinx/fem/element.py b/python/dolfinx/fem/element.py index 9e13b11f8f..b33244bf7d 100644 --- a/python/dolfinx/fem/element.py +++ b/python/dolfinx/fem/element.py @@ -6,7 +6,7 @@ """Finite elements.""" from functools import singledispatch -from typing import Generic, TypeVar +from typing import Generic import numpy as np import numpy.typing as npt @@ -14,11 +14,10 @@ import basix import basix.ufl from dolfinx import cpp as _cpp +from dolfinx.typing import Real -_T = TypeVar("_T", np.float32, np.float64) - -class CoordinateElement(Generic[_T]): +class CoordinateElement(Generic[Real]): """Coordinate element describing the geometry map for mesh cells.""" _cpp_object: _cpp.fem.CoordinateElement_float32 | _cpp.fem.CoordinateElement_float64 @@ -63,7 +62,9 @@ def create_dof_layout(self) -> _cpp.fem.ElementDofLayout: """Compute and return the dof layout.""" return self._cpp_object.create_dof_layout() - def push_forward(self, X: npt.NDArray[_T], cell_geometry: npt.NDArray[_T]) -> npt.NDArray[_T]: + def push_forward( + self, X: npt.NDArray[Real], cell_geometry: npt.NDArray[Real] + ) -> npt.NDArray[Real]: """Push points on the reference cell forward to the physical cell. Args: @@ -81,11 +82,11 @@ def push_forward(self, X: npt.NDArray[_T], cell_geometry: npt.NDArray[_T]) -> np def pull_back( self, - x: npt.NDArray[_T], - cell_geometry: npt.NDArray[_T], + x: npt.NDArray[Real], + cell_geometry: npt.NDArray[Real], tol: float = 1.0e-6, maxit: int = 15, - ) -> npt.NDArray[_T]: + ) -> npt.NDArray[Real]: """Pull points on the physical cell back to the reference cell. For non-affine cells, the pull-back is a nonlinear operation. @@ -169,7 +170,7 @@ def _(e: basix.finite_element.FiniteElement) -> CoordinateElement: return CoordinateElement(_cpp.fem.CoordinateElement_float64(e._e)) -class FiniteElement(Generic[_T]): +class FiniteElement(Generic[Real]): """A finite element.""" _cpp_object: _cpp.fem.FiniteElement_float32 | _cpp.fem.FiniteElement_float64 @@ -223,7 +224,7 @@ def value_shape(self) -> npt.NDArray[np.integer]: return self._cpp_object.value_shape @property - def interpolation_points(self) -> npt.NDArray[_T]: + def interpolation_points(self) -> npt.NDArray[Real]: """Points at which to evaluate the function to be interpolated. Interpolation point coordinates on the reference cell, returning @@ -281,7 +282,7 @@ def signature(self) -> str: return self._cpp_object.signature def T_apply( - self, x: npt.NDArray[_T], cell_permutations: npt.NDArray[np.uint32], dim: int + self, x: npt.NDArray[Real], cell_permutations: npt.NDArray[np.uint32], dim: int ) -> None: """Transform basis from reference to physical ordering/orientation. @@ -304,7 +305,7 @@ def T_apply( self._cpp_object.T_apply(x, cell_permutations, dim) def Tt_apply( - self, x: npt.NDArray[_T], cell_permutations: npt.NDArray[np.uint32], dim: int + self, x: npt.NDArray[Real], cell_permutations: npt.NDArray[np.uint32], dim: int ) -> None: """Apply the transpose of the operator applied by T_apply(). @@ -318,7 +319,7 @@ def Tt_apply( self._cpp_object.Tt_apply(x, cell_permutations, dim) def Tt_inv_apply( - self, x: npt.NDArray[_T], cell_permutations: npt.NDArray[np.uint32], dim: int + self, x: npt.NDArray[Real], cell_permutations: npt.NDArray[np.uint32], dim: int ) -> None: """Apply the inverse transpose of T_apply(). diff --git a/python/dolfinx/fem/forms.py b/python/dolfinx/fem/forms.py index 450d4f6717..59034b440d 100644 --- a/python/dolfinx/fem/forms.py +++ b/python/dolfinx/fem/forms.py @@ -25,6 +25,7 @@ from dolfinx import default_scalar_type, jit from dolfinx.fem import IntegralType from dolfinx.fem.function import Constant, Function, FunctionSpace +from dolfinx.typing import Scalar if typing.TYPE_CHECKING: # import dolfinx.mesh just when doing type checking to avoid @@ -32,10 +33,8 @@ from dolfinx.mesh import EntityMap as _EntityMap from dolfinx.mesh import Mesh, MeshTags -_S = typing.TypeVar("_S", np.float32, np.float64, np.complex64, np.complex128) # scalar - -class Form(typing.Generic[_S]): +class Form(typing.Generic[Scalar]): """A finite element form.""" _cpp_object: ( diff --git a/python/dolfinx/fem/function.py b/python/dolfinx/fem/function.py index dafccb0eba..58d9640365 100644 --- a/python/dolfinx/fem/function.py +++ b/python/dolfinx/fem/function.py @@ -11,7 +11,7 @@ import typing from collections.abc import Callable, Sequence from functools import cached_property, singledispatch -from typing import Generic, TypeVar +from typing import Generic import numpy as np import numpy.typing as npt @@ -23,17 +23,15 @@ from dolfinx.fem.dofmap import DofMap from dolfinx.fem.element import FiniteElement, finiteelement from dolfinx.geometry import PointOwnershipData +from dolfinx.typing import Real, Scalar if typing.TYPE_CHECKING: from mpi4py import MPI as _MPI from dolfinx.mesh import Mesh -_T = TypeVar("_T", np.float32, np.float64) -_S = TypeVar("_S", np.float32, np.float64, np.complex64, np.complex128) # scalar - -class Constant(ufl.Constant, Generic[_S]): +class Constant(ufl.Constant, Generic[Scalar]): """A constant with respect to a domain.""" _cpp_object: ( @@ -76,7 +74,7 @@ def value(self): return self._cpp_object.value @value.setter - def value(self, v: npt.NDArray[_S]) -> None: + def value(self, v: npt.NDArray[Scalar]) -> None: np.copyto(self._cpp_object.value, np.asarray(v)) @property @@ -99,7 +97,7 @@ def __complex__(self) -> complex: return complex(self.value) -class Expression(Generic[_S]): +class Expression(Generic[Scalar]): """An object for evaluating functions of finite element functions. Represents a mathematical expression evaluated at a pre-defined set @@ -215,8 +213,8 @@ def eval( self, mesh: Mesh, entities: npt.NDArray[np.int32], - values: npt.NDArray[_S] | None = None, - ) -> npt.NDArray[_S]: + values: npt.NDArray[Scalar] | None = None, + ) -> npt.NDArray[Scalar]: """Evaluate Expression on entities. Args: @@ -312,7 +310,7 @@ def dtype(self) -> npt.DTypeLike: return np.dtype(self._cpp_object.dtype) -class Function(ufl.Coefficient, Generic[_S]): +class Function(ufl.Coefficient, Generic[Scalar]): """A finite element function. A finite element function is represented by a function space @@ -328,12 +326,12 @@ class Function(ufl.Coefficient, Generic[_S]): | _cpp.fem.Function_float64 ) - _x: la.Vector[_S] + _x: la.Vector[Scalar] def __init__( self, V: FunctionSpace, - x: la.Vector[_S] | None = None, + x: la.Vector[Scalar] | None = None, name: str | None = None, dtype: npt.DTypeLike | None = None, ): @@ -402,10 +400,10 @@ def eval( self, x: npt.ArrayLike, cells: npt.NDArray[np.int32], - u: None | npt.NDArray[_S] = None, + u: None | npt.NDArray[Scalar] = None, tol: float = 1.0e-6, maxit: int = 15, - ) -> npt.NDArray[_S]: + ) -> npt.NDArray[Scalar]: """Evaluate Function at points x. Args: @@ -449,7 +447,7 @@ def eval( def interpolate_nonmatching( self, - u0: Function[_S], + u0: Function[Scalar], cells: npt.NDArray[np.int32], interpolation_data: PointOwnershipData, tol: float = 1e-6, @@ -475,7 +473,7 @@ def interpolate_nonmatching( def interpolate( self, - u0: Callable | Expression[_S] | Function[_S], + u0: Callable | Expression[Scalar] | Function[Scalar], cells0: npt.NDArray[np.int32] | None = None, cells1: npt.NDArray[np.int32] | None = None, ) -> None: @@ -523,7 +521,7 @@ def _(e0: Expression): ) self._cpp_object.interpolate_f(np.asarray(u0(x), dtype=self.dtype), cells0) - def copy(self) -> Function[_S]: + def copy(self) -> Function[Scalar]: """Create a copy of the Function. The function space is shared and the degree-of-freedom vector is @@ -539,7 +537,7 @@ def copy(self) -> Function[_S]: ) @property - def x(self) -> la.Vector[_S]: + def x(self) -> la.Vector[Scalar]: """Vector holding the degrees-of-freedom.""" return self._x @@ -561,7 +559,7 @@ def __str__(self) -> str: """Pretty print representation.""" return self.name - def sub(self, i: int) -> Function[_S]: + def sub(self, i: int) -> Function[Scalar]: """Return a sub-function (a view into the ``Function``). Sub-functions are indexed ``i = 0, ..., N-1``, where ``N`` is @@ -580,7 +578,7 @@ def sub(self, i: int) -> Function[_S]: """ return Function(self._V.sub(i), self.x, name=f"{self!s}_{i}") - def split(self) -> tuple[Function[_S], ...]: + def split(self) -> tuple[Function[Scalar], ...]: """Extract (any) sub-functions. A sub-function can be extracted from a discrete function that is @@ -595,7 +593,7 @@ def split(self) -> tuple[Function[_S], ...]: raise RuntimeError("No subfunctions to extract") return tuple(self.sub(i) for i in range(num_sub_spaces)) - def collapse(self) -> Function[_S]: + def collapse(self) -> Function[Scalar]: """Create a collapsed version of this Function.""" u_collapsed = self._cpp_object.collapse() # type: ignore V_collapsed = FunctionSpace( @@ -686,11 +684,11 @@ def functionspace( return FunctionSpace(mesh, ufl_e, cppV) -class FunctionSpace(ufl.FunctionSpace, Generic[_T]): +class FunctionSpace(ufl.FunctionSpace, Generic[Real]): """A space on which Functions (fields) can be defined.""" _cpp_object: _cpp.fem.FunctionSpace_float32 | _cpp.fem.FunctionSpace_float64 - _mesh: Mesh[_T] + _mesh: Mesh[Real] def __init__( self, @@ -717,7 +715,7 @@ def __init__( self._mesh = mesh super().__init__(ufl_domain, element) - def clone(self) -> FunctionSpace[_T]: + def clone(self) -> FunctionSpace[Real]: """Create a FunctionSpace which shares data with this space. The new space has a different unique integer ID. @@ -752,7 +750,7 @@ def num_sub_spaces(self) -> int: """Number of sub spaces.""" return self.element.num_sub_elements - def sub(self, i: int) -> FunctionSpace[_T]: + def sub(self, i: int) -> FunctionSpace[Real]: """Return the i-th sub space. Args: @@ -800,7 +798,7 @@ def ufl_function_space(self) -> ufl.FunctionSpace: return self @cached_property - def element(self) -> FiniteElement[_T]: + def element(self) -> FiniteElement[Real]: """Function space finite element.""" return FiniteElement(self._cpp_object.element) @@ -814,11 +812,11 @@ def dofmaps(self, idx: int) -> DofMap: return DofMap(self._cpp_object.dofmaps(idx)) @property - def mesh(self) -> Mesh[_T]: + def mesh(self) -> Mesh[Real]: """Mesh on which the function space is defined.""" return self._mesh - def collapse(self) -> tuple[FunctionSpace[_T], list[npt.NDArray[np.int32]]]: + def collapse(self) -> tuple[FunctionSpace[Real], list[npt.NDArray[np.int32]]]: """Create a new function space by collapsing a subspace. Returns: @@ -829,7 +827,7 @@ def collapse(self) -> tuple[FunctionSpace[_T], list[npt.NDArray[np.int32]]]: V = FunctionSpace(self._mesh, self.ufl_element(), cpp_space) return V, dofs - def tabulate_dof_coordinates(self) -> npt.NDArray[_T]: + def tabulate_dof_coordinates(self) -> npt.NDArray[Real]: """Tabulate coordinates of function space degrees-of-freedom. Returns: diff --git a/python/dolfinx/geometry.py b/python/dolfinx/geometry.py index 900bb5b1f9..5c3d6e40d0 100644 --- a/python/dolfinx/geometry.py +++ b/python/dolfinx/geometry.py @@ -18,6 +18,7 @@ from dolfinx import cpp as _cpp from dolfinx.graph import AdjacencyList +from dolfinx.typing import Real __all__ = [ "BoundingBoxTree", @@ -34,10 +35,8 @@ "squared_distance", ] -_T = typing.TypeVar("_T", np.float32, np.float64) - -class PointOwnershipData(typing.Generic[_T]): +class PointOwnershipData(typing.Generic[Real]): """Class for storing data related to the ownership of points.""" _cpp_object: _cpp.geometry.PointOwnershipData_float32 | _cpp.geometry.PointOwnershipData_float64 @@ -57,7 +56,7 @@ def dest_owner(self) -> npt.NDArray[np.int32]: return self._cpp_object.dest_owners @property - def dest_points(self) -> npt.NDArray[_T]: + def dest_points(self) -> npt.NDArray[Real]: """Points owned by current rank.""" return self._cpp_object.dest_points @@ -67,7 +66,7 @@ def dest_cells(self) -> npt.NDArray[np.int32]: return self._cpp_object.dest_cells -class BoundingBoxTree(typing.Generic[_T]): +class BoundingBoxTree(typing.Generic[Real]): """Bounding box trees used in collision detection.""" _cpp_object: _cpp.geometry.BoundingBoxTree_float32 | _cpp.geometry.BoundingBoxTree_float64 @@ -88,7 +87,7 @@ def num_bboxes(self) -> int: return self._cpp_object.num_bboxes @property - def bbox_coordinates(self) -> npt.NDArray[_T]: + def bbox_coordinates(self) -> npt.NDArray[Real]: """Coordinates of lower and upper corners of bounding boxes. Note: @@ -97,7 +96,7 @@ def bbox_coordinates(self) -> npt.NDArray[_T]: """ return self._cpp_object.bbox_coordinates - def get_bbox(self, i) -> npt.NDArray[_T]: + def get_bbox(self, i) -> npt.NDArray[Real]: """Get lower and upper corners of the ith bounding box. Args: @@ -110,18 +109,18 @@ def get_bbox(self, i) -> npt.NDArray[_T]: """ return self._cpp_object.get_bbox(i) - def create_global_tree(self, comm) -> BoundingBoxTree[_T]: + def create_global_tree(self, comm) -> BoundingBoxTree[Real]: """Create a global bounding box tree.""" return BoundingBoxTree(self._cpp_object.create_global_tree(comm)) def bb_tree( - mesh: Mesh[_T], + mesh: Mesh[Real], dim: int, *, padding: float = 0.0, entities: npt.NDArray[np.int32] | None = None, -) -> BoundingBoxTree[_T]: +) -> BoundingBoxTree[Real]: """Create a bounding box tree for use in collision detection. Args: @@ -153,7 +152,7 @@ def bb_tree( def compute_collisions_trees( - tree0: BoundingBoxTree[_T], tree1: BoundingBoxTree[_T] + tree0: BoundingBoxTree[Real], tree1: BoundingBoxTree[Real] ) -> npt.NDArray[np.int32]: """Compute all collisions between two bounding box trees. @@ -169,7 +168,7 @@ def compute_collisions_trees( return _cpp.geometry.compute_collisions_trees(tree0._cpp_object, tree1._cpp_object) -def compute_collisions_points(tree: BoundingBoxTree[_T], x: npt.NDArray[_T]) -> AdjacencyList: +def compute_collisions_points(tree: BoundingBoxTree[Real], x: npt.NDArray[Real]) -> AdjacencyList: """Compute collisions between points and leaf bounding boxes. Bounding boxes can overlap, therefore points can collide with more @@ -188,10 +187,10 @@ def compute_collisions_points(tree: BoundingBoxTree[_T], x: npt.NDArray[_T]) -> def compute_closest_entity( - tree: BoundingBoxTree[_T], - midpoint_tree: BoundingBoxTree[_T], - mesh: Mesh[_T], - points: npt.NDArray[_T], + tree: BoundingBoxTree[Real], + midpoint_tree: BoundingBoxTree[Real], + mesh: Mesh[Real], + points: npt.NDArray[Real], ) -> npt.NDArray[np.int32]: """Compute closest mesh entity to a point. @@ -214,8 +213,8 @@ def compute_closest_entity( def create_midpoint_tree( - mesh: Mesh[_T], dim: int, entities: npt.NDArray[np.int32] -) -> BoundingBoxTree[_T]: + mesh: Mesh[Real], dim: int, entities: npt.NDArray[np.int32] +) -> BoundingBoxTree[Real]: """Create bounding box tree for the midpoints of a subset of entities. Args: @@ -230,7 +229,7 @@ def create_midpoint_tree( def compute_colliding_cells( - msh: Mesh[_T], candidates: AdjacencyList, x: npt.NDArray[_T] + msh: Mesh[Real], candidates: AdjacencyList, x: npt.NDArray[Real] ) -> AdjacencyList: """From a mesh, find which cells collide with a set of points. @@ -251,8 +250,8 @@ def compute_colliding_cells( def squared_distance( - mesh: Mesh[_T], dim: int, entities: npt.NDArray[np.int32], points: npt.NDArray[_T] -) -> npt.NDArray[_T]: + mesh: Mesh[Real], dim: int, entities: npt.NDArray[np.int32], points: npt.NDArray[Real] +) -> npt.NDArray[Real]: """Compute the squared distance between a point and a mesh entity. The distance is computed between the ith input points and the ith @@ -272,7 +271,7 @@ def squared_distance( return _cpp.geometry.squared_distance(mesh._cpp_object, dim, entities, points) -def compute_distance_gjk(p: npt.NDArray[_T], q: npt.NDArray[_T]) -> npt.NDArray[_T]: +def compute_distance_gjk(p: npt.NDArray[Real], q: npt.NDArray[Real]) -> npt.NDArray[Real]: """Compute the distance between two convex bodies. Each body is defined by a set of points. Uses the @@ -294,8 +293,8 @@ def compute_distance_gjk(p: npt.NDArray[_T], q: npt.NDArray[_T]) -> npt.NDArray[ def compute_distances_gjk( - bodies: list[npt.NDArray[_T]], q: npt.NDArray[_T], num_threads: int -) -> npt.NDArray[_T]: + bodies: list[npt.NDArray[Real]], q: npt.NDArray[Real], num_threads: int +) -> npt.NDArray[Real]: """Compute the distance between a set of convex bodies. For each convex body defined in `bodies`; @@ -323,10 +322,10 @@ def compute_distances_gjk( def determine_point_ownership( mesh: Mesh, - points: npt.NDArray[_T], + points: npt.NDArray[Real], padding: float, cells: npt.NDArray[np.int32] | None = None, -) -> PointOwnershipData[_T]: +) -> PointOwnershipData[Real]: """Build point ownership data for a mesh-points pair. First, potential collisions are found by computing intersections diff --git a/python/dolfinx/graph.py b/python/dolfinx/graph.py index af8ec7afea..2f3f22ea4a 100644 --- a/python/dolfinx/graph.py +++ b/python/dolfinx/graph.py @@ -5,13 +5,14 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Graph representations and operations on graphs.""" -from typing import Generic, TypeVar +from typing import Generic import numpy as np import numpy.typing as npt from dolfinx import cpp as _cpp from dolfinx.cpp.graph import partitioner +from dolfinx.typing import Index # Import graph partitioners, which may or may not be available # (dependent on build configuration) @@ -39,10 +40,7 @@ ] -_T = TypeVar("_T", np.int32, np.int64) - - -class AdjacencyList(Generic[_T]): +class AdjacencyList(Generic[Index]): """Adjacency list representation of a graph.""" _cpp_object: ( @@ -74,7 +72,7 @@ def __repr__(self): """String representation of the adjacency list.""" return self._cpp_object.__repr__() - def links(self, node: int) -> npt.NDArray[_T]: + def links(self, node: int) -> npt.NDArray[Index]: """Retrieve the links of a node. Note: @@ -90,7 +88,7 @@ def links(self, node: int) -> npt.NDArray[_T]: return self._cpp_object.links(node) @property - def array(self) -> npt.NDArray[_T]: + def array(self) -> npt.NDArray[Index]: """Array representation of the adjacency list. Note: @@ -122,8 +120,8 @@ def num_nodes(self) -> np.int32: def adjacencylist( - data: npt.NDArray[_T], offsets: npt.NDArray[np.int32] | None = None -) -> AdjacencyList[_T]: + data: npt.NDArray[Index], offsets: npt.NDArray[np.int32] | None = None +) -> AdjacencyList[Index]: """Create an :class:`AdjacencyList` for `int32` or `int64` datasets. Args: diff --git a/python/dolfinx/la/__init__.py b/python/dolfinx/la/__init__.py index 662684e610..120fadef7e 100644 --- a/python/dolfinx/la/__init__.py +++ b/python/dolfinx/la/__init__.py @@ -14,6 +14,7 @@ from dolfinx import cpp as _cpp from dolfinx.cpp.common import IndexMap from dolfinx.cpp.la import BlockMode, InsertMode, Norm +from dolfinx.typing import Scalar __all__ = [ "InsertMode", @@ -122,10 +123,7 @@ def scatter_reverse(self, mode: InsertMode) -> None: self._cpp_object.scatter_reverse(mode) -_MT = TypeVar("_MT", np.float32, np.float64, np.complex64, np.complex128) - - -class MatrixCSR(Generic[_MT]): +class MatrixCSR(Generic[Scalar]): """Distributed compressed sparse row matrix.""" _cpp_object: ( @@ -163,7 +161,7 @@ def index_map(self, i: int) -> IndexMap: """ return self._cpp_object.index_map(i) - def mult(self, x: Vector[_MT], y: Vector[_MT], transpose: bool = False) -> None: + def mult(self, x: Vector[Scalar], y: Vector[Scalar], transpose: bool = False) -> None: """Compute ``y += Ax`` or ``y += A^T x``. Args: @@ -208,7 +206,7 @@ def block_size(self) -> list: def add( self, - x: npt.NDArray[_MT], + x: npt.NDArray[Scalar], rows: npt.NDArray[np.int32], cols: npt.NDArray[np.int32], bs: int = 1, @@ -218,7 +216,7 @@ def add( def set( self, - x: npt.NDArray[_MT], + x: npt.NDArray[Scalar], rows: npt.NDArray[np.int32], cols: npt.NDArray[np.int32], bs: int = 1, @@ -226,7 +224,7 @@ def set( """Set a block of values in the matrix.""" self._cpp_object.set(x, rows, cols, bs) - def set_value(self, x: _MT) -> None: + def set_value(self, x: Scalar) -> None: """Set all non-zero entries to a value. Args: @@ -247,7 +245,7 @@ def squared_norm(self) -> np.floating: return self._cpp_object.squared_norm() @property - def data(self) -> npt.NDArray[_MT]: + def data(self) -> npt.NDArray[Scalar]: """Underlying matrix entry data.""" return self._cpp_object.data @@ -261,7 +259,7 @@ def indptr(self) -> npt.NDArray[np.int64]: """Local row pointers.""" return self._cpp_object.indptr - def to_dense(self) -> npt.NDArray[_MT]: + def to_dense(self) -> npt.NDArray[Scalar]: """Copy to a dense 2D array. Note: diff --git a/python/dolfinx/mesh.py b/python/dolfinx/mesh.py index 3d12c09bdf..cec93ec3ee 100644 --- a/python/dolfinx/mesh.py +++ b/python/dolfinx/mesh.py @@ -43,6 +43,7 @@ from dolfinx.fem import CoordinateElement as _CoordinateElement from dolfinx.fem import coordinate_element as _coordinate_element from dolfinx.graph import AdjacencyList +from dolfinx.typing import Real __all__ = [ "CellType", @@ -85,9 +86,6 @@ ] -_T = typing.TypeVar("_T", np.float32, np.float64) - - @singledispatch def create_cell_partitioner( part: Callable, mode: GhostMode, max_facet_to_cell_links: int @@ -273,7 +271,7 @@ def cell_type(self) -> CellType: return self._cpp_object.cell_type -class Geometry(typing.Generic[_T]): +class Geometry(typing.Generic[Real]): """The geometry of a :class:`dolfinx.mesh.Mesh`.""" _cpp_object: _cpp.mesh.Geometry_float32 | _cpp.mesh.Geometry_float64 @@ -319,7 +317,7 @@ def input_global_indices(self) -> npt.NDArray[np.int64]: return self._cpp_object.input_global_indices @property - def x(self) -> npt.NDArray[_T]: + def x(self) -> npt.NDArray[Real]: """Geometry coordinate points. Shape is ``shape=(num_points, 3)``. @@ -327,12 +325,12 @@ def x(self) -> npt.NDArray[_T]: return self._cpp_object.x -class Mesh(typing.Generic[_T]): +class Mesh(typing.Generic[Real]): """A mesh.""" _mesh: _cpp.mesh.Mesh_float32 | _cpp.mesh.Mesh_float64 _topology: Topology - _geometry: Geometry[_T] + _geometry: Geometry[Real] _ufl_domain: ufl.Mesh | None def __init__( @@ -416,7 +414,7 @@ def topology(self) -> Topology: return self._topology @property - def geometry(self) -> Geometry[_T]: + def geometry(self) -> Geometry[Real]: """Mesh geometry.""" return self._geometry diff --git a/python/dolfinx/typing.py b/python/dolfinx/typing.py new file mode 100644 index 0000000000..78101bad1c --- /dev/null +++ b/python/dolfinx/typing.py @@ -0,0 +1,18 @@ +# Copyright (C) 2026 Paul T. Kühner +# +# This file is part of DOLFINx (https://www.fenicsproject.org) +# +# SPDX-License-Identifier: LGPL-3.0-or-later + +"""Common typing functionality.""" + +from typing import TypeVar + +import numpy as np + +__all__ = ["Index", "Real", "Scalar"] + +Index = TypeVar("Index", np.int32, np.int64) + +Real = TypeVar("Real", np.float32, np.float64) +Scalar = TypeVar("Scalar", np.float32, np.float64, np.complex64, np.complex128) From 1090e705225e141f656ded286ffdeac22ec72c25 Mon Sep 17 00:00:00 2001 From: "Jack S. Hale" Date: Mon, 4 May 2026 13:06:53 +0200 Subject: [PATCH 18/19] Document SuperLU_DIST type limitations Add a comment about SuperLU_DIST type support. --- python/dolfinx/la/superlu_dist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/dolfinx/la/superlu_dist.py b/python/dolfinx/la/superlu_dist.py index 4a5d8e9c8b..0366dd5d2d 100644 --- a/python/dolfinx/la/superlu_dist.py +++ b/python/dolfinx/la/superlu_dist.py @@ -24,6 +24,7 @@ __all__ = ["SuperLUDistMatrix", "SuperLUDistSolver", "superlu_dist_matrix", "superlu_dist_solver"] +# As of 2026, SuperLU_DIST only supports these types, so the general Scalar type cannot be used. _T = TypeVar("_T", np.float32, np.float64, np.complex128) From 4094830f2d3e5db82eed84bd72697c9a0f6e28ee Mon Sep 17 00:00:00 2001 From: "Jack S. Hale" Date: Mon, 4 May 2026 14:23:36 +0200 Subject: [PATCH 19/19] Fix comment --- python/dolfinx/la/superlu_dist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/dolfinx/la/superlu_dist.py b/python/dolfinx/la/superlu_dist.py index 0366dd5d2d..e51ebd8782 100644 --- a/python/dolfinx/la/superlu_dist.py +++ b/python/dolfinx/la/superlu_dist.py @@ -24,7 +24,8 @@ __all__ = ["SuperLUDistMatrix", "SuperLUDistSolver", "superlu_dist_matrix", "superlu_dist_solver"] -# As of 2026, SuperLU_DIST only supports these types, so the general Scalar type cannot be used. +# As of 2026, SuperLU_DIST only supports these types, so the general Scalar +# type including np.complex64 cannot be used. _T = TypeVar("_T", np.float32, np.float64, np.complex128)