Skip to content

Commit 4044d68

Browse files
committed
Centralize typing stuff in arraycontext.typing
1 parent 30d8a4d commit 4044d68

23 files changed

Lines changed: 404 additions & 336 deletions

arraycontext/__init__.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@
3030
"""
3131

3232
from .container import (
33-
ArithArrayContainer,
34-
ArrayContainer,
35-
ArrayContainerT,
3633
NotAnArrayContainerError,
3734
SerializationKey,
3835
SerializedContainer,
@@ -72,9 +69,26 @@
7269
with_array_context,
7370
)
7471
from .context import (
75-
Array,
7672
ArrayContext,
7773
ArrayContextFactory,
74+
tag_axes,
75+
)
76+
from .impl.jax import EagerJAXArrayContext
77+
from .impl.numpy import NumpyArrayContext
78+
from .impl.pyopencl import PyOpenCLArrayContext
79+
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
80+
from .loopy import make_loopy_program
81+
from .pytest import (
82+
PytestArrayContextFactory,
83+
PytestPyOpenCLArrayContextFactory,
84+
pytest_generate_tests_for_array_contexts,
85+
)
86+
from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag
87+
from .typing import (
88+
ArithArrayContainer,
89+
Array,
90+
ArrayContainer,
91+
ArrayContainerT,
7892
ArrayOrArithContainer,
7993
ArrayOrArithContainerOrScalar,
8094
ArrayOrArithContainerOrScalarT,
@@ -86,21 +100,10 @@
86100
ArrayOrScalar,
87101
ArrayOrScalarT,
88102
ArrayT,
103+
ContainerOrScalarT,
89104
Scalar,
90105
ScalarLike,
91-
tag_axes,
92-
)
93-
from .impl.jax import EagerJAXArrayContext
94-
from .impl.numpy import NumpyArrayContext
95-
from .impl.pyopencl import PyOpenCLArrayContext
96-
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
97-
from .loopy import make_loopy_program
98-
from .pytest import (
99-
PytestArrayContextFactory,
100-
PytestPyOpenCLArrayContextFactory,
101-
pytest_generate_tests_for_array_contexts,
102106
)
103-
from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag
104107

105108

106109
__all__ = (
@@ -123,6 +126,7 @@
123126
"ArrayT",
124127
"BcastUntilActxArray",
125128
"CommonSubexpressionTag",
129+
"ContainerOrScalarT",
126130
"EagerJAXArrayContext",
127131
"ElementwiseMapKernelTag",
128132
"NotAnArrayContainerError",

arraycontext/container/__init__.py

Lines changed: 10 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,7 @@
4949
This should be considered experimental for now, and it may well change.
5050
5151
.. autoclass:: ArithArrayContainer
52-
.. class:: ArrayContainerT
53-
54-
A type variable with a lower bound of :class:`ArrayContainer`.
52+
.. autoclass:: ArrayContainerT
5553
5654
.. autoexception:: NotAnArrayContainerError
5755
@@ -125,80 +123,35 @@
125123
from types import GenericAlias, UnionType
126124
from typing import (
127125
TYPE_CHECKING,
128-
ClassVar,
129-
Protocol,
130126
TypeAlias,
131-
TypeVar,
132127
get_origin,
133128
)
134129

135130
# For use in singledispatch type annotations, because sphinx can't figure out
136131
# what 'np' is.
137132
import numpy
138133
import numpy as np
139-
from typing_extensions import Self, TypeIs
134+
from typing_extensions import TypeIs
140135

141-
from pytools.obj_array import ObjectArrayND
136+
from pytools.obj_array import ObjectArrayND as ObjectArrayND
142137

143-
from arraycontext.context import (
138+
from arraycontext.typing import (
139+
ArrayContainer,
140+
ArrayContainerT,
144141
ArrayOrArithContainer,
145-
ArrayOrArithContainerOrScalar,
142+
ArrayOrArithContainerOrScalar as ArrayOrArithContainerOrScalar,
146143
ArrayOrContainerOrScalar,
147144
)
148145

149146

150147
if TYPE_CHECKING:
151148
from pymbolic.geometric_algebra import CoeffT, MultiVector
152149

153-
from arraycontext.context import ArrayContext, ArrayOrScalar
154-
155-
156-
# {{{ ArrayContainer
157-
158-
class _UserDefinedArrayContainer(Protocol):
159-
# This is used as a type annotation in dataclasses that are processed
160-
# by dataclass_array_container, where it's used to recognize attributes
161-
# that are container-typed.
162-
163-
# This method prevents ArrayContainer from matching any object, while
164-
# matching numpy object arrays and many array containers.
165-
__array_ufunc__: ClassVar[None]
166-
167-
168-
ArrayContainer: TypeAlias = (
169-
ObjectArrayND[ArrayOrContainerOrScalar]
170-
| _UserDefinedArrayContainer
171-
)
172-
173-
174-
class _UserDefinedArithArrayContainer(_UserDefinedArrayContainer, Protocol):
175-
# This is loose and permissive, assuming that any array can be added
176-
# to any container. The alternative would be to plaster type-ignores
177-
# on all those uses. Achieving typing precision on what broadcasting is
178-
# allowable seems like a huge endeavor and is likely not feasible without
179-
# a mypy plugin. Maybe some day? -AK, November 2024
180-
181-
def __neg__(self) -> Self: ...
182-
def __abs__(self) -> Self: ...
183-
def __add__(self, other: ArrayOrScalar | Self) -> Self: ...
184-
def __radd__(self, other: ArrayOrScalar | Self) -> Self: ...
185-
def __sub__(self, other: ArrayOrScalar | Self) -> Self: ...
186-
def __rsub__(self, other: ArrayOrScalar | Self) -> Self: ...
187-
def __mul__(self, other: ArrayOrScalar | Self) -> Self: ...
188-
def __rmul__(self, other: ArrayOrScalar | Self) -> Self: ...
189-
def __truediv__(self, other: ArrayOrScalar | Self) -> Self: ...
190-
def __rtruediv__(self, other: ArrayOrScalar | Self) -> Self: ...
191-
def __pow__(self, other: ArrayOrScalar | Self) -> Self: ...
192-
def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ...
193-
194-
195-
ArithArrayContainer: TypeAlias = (
196-
ObjectArrayND[ArrayOrArithContainerOrScalar]
197-
| _UserDefinedArithArrayContainer)
198-
150+
from arraycontext.context import ArrayContext
151+
from arraycontext.typing import ArrayOrScalar as ArrayOrScalar
199152

200-
ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer)
201153

154+
# {{{ ArrayContainer traversals
202155

203156
class NotAnArrayContainerError(TypeError):
204157
""":class:`TypeError` subclass raised when an array container is expected."""

arraycontext/container/arithmetic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@
6868
if TYPE_CHECKING:
6969
from collections.abc import Callable
7070

71-
from arraycontext.context import (
72-
ArrayContext,
71+
from arraycontext.context import ArrayContext
72+
from arraycontext.typing import (
7373
ArrayOrContainer,
7474
ArrayOrContainerOrScalar,
7575
)

arraycontext/container/dataclass.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,12 @@
5656

5757
from pytools.obj_array import ObjectArray
5858

59-
from arraycontext.container import ArrayContainer, is_array_container_type
60-
from arraycontext.context import ArrayOrContainer, ArrayOrContainerOrScalar
59+
from arraycontext.container import is_array_container_type
60+
from arraycontext.typing import (
61+
ArrayContainer,
62+
ArrayOrContainer,
63+
ArrayOrContainerOrScalar,
64+
)
6165

6266

6367
if TYPE_CHECKING:

arraycontext/container/traversal.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,26 @@
3333
.. autofunction:: unflatten
3434
.. autofunction:: flat_size_and_dtype
3535
36-
Numpy conversion
37-
~~~~~~~~~~~~~~~~
38-
.. autofunction:: from_numpy
39-
.. autofunction:: to_numpy
40-
4136
Algebraic operations
4237
~~~~~~~~~~~~~~~~~~~~
4338
.. autofunction:: outer
39+
40+
.. currentmodule:: arraycontext.traversal
41+
42+
References
43+
----------
44+
45+
.. class:: ArrayOrScalar
46+
47+
See :class:`arraycontext.ArrayOrScalar`.
48+
49+
.. class:: ArrayOrContainer
50+
51+
See :class:`arraycontext.ArrayOrContainer`.
52+
53+
.. class:: ArrayContainerT
54+
55+
See :class:`arraycontext.ArrayContainerT`.
4456
"""
4557

4658
from __future__ import annotations
@@ -84,24 +96,27 @@
8496
)
8597

8698
from arraycontext.container import (
87-
ArrayContainer,
88-
ArrayContainerT,
8999
NotAnArrayContainerError,
90100
SerializationKey,
91101
deserialize_container,
92102
get_container_context_recursively_opt,
93103
is_array_container,
94104
serialize_container,
95105
)
96-
from arraycontext.context import is_scalar_like, shape_is_int_only
106+
from arraycontext.typing import (
107+
ArrayContainer,
108+
ArrayContainerT,
109+
is_scalar_like,
110+
shape_is_int_only,
111+
)
97112

98113

99114
if TYPE_CHECKING:
100115
from collections.abc import Callable, Collection, Iterable
101116

102-
from arraycontext.context import (
117+
from arraycontext.context import ArrayContext
118+
from arraycontext.typing import (
103119
Array,
104-
ArrayContext,
105120
ArrayOrContainer,
106121
ArrayOrContainerOrScalar,
107122
ArrayOrContainerOrScalarT,

0 commit comments

Comments
 (0)