Skip to content

Commit 6ef4bbc

Browse files
committed
Towards grudge typing
This improves many aspects of typing in arraycontext: - It improves type checking (and consistency) in the traversals. - It allows scalars consistently. - It adds some overloads for traversal functions. - It adds `rec_map_container`, which is simpler (and easier to type). - It makes array containers recognizable to the type checker. This works via a heuristic, by having `__array_ufunc__ == None`. - It adds more types in the base fake numpy and shifts some implementation aspects there.
1 parent 9f0934b commit 6ef4bbc

21 files changed

Lines changed: 1335 additions & 604 deletions

arraycontext/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
multimapped_over_array_containers,
6464
outer,
6565
rec_map_array_container,
66+
rec_map_container,
6667
rec_map_reduce_array_container,
6768
rec_multimap_array_container,
6869
rec_multimap_reduce_array_container,
@@ -154,6 +155,7 @@
154155
"outer",
155156
"pytest_generate_tests_for_array_contexts",
156157
"rec_map_array_container",
158+
"rec_map_container",
157159
"rec_map_reduce_array_container",
158160
"rec_multimap_array_container",
159161
"rec_multimap_reduce_array_container",

arraycontext/container/__init__.py

Lines changed: 87 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,42 @@
1-
# mypy: disallow-untyped-defs
2-
31
"""
42
.. currentmodule:: arraycontext
53
64
.. autoclass:: ArrayContainer
5+
A protocol for generic containers of the array type supported by the
6+
:class:`ArrayContext`.
7+
8+
The functionality required for the container to operated is supplied via
9+
:func:`functools.singledispatch`. Implementations of the following functions need
10+
to be registered for a type serving as an :class:`ArrayContainer`:
11+
12+
* :func:`serialize_container` for serialization, which gives the components
13+
of the array.
14+
* :func:`deserialize_container` for deserialization, which constructs a
15+
container from a set of components.
16+
* :func:`get_container_context_opt` retrieves the :class:`ArrayContext` from
17+
a container, if it has one.
18+
19+
This allows enumeration of the component arrays in a container and the
20+
construction of modified containers from an iterable of those component arrays.
21+
22+
Packages may register their own types as array containers. They must not
23+
register other types (e.g. :class:`list`) as array containers.
24+
The type :class:`numpy.ndarray` is considered an array container, but
25+
only arrays with dtype *object* may be used as such. (This is so
26+
because object arrays cannot be distinguished from non-object arrays
27+
via their type.)
28+
29+
The container and its serialization interface has goals and uses
30+
approaches similar to JAX's
31+
`PyTrees <https://jax.readthedocs.io/en/latest/pytrees.html>`__,
32+
however its implementation differs a bit.
33+
34+
.. note::
35+
36+
This class is used in type annotation and as a marker of array container
37+
attributes for :func:`~arraycontext.dataclass_array_container`.
38+
As a protocol, it is not intended as a superclass.
39+
k
740
.. autoclass:: ArithArrayContainer
841
.. class:: ArrayContainerT
942
@@ -51,6 +84,12 @@
5184

5285
from __future__ import annotations
5386

87+
from types import GenericAlias, UnionType
88+
89+
from numpy.typing import NDArray
90+
91+
from arraycontext.context import ArrayOrArithContainer, ArrayOrContainerOrScalar
92+
5493

5594
__copyright__ = """
5695
Copyright (C) 2020-1 University of Illinois Board of Trustees
@@ -78,75 +117,45 @@
78117

79118
from collections.abc import Hashable, Sequence
80119
from functools import singledispatch
81-
from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar
120+
from typing import (
121+
TYPE_CHECKING,
122+
Any,
123+
ClassVar,
124+
Protocol,
125+
TypeAlias,
126+
TypeVar,
127+
get_origin,
128+
)
82129

83130
# For use in singledispatch type annotations, because sphinx can't figure out
84131
# what 'np' is.
85132
import numpy
86133
import numpy as np
87-
from typing_extensions import Self
134+
from typing_extensions import Self, TypeIs
88135

89136

90137
if TYPE_CHECKING:
91-
from pymbolic.geometric_algebra import MultiVector
138+
from pymbolic.geometric_algebra import CoeffT, MultiVector
92139

93-
from arraycontext import ArrayOrContainer
94140
from arraycontext.context import ArrayContext, ArrayOrScalar
95141

96142

97143
# {{{ ArrayContainer
98144

99-
class ArrayContainer(Protocol):
100-
"""
101-
A protocol for generic containers of the array type supported by the
102-
:class:`ArrayContext`.
103-
104-
The functionality required for the container to operated is supplied via
105-
:func:`functools.singledispatch`. Implementations of the following functions need
106-
to be registered for a type serving as an :class:`ArrayContainer`:
107-
108-
* :func:`serialize_container` for serialization, which gives the components
109-
of the array.
110-
* :func:`deserialize_container` for deserialization, which constructs a
111-
container from a set of components.
112-
* :func:`get_container_context_opt` retrieves the :class:`ArrayContext` from
113-
a container, if it has one.
114-
115-
This allows enumeration of the component arrays in a container and the
116-
construction of modified containers from an iterable of those component arrays.
117-
118-
Packages may register their own types as array containers. They must not
119-
register other types (e.g. :class:`list`) as array containers.
120-
The type :class:`numpy.ndarray` is considered an array container, but
121-
only arrays with dtype *object* may be used as such. (This is so
122-
because object arrays cannot be distinguished from non-object arrays
123-
via their type.)
124-
125-
The container and its serialization interface has goals and uses
126-
approaches similar to JAX's
127-
`PyTrees <https://jax.readthedocs.io/en/latest/pytrees.html>`__,
128-
however its implementation differs a bit.
129-
130-
.. note::
131-
132-
This class is used in type annotation and as a marker of array container
133-
attributes for :func:`~arraycontext.dataclass_array_container`.
134-
As a protocol, it is not intended as a superclass.
135-
"""
136-
137-
# Array containers do not need to have any particular features, so this
138-
# protocol is deliberately empty.
139-
140-
# This *is* used as a type annotation in dataclasses that are processed
145+
class _UserDefinedArrayContainer(Protocol):
146+
# This is used as a type annotation in dataclasses that are processed
141147
# by dataclass_array_container, where it's used to recognize attributes
142148
# that are container-typed.
143149

150+
# This method prevents ArrayContainer from matching any object, while
151+
# matching numpy object arrays and many array containers.
152+
__array_ufunc__: ClassVar[None]
144153

145-
class ArithArrayContainer(ArrayContainer, Protocol):
146-
"""
147-
A sub-protocol of :class:`ArrayContainer` that supports basic arithmetic.
148-
"""
149154

155+
ArrayContainer: TypeAlias = NDArray[Any] | _UserDefinedArrayContainer
156+
157+
158+
class _UserDefinedArithArrayContainer(_UserDefinedArrayContainer, Protocol):
150159
# This is loose and permissive, assuming that any array can be added
151160
# to any container. The alternative would be to plaster type-ignores
152161
# on all those uses. Achieving typing precision on what broadcasting is
@@ -167,6 +176,9 @@ def __pow__(self, other: ArrayOrScalar | Self) -> Self: ...
167176
def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ...
168177

169178

179+
ArithArrayContainer: TypeAlias = NDArray[Any] | _UserDefinedArithArrayContainer
180+
181+
170182
ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer)
171183

172184

@@ -175,7 +187,8 @@ class NotAnArrayContainerError(TypeError):
175187

176188

177189
SerializationKey: TypeAlias = Hashable
178-
SerializedContainer: TypeAlias = Sequence[tuple[SerializationKey, "ArrayOrContainer"]]
190+
SerializedContainer: TypeAlias = Sequence[
191+
tuple[SerializationKey, ArrayOrContainerOrScalar]]
179192

180193

181194
@singledispatch
@@ -221,7 +234,7 @@ def deserialize_container(
221234
f"'{type(template).__name__}' cannot be deserialized as a container")
222235

223236

224-
def is_array_container_type(cls: type) -> bool:
237+
def is_array_container_type(cls: type | GenericAlias | UnionType) -> bool:
225238
"""
226239
:returns: *True* if the type *cls* has a registered implementation of
227240
:func:`serialize_container`, or if it is an :class:`ArrayContainer`.
@@ -233,15 +246,22 @@ def is_array_container_type(cls: type) -> bool:
233246
function will say that :class:`numpy.ndarray` is an array container
234247
type, only object arrays *actually are* array containers.
235248
"""
236-
assert isinstance(cls, type), f"must pass a {type!r}, not a '{cls!r}'"
249+
if cls is ArrayContainer:
250+
return True
251+
252+
while isinstance(cls, GenericAlias):
253+
cls = get_origin(cls)
254+
255+
assert isinstance(cls, type), (
256+
f"must pass a {type!r}, not a '{cls!r}'")
237257

238258
return (
239-
cls is ArrayContainer
259+
cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison]
240260
or (serialize_container.dispatch(cls)
241261
is not serialize_container.__wrapped__)) # type:ignore[attr-defined]
242262

243263

244-
def is_array_container(ary: object) -> bool:
264+
def is_array_container(ary: object) -> TypeIs[ArrayContainer]:
245265
"""
246266
:returns: *True* if the instance *ary* has a registered implementation of
247267
:func:`serialize_container`.
@@ -317,7 +337,7 @@ def _deserialize_ndarray_container( # type: ignore[misc]
317337
# {{{ get_container_context_recursively
318338

319339
def get_container_context_recursively_opt(
320-
ary: ArrayContainer) -> ArrayContext | None:
340+
ary: ArrayOrContainerOrScalar) -> ArrayContext | None:
321341
"""Walks the :class:`ArrayContainer` hierarchy to find an
322342
:class:`ArrayContext` associated with it.
323343
@@ -351,7 +371,7 @@ def get_container_context_recursively_opt(
351371
return actx
352372

353373

354-
def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | None:
374+
def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext:
355375
"""Walks the :class:`ArrayContainer` hierarchy to find an
356376
:class:`ArrayContext` associated with it.
357377
@@ -362,13 +382,7 @@ def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | Non
362382
"""
363383
actx = get_container_context_recursively_opt(ary)
364384
if actx is None:
365-
# raise ValueError("no array context was found")
366-
from warnings import warn
367-
warn("No array context was found. This will be an error starting in "
368-
"July of 2022. If you would like the function to return "
369-
"None if no array context was found, use "
370-
"get_container_context_recursively_opt.",
371-
DeprecationWarning, stacklevel=2)
385+
raise ValueError("no array context was found")
372386

373387
return actx
374388

@@ -380,19 +394,20 @@ def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | Non
380394
# FYI: This doesn't, and never should, make arraycontext directly depend on pymbolic.
381395
# (Though clearly there exists a dependency via loopy.)
382396

383-
def _serialize_multivec_as_container(mv: MultiVector) -> SerializedContainer:
397+
def _serialize_multivec_as_container(
398+
mv: MultiVector[ArrayOrArithContainer]
399+
) -> SerializedContainer:
384400
return list(mv.data.items())
385401

386402

387-
# FIXME: Ignored due to https://github.com/python/mypy/issues/13040
388-
def _deserialize_multivec_as_container( # type: ignore[misc]
389-
template: MultiVector,
390-
serialized: SerializedContainer) -> MultiVector:
403+
def _deserialize_multivec_as_container(
404+
template: MultiVector[CoeffT],
405+
serialized: SerializedContainer) -> MultiVector[CoeffT]:
391406
from pymbolic.geometric_algebra import MultiVector
392407
return MultiVector(dict(serialized), space=template.space)
393408

394409

395-
def _get_container_context_opt_from_multivec(mv: MultiVector) -> None:
410+
def _get_container_context_opt_from_multivec(mv: MultiVector[CoeffT]) -> None:
396411
return None
397412

398413

arraycontext/container/arithmetic.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@
6262
if TYPE_CHECKING:
6363
from collections.abc import Callable
6464

65-
from arraycontext.context import ArrayContext, ArrayOrContainer
65+
from arraycontext.context import (
66+
ArrayContext,
67+
ArrayOrContainer,
68+
ArrayOrContainerOrScalar,
69+
)
6670

6771

6872
# {{{ with_container_arithmetic
@@ -772,11 +776,11 @@ def __post_init__(self) -> None:
772776

773777
def _binary_op(self,
774778
op: Callable[
775-
[ArrayOrContainer, ArrayOrContainer],
776-
ArrayOrContainer
779+
[ArrayOrContainerOrScalar, ArrayOrContainerOrScalar],
780+
ArrayOrContainerOrScalar
777781
],
778-
right: ArrayOrContainer
779-
) -> ArrayOrContainer:
782+
right: ArrayOrContainerOrScalar
783+
) -> ArrayOrContainerOrScalar:
780784
try:
781785
serialized = serialize_container(right)
782786
except NotAnArrayContainerError:
@@ -791,11 +795,11 @@ def _binary_op(self,
791795

792796
def _rev_binary_op(self,
793797
op: Callable[
794-
[ArrayOrContainer, ArrayOrContainer],
795-
ArrayOrContainer
798+
[ArrayOrContainerOrScalar, ArrayOrContainerOrScalar],
799+
ArrayOrContainerOrScalar
796800
],
797-
left: ArrayOrContainer
798-
) -> ArrayOrContainer:
801+
left: ArrayOrContainerOrScalar
802+
) -> ArrayOrContainerOrScalar:
799803
try:
800804
serialized = serialize_container(left)
801805
except NotAnArrayContainerError:

arraycontext/container/dataclass.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# mypy: disallow-untyped-defs
2-
31
"""
42
.. currentmodule:: arraycontext
53
.. autofunction:: dataclass_array_container
@@ -34,7 +32,7 @@
3432
from dataclasses import fields, is_dataclass
3533
from typing import TYPE_CHECKING, NamedTuple, Union, get_args, get_origin
3634

37-
from arraycontext.container import is_array_container_type
35+
from arraycontext.container import ArrayContainer, is_array_container_type
3836

3937

4038
if TYPE_CHECKING:
@@ -99,7 +97,12 @@ def is_array_field(f: _Field) -> bool:
9997
#
10098
# This is not set in stone, but mostly driven by current usage!
10199

100+
# pyright has no idea what we're up to. :)
101+
if field_type is ArrayContainer: # pyright: ignore[reportUnnecessaryComparison]
102+
return True
103+
102104
origin = get_origin(field_type)
105+
103106
# NOTE: `UnionType` is returned when using `Type1 | Type2`
104107
if origin in (Union, UnionType):
105108
if all(is_array_type(arg) for arg in get_args(field_type)):

0 commit comments

Comments
 (0)