1- # mypy: disallow-untyped-defs
2-
31"""
42.. currentmodule:: arraycontext
53
64.. autoclass:: ArrayContainer
5+ A protocol for generic containers of the array type supported by the
6+ :class:`ArrayContext`.
7+
8+ The functionality required for the container to operated is supplied via
9+ :func:`functools.singledispatch`. Implementations of the following functions need
10+ to be registered for a type serving as an :class:`ArrayContainer`:
11+
12+ * :func:`serialize_container` for serialization, which gives the components
13+ of the array.
14+ * :func:`deserialize_container` for deserialization, which constructs a
15+ container from a set of components.
16+ * :func:`get_container_context_opt` retrieves the :class:`ArrayContext` from
17+ a container, if it has one.
18+
19+ This allows enumeration of the component arrays in a container and the
20+ construction of modified containers from an iterable of those component arrays.
21+
22+ Packages may register their own types as array containers. They must not
23+ register other types (e.g. :class:`list`) as array containers.
24+ The type :class:`numpy.ndarray` is considered an array container, but
25+ only arrays with dtype *object* may be used as such. (This is so
26+ because object arrays cannot be distinguished from non-object arrays
27+ via their type.)
28+
29+ The container and its serialization interface has goals and uses
30+ approaches similar to JAX's
31+ `PyTrees <https://jax.readthedocs.io/en/latest/pytrees.html>`__,
32+ however its implementation differs a bit.
33+
34+ .. note::
35+
36+ This class is used in type annotation and as a marker of array container
37+ attributes for :func:`~arraycontext.dataclass_array_container`.
38+ As a protocol, it is not intended as a superclass.
39+ k
740.. autoclass:: ArithArrayContainer
841.. class:: ArrayContainerT
942
5184
5285from __future__ import annotations
5386
87+ from types import GenericAlias , UnionType
88+
89+ from numpy .typing import NDArray
90+
91+ from arraycontext .context import ArrayOrArithContainer , ArrayOrContainerOrScalar
92+
5493
5594__copyright__ = """
5695Copyright (C) 2020-1 University of Illinois Board of Trustees
78117
79118from collections .abc import Hashable , Sequence
80119from functools import singledispatch
81- from typing import TYPE_CHECKING , Protocol , TypeAlias , TypeVar
120+ from typing import (
121+ TYPE_CHECKING ,
122+ Any ,
123+ ClassVar ,
124+ Protocol ,
125+ TypeAlias ,
126+ TypeVar ,
127+ get_origin ,
128+ )
82129
83130# For use in singledispatch type annotations, because sphinx can't figure out
84131# what 'np' is.
85132import numpy
86133import numpy as np
87- from typing_extensions import Self
134+ from typing_extensions import Self , TypeIs
88135
89136
90137if TYPE_CHECKING :
91- from pymbolic .geometric_algebra import MultiVector
138+ from pymbolic .geometric_algebra import CoeffT , MultiVector
92139
93- from arraycontext import ArrayOrContainer
94140 from arraycontext .context import ArrayContext , ArrayOrScalar
95141
96142
97143# {{{ ArrayContainer
98144
99- class ArrayContainer (Protocol ):
100- """
101- A protocol for generic containers of the array type supported by the
102- :class:`ArrayContext`.
103-
104- The functionality required for the container to operated is supplied via
105- :func:`functools.singledispatch`. Implementations of the following functions need
106- to be registered for a type serving as an :class:`ArrayContainer`:
107-
108- * :func:`serialize_container` for serialization, which gives the components
109- of the array.
110- * :func:`deserialize_container` for deserialization, which constructs a
111- container from a set of components.
112- * :func:`get_container_context_opt` retrieves the :class:`ArrayContext` from
113- a container, if it has one.
114-
115- This allows enumeration of the component arrays in a container and the
116- construction of modified containers from an iterable of those component arrays.
117-
118- Packages may register their own types as array containers. They must not
119- register other types (e.g. :class:`list`) as array containers.
120- The type :class:`numpy.ndarray` is considered an array container, but
121- only arrays with dtype *object* may be used as such. (This is so
122- because object arrays cannot be distinguished from non-object arrays
123- via their type.)
124-
125- The container and its serialization interface has goals and uses
126- approaches similar to JAX's
127- `PyTrees <https://jax.readthedocs.io/en/latest/pytrees.html>`__,
128- however its implementation differs a bit.
129-
130- .. note::
131-
132- This class is used in type annotation and as a marker of array container
133- attributes for :func:`~arraycontext.dataclass_array_container`.
134- As a protocol, it is not intended as a superclass.
135- """
136-
137- # Array containers do not need to have any particular features, so this
138- # protocol is deliberately empty.
139-
140- # This *is* used as a type annotation in dataclasses that are processed
145+ class _UserDefinedArrayContainer (Protocol ):
146+ # This is used as a type annotation in dataclasses that are processed
141147 # by dataclass_array_container, where it's used to recognize attributes
142148 # that are container-typed.
143149
150+ # This method prevents ArrayContainer from matching any object, while
151+ # matching numpy object arrays and many array containers.
152+ __array_ufunc__ : ClassVar [None ]
144153
145- class ArithArrayContainer (ArrayContainer , Protocol ):
146- """
147- A sub-protocol of :class:`ArrayContainer` that supports basic arithmetic.
148- """
149154
155+ ArrayContainer : TypeAlias = NDArray [Any ] | _UserDefinedArrayContainer
156+
157+
158+ class _UserDefinedArithArrayContainer (_UserDefinedArrayContainer , Protocol ):
150159 # This is loose and permissive, assuming that any array can be added
151160 # to any container. The alternative would be to plaster type-ignores
152161 # on all those uses. Achieving typing precision on what broadcasting is
@@ -167,6 +176,9 @@ def __pow__(self, other: ArrayOrScalar | Self) -> Self: ...
167176 def __rpow__ (self , other : ArrayOrScalar | Self ) -> Self : ...
168177
169178
179+ ArithArrayContainer : TypeAlias = NDArray [Any ] | _UserDefinedArithArrayContainer
180+
181+
170182ArrayContainerT = TypeVar ("ArrayContainerT" , bound = ArrayContainer )
171183
172184
@@ -175,7 +187,8 @@ class NotAnArrayContainerError(TypeError):
175187
176188
177189SerializationKey : TypeAlias = Hashable
178- SerializedContainer : TypeAlias = Sequence [tuple [SerializationKey , "ArrayOrContainer" ]]
190+ SerializedContainer : TypeAlias = Sequence [
191+ tuple [SerializationKey , ArrayOrContainerOrScalar ]]
179192
180193
181194@singledispatch
@@ -221,7 +234,7 @@ def deserialize_container(
221234 f"'{ type (template ).__name__ } ' cannot be deserialized as a container" )
222235
223236
224- def is_array_container_type (cls : type ) -> bool :
237+ def is_array_container_type (cls : type | GenericAlias | UnionType ) -> bool :
225238 """
226239 :returns: *True* if the type *cls* has a registered implementation of
227240 :func:`serialize_container`, or if it is an :class:`ArrayContainer`.
@@ -233,15 +246,22 @@ def is_array_container_type(cls: type) -> bool:
233246 function will say that :class:`numpy.ndarray` is an array container
234247 type, only object arrays *actually are* array containers.
235248 """
236- assert isinstance (cls , type ), f"must pass a { type !r} , not a '{ cls !r} '"
249+ if cls is ArrayContainer :
250+ return True
251+
252+ while isinstance (cls , GenericAlias ):
253+ cls = get_origin (cls )
254+
255+ assert isinstance (cls , type ), (
256+ f"must pass a { type !r} , not a '{ cls !r} '" )
237257
238258 return (
239- cls is ArrayContainer
259+ cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison]
240260 or (serialize_container .dispatch (cls )
241261 is not serialize_container .__wrapped__ )) # type:ignore[attr-defined]
242262
243263
244- def is_array_container (ary : object ) -> bool :
264+ def is_array_container (ary : object ) -> TypeIs [ ArrayContainer ] :
245265 """
246266 :returns: *True* if the instance *ary* has a registered implementation of
247267 :func:`serialize_container`.
@@ -317,7 +337,7 @@ def _deserialize_ndarray_container( # type: ignore[misc]
317337# {{{ get_container_context_recursively
318338
319339def get_container_context_recursively_opt (
320- ary : ArrayContainer ) -> ArrayContext | None :
340+ ary : ArrayOrContainerOrScalar ) -> ArrayContext | None :
321341 """Walks the :class:`ArrayContainer` hierarchy to find an
322342 :class:`ArrayContext` associated with it.
323343
@@ -351,7 +371,7 @@ def get_container_context_recursively_opt(
351371 return actx
352372
353373
354- def get_container_context_recursively (ary : ArrayContainer ) -> ArrayContext | None :
374+ def get_container_context_recursively (ary : ArrayContainer ) -> ArrayContext :
355375 """Walks the :class:`ArrayContainer` hierarchy to find an
356376 :class:`ArrayContext` associated with it.
357377
@@ -362,13 +382,7 @@ def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | Non
362382 """
363383 actx = get_container_context_recursively_opt (ary )
364384 if actx is None :
365- # raise ValueError("no array context was found")
366- from warnings import warn
367- warn ("No array context was found. This will be an error starting in "
368- "July of 2022. If you would like the function to return "
369- "None if no array context was found, use "
370- "get_container_context_recursively_opt." ,
371- DeprecationWarning , stacklevel = 2 )
385+ raise ValueError ("no array context was found" )
372386
373387 return actx
374388
@@ -380,19 +394,20 @@ def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | Non
380394# FYI: This doesn't, and never should, make arraycontext directly depend on pymbolic.
381395# (Though clearly there exists a dependency via loopy.)
382396
383- def _serialize_multivec_as_container (mv : MultiVector ) -> SerializedContainer :
397+ def _serialize_multivec_as_container (
398+ mv : MultiVector [ArrayOrArithContainer ]
399+ ) -> SerializedContainer :
384400 return list (mv .data .items ())
385401
386402
387- # FIXME: Ignored due to https://github.com/python/mypy/issues/13040
388- def _deserialize_multivec_as_container ( # type: ignore[misc]
389- template : MultiVector ,
390- serialized : SerializedContainer ) -> MultiVector :
403+ def _deserialize_multivec_as_container (
404+ template : MultiVector [CoeffT ],
405+ serialized : SerializedContainer ) -> MultiVector [CoeffT ]:
391406 from pymbolic .geometric_algebra import MultiVector
392407 return MultiVector (dict (serialized ), space = template .space )
393408
394409
395- def _get_container_context_opt_from_multivec (mv : MultiVector ) -> None :
410+ def _get_container_context_opt_from_multivec (mv : MultiVector [ CoeffT ] ) -> None :
396411 return None
397412
398413
0 commit comments