Skip to content

Commit 245f8ed

Browse files
committed
Add outlining pass to array expression
1 parent 7f2077b commit 245f8ed

3 files changed

Lines changed: 206 additions & 0 deletions

File tree

arraycontext/context.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ class ArrayContext(ABC):
273273
.. automethod:: tag
274274
.. automethod:: tag_axis
275275
.. automethod:: compile
276+
.. automethod:: outline
276277
"""
277278

278279
array_types: Tuple[type, ...] = ()
@@ -524,6 +525,25 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
524525
"""
525526
return f
526527

528+
def outline(self,
529+
f: Callable[..., Any],
530+
name: Optional[str] = None) -> Callable[..., Any]:
531+
"""
532+
Returns a drop-in-replacement for *f*. The behavior of the returned
533+
callable is specific to the derived class.
534+
535+
The reason for the existence of such a routine is mainly for
536+
arraycontexts that allow a lazy mode of execution. In such
537+
arraycontexts, the computations within *f* maybe staged to potentially
538+
enable additional compiler transformations. See
539+
:func:`pytato.functions.trace_call` or :func:`jax.named_call` for
540+
examples.
541+
542+
:arg f: the function executing the computation to be staged.
543+
:return: a function with the same signature as *f*.
544+
"""
545+
return f
546+
527547
# undocumented for now
528548
@property
529549
@abstractmethod

arraycontext/impl/pytato/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,20 @@ def get_target(self):
216216

217217
# }}}
218218

219+
def outline(self,
220+
f: Callable[..., Any],
221+
name: Optional[str] = None,
222+
tags: FrozenSet[Tag] = frozenset()
223+
) -> Callable[..., Any]:
224+
from pytato.tags import FunctionIdentifier
225+
226+
from .outline import OutlinedCall
227+
name = name or getattr(f, "__name__", None)
228+
if name is not None:
229+
tags = tags | {FunctionIdentifier(name)}
230+
231+
return OutlinedCall(self, f, tags)
232+
219233
# }}}
220234

221235

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""
2+
.. autoclass:: OutlinedCall
3+
"""
4+
__copyright__ = """
5+
Copyright (C) 2023 Kaushik Kulkarni
6+
"""
7+
8+
__license__ = """
9+
Permission is hereby granted, free of charge, to any person obtaining a copy
10+
of this software and associated documentation files (the "Software"), to deal
11+
in the Software without restriction, including without limitation the rights
12+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13+
copies of the Software, and to permit persons to whom the Software is
14+
furnished to do so, subject to the following conditions:
15+
16+
The above copyright notice and this permission notice shall be included in
17+
all copies or substantial portions of the Software.
18+
19+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
25+
THE SOFTWARE.
26+
"""
27+
28+
import itertools
29+
from dataclasses import dataclass
30+
from typing import Any, Callable, Dict, FrozenSet, Mapping, Tuple
31+
32+
import numpy as np
33+
from immutables import Map
34+
35+
import pytato as pt
36+
from pytools.tag import Tag
37+
38+
from arraycontext.container import is_array_container_type
39+
from arraycontext.container.traversal import rec_keyed_map_array_container
40+
from arraycontext.context import ArrayOrContainer
41+
from arraycontext.impl.pytato import _BasePytatoArrayContext
42+
43+
44+
def _get_arg_id_to_arg(args: Tuple[Any, ...],
45+
kwargs: Mapping[str, Any]
46+
) -> Map[Tuple[Any, ...], Any]:
47+
"""
48+
Helper for :meth:`OulinedCall.__call__`. Extracts mappings from argument id
49+
to argument values. See
50+
:attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's
51+
representation.
52+
"""
53+
arg_id_to_arg: Dict[Tuple[Any, ...], Any] = {}
54+
55+
for kw, arg in itertools.chain(enumerate(args),
56+
kwargs.items()):
57+
if np.isscalar(arg):
58+
# do not make scalars as placeholders since we inline them.
59+
pass
60+
elif is_array_container_type(arg.__class__):
61+
def id_collector(keys, ary):
62+
arg_id = (kw,) + keys # noqa: B023
63+
arg_id_to_arg[arg_id] = ary # noqa: B023
64+
return ary
65+
66+
rec_keyed_map_array_container(id_collector, arg)
67+
elif isinstance(arg, pt.Array):
68+
arg_id = (kw,)
69+
arg_id_to_arg[arg_id] = arg
70+
else:
71+
raise ValueError("Argument to a compiled operator should be"
72+
" either a scalar, pt.Array or an array container. Got"
73+
f" '{arg}'.")
74+
75+
return Map(arg_id_to_arg)
76+
77+
78+
def _get_placeholder_replacement(arg, kw, arg_id_to_name):
79+
"""
80+
Helper for :class:`OutlinedCall.__call__`. Returns the placeholder version
81+
of an argument to :attr:`OutlinedCall.f`.
82+
"""
83+
if np.isscalar(arg):
84+
return arg
85+
elif isinstance(arg, pt.Array):
86+
name = arg_id_to_name[(kw,)]
87+
return pt.make_placeholder(name, arg.shape, arg.dtype)
88+
elif is_array_container_type(arg.__class__):
89+
def _rec_to_placeholder(keys, ary):
90+
name = arg_id_to_name[(kw,) + keys]
91+
return pt.make_placeholder(name,
92+
ary.shape,
93+
ary.dtype)
94+
95+
return rec_keyed_map_array_container(_rec_to_placeholder, arg)
96+
else:
97+
raise NotImplementedError(type(arg))
98+
99+
100+
def _get_input_arg_id_str(arg_id: Tuple[Any, ...]) -> str:
101+
from arraycontext.impl.pytato.utils import _ary_container_key_stringifier
102+
return f"_actx_in_{_ary_container_key_stringifier(arg_id)}"
103+
104+
105+
def _get_output_arg_id_str(arg_id: Tuple[Any, ...]) -> str:
106+
from arraycontext.impl.pytato.utils import _ary_container_key_stringifier
107+
return f"_actx_out_{_ary_container_key_stringifier(arg_id)}"
108+
109+
110+
@dataclass(frozen=True)
111+
class OutlinedCall:
112+
actx: _BasePytatoArrayContext
113+
f: Callable[..., Any]
114+
tags: FrozenSet[Tag]
115+
116+
def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
117+
arg_id_to_arg = _get_arg_id_to_arg(args, kwargs)
118+
input_id_to_name_in_function = {arg_id: _get_input_arg_id_str(arg_id)
119+
for arg_id in arg_id_to_arg}
120+
121+
pl_args = [_get_placeholder_replacement(arg, iarg,
122+
input_id_to_name_in_function)
123+
for iarg, arg in enumerate(args)]
124+
pl_kwargs = {kw: _get_placeholder_replacement(arg, kw,
125+
input_id_to_name_in_function)
126+
for kw, arg in kwargs.items()}
127+
128+
output = self.f(*pl_args, **pl_kwargs)
129+
130+
if isinstance(output, pt.Array):
131+
returns = {"_": output}
132+
ret_type = pt.function.ReturnType.ARRAY
133+
elif is_array_container_type(output.__class__):
134+
returns = {}
135+
136+
def _unpack_container(key, ary):
137+
key = _get_output_arg_id_str(key)
138+
returns[key] = ary
139+
return ary
140+
141+
rec_keyed_map_array_container(_unpack_container, output)
142+
ret_type = pt.function.ReturnType.DICT_OF_ARRAYS
143+
else:
144+
raise NotImplementedError(type(output))
145+
146+
func_def = pt.function.FunctionDefinition(
147+
frozenset(input_id_to_name_in_function.values()),
148+
ret_type,
149+
Map(returns),
150+
tags=self.tags,
151+
)
152+
153+
call_parameters = {input_id_to_name_in_function[arg_id]: arg
154+
for arg_id, arg in arg_id_to_arg.items()}
155+
call_site_output = func_def(**call_parameters)
156+
157+
if isinstance(output, pt.Array):
158+
return call_site_output
159+
elif is_array_container_type(output.__class__):
160+
def _pack_into_container(key, ary):
161+
key = _get_output_arg_id_str(key)
162+
return call_site_output[key]
163+
164+
call_site_output_as_container = rec_keyed_map_array_container(
165+
_pack_into_container,
166+
output)
167+
return call_site_output_as_container
168+
else:
169+
raise NotImplementedError(type(output))
170+
171+
172+
# vim: foldmethod=marker

0 commit comments

Comments
 (0)