Skip to content

Commit 6de41e4

Browse files
committed
more typing work
1 parent 62e3586 commit 6de41e4

3 files changed

Lines changed: 10 additions & 77 deletions

File tree

.basedpyright/baseline.json

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -14607,22 +14607,6 @@
1460714607
"lineCount": 1
1460814608
}
1460914609
},
14610-
{
14611-
"code": "reportPrivateUsage",
14612-
"range": {
14613-
"startColumn": 53,
14614-
"endColumn": 58,
14615-
"lineCount": 1
14616-
}
14617-
},
14618-
{
14619-
"code": "reportPrivateUsage",
14620-
"range": {
14621-
"startColumn": 53,
14622-
"endColumn": 58,
14623-
"lineCount": 1
14624-
}
14625-
},
1462614610
{
1462714611
"code": "reportUnknownMemberType",
1462814612
"range": {
@@ -14903,22 +14887,6 @@
1490314887
"lineCount": 1
1490414888
}
1490514889
},
14906-
{
14907-
"code": "reportPrivateUsage",
14908-
"range": {
14909-
"startColumn": 53,
14910-
"endColumn": 58,
14911-
"lineCount": 1
14912-
}
14913-
},
14914-
{
14915-
"code": "reportPrivateUsage",
14916-
"range": {
14917-
"startColumn": 53,
14918-
"endColumn": 58,
14919-
"lineCount": 1
14920-
}
14921-
},
1492214890
{
1492314891
"code": "reportUnknownMemberType",
1492414892
"range": {
@@ -20891,38 +20859,6 @@
2089120859
"lineCount": 1
2089220860
}
2089320861
},
20894-
{
20895-
"code": "reportAttributeAccessIssue",
20896-
"range": {
20897-
"startColumn": 37,
20898-
"endColumn": 54,
20899-
"lineCount": 1
20900-
}
20901-
},
20902-
{
20903-
"code": "reportAttributeAccessIssue",
20904-
"range": {
20905-
"startColumn": 25,
20906-
"endColumn": 29,
20907-
"lineCount": 1
20908-
}
20909-
},
20910-
{
20911-
"code": "reportAttributeAccessIssue",
20912-
"range": {
20913-
"startColumn": 25,
20914-
"endColumn": 29,
20915-
"lineCount": 1
20916-
}
20917-
},
20918-
{
20919-
"code": "reportAttributeAccessIssue",
20920-
"range": {
20921-
"startColumn": 38,
20922-
"endColumn": 55,
20923-
"lineCount": 1
20924-
}
20925-
},
2092620862
{
2092720863
"code": "reportMissingParameterType",
2092820864
"range": {

arraycontext/impl/pytato/__init__.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,10 @@ def __init__(
152152
super().__init__()
153153

154154
import pytato as pt
155-
self._freeze_prg_cache: dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {}
155+
self._freeze_prg_cache: dict[
156+
pt.AbstractResultWithNamedArrays, lp.TranslationUnit] = {}
156157
self._dag_transform_cache: dict[
157-
pt.DictOfNamedArrays,
158+
pt.AbstractResultWithNamedArrays,
158159
tuple[pt.AbstractResultWithNamedArrays, str]] = {}
159160

160161
if compile_trace_callback is None:
@@ -600,12 +601,10 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
600601
dag = Deduplicator()(dag)
601602

602603
# FIXME: Remove this if/when _normalize_pt_expr gets support for functions
603-
dag = pt.tag_all_calls_to_be_inlined(
604-
dag)
604+
dag = pt.tag_all_calls_to_be_inlined(dag)
605605
dag = pt.inline_calls(dag)
606606

607-
normalized_expr, bound_arguments = _normalize_pt_expr(
608-
dag)
607+
normalized_expr, bound_arguments = _normalize_pt_expr(dag)
609608

610609
try:
611610
pt_prg = self._freeze_prg_cache[normalized_expr]
@@ -760,10 +759,10 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
760759
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
761760
) -> pytato.AbstractResultWithNamedArrays:
762761
import pytato as pt
763-
tdag = pt.tag_all_calls_to_be_inlined(dag)
764-
tdag = pt.inline_calls(tdag)
765-
tdag = pt.transform.materialize_with_mpms(tdag)
766-
return tdag
762+
dag = pt.tag_all_calls_to_be_inlined(dag)
763+
dag = pt.inline_calls(dag)
764+
dag = pt.transform.materialize_with_mpms(dag)
765+
return dag
767766

768767
def einsum(self, spec, *args, arg_names=None, tagged=()):
769768
import pytato as pt

arraycontext/impl/pytato/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
Array,
5151
Axis as PtAxis,
5252
DataWrapper,
53-
DictOfNamedArrays,
5453
Placeholder,
5554
SizeParam,
5655
make_placeholder,
@@ -138,7 +137,7 @@ def map_function_definition(
138137
# definitions can't contain non-argument placeholders
139138
def _normalize_pt_expr(
140139
expr: AbstractResultWithNamedArrays
141-
) -> tuple[DictOfNamedArrays, Mapping[str, Any]]:
140+
) -> tuple[AbstractResultWithNamedArrays, Mapping[str, Any]]:
142141
"""
143142
Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a
144143
normalized form of *expr*, with all instances of
@@ -157,7 +156,6 @@ def _normalize_pt_expr(
157156

158157
normalize_mapper = _DatawrapperToBoundPlaceholderMapper()
159158
normalized_expr = normalize_mapper(expr)
160-
assert isinstance(normalized_expr, DictOfNamedArrays)
161159
return normalized_expr, normalize_mapper.bound_arguments
162160

163161

0 commit comments

Comments
 (0)