Skip to content

Commit acd9eef

Browse files
committed
Towards typing of outlining
1 parent bef8ad2 commit acd9eef

3 files changed

Lines changed: 66 additions & 41 deletions

File tree

arraycontext/container/traversal.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
ArrayOrContainer,
9292
ArrayOrContainerOrScalar,
9393
ArrayOrContainerT,
94+
ArrayOrContainerTc,
9495
ArrayT,
9596
ScalarLike,
9697
)
@@ -399,28 +400,33 @@ def keyed_map_array_container(
399400
])
400401

401402

403+
def _rec_keyed_map_array_container_rec(
404+
f: Callable[[tuple[SerializationKey, ...], ArrayT], ArrayT],
405+
keys: tuple[SerializationKey, ...],
406+
ary_: ArrayOrContainerT
407+
) -> ArrayOrContainerT:
408+
try:
409+
iterable = serialize_container(ary_)
410+
except NotAnArrayContainerError:
411+
return cast(ArrayOrContainerT, f(keys, cast(ArrayT, ary_)))
412+
else:
413+
return deserialize_container(ary_, [
414+
(key, _rec_keyed_map_array_container_rec(
415+
f, (*keys, key), subary)) for key, subary in iterable
416+
])
417+
418+
402419
def rec_keyed_map_array_container(
403420
f: Callable[[tuple[SerializationKey, ...], ArrayT], ArrayT],
404-
ary: ArrayOrContainer) -> ArrayOrContainer:
421+
ary: ArrayOrContainerT) -> ArrayOrContainerT:
405422
"""
406423
Works similarly to :func:`rec_map_array_container`, except that *f* also
407424
takes in a traversal path to the leaf array. The traversal path argument is
408425
passed in as a tuple of identifiers of the arrays traversed before reaching
409426
the current array.
410427
"""
411428

412-
def rec(keys: tuple[SerializationKey, ...],
413-
ary_: ArrayOrContainerT) -> ArrayOrContainerT:
414-
try:
415-
iterable = serialize_container(ary_)
416-
except NotAnArrayContainerError:
417-
return cast(ArrayOrContainerT, f(keys, cast(ArrayT, ary_)))
418-
else:
419-
return deserialize_container(ary_, [
420-
(key, rec((*keys, key), subary)) for key, subary in iterable
421-
])
422-
423-
return rec((), ary)
429+
return _rec_keyed_map_array_container_rec(f, (), ary)
424430

425431
# }}}
426432

arraycontext/impl/pytato/outline.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,30 +31,36 @@
3131
import itertools
3232
from collections.abc import Callable, Mapping
3333
from dataclasses import dataclass
34-
from typing import Any
34+
from typing import cast
3535

3636
import numpy as np
3737
from immutabledict import immutabledict
3838

3939
import pytato as pt
40+
from pymbolic import Scalar
4041
from pytools.tag import Tag
4142

4243
from arraycontext.container import is_array_container_type
4344
from arraycontext.container.traversal import rec_keyed_map_array_container
44-
from arraycontext.context import ArrayOrContainer, ArrayT
45+
from arraycontext.context import (
46+
Array,
47+
ArrayOrContainer,
48+
ArrayOrContainerTc,
49+
ArrayT,
50+
)
4551
from arraycontext.impl.pytato import _BasePytatoArrayContext
4652

4753

48-
def _get_arg_id_to_arg(args: tuple[Any, ...],
49-
kwargs: Mapping[str, Any]
50-
) -> immutabledict[tuple[Any, ...], Any]:
54+
def _get_arg_id_to_arg(args: tuple[object, ...],
55+
kwargs: Mapping[str, object]
56+
) -> immutabledict[tuple[object, ...], object]:
5157
"""
5258
Helper for :meth:`OulinedCall.__call__`. Extracts mappings from argument id
5359
to argument values. See
5460
:attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's
5561
representation.
5662
"""
57-
arg_id_to_arg: dict[tuple[Any, ...], Any] = {}
63+
arg_id_to_arg: dict[tuple[object, ...], object] = {}
5864

5965
for kw, arg in itertools.chain(enumerate(args),
6066
kwargs.items()):
@@ -64,7 +70,7 @@ def _get_arg_id_to_arg(args: tuple[Any, ...],
6470
# do not make scalars as placeholders since we inline them.
6571
pass
6672
elif is_array_container_type(arg.__class__):
67-
def id_collector(keys: tuple[Any, ...], ary: ArrayT) -> ArrayT:
73+
def id_collector(keys: tuple[object, ...], ary: ArrayT) -> ArrayT:
6874
if np.isscalar(ary):
6975
pass
7076
else:
@@ -85,21 +91,21 @@ def id_collector(keys: tuple[Any, ...], ary: ArrayT) -> ArrayT:
8591

8692

8793
def _get_input_arg_id_str(
88-
arg_id: tuple[Any, ...], prefix: str | None = None) -> str:
94+
arg_id: tuple[object, ...], prefix: str | None = None) -> str:
8995
if prefix is None:
9096
prefix = ""
9197
from arraycontext.impl.pytato.utils import _ary_container_key_stringifier
9298
return f"_actx_{prefix}_in_{_ary_container_key_stringifier(arg_id)}"
9399

94100

95-
def _get_output_arg_id_str(arg_id: tuple[Any, ...]) -> str:
101+
def _get_output_arg_id_str(arg_id: tuple[object, ...]) -> str:
96102
from arraycontext.impl.pytato.utils import _ary_container_key_stringifier
97103
return f"_actx_out_{_ary_container_key_stringifier(arg_id)}"
98104

99105

100106
def _get_arg_id_to_placeholder(
101-
arg_id_to_arg: Mapping[tuple[Any, ...], Any],
102-
prefix: str | None = None) -> immutabledict[tuple[Any, ...], pt.Placeholder]:
107+
arg_id_to_arg: Mapping[tuple[object, ...], object],
108+
prefix: str | None = None) -> immutabledict[tuple[object, ...], pt.Placeholder]:
103109
"""
104110
Helper for :meth:`OulinedCall.__call__`. Constructs a :class:`pytato.Placeholder`
105111
for each argument in *arg_id_to_arg*. See
@@ -115,27 +121,26 @@ def _get_arg_id_to_placeholder(
115121

116122

117123
def _call_with_placeholders(
118-
f: Callable[..., Any],
119-
args: tuple[Any],
120-
kwargs: Mapping[str, Any],
121-
arg_id_to_placeholder: Mapping[tuple[Any, ...], pt.Placeholder]) -> Any:
124+
f: Callable[..., object],
125+
args: tuple[object],
126+
kwargs: Mapping[str, object],
127+
arg_id_to_placeholder: Mapping[tuple[object, ...], pt.Placeholder]) -> object:
122128
"""
123129
Construct placeholders analogous to *args* and *kwargs* and call *f*.
124130
"""
125131
def get_placeholder_replacement(
126-
arg: ArrayOrContainer | None, key: tuple[Any, ...]
127-
) -> ArrayOrContainer | None:
132+
arg: ArrayOrContainerTc | Scalar | None, key: tuple[object, ...]
133+
) -> ArrayOrContainerTc | Scalar | None:
128134
if arg is None:
129135
return None
130136
elif np.isscalar(arg):
131-
return arg
137+
return cast(Scalar, arg)
132138
elif isinstance(arg, pt.Array):
133-
return arg_id_to_placeholder[key]
139+
return cast(ArrayOrContainerTc, arg_id_to_placeholder[key])
134140
elif is_array_container_type(arg.__class__):
135-
def _rec_to_placeholder(keys: tuple[Any, ...], ary: pt.Array) -> pt.Array:
141+
def _rec_to_placeholder(keys: tuple[object, ...], ary: ArrayT) -> ArrayT:
136142
result = get_placeholder_replacement(ary, key + keys)
137-
assert isinstance(result, pt.Array)
138-
return result
143+
return cast(ArrayT, result)
139144

140145
return rec_keyed_map_array_container(_rec_to_placeholder, arg)
141146
else:
@@ -157,7 +162,7 @@ def _unpack_output(
157162
elif is_array_container_type(output.__class__):
158163
unpacked_output = {}
159164

160-
def _unpack_container(key: tuple[Any, ...], ary: ArrayT) -> ArrayT:
165+
def _unpack_container(key: tuple[object, ...], ary: ArrayT) -> ArrayT:
161166
key_str = _get_output_arg_id_str(key)
162167
unpacked_output[key_str] = ary
163168
return ary
@@ -171,7 +176,7 @@ def _unpack_container(key: tuple[Any, ...], ary: ArrayT) -> ArrayT:
171176

172177
def _pack_output(
173178
output_template: ArrayOrContainer,
174-
unpacked_output: pt.Array | immutabledict[str, pt.Array]
179+
unpacked_output: Array | immutabledict[str, Array]
175180
) -> ArrayOrContainer:
176181
"""
177182
Pack *unpacked_output* into array containers according to *output_template*.
@@ -182,7 +187,7 @@ def _pack_output(
182187
elif is_array_container_type(output_template.__class__):
183188
assert isinstance(unpacked_output, immutabledict)
184189

185-
def _pack_into_container(key: tuple[Any, ...], ary: pt.Array) -> pt.Array:
190+
def _pack_into_container(key: tuple[object, ...], ary: Array) -> Array:
186191
key_str = _get_output_arg_id_str(key)
187192
return unpacked_output[key_str]
188193

@@ -194,10 +199,10 @@ def _pack_into_container(key: tuple[Any, ...], ary: pt.Array) -> pt.Array:
194199
@dataclass(frozen=True)
195200
class OutlinedCall:
196201
actx: _BasePytatoArrayContext
197-
f: Callable[..., Any]
202+
f: Callable[..., object]
198203
tags: frozenset[Tag]
199204

200-
def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
205+
def __call__(self, *args: object, **kwargs: object) -> ArrayOrContainer:
201206
arg_id_to_arg = _get_arg_id_to_arg(args, kwargs)
202207

203208
from .utils import _verify_is_dag
@@ -263,7 +268,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
263268
call_site_output = func_def(**call_bindings)
264269

265270
assert isinstance(call_site_output, pt.Array | immutabledict)
266-
return _pack_output(output, call_site_output)
271+
# FIXME: pt.Array is not an actx Array
272+
return _pack_output(cast("Array | immutabledict[str, Array]", output),
273+
cast("Array | immutabledict[str, Array]", call_site_output))
267274

268275

269276
# vim: foldmethod=marker

pyproject.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ reportImplicitStringConcatenation = "none"
138138
reportUnnecessaryIsInstance = "none"
139139
reportUnusedCallResult = "none"
140140
reportExplicitAny = "none"
141+
reportPrivateUsage = "hint"
141142

142143
# This reports even cycles that are qualified by 'if TYPE_CHECKING'. Not what
143144
# we care about at this moment.
@@ -157,3 +158,14 @@ reportPrivateUsage = "none"
157158
reportMissingTypeStubs = "hint"
158159
reportAny = "hint"
159160

161+
[[tool.basedpyright.executionEnvironments]]
162+
root = "examples"
163+
reportUnknownArgumentType = "hint"
164+
reportUnknownMemberType = "hint"
165+
reportUnknownVariableType = "hint"
166+
reportUnknownParameterType = "hint"
167+
reportMissingTypeArgument = "hint"
168+
reportPrivateUsage = "none"
169+
reportMissingTypeStubs = "hint"
170+
reportAny = "hint"
171+

0 commit comments

Comments
 (0)