|
13 | 13 | The following :mod:`pytato`-based array contexts are provided: |
14 | 14 |
|
15 | 15 | .. autoclass:: PytatoPyOpenCLArrayContext |
| 16 | +.. autoclass:: PytatoParallelPyOpenCLArrayContext |
16 | 17 | .. autoclass:: PytatoJAXArrayContext |
17 | 18 |
|
18 | 19 |
|
|
28 | 29 | .. automodule:: arraycontext.impl.pytato.utils |
29 | 30 | """ |
30 | 31 | __copyright__ = """ |
31 | | -Copyright (C) 2020-1 University of Illinois Board of Trustees |
| 32 | +Copyright (C) 2020-6 University of Illinois Board of Trustees |
| 33 | +Copyright (C) 2022-3 Kaushik Kulkarni |
32 | 34 | """ |
33 | 35 |
|
34 | 36 | __license__ = """ |
@@ -827,9 +829,15 @@ def compile(self, |
827 | 829 | def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays |
828 | 830 | ) -> pytato.AbstractResultWithNamedArrays: |
829 | 831 | import pytato as pt |
| 832 | + |
| 833 | + dag = pt.transform.deduplicate_data_wrappers(dag) |
| 834 | + |
830 | 835 | dag = pt.tag_all_calls_to_be_inlined(dag) |
831 | 836 | dag = pt.inline_calls(dag) |
832 | | - return pt.transform.materialize_with_mpms(dag) |
| 837 | + |
| 838 | + dag = pt.transform.materialize_with_mpms(dag) |
| 839 | + |
| 840 | + return dag |
833 | 841 |
|
834 | 842 | @override |
835 | 843 | def einsum(self, spec, *args, arg_names=None, tagged=()): |
@@ -909,6 +917,113 @@ def clone(self): |
909 | 917 | # }}} |
910 | 918 |
|
911 | 919 |
|
| 920 | +# {{{ PytatoParallelPyOpenCLArrayContext |
| 921 | + |
| 922 | +class PytatoParallelPyOpenCLArrayContext(PytatoPyOpenCLArrayContext): |
| 923 | + """ |
| 924 | + Same as :class:`PytatoPyOpenCLArrayContext`, but parallelizes across the device. |
| 925 | +
|
| 926 | + .. note:: |
| 927 | +
|
| 928 | + Refer to :meth:`transform_dag` and :meth:`transform_loopy_program` for |
| 929 | + details on the transformation algorithm provided by this array context. |
| 930 | +
|
| 931 | + .. automethod:: transform_dag |
| 932 | + .. automethod:: transform_loopy_program |
| 933 | + """ |
| 934 | + # FIXME: Is this something that the base PytatoParallelPyOpenCLArrayContext |
| 935 | + # should be calling, or should it be left for more-concrete derived array |
| 936 | + # contexts? If the latter, where should it live? |
| 937 | + def _materialize_einsum_inputs_and_outputs( |
| 938 | + self, dag: pytato.AbstractResultWithNamedArrays |
| 939 | + ) -> pytato.AbstractResultWithNamedArrays: |
| 940 | + import pytato as pt |
| 941 | + |
| 942 | + from .utils import ( |
| 943 | + get_inputs_and_outputs_of_einsum, |
| 944 | + get_inputs_and_outputs_of_reduction_nodes, |
| 945 | + ) |
| 946 | + |
| 947 | + einsum_inputs, einsum_outputs = get_inputs_and_outputs_of_einsum(dag) |
| 948 | + redn_inputs, redn_outputs = get_inputs_and_outputs_of_reduction_nodes(dag) |
| 949 | + reduction_inputs_outputs = ( |
| 950 | + einsum_inputs | einsum_outputs | redn_inputs | redn_outputs) |
| 951 | + |
| 952 | + def materialize( |
| 953 | + expr: pt.transform.ArrayOrNames) -> pt.transform.ArrayOrNames: |
| 954 | + if expr in reduction_inputs_outputs: |
| 955 | + if isinstance(expr, pt.InputArgumentBase): |
| 956 | + return expr |
| 957 | + else: |
| 958 | + return expr.tagged(pt.tags.ImplStored()) |
| 959 | + else: |
| 960 | + return expr |
| 961 | + |
| 962 | + return pt.transform.map_and_copy(dag, materialize) |
| 963 | + |
| 964 | + @override |
| 965 | + def transform_dag( |
| 966 | + self, dag: pytato.AbstractResultWithNamedArrays |
| 967 | + ) -> pytato.AbstractResultWithNamedArrays: |
| 968 | + r""" |
| 969 | + Returns a transformed version of *dag*, where the applied transform is: |
| 970 | +
|
| 971 | + #. Materialize as per MPMS materialization heuristic. |
| 972 | + #. materialize every :class:`pytato.array.Einsum`\ 's inputs and outputs. |
| 973 | + """ |
| 974 | + import pytato as pt |
| 975 | + |
| 976 | + dag = pt.transform.deduplicate_data_wrappers(dag) |
| 977 | + |
| 978 | + dag = pt.tag_all_calls_to_be_inlined(dag) |
| 979 | + dag = pt.inline_calls(dag) |
| 980 | + |
| 981 | + dag = pt.transform.materialize_with_mpms(dag) |
| 982 | + dag = self._materialize_einsum_inputs_and_outputs(dag) |
| 983 | + |
| 984 | + return dag |
| 985 | + |
| 986 | + def _parallelize_across_device( |
| 987 | + self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit: |
| 988 | + from .parallelize import ( |
| 989 | + add_gbarrier_between_disjoint_loop_sets, |
| 990 | + alias_global_temporaries, |
| 991 | + split_iteration_domain_across_work_items, |
| 992 | + ) |
| 993 | + |
| 994 | + # Must add barriers before parallelizing, because some parallelization |
| 995 | + # transformations create new loop sets (for example, scalar reductions) and |
| 996 | + # create their own barriers as part of that process |
| 997 | + t_unit = add_gbarrier_between_disjoint_loop_sets(t_unit) |
| 998 | + |
| 999 | + t_unit = split_iteration_domain_across_work_items( |
| 1000 | + t_unit, self.queue.device.max_compute_units) |
| 1001 | + |
| 1002 | + # FIXME: Is this something that this abstract-ish |
| 1003 | + # PytatoParallelPyOpenCLArrayContext class should be calling, or should it |
| 1004 | + # be left for more-concrete derived array contexts? If the latter, where |
| 1005 | + # should it live? |
| 1006 | + t_unit = alias_global_temporaries(t_unit) |
| 1007 | + |
| 1008 | + return t_unit |
| 1009 | + |
| 1010 | + def transform_loopy_program( |
| 1011 | + self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit: |
| 1012 | + r""" |
| 1013 | + Returns a transformed version of *t_unit*, where the applied transform is: |
| 1014 | +
|
| 1015 | + #. An execution grid size :math:`G` is selected based on *self*'s |
| 1016 | + OpenCL-device. |
| 1017 | + #. The iteration domain for each statement in the *t_unit* is divided to |
| 1018 | + equally among the work-items in :math:`G`. |
| 1019 | + #. Kernel boundaries are drawn between every set of disjoint loops. |
| 1020 | + #. Once the kernel boundaries are inferred, :func:`alias_global_temporaries` |
| 1021 | + is invoked to reduce the memory peak memory used by the transformed |
| 1022 | + program. |
| 1023 | + """ |
| 1024 | + return self._parallelize_across_device(t_unit) |
| 1025 | + |
| 1026 | + |
912 | 1027 | # {{{ PytatoJAXArrayContext |
913 | 1028 |
|
914 | 1029 | class PytatoJAXArrayContext(_BasePytatoArrayContext): |
|
0 commit comments