Skip to content

Commit dc62cdc

Browse files
Merge branch 'main' into cupyactx
2 parents ebd4828 + 19bb4fe commit dc62cdc

14 files changed

Lines changed: 610 additions & 3941 deletions

File tree

.basedpyright/baseline.json

Lines changed: 398 additions & 3806 deletions
Large diffs are not rendered by default.

.github/workflows/ci.yml

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,6 @@ jobs:
3131
pip install ruff
3232
ruff check
3333
34-
mypy:
35-
name: Mypy
36-
runs-on: ubuntu-latest
37-
steps:
38-
- uses: actions/checkout@v4
39-
- name: "Main Script"
40-
run: |
41-
set -x
42-
USE_CONDA_BUILD=1
43-
CONDA_ENVIRONMENT=.test-conda-env-py3.yml
44-
echo "- cupy" >> "$CONDA_ENVIRONMENT"
45-
46-
curl -L -O https://tiker.net/ci-support-v0
47-
. ./ci-support-v0
48-
49-
build_py_project_in_conda_env
50-
python -m pip install mypy pytest
51-
./run-mypy.sh
52-
5334
basedpyright:
5435
runs-on: ubuntu-latest
5536
steps:

.gitlab-ci.yml

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -115,20 +115,6 @@ Ruff:
115115
except:
116116
- tags
117117

118-
Mypy:
119-
script: |
120-
EXTRA_INSTALL="mypy pytest"
121-
122-
curl -L -O https://tiker.net/ci-support-v0
123-
. ./ci-support-v0
124-
125-
build_py_project_in_venv
126-
./run-mypy.sh
127-
tags:
128-
- python3
129-
except:
130-
- tags
131-
132118
Downstream:
133119
parallel:
134120
matrix:

arraycontext/container/traversal.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@
7777

7878
import numpy as np
7979

80+
from pymbolic.typing import Integer
81+
8082
from arraycontext.container import (
8183
ArrayContainer,
8284
NotAnArrayContainerError,
@@ -91,7 +93,6 @@
9193
ArrayOrContainer,
9294
ArrayOrContainerOrScalar,
9395
ArrayOrContainerT,
94-
ArrayT,
9596
ScalarLike,
9697
)
9798

@@ -400,21 +401,20 @@ def keyed_map_array_container(
400401

401402

402403
def rec_keyed_map_array_container(
403-
f: Callable[[tuple[SerializationKey, ...], ArrayT], ArrayT],
404+
f: Callable[[tuple[SerializationKey, ...], Array], Array],
404405
ary: ArrayOrContainer) -> ArrayOrContainer:
405406
"""
406407
Works similarly to :func:`rec_map_array_container`, except that *f* also
407408
takes in a traversal path to the leaf array. The traversal path argument is
408409
passed in as a tuple of identifiers of the arrays traversed before reaching
409410
the current array.
410411
"""
411-
412412
def rec(keys: tuple[SerializationKey, ...],
413-
ary_: ArrayOrContainerT) -> ArrayOrContainerT:
413+
ary_: ArrayOrContainer) -> ArrayOrContainer:
414414
try:
415415
iterable = serialize_container(ary_)
416416
except NotAnArrayContainerError:
417-
return cast(ArrayOrContainerT, f(keys, cast(ArrayT, ary_)))
417+
return cast(ArrayOrContainer, f(keys, cast(Array, ary_)))
418418
else:
419419
return deserialize_container(ary_, [
420420
(key, rec((*keys, key), subary)) for key, subary in iterable
@@ -777,7 +777,7 @@ def unflatten(
777777
checks are skipped.
778778
"""
779779
# NOTE: https://github.com/python/mypy/issues/7057
780-
offset = 0
780+
offset: int = 0
781781
common_dtype = None
782782

783783
def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
@@ -790,7 +790,11 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
790790

791791
# {{{ validate subary
792792

793-
if (offset + template_subary_c.size) > ary.size:
793+
if (
794+
isinstance(offset, Integer)
795+
and isinstance(template_subary_c.size, Integer)
796+
and isinstance(ary.size, Integer)
797+
and (offset + template_subary_c.size) > ary.size):
794798
raise ValueError("'template' and 'ary' sizes do not match: "
795799
"'template' is too large") from None
796800

@@ -813,6 +817,12 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
813817

814818
# {{{ reshape
815819

820+
if not isinstance(template_subary_c.size, Integer):
821+
raise NotImplementedError(
822+
"unflatten is not implemented for arrays with array-valued "
823+
"size.") from None
824+
825+
# FIXME: Not sure how to make the slicing part work for Array-valued sizes
816826
flat_subary = ary[offset:offset + template_subary_c.size]
817827
try:
818828
subary = actx.np.reshape(flat_subary,
@@ -871,15 +881,15 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
871881

872882

873883
def flat_size_and_dtype(
874-
ary: ArrayOrContainer) -> tuple[int, np.dtype[Any] | None]:
884+
ary: ArrayOrContainer) -> tuple[Array | Integer, np.dtype[Any] | None]:
875885
"""
876886
:returns: a tuple ``(size, dtype)`` that would be the length and
877887
:class:`numpy.dtype` of the one-dimensional array returned by
878888
:func:`flatten`.
879889
"""
880890
common_dtype = None
881891

882-
def _flat_size(subary: ArrayOrContainer) -> int:
892+
def _flat_size(subary: ArrayOrContainer) -> Array | Integer:
883893
nonlocal common_dtype
884894

885895
try:

arraycontext/context.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@
174174
import numpy as np
175175
from typing_extensions import Self
176176

177+
from pymbolic.typing import Integer, Scalar as _Scalar
177178
from pytools import memoize_method
178179
from pytools.tag import ToTagSetConvertible
179180

@@ -202,11 +203,11 @@ class Array(Protocol):
202203
"""
203204

204205
@property
205-
def shape(self) -> tuple[int, ...]:
206+
def shape(self) -> tuple[Array | Integer, ...]:
206207
...
207208

208209
@property
209-
def size(self) -> int:
210+
def size(self) -> Array | Integer:
210211
...
211212

212213
@property
@@ -220,22 +221,27 @@ def dtype(self) -> np.dtype[Any]:
220221
def __getitem__(self, index: Any) -> Array:
221222
...
222223

223-
# some basic arithmetic that's supposed to work
224-
def __neg__(self) -> Self: ...
225-
def __abs__(self) -> Self: ...
226-
def __add__(self, other: Self | ScalarLike) -> Self: ...
227-
def __radd__(self, other: Self | ScalarLike) -> Self: ...
228-
def __sub__(self, other: Self | ScalarLike) -> Self: ...
229-
def __rsub__(self, other: Self | ScalarLike) -> Self: ...
230-
def __mul__(self, other: Self | ScalarLike) -> Self: ...
231-
def __rmul__(self, other: Self | ScalarLike) -> Self: ...
232-
def __truediv__(self, other: Self | ScalarLike) -> Self: ...
233-
def __rtruediv__(self, other: Self | ScalarLike) -> Self: ...
224+
# Some basic arithmetic that's supposed to work
225+
# Need to return Array instead of Self because for some array types, arithmetic
226+
# operations on one subtype may result in a different subtype.
227+
# For example, pytato arrays: <Placeholder> + 1 -> <IndexLambda>
228+
def __neg__(self) -> Array: ...
229+
def __abs__(self) -> Array: ...
230+
def __add__(self, other: Self | ScalarLike) -> Array: ...
231+
def __radd__(self, other: Self | ScalarLike) -> Array: ...
232+
def __sub__(self, other: Self | ScalarLike) -> Array: ...
233+
def __rsub__(self, other: Self | ScalarLike) -> Array: ...
234+
def __mul__(self, other: Self | ScalarLike) -> Array: ...
235+
def __rmul__(self, other: Self | ScalarLike) -> Array: ...
236+
def __pow__(self, other: Self | ScalarLike) -> Array: ...
237+
def __rpow__(self, other: Self | ScalarLike) -> Array: ...
238+
def __truediv__(self, other: Self | ScalarLike) -> Array: ...
239+
def __rtruediv__(self, other: Self | ScalarLike) -> Array: ...
234240

235241

236242
# deprecated, use ScalarLike instead
237-
ScalarLike: TypeAlias = int | float | complex | np.generic
238-
Scalar = ScalarLike
243+
Scalar = _Scalar
244+
ScalarLike = Scalar
239245
ScalarLikeT = TypeVar("ScalarLikeT", bound=ScalarLike)
240246

241247
# NOTE: I'm kind of not sure about the *Tc versions of these type variables.

arraycontext/impl/pyopencl/__init__.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"""
3333

3434
from collections.abc import Callable
35-
from typing import TYPE_CHECKING
35+
from typing import TYPE_CHECKING, Literal
3636
from warnings import warn
3737

3838
import numpy as np
@@ -51,7 +51,8 @@
5151

5252
if TYPE_CHECKING:
5353
import loopy as lp
54-
import pyopencl
54+
import pyopencl as cl
55+
import pyopencl.array as cl_array
5556

5657

5758
# {{{ PyOpenCLArrayContext
@@ -81,9 +82,17 @@ class PyOpenCLArrayContext(ArrayContext):
8182
.. automethod:: transform_loopy_program
8283
"""
8384

85+
context: cl.Context
86+
queue: cl.CommandQueue
87+
allocator: cl_array.Allocator | None
88+
89+
_force_device_scalars: Literal[True]
90+
_passed_force_device_scalars: bool
91+
_wait_event_queue_length: int
92+
8493
def __init__(self,
85-
queue: pyopencl.CommandQueue,
86-
allocator: pyopencl.tools.AllocatorBase | None = None,
94+
queue: cl.CommandQueue,
95+
allocator: cl_array.Allocator | None = None,
8796
wait_event_queue_length: int | None = None,
8897
force_device_scalars: bool | None = None) -> None:
8998
r"""

arraycontext/impl/pyopencl/fake_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
rec_multimap_array_container,
4242
rec_multimap_reduce_array_container,
4343
)
44-
from arraycontext.context import Array, ArrayOrContainer
44+
from arraycontext.context import Array as actx_Array, ArrayOrContainer
4545
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
4646
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
4747
from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
@@ -206,7 +206,7 @@ def _any(ary):
206206
_any,
207207
a)
208208

209-
def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
209+
def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> actx_Array:
210210
actx = self._array_context
211211
queue = actx.queue
212212

arraycontext/impl/pyopencl/taggable_cl_array.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
from __future__ import annotations
88

99
from dataclasses import dataclass
10-
from typing import Any
10+
from typing import Any, Literal
1111

1212
import numpy as np
13+
from numpy.typing import DTypeLike
1314

15+
import pyopencl as cl
1416
import pyopencl.array as cla
1517
from pytools import memoize
1618
from pytools.tag import Tag, Taggable, ToTagSetConvertible
@@ -74,6 +76,9 @@ class TaggableCLArray(cla.Array, Taggable):
7476
record application-specific metadata to drive the optimizations in
7577
:meth:`arraycontext.PyOpenCLArrayContext.transform_loopy_program`.
7678
"""
79+
tags: frozenset[Tag]
80+
axes: tuple[Axis, ...]
81+
7782
def __init__(self, cq, shape, dtype, order="C", allocator=None,
7883
data=None, offset=0, strides=None, events=None, _flags=None,
7984
_fast=False, _size=None, _context=None, _queue=None,
@@ -165,13 +170,20 @@ def to_tagged_cl_array(ary: cla.Array,
165170
# }}}
166171

167172

173+
_EMPTY_TAG_SET: frozenset[Tag] = frozenset()
174+
175+
168176
# {{{ creation
169177

170-
def empty(queue, shape, dtype=float, *,
171-
axes: tuple[Axis, ...] | None = None,
172-
tags: frozenset[Tag] = frozenset(),
173-
order: str = "C",
174-
allocator=None) -> TaggableCLArray:
178+
def empty(
179+
queue: cl.CommandQueue,
180+
shape: tuple[int, ...] | int,
181+
dtype: DTypeLike = float,
182+
*, axes: tuple[Axis, ...] | None = None,
183+
tags: frozenset[Tag] = _EMPTY_TAG_SET,
184+
order: Literal["C"] | Literal["F"] = "C",
185+
allocator: cla.Allocator | None = None,
186+
) -> TaggableCLArray:
175187
if dtype is not None:
176188
dtype = np.dtype(dtype)
177189

@@ -181,11 +193,15 @@ def empty(queue, shape, dtype=float, *,
181193
order=order, allocator=allocator)
182194

183195

184-
def zeros(queue, shape, dtype=float, *,
185-
axes: tuple[Axis, ...] | None = None,
186-
tags: frozenset[Tag] = frozenset(),
187-
order: str = "C",
188-
allocator=None) -> TaggableCLArray:
196+
def zeros(
197+
queue: cl.CommandQueue,
198+
shape: tuple[int, ...] | int,
199+
dtype: DTypeLike = float,
200+
*, axes: tuple[Axis, ...] | None = None,
201+
tags: frozenset[Tag] = _EMPTY_TAG_SET,
202+
order: Literal["C"] | Literal["F"] = "C",
203+
allocator: cla.Allocator | None = None,
204+
) -> TaggableCLArray:
189205
result = empty(
190206
queue, shape, dtype=dtype, axes=axes, tags=tags,
191207
order=order, allocator=allocator)
@@ -194,10 +210,13 @@ def zeros(queue, shape, dtype=float, *,
194210
return result
195211

196212

197-
def to_device(queue, ary, *,
198-
axes: tuple[Axis, ...] | None = None,
199-
tags: frozenset[Tag] = frozenset(),
200-
allocator=None):
213+
def to_device(
214+
queue: cl.CommandQueue,
215+
ary: np.ndarray[Any],
216+
*, axes: tuple[Axis, ...] | None = None,
217+
tags: frozenset[Tag] = _EMPTY_TAG_SET,
218+
allocator: cla.Allocator | None = None,
219+
) -> TaggableCLArray:
201220
return to_tagged_cl_array(
202221
cla.to_device(queue, ary, allocator=allocator),
203222
axes=axes, tags=tags)

0 commit comments

Comments
 (0)