Skip to content

Commit 0a02da1

Browse files
committed
add PytatoParallelPyOpenCLArrayContext
1 parent 8762568 commit 0a02da1

7 files changed

Lines changed: 1122 additions & 5 deletions

File tree

arraycontext/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@
8080
from .impl.jax import EagerJAXArrayContext
8181
from .impl.numpy import NumpyArrayContext
8282
from .impl.pyopencl import PyOpenCLArrayContext
83-
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
83+
from .impl.pytato import (
84+
PytatoJAXArrayContext,
85+
PytatoParallelPyOpenCLArrayContext,
86+
PytatoPyOpenCLArrayContext,
87+
)
8488
from .loopy import make_loopy_program
8589
from .pytest import (
8690
PytestArrayContextFactory,
@@ -140,6 +144,7 @@
140144
"NumpyArrayContext",
141145
"PyOpenCLArrayContext",
142146
"PytatoJAXArrayContext",
147+
"PytatoParallelPyOpenCLArrayContext",
143148
"PytatoPyOpenCLArrayContext",
144149
"PytestArrayContextFactory",
145150
"PytestPyOpenCLArrayContextFactory",

arraycontext/impl/pytato/__init__.py

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
The following :mod:`pytato`-based array contexts are provided:
1414
1515
.. autoclass:: PytatoPyOpenCLArrayContext
16+
.. autoclass:: PytatoParallelPyOpenCLArrayContext
1617
.. autoclass:: PytatoJAXArrayContext
1718
1819
@@ -28,7 +29,8 @@
2829
.. automodule:: arraycontext.impl.pytato.utils
2930
"""
3031
__copyright__ = """
31-
Copyright (C) 2020-1 University of Illinois Board of Trustees
32+
Copyright (C) 2020-6 University of Illinois Board of Trustees
33+
Copyright (C) 2022-3 Kaushik Kulkarni
3234
"""
3335

3436
__license__ = """
@@ -827,9 +829,15 @@ def compile(self,
827829
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
828830
) -> pytato.AbstractResultWithNamedArrays:
829831
import pytato as pt
832+
833+
dag = pt.transform.deduplicate_data_wrappers(dag)
834+
830835
dag = pt.tag_all_calls_to_be_inlined(dag)
831836
dag = pt.inline_calls(dag)
832-
return pt.transform.materialize_with_mpms(dag)
837+
838+
dag = pt.transform.materialize_with_mpms(dag)
839+
840+
return dag
833841

834842
@override
835843
def einsum(self, spec, *args, arg_names=None, tagged=()):
@@ -909,6 +917,113 @@ def clone(self):
909917
# }}}
910918

911919

920+
# {{{ PytatoParallelPyOpenCLArrayContext
921+
922+
class PytatoParallelPyOpenCLArrayContext(PytatoPyOpenCLArrayContext):
923+
"""
924+
Same as :class:`PytatoPyOpenCLArrayContext`, but parallelizes across the device.
925+
926+
.. note::
927+
928+
Refer to :meth:`transform_dag` and :meth:`transform_loopy_program` for
929+
details on the transformation algorithm provided by this array context.
930+
931+
.. automethod:: transform_dag
932+
.. automethod:: transform_loopy_program
933+
"""
934+
# FIXME: Is this something that the base PytatoParallelPyOpenCLArrayContext
935+
# should be calling, or should it be left for more-concrete derived array
936+
# contexts? If the latter, where should it live?
937+
def _materialize_einsum_inputs_and_outputs(
938+
self, dag: pytato.AbstractResultWithNamedArrays
939+
) -> pytato.AbstractResultWithNamedArrays:
940+
import pytato as pt
941+
942+
from .utils import (
943+
get_inputs_and_outputs_of_einsum,
944+
get_inputs_and_outputs_of_reduction_nodes,
945+
)
946+
947+
einsum_inputs, einsum_outputs = get_inputs_and_outputs_of_einsum(dag)
948+
redn_inputs, redn_outputs = get_inputs_and_outputs_of_reduction_nodes(dag)
949+
reduction_inputs_outputs = (
950+
einsum_inputs | einsum_outputs | redn_inputs | redn_outputs)
951+
952+
def materialize(
953+
expr: pt.transform.ArrayOrNames) -> pt.transform.ArrayOrNames:
954+
if expr in reduction_inputs_outputs:
955+
if isinstance(expr, pt.InputArgumentBase):
956+
return expr
957+
else:
958+
return expr.tagged(pt.tags.ImplStored())
959+
else:
960+
return expr
961+
962+
return pt.transform.map_and_copy(dag, materialize)
963+
964+
@override
965+
def transform_dag(
966+
self, dag: pytato.AbstractResultWithNamedArrays
967+
) -> pytato.AbstractResultWithNamedArrays:
968+
r"""
969+
Returns a transformed version of *dag*, where the applied transform is:
970+
971+
#. Materialize as per MPMS materialization heuristic.
972+
#. materialize every :class:`pytato.array.Einsum`\ 's inputs and outputs.
973+
"""
974+
import pytato as pt
975+
976+
dag = pt.transform.deduplicate_data_wrappers(dag)
977+
978+
dag = pt.tag_all_calls_to_be_inlined(dag)
979+
dag = pt.inline_calls(dag)
980+
981+
dag = pt.transform.materialize_with_mpms(dag)
982+
dag = self._materialize_einsum_inputs_and_outputs(dag)
983+
984+
return dag
985+
986+
def _parallelize_across_device(
987+
self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
988+
from .parallelize import (
989+
add_gbarrier_between_disjoint_loop_sets,
990+
alias_global_temporaries,
991+
split_iteration_domain_across_work_items,
992+
)
993+
994+
# Must add barriers before parallelizing, because some parallelization
995+
# transformations create new loop sets (for example, scalar reductions) and
996+
# create their own barriers as part of that process
997+
t_unit = add_gbarrier_between_disjoint_loop_sets(t_unit)
998+
999+
t_unit = split_iteration_domain_across_work_items(
1000+
t_unit, self.queue.device.max_compute_units)
1001+
1002+
# FIXME: Is this something that this abstract-ish
1003+
# PytatoParallelPyOpenCLArrayContext class should be calling, or should it
1004+
# be left for more-concrete derived array contexts? If the latter, where
1005+
# should it live?
1006+
t_unit = alias_global_temporaries(t_unit)
1007+
1008+
return t_unit
1009+
1010+
def transform_loopy_program(
1011+
self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
1012+
r"""
1013+
Returns a transformed version of *t_unit*, where the applied transform is:
1014+
1015+
#. An execution grid size :math:`G` is selected based on *self*'s
1016+
OpenCL-device.
1017+
#. The iteration domain for each statement in the *t_unit* is divided to
1018+
equally among the work-items in :math:`G`.
1019+
#. Kernel boundaries are drawn between every set of disjoint loops.
1020+
#. Once the kernel boundaries are inferred, :func:`alias_global_temporaries`
1021+
is invoked to reduce the memory peak memory used by the transformed
1022+
program.
1023+
"""
1024+
return self._parallelize_across_device(t_unit)
1025+
1026+
9121027
# {{{ PytatoJAXArrayContext
9131028

9141029
class PytatoJAXArrayContext(_BasePytatoArrayContext):

0 commit comments

Comments
 (0)