diff --git a/.gitignore b/.gitignore index 52faecbe..3b726fb8 100644 --- a/.gitignore +++ b/.gitignore @@ -131,7 +131,7 @@ venv.bak/ .ropeproject # mkdocs documentation -/site +site/ # mypy .mypy_cache/ @@ -155,3 +155,7 @@ cython_debug/ #.idea/ /.vscode/ *.png + + +# custom run scripts +*.sh \ No newline at end of file diff --git a/irspec/docs/collective/.$base.drawio.bkp b/irspec/docs/collective/.$base.drawio.bkp new file mode 100644 index 00000000..99ca9f27 --- /dev/null +++ b/irspec/docs/collective/.$base.drawio.bkp @@ -0,0 +1,683 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/irspec/docs/collective/base.drawio b/irspec/docs/collective/base.drawio new 
file mode 100644 index 00000000..83fe4757 --- /dev/null +++ b/irspec/docs/collective/base.drawio @@ -0,0 +1,683 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/irspec/docs/collective/collective.md b/irspec/docs/collective/collective.md new file mode 100644 index 00000000..41d3ce5d --- /dev/null +++ b/irspec/docs/collective/collective.md @@ -0,0 +1,244 @@ +# Collective IR + +The goal of this document is to give an overview of the key concepts present in the IR. It does not (yet) fully describe the semantics of the computation. + + +## Syntax Fundamentals + +### Streams +The stream class of the Spatial IR is extended with `multistream`, for a scalar type ``. 
+ +MultiStreams take a name and a root in (x,y) coordinates as arguments. Additionally a Broadcast or Reduce can be defined. + +### Collective Functions +Collective Communication functions can be called inside the compute block. For further implementation details see the specific collective definition. + +## Broadcast +A broadcast is defined with the standard send and receive framework provided by the Spatial IR. It is differentiated from the single point to point communication by using a `multistream` instead of a standard stream. This mimics the support for broadcast communication found in many spatial architectures. + +In the dataflow block a broadcast is defined in the following way: +``` +multistream name = broadcast_stream(root_x, root_y) { + channels = auto + } +``` +where (root_x, root_y) defines the sender. The name is important to give as an argument in the compute blocks. With channels a specific channel can be targeted for the communication in architectures that support it. In almost all situations auto should lead to optimal results. + +Sending data in a broadcast that is defined via the multistream `bcast` can therefore be defined as: +```rust +compute i16 variable, i16 variable in subgrid_expression { + send(data, bcast) +} +``` +where data is the data being sent. + +???+ example "Example: Simple Broadcast" + ```rust + compute i16 i, i16 j in [1:N, 0] { + await receive(a, bcast) + } + + compute i16 i, i16 j in [0, 0] { + await send(a, bcast) + } + ``` + where `i`, `j` are `i16` variables that are bound to the coordinates of the PEs in the subgrid and `bcast` is a multistream. 
+ + This can be generated with the following code: + ```rust + dataflow i16 i, i16 j in [0:N, 0] { + multistream bcast = broadcast_stream(0, 0) { + channels = auto + } + } + + compute i16 i, i16 j in [0:N, 0] { + await broadcast(a, bcast); + } + ``` + +In the future the functionality could be extended with an optional send-receive routing (like in the reduce case) for devices that do not support broadcast communication. + +## Reduce + +Most architectures do not support Reduce operations. Therefore we translate reduces to simple send-receive communication. + +``` +NOTE: We currently only support reduce in an N-by-N grid +that cannot be defined partially or in multiple rounds. +``` + +In the dataflow block a reduce is defined in the following way: +``` +multistream name = reduce_stream(root_x, root_y) { + algorithm = auto, + op = CL_SUM, + pipelined = true + } +``` +where (root_x, root_y) defines the receiver. The name must be passed as an argument in the compute blocks. `algorithm` chooses the layout the communication follows. Further details on the different layouts available can be found in the [Layouts section](layouts.md). `op` defines which operation to use for the reduce. The currently supported list can be found below. `pipelined` can be either 'true' or 'false' and defines whether, when sending arrays, the whole array is received by the next processing element before it is forwarded (pipelined = false) or each element of the array is sent on before the next element is received (pipelined = true). + +The options for the operation `op` are: + +- CL_SUM (returns the sum of all elements) +- CL_PRODUCT (returns the product of all elements) + +At the moment parameters are not allowed in the range of the coordinate grid, i.e. +```rust +dataflow i16 i, i16 j in [0:N , 0:N] {...} +``` +is not allowed. + +In the compute blocks a reduce can then be used with the following line: +```rust +await reduce(data, name) +``` +where data is the element/array to reduce on. 
+ +???+ example "Example: Simple Reduce with snake communication" + ```rust + kernel @add() { + place i16 i, i16 j in [0:5 , 0:5] { + i16[100] a + } + dataflow i16 i, i16 j in [0:5:1 , 0:5:1] { + stream reduce = relative_stream(-1, 0) + } + dataflow i16 i, i16 j in [0:5:1 , 1:4:1] { + stream reduce#1 = relative_stream(1, 0) + } + dataflow i16 i, i16 j in [0:1:1 , 1:5:1] { + stream reduce#2 = relative_stream(0, -1) + } + dataflow i16 i, i16 j in [4:5:1 , 0:4:1] { + stream reduce#3 = relative_stream(0, -1) + } + compute i16 i, i16 j in [0:1 , 0:1] { + a[0] = 1 + await foreach i32 reduce_runner, i16 reduce_receive in [0:100], receive(reduce#1) { + a[reduce_runner] = (a[reduce_runner] + reduce_receive) + } + } + compute i16 i, i16 j in [0:1 , 1:2] { + a[0] = 1 + await foreach i32 reduce_runner#1, i16 reduce_receive#1 in [0:100], receive(reduce#1) { + a[reduce_runner#1] = (a[reduce_runner#1] + reduce_receive#1) + } + await send(a, reduce#2) + } + compute i16 i, i16 j in [0:1 , 2:3] { + a[0] = 1 + await foreach i32 reduce_runner#2, i16 reduce_receive#2 in [0:100], receive(reduce#1) { + a[reduce_runner#2] = (a[reduce_runner#2] + reduce_receive#2) + } + await send(a, reduce#2) + } + compute i16 i, i16 j in [0:1 , 3:4] { + a[0] = 1 + await foreach i32 reduce_runner#3, i16 reduce_receive#3 in [0:100], receive(reduce#1) { + a[reduce_runner#3] = (a[reduce_runner#3] + reduce_receive#3) + } + await send(a, reduce#2) + } + compute i16 i, i16 j in [0:1 , 4:5] { + a[0] = 1 + await foreach i32 reduce_runner#4, i16 reduce_receive#4 in [0:100], receive(reduce#1) { + a[reduce_runner#4] = (a[reduce_runner#4] + reduce_receive#4) + } + await send(a, reduce#3) + } + compute i16 i, i16 j in [4:5 , 0:1] { + a[0] = 1 + await foreach i32 reduce_runner#5, i16 reduce_receive#5 in [0:100], receive(reduce#4) { + a[reduce_runner#5] = (a[reduce_runner#5] + reduce_receive#5) + } + await send(a, reduce#1) + } + compute i16 i, i16 j in [4:5 , 1:2] { + a[0] = 1 + await foreach i32 reduce_runner#6, 
i16 reduce_receive#6 in [0:100], receive(reduce#2) { + a[reduce_runner#6] = (a[reduce_runner#6] + reduce_receive#6) + } + await send(a, reduce#1) + } + compute i16 i, i16 j in [4:5 , 2:3] { + a[0] = 1 + await foreach i32 reduce_runner#7, i16 reduce_receive#7 in [0:100], receive(reduce#2) { + a[reduce_runner#7] = (a[reduce_runner#7] + reduce_receive#7) + } + await send(a, reduce#1) + } + compute i16 i, i16 j in [4:5 , 3:4] { + a[0] = 1 + await foreach i32 reduce_runner#8, i16 reduce_receive#8 in [0:100], receive(reduce#2) { + a[reduce_runner#8] = (a[reduce_runner#8] + reduce_receive#8) + } + await send(a, reduce#1) + } + compute i16 i, i16 j in [4:5 , 4:5] { + a[0] = 1 + await send(a, reduce#1) + } + compute i16 i, i16 j in [1:4 , 0:1] { + a[0] = 1 + await foreach i32 reduce_runner#9, i16 reduce_receive#9 in [0:100], receive(reduce#1) { + a[reduce_runner#9] = (a[reduce_runner#9] + reduce_receive#9) + } + await send(a, reduce#1) + } + compute i16 i, i16 j in [1:4 , 1:2] { + a[0] = 1 + await foreach i32 reduce_runner#10, i16 reduce_receive#10 in [0:100], receive(reduce#1) { + a[reduce_runner#10] = (a[reduce_runner#10] + reduce_receive#10) + } + await send(a, reduce#1) + } + compute i16 i, i16 j in [1:4 , 2:3] { + a[0] = 1 + await foreach i32 reduce_runner#11, i16 reduce_receive#11 in [0:100], receive(reduce#1) { + a[reduce_runner#11] = (a[reduce_runner#11] + reduce_receive#11) + } + await send(a, reduce#1) + } + compute i16 i, i16 j in [1:4 , 3:4] { + a[0] = 1 + await foreach i32 reduce_runner#12, i16 reduce_receive#12 in [0:100], receive(reduce#1) { + a[reduce_runner#12] = (a[reduce_runner#12] + reduce_receive#12) + } + await send(a, reduce#1) + } + compute i16 i, i16 j in [1:4 , 4:5] { + a[0] = 1 + await foreach i32 reduce_runner#13, i16 reduce_receive#13 in [0:100], receive(reduce#1) { + a[reduce_runner#13] = (a[reduce_runner#13] + reduce_receive#13) + } + await send(a, reduce#1) + } + } + ``` + + can be generated from: + + ```rust + kernel @add() { + + place i16 
i, i16 j in [0:5, 0:5] { + i16[100] a + } + + dataflow i16 i, i16 j in [0:5, 0:5] { + multistream red = reduce_stream(0, 0) { + algorithm = snake, + op = CL_SUM, + pipelined = false + } + } + + compute i16 i, i16 j in [0:5, 0:5] { + a[0] = 1 + await reduce(a, red) + } + } + ``` + + The layout for this example can be found as the snake example in the [Layouts section](layouts.md). \ No newline at end of file diff --git a/irspec/docs/collective/design_goals.md b/irspec/docs/collective/design_goals.md new file mode 100644 index 00000000..d1a8b3d3 --- /dev/null +++ b/irspec/docs/collective/design_goals.md @@ -0,0 +1,6 @@ +# Design Goals + +- Models collective communication schemas + - Broadcast + - Reduce +- Integrates into Spatial IR for device agnostic communication abstractions \ No newline at end of file diff --git a/irspec/docs/collective/grid.drawio.svg b/irspec/docs/collective/grid.drawio.svg new file mode 100644 index 00000000..eba57855 --- /dev/null +++ b/irspec/docs/collective/grid.drawio.svg @@ -0,0 +1,4 @@ + + + +
*
\ No newline at end of file diff --git a/irspec/docs/collective/layouts.md b/irspec/docs/collective/layouts.md new file mode 100644 index 00000000..88e6468a --- /dev/null +++ b/irspec/docs/collective/layouts.md @@ -0,0 +1,31 @@ +# Layouts + +When using Collective Reduce functions, two different communication schemas / layouts can be used. + +## Usage + +To choose the layout the `algorithm` flag can be set to +``` +algorithm = grid +or +algorithm = snake +or +algorithm = auto +``` +A schematic example for both the snake and grid layout can be found below. Currently `auto` chooses the grid algorithm. The snake algorithm can only be chosen if the root of the reduce is in one of the 4 corners of the communication grid the reduce is defined on. For large arrays snake will maximize throughput while for short arrays grid will minimize latency. + +## Definitions + +Below are schematics to understand the logic of the snake and grid pattern. The root of the reduce is marked with a star `*`. + +### Snake + +The snake pattern currently only works with the root in one of the four corners. It then puts all the PEs on a string favoring horizontal communication. + +![Snake layout schematic](snake.drawio.svg) + +### Grid + +The grid pattern works with the root in any PE. The first reduction is performed horizontally and the second one vertically. + +![Grid layout schematic](grid.drawio.svg) \ No newline at end of file diff --git a/irspec/docs/collective/snake.drawio.svg b/irspec/docs/collective/snake.drawio.svg new file mode 100644 index 00000000..08473e95 --- /dev/null +++ b/irspec/docs/collective/snake.drawio.svg @@ -0,0 +1,4 @@ + + + +
*
\ No newline at end of file diff --git a/irspec/mkdocs.yml b/irspec/mkdocs.yml index 76ffdc4c..eb82843b 100644 --- a/irspec/mkdocs.yml +++ b/irspec/mkdocs.yml @@ -13,6 +13,10 @@ nav: - Routing Semantics: spatial/routing.md - Parameterized Semantics: spatial/parametric.md - Examples: spatial/examples.md + - Collective IR: + - Design Goals: collective/design_goals.md + - Specification: collective/collective.md + - Layouts: collective/layouts.md - Dataflow Task IR: dataflowtask/dataflowtask.md markdown_extensions: diff --git a/samples/collective/hard_reduce_1.ref_tile b/samples/collective/hard_reduce_1.ref_tile new file mode 100644 index 00000000..ef96d70d --- /dev/null +++ b/samples/collective/hard_reduce_1.ref_tile @@ -0,0 +1,26 @@ +[0:1 , 0:1] +[0:1 , 1:2] +[0:1 , 2:3] +[0:1 , 3:4] +[0:1 , 4:5] +[1:2 , 0:1] +[1:2 , 1:2] +[1:2 , 2:3] +[1:2 , 3:4] +[1:2 , 4:5] +[2:3 , 0:1] +[2:3 , 1:2] +[2:3 , 2:3] +[2:3 , 3:4] +[2:3 , 4:5] +[3:4 , 0:1] +[3:4 , 1:2] +[3:4 , 2:3] +[3:4 , 3:4] +[3:4 , 4:5] +[4:5 , 0:1] +[4:5 , 1:2] +[4:5 , 2:3] +[4:5 , 3:4] +[4:5 , 4:5] +14 \ No newline at end of file diff --git a/samples/collective/hard_reduce_1.sptl b/samples/collective/hard_reduce_1.sptl new file mode 100644 index 00000000..da2f608f --- /dev/null +++ b/samples/collective/hard_reduce_1.sptl @@ -0,0 +1,32 @@ + +kernel @add() { + + place i16 i, i16 j in [0:5, 0:5] { + i16[100] a + } + + dataflow i16 i, i16 j in [0:5, 0:5] { + multistream red = reduce_stream(0, 0) { + algorithm = grid, + op = CL_SUM, + pipelined = false + } + multistream red1 = reduce_stream(2, 2) { + algorithm = grid, + op = CL_SUM, + pipelined = true + } + multistream red2 = reduce_stream(4, 4) { + algorithm = snake, + op = CL_SUM, + pipelined = false + } + } + + compute i16 i, i16 j in [0:5, 0:5] { + a[0] = 1 + await reduce(a, red) + await reduce(a, red1) + await reduce(a, red2) + } +} \ No newline at end of file diff --git a/samples/collective/hard_reduce_2.ref_tile b/samples/collective/hard_reduce_2.ref_tile new 
file mode 100644 index 00000000..1a7319c3 --- /dev/null +++ b/samples/collective/hard_reduce_2.ref_tile @@ -0,0 +1,16 @@ +[0:1 , 0:1] +[0:1 , 4:5] +[4:5 , 0:1] +[4:5 , 4:5] +[0:1 , 1:2] +[0:1 , 2:3] +[0:1 , 3:4] +[4:5 , 1:2] +[4:5 , 2:3] +[4:5 , 3:4] +[1:4 , 0:1] +[1:4 , 1:2] +[1:4 , 2:3] +[1:4 , 3:4] +[1:4 , 4:5] +8 \ No newline at end of file diff --git a/samples/collective/hard_reduce_2.sptl b/samples/collective/hard_reduce_2.sptl new file mode 100644 index 00000000..d165aa51 --- /dev/null +++ b/samples/collective/hard_reduce_2.sptl @@ -0,0 +1,44 @@ + +kernel @add() { + + place i16 i, i16 j in [0:5, 0:5] { + i16[100] a + } + + dataflow i16 i, i16 j in [0:5, 0:5] { + multistream red = reduce_stream(0, 0) { + algorithm = grid, + op = CL_SUM, + pipelined = false + } + multistream red2 = reduce_stream(0, 4) { + algorithm = grid, + op = CL_SUM, + pipelined = false + } + multistream red3 = reduce_stream(4, 0) { + algorithm = snake, + op = CL_SUM, + pipelined = false + } + } + + compute i16 i, i16 j in [0:5, 0:5] { + a[0] = 1 + for i32 test1 in [0:999] { + a[0] = 100 + await reduce(a, red) + for i32 test2 in [0:25] { + await reduce(a, red2) + await reduce(a, red) + for i32 test3 in [0:30] { + for i32 test4 in [0:35] { + for i32 test5 in [0:40] { + await reduce(a, red3) + } + } + } + } + } + } +} \ No newline at end of file diff --git a/samples/collective/hard_reduce_3.ref_tile b/samples/collective/hard_reduce_3.ref_tile new file mode 100644 index 00000000..ba06f1a4 --- /dev/null +++ b/samples/collective/hard_reduce_3.ref_tile @@ -0,0 +1,8 @@ +[0:1 , 0:1] +[0:1 , 4:5] +[4:5 , 0:1] +[4:5 , 4:5] +[0:1 , 1:4] +[4:5 , 1:4] +[1:4 , 0:5] +6 \ No newline at end of file diff --git a/samples/collective/hard_reduce_3.sptl b/samples/collective/hard_reduce_3.sptl new file mode 100644 index 00000000..8a57ea51 --- /dev/null +++ b/samples/collective/hard_reduce_3.sptl @@ -0,0 +1,44 @@ + +kernel @add() { + + place i16 i, i16 j in [0:5, 0:5] { + i16[100] a + } + + dataflow i16 i, i16 j 
in [0:5, 0:5] { + multistream red = reduce_stream(0, 0) { + algorithm = grid, + op = CL_SUM, + pipelined = false + } + multistream red2 = reduce_stream(0, 4) { + algorithm = grid, + op = CL_SUM, + pipelined = false + } + multistream red3 = reduce_stream(4, 0) { + algorithm = grid, + op = CL_SUM, + pipelined = false + } + } + + compute i16 i, i16 j in [0:5, 0:5] { + a[0] = 1 + for i32 test1 in [0:999] { + a[0] = 100 + await reduce(a, red) + for i32 test2 in [0:25] { + await reduce(a, red2) + await reduce(a, red) + for i32 test3 in [0:30] { + for i32 test4 in [0:35] { + for i32 test5 in [0:40] { + await reduce(a, red3) + } + } + } + } + } + } +} \ No newline at end of file diff --git a/samples/collective/medium_reduce_grid_1.ref_tile b/samples/collective/medium_reduce_grid_1.ref_tile new file mode 100644 index 00000000..66d2a222 --- /dev/null +++ b/samples/collective/medium_reduce_grid_1.ref_tile @@ -0,0 +1,14 @@ +[0:1 , 0:1] +[2:3 , 2:3] +[4:5 , 4:5] +[0:1 , 4:5] +[0:1 , 1:4] +[1:2 , 0:5] +[3:4 , 0:5] +[2:3 , 0:1] +[2:3 , 1:2] +[2:3 , 3:4] +[2:3 , 4:5] +[4:5 , 0:1] +[4:5 , 1:4] +8 \ No newline at end of file diff --git a/samples/collective/medium_reduce_grid_1.sptl b/samples/collective/medium_reduce_grid_1.sptl new file mode 100644 index 00000000..9abf7f34 --- /dev/null +++ b/samples/collective/medium_reduce_grid_1.sptl @@ -0,0 +1,32 @@ + +kernel @add() { + + place i16 i, i16 j in [0:5, 0:5] { + i16[100] a + } + + dataflow i16 i, i16 j in [0:5, 0:5] { + multistream red = reduce_stream(0, 0) { + algorithm = grid, + op = CL_SUM, + pipelined = false + } + multistream red1 = reduce_stream(2, 2) { + algorithm = grid, + op = CL_SUM, + pipelined = false + } + multistream red2 = reduce_stream(4, 4) { + algorithm = grid, + op = CL_SUM, + pipelined = false + } + } + + compute i16 i, i16 j in [0:5, 0:5] { + a[0] = 1 + await reduce(a, red) + await reduce(a, red1) + await reduce(a, red2) + } +} \ No newline at end of file diff --git a/samples/collective/simple_bcast.ref_tile 
b/samples/collective/simple_bcast.ref_tile new file mode 100644 index 00000000..8d20c461 --- /dev/null +++ b/samples/collective/simple_bcast.ref_tile @@ -0,0 +1,4 @@ +[0:1 , 0:1] +[0:1 , 1:N] +[1:N , 0:N] +1 \ No newline at end of file diff --git a/samples/collective/simple_bcast.sptl b/samples/collective/simple_bcast.sptl new file mode 100644 index 00000000..dcd4faa6 --- /dev/null +++ b/samples/collective/simple_bcast.sptl @@ -0,0 +1,18 @@ + +kernel @add(stream readonly a_in, stream[N, N] writeonly out) { + + place u16 i, u16 j in [0:N, 0:N] { + f32[K] a; + } + + dataflow i16 i, i16 j in [0:N, 0:N] { + multistream bcast = broadcast_stream(0, 0) { + channels = auto + } + } + + compute i16 i, i16 j in [0:N, 0:N] { + await broadcast(a, bcast); + } + +} diff --git a/samples/collective/simple_reduce_grid_1.ref_tile b/samples/collective/simple_reduce_grid_1.ref_tile new file mode 100644 index 00000000..ff6a7b52 --- /dev/null +++ b/samples/collective/simple_reduce_grid_1.ref_tile @@ -0,0 +1,6 @@ +[0:1 , 0:1] +[0:1 , 1:4] +[0:1 , 4:5] +[1:4 , 0:5] +[4:5 , 0:5] +2 \ No newline at end of file diff --git a/samples/collective/simple_reduce_grid_1.sptl b/samples/collective/simple_reduce_grid_1.sptl new file mode 100644 index 00000000..15d4ca62 --- /dev/null +++ b/samples/collective/simple_reduce_grid_1.sptl @@ -0,0 +1,20 @@ + +kernel @add() { + + place i16 i, i16 j in [0:5, 0:5] { + i16[100] a + } + + dataflow i16 i, i16 j in [0:5, 0:5] { + multistream red = reduce_stream(0, 0) { + algorithm = grid, + op = CL_SUM, + pipelined = false + } + } + + compute i16 i, i16 j in [0:5, 0:5] { + a[0] = 1 + await reduce(a, red) + } +} \ No newline at end of file diff --git a/samples/collective/simple_reduce_grid_2.ref_tile b/samples/collective/simple_reduce_grid_2.ref_tile new file mode 100644 index 00000000..548a6420 --- /dev/null +++ b/samples/collective/simple_reduce_grid_2.ref_tile @@ -0,0 +1,8 @@ +[0:1 , 0:5] +[1:2 , 0:1] +[1:2 , 1:3] +[1:2 , 3:4] +[1:2 , 4:5] +[2:4 , 0:5] +[4:5 , 
0:5] +4 \ No newline at end of file diff --git a/samples/collective/simple_reduce_grid_2.sptl b/samples/collective/simple_reduce_grid_2.sptl new file mode 100644 index 00000000..e8486149 --- /dev/null +++ b/samples/collective/simple_reduce_grid_2.sptl @@ -0,0 +1,20 @@ + +kernel @add() { + + place i16 i, i16 j in [0:5, 0:5] { + i16[100] a + } + + dataflow i16 i, i16 j in [0:5, 0:5] { + multistream red = reduce_stream(1, 3) { + algorithm = grid, + op = CL_SUM, + pipelined = false + } + } + + compute i16 i, i16 j in [0:5, 0:5] { + a[0] = 1 + await reduce(a, red) + } +} \ No newline at end of file diff --git a/samples/collective/simple_reduce_grid_pipelined_1.ref_tile b/samples/collective/simple_reduce_grid_pipelined_1.ref_tile new file mode 100644 index 00000000..f1ce4840 --- /dev/null +++ b/samples/collective/simple_reduce_grid_pipelined_1.ref_tile @@ -0,0 +1,26 @@ +[0:1 , 0:1] +[0:1 , 1:2] +[0:1 , 2:3] +[0:1 , 3:4] +[0:1 , 4:5] +[1:2 , 0:1] +[1:2 , 1:2] +[1:2 , 2:3] +[1:2 , 3:4] +[1:2 , 4:5] +[2:3 , 0:1] +[2:3 , 1:2] +[2:3 , 2:3] +[2:3 , 3:4] +[2:3 , 4:5] +[3:4 , 0:1] +[3:4 , 1:2] +[3:4 , 2:3] +[3:4 , 3:4] +[3:4 , 4:5] +[4:5 , 0:1] +[4:5 , 1:2] +[4:5 , 2:3] +[4:5 , 3:4] +[4:5 , 4:5] +4 \ No newline at end of file diff --git a/samples/collective/simple_reduce_grid_pipelined_1.sptl b/samples/collective/simple_reduce_grid_pipelined_1.sptl new file mode 100644 index 00000000..a64eb84b --- /dev/null +++ b/samples/collective/simple_reduce_grid_pipelined_1.sptl @@ -0,0 +1,20 @@ + +kernel @add() { + + place i16 i, i16 j in [0:5, 0:5] { + i16[100] a + } + + dataflow i16 i, i16 j in [0:5, 0:5] { + multistream red = reduce_stream(0, 0) { + algorithm = grid, + op = CL_SUM, + pipelined = true + } + } + + compute i16 i, i16 j in [0:5, 0:5] { + a[0] = 1 + await reduce(a, red) + } +} \ No newline at end of file diff --git a/samples/collective/simple_reduce_grid_pipelined_2.ref_tile b/samples/collective/simple_reduce_grid_pipelined_2.ref_tile new file mode 100644 index 
00000000..4bf29eac --- /dev/null +++ b/samples/collective/simple_reduce_grid_pipelined_2.ref_tile @@ -0,0 +1,26 @@ +[0:1 , 0:1] +[0:1 , 1:2] +[0:1 , 2:3] +[0:1 , 3:4] +[0:1 , 4:5] +[1:2 , 0:1] +[1:2 , 1:2] +[1:2 , 2:3] +[1:2 , 3:4] +[1:2 , 4:5] +[2:3 , 0:1] +[2:3 , 1:2] +[2:3 , 2:3] +[2:3 , 3:4] +[2:3 , 4:5] +[3:4 , 0:1] +[3:4 , 1:2] +[3:4 , 2:3] +[3:4 , 3:4] +[3:4 , 4:5] +[4:5 , 0:1] +[4:5 , 1:2] +[4:5 , 2:3] +[4:5 , 3:4] +[4:5 , 4:5] +6 \ No newline at end of file diff --git a/samples/collective/simple_reduce_grid_pipelined_2.sptl b/samples/collective/simple_reduce_grid_pipelined_2.sptl new file mode 100644 index 00000000..85eab140 --- /dev/null +++ b/samples/collective/simple_reduce_grid_pipelined_2.sptl @@ -0,0 +1,20 @@ + +kernel @add() { + + place i16 i, i16 j in [0:5, 0:5] { + i16[100] a + } + + dataflow i16 i, i16 j in [0:5, 0:5] { + multistream red = reduce_stream(1, 3) { + algorithm = grid, + op = CL_SUM, + pipelined = true + } + } + + compute i16 i, i16 j in [0:5, 0:5] { + a[0] = 1 + await reduce(a, red) + } +} \ No newline at end of file diff --git a/samples/collective/simple_reduce_looped.ref_tile b/samples/collective/simple_reduce_looped.ref_tile new file mode 100644 index 00000000..f1ce4840 --- /dev/null +++ b/samples/collective/simple_reduce_looped.ref_tile @@ -0,0 +1,26 @@ +[0:1 , 0:1] +[0:1 , 1:2] +[0:1 , 2:3] +[0:1 , 3:4] +[0:1 , 4:5] +[1:2 , 0:1] +[1:2 , 1:2] +[1:2 , 2:3] +[1:2 , 3:4] +[1:2 , 4:5] +[2:3 , 0:1] +[2:3 , 1:2] +[2:3 , 2:3] +[2:3 , 3:4] +[2:3 , 4:5] +[3:4 , 0:1] +[3:4 , 1:2] +[3:4 , 2:3] +[3:4 , 3:4] +[3:4 , 4:5] +[4:5 , 0:1] +[4:5 , 1:2] +[4:5 , 2:3] +[4:5 , 3:4] +[4:5 , 4:5] +4 \ No newline at end of file diff --git a/samples/collective/simple_reduce_looped.sptl b/samples/collective/simple_reduce_looped.sptl new file mode 100644 index 00000000..151cf0a6 --- /dev/null +++ b/samples/collective/simple_reduce_looped.sptl @@ -0,0 +1,23 @@ + +kernel @add() { + + place i16 i, i16 j in [0:5, 0:5] { + i16[100] a + } + + dataflow i16 i, i16 j 
in [0:5, 0:5] { + multistream red = reduce_stream(0, 0) { + algorithm = grid, + op = CL_SUM, + pipelined = true + } + } + + compute i16 i, i16 j in [0:5, 0:5] { + a[0] = 1 + for i32 test1 in [0:999] { + a[0] = 100 + await reduce(a, red) + } + } +} \ No newline at end of file diff --git a/samples/collective/simple_reduce_snake_1.ref_tile b/samples/collective/simple_reduce_snake_1.ref_tile new file mode 100644 index 00000000..99a429c3 --- /dev/null +++ b/samples/collective/simple_reduce_snake_1.ref_tile @@ -0,0 +1,16 @@ +[0:1 , 0:1] +[0:1 , 1:2] +[0:1 , 2:3] +[0:1 , 3:4] +[0:1 , 4:5] +[4:5 , 0:1] +[4:5 , 1:2] +[4:5 , 2:3] +[4:5 , 3:4] +[4:5 , 4:5] +[1:4 , 0:1] +[1:4 , 1:2] +[1:4 , 2:3] +[1:4 , 3:4] +[1:4 , 4:5] +4 \ No newline at end of file diff --git a/samples/collective/simple_reduce_snake_1.sptl b/samples/collective/simple_reduce_snake_1.sptl new file mode 100644 index 00000000..b81c445d --- /dev/null +++ b/samples/collective/simple_reduce_snake_1.sptl @@ -0,0 +1,20 @@ + +kernel @add() { + + place i16 i, i16 j in [0:5, 0:5] { + i16[100] a + } + + dataflow i16 i, i16 j in [0:5, 0:5] { + multistream red = reduce_stream(0, 0) { + algorithm = snake, + op = CL_SUM, + pipelined = false + } + } + + compute i16 i, i16 j in [0:5, 0:5] { + a[0] = 1 + await reduce(a, red) + } +} \ No newline at end of file diff --git a/samples/collective/simple_reduce_snake_pipelined_1.ref_tile b/samples/collective/simple_reduce_snake_pipelined_1.ref_tile new file mode 100644 index 00000000..4bf29eac --- /dev/null +++ b/samples/collective/simple_reduce_snake_pipelined_1.ref_tile @@ -0,0 +1,26 @@ +[0:1 , 0:1] +[0:1 , 1:2] +[0:1 , 2:3] +[0:1 , 3:4] +[0:1 , 4:5] +[1:2 , 0:1] +[1:2 , 1:2] +[1:2 , 2:3] +[1:2 , 3:4] +[1:2 , 4:5] +[2:3 , 0:1] +[2:3 , 1:2] +[2:3 , 2:3] +[2:3 , 3:4] +[2:3 , 4:5] +[3:4 , 0:1] +[3:4 , 1:2] +[3:4 , 2:3] +[3:4 , 3:4] +[3:4 , 4:5] +[4:5 , 0:1] +[4:5 , 1:2] +[4:5 , 2:3] +[4:5 , 3:4] +[4:5 , 4:5] +6 \ No newline at end of file diff --git 
def current_version(self, name: str) -> T:
    """Return ``name`` paired with the counter value currently stored for it.

    The counter is only read, never advanced. Because ``next_version``
    hands out the stored value and then increments it, the object
    returned here carries the version number that the next call to
    ``next_version(name)`` would issue.

    Args:
        name: Variable name to look up in the version counter.

    Returns:
        ``self.cls(name, <stored counter value>)``.
    """
    stored_counter = self._var_counter[name]
    return self.cls(name, stored_counter)
class BroadcastOptimizer():
    """Rewrite ``broadcast`` statements of a kernel into send/receive pairs.

    The optimizer first collects the root of every broadcast multistream
    declared in the dataflow blocks, then re-tiles the compute blocks so
    that each root processing element ends up in a compute block of its
    own, and finally replaces every broadcast statement with a send (in
    the root's block) or a receive (in every other block).
    """

    # Stand-in upper bound for a range whose stop is not a ConstantLiteral
    # (e.g. a symbolic kernel parameter). It lets the tiling arithmetic run
    # on plain integers; it is swapped back for the original stop
    # expression before new compute blocks are emitted.
    _UNBOUNDED_STOP = 9999999999999

    # Both attributes are per-instance state assigned in __init__; they are
    # declared without a class-level value so that no mutable object is
    # shared between instances.
    # roots: one [x, x + 1, y, y + 1] entry per broadcast root.
    roots: list[list[int]]
    # broadcast_operations: multistream name -> [root_x, root_y].
    broadcast_operations: dict[str, list[int]]

    def __init__(self, kernel: Kernel) -> None:
        """Copy the pieces of ``kernel`` that the pass reads and rewrites."""
        self.name = kernel.name
        self.parameters = kernel.parameters
        self.arguments = kernel.arguments
        self.body = kernel.body
        self.versioning = Versioning[Identifier](Identifier)
        self.roots = []
        self.broadcast_operations = {}

    def broadcast_subroutine(self) -> Kernel:
        """Run the pass and return the rewritten kernel (entry point).

        When the kernel declares no broadcast multistream the body is
        returned unchanged; otherwise the compute blocks are re-tiled and
        their broadcast statements replaced.
        """
        self.find_roots()
        if self.roots:
            self.fix_subgrid()
            self.replace_broadcast()
        return Kernel(name=self.name, parameters=self.parameters, arguments=self.arguments, body=self.body)

    def find_roots(self) -> None:
        """Collect the root coordinates of every broadcast multistream.

        Each broadcast root (x, y) is appended to ``self.roots`` as the
        unit cell ``[x, x + 1, y, y + 1]`` and recorded under its stream
        name in ``self.broadcast_operations`` as ``[x, y]``.
        """
        for elem in self.body:
            if not isinstance(elem, DataflowBlock):
                continue
            for stmt in elem.statements:
                is_broadcast = (
                    isinstance(stmt, MulStreamDeclaration)
                    and isinstance(stmt.routing, BroadcastRoutingDeclaration)
                )
                if not is_broadcast:
                    continue
                root_x = stmt.x.value.value
                root_y = stmt.y.value.value
                self.roots.append([root_x, root_x + 1, root_y, root_y + 1])
                self.broadcast_operations[stmt.stream_name.name] = [root_x, root_y]

    @classmethod
    def _subgrid_bounds(cls, elem):
        """Read the integer bounds of a block's subgrid.

        Returns ``(x_start, x_stop, y_start, y_stop, x_parametric,
        y_parametric)``. A stop that is not a ConstantLiteral is reported
        as ``_UNBOUNDED_STOP`` with the matching ``*_parametric`` flag set.
        """
        x_range = elem.subgrid.x_range
        y_range = elem.subgrid.y_range

        x_start = x_range.start.value.value
        x_stop = x_range.stop.value.value if isinstance(x_range.stop.value, ConstantLiteral) else None
        y_start = y_range.start.value.value
        y_stop = y_range.stop.value.value if isinstance(y_range.stop.value, ConstantLiteral) else None

        x_parametric = x_stop is None
        y_parametric = y_stop is None
        if x_parametric:
            x_stop = cls._UNBOUNDED_STOP
        if y_parametric:
            y_stop = cls._UNBOUNDED_STOP
        return x_start, x_stop, y_start, y_stop, x_parametric, y_parametric

    @staticmethod
    def _split_around_root(grid: list, root: list[int]) -> None:
        """Split the cells of ``grid`` in place along the edges of ``root``.

        ``grid`` holds cells of the form ``[[x_lo, x_hi], [y_lo, y_hi]]``
        and ``root`` is ``[x_lo, x_hi, y_lo, y_hi]``. A cell is cut along x
        when one of the root's x edges lies strictly inside it; it is cut
        along y only when the root's x extent covers the cell's x extent.
        """
        root_x_lo, root_x_hi, root_y_lo, root_y_hi = root
        replaced = []
        # The list grows while it is scanned on purpose: freshly created
        # pieces are visited in the same sweep so they can be cut again.
        idx = 0
        while idx < len(grid):
            cell = grid[idx]
            idx += 1
            (x_lo, x_hi), (y_lo, y_hi) = cell
            root_covers_x = root_x_lo <= x_lo and root_x_hi >= x_hi

            if x_lo < root_x_lo < x_hi:
                cut_x, cut_y = root_x_lo, None
            elif x_lo < root_x_hi < x_hi:
                cut_x, cut_y = root_x_hi, None
            elif y_lo < root_y_lo < y_hi and root_covers_x:
                cut_x, cut_y = None, root_y_lo
            elif y_lo < root_y_hi < y_hi and root_covers_x:
                cut_x, cut_y = None, root_y_hi
            else:
                continue

            if cut_x is not None:
                grid.append([[x_lo, cut_x], [y_lo, y_hi]])
                grid.append([[cut_x, x_hi], [y_lo, y_hi]])
            else:
                grid.append([[x_lo, x_hi], [y_lo, cut_y]])
                grid.append([[x_lo, x_hi], [cut_y, y_hi]])
            replaced.append(cell)

        # Drop the cells that were replaced by their two halves.
        for cell in replaced:
            grid.remove(cell)

    @classmethod
    def _stop_value(cls, stop, axis_range):
        """Turn an integer stop back into an IR value for a parametric block.

        The placeholder bound maps back to the block's original stop
        expression; any other bound becomes a ConstantLiteral carrying the
        dtype of the range's start.
        """
        if stop == cls._UNBOUNDED_STOP:
            return axis_range.stop.value
        return ConstantLiteral(value=stop, dtype=axis_range.start.value.dtype)

    def fix_subgrid(self) -> None:
        """Re-tile every compute block around the collected broadcast roots.

        Each compute block is replaced by one block per resulting tile, all
        sharing the original variables and statements. Elements of the body
        that are not compute blocks are kept as they are.
        """
        newbody = []
        for elem in self.body:
            if not isinstance(elem, ComputeBlock):
                newbody.append(elem)
                continue

            (x_start, x_stop, y_start, y_stop,
             x_parametric, y_parametric) = self._subgrid_bounds(elem)

            grid = [[[x_start, x_stop], [y_start, y_stop]]]
            for root in self.roots:
                self._split_around_root(grid, root)

            parametric = x_parametric or y_parametric
            x_range = elem.subgrid.x_range
            y_range = elem.subgrid.y_range
            for (x_lo, x_hi), (y_lo, y_hi) in grid:
                if parametric:
                    x_stop_value = self._stop_value(x_hi, x_range)
                    y_stop_value = self._stop_value(y_hi, y_range)
                else:
                    x_stop_value = ConstantLiteral(x_hi, ScalarType.i32)
                    y_stop_value = ConstantLiteral(y_hi, ScalarType.i32)

                newbody.append(
                    ComputeBlock(
                        elem.variables,
                        SubgridExpression(
                            RangeExpression(
                                start=Expression(ConstantLiteral(x_lo, ScalarType.i32)),
                                stop=Expression(x_stop_value)
                            ),
                            RangeExpression(
                                start=Expression(ConstantLiteral(y_lo, ScalarType.i32)),
                                stop=Expression(y_stop_value)
                            )
                        ),
                        elem.statements
                    )
                )

        self.body = newbody

    def _replace_broadcast(self, stmt, elem) -> list[Expression]:
        """Return the statement list that replaces one broadcast statement.

        The result is a single SendStatement when ``elem`` covers exactly
        the root cell of the statement's multistream, and a single
        ReceiveStatement otherwise.

        Raises:
            KeyError: if the statement's stream name was not collected by
                ``find_roots``.
        """
        x_start, x_stop, y_start, y_stop, _, _ = self._subgrid_bounds(elem)
        root_x, root_y = self.broadcast_operations[stmt.stream_name.name]

        is_root_block = (
            x_start == root_x
            and y_start == root_y
            and x_stop == root_x + 1
            and y_stop == root_y + 1
        )
        statement_cls = SendStatement if is_root_block else ReceiveStatement
        return [
            statement_cls(
                local_array=stmt.local_array,
                stream_name=stmt.stream_name,
                completion_name=None
            )
        ]
replaced_stmts = self.replace_stmt(body_stmt, elem, to_replace) + for replaced_stmt in replaced_stmts: + new_body.append(replaced_stmt) + input_stmt.body = new_body + return [input_stmt] + + + ## + # Changes the occurences of broadcast statements in the compute blocks + ## + def replace_broadcast(self) -> None: + finalbody = [] + for elem in self.body: + if isinstance(elem, ComputeBlock): + statements = [] + for stmt in elem.statements: + new_stmts = self.replace_stmt(stmt, elem, BroadcastStatement) + for nstmt in new_stmts: + statements.append(nstmt) + finalbody.append(ComputeBlock(elem.variables, elem.subgrid, statements)) + else: + finalbody.append(elem) + + self.body = finalbody + return None \ No newline at end of file diff --git a/spatialstencil/optimizations/spatial_reduce.py b/spatialstencil/optimizations/spatial_reduce.py new file mode 100644 index 00000000..4a7fe6ed --- /dev/null +++ b/spatialstencil/optimizations/spatial_reduce.py @@ -0,0 +1,1376 @@ +from spatialstencil.syntax.spatial_ir.irnodes import (Kernel, ComputeBlock, ReduceStatement, Expression, SubgridExpression, RangeExpression, ConstantLiteral, ScalarType, + DataflowBlock, MulStreamDeclaration, ReduceRoutingDeclaration, StreamType, TypedIdentifier, ForeachStatement, ArraySlice, + BinaryOperator, SendStatement, ReceiveGenerator, AssignmentStatement, RelativeStreamDeclaration, PlaceBlock, Phase, + Parameter, KernelArgument, ReceiveStatement, ForStatement, FieldDeclaration, ArrayType, MapStatement, AsyncBlock, Identifier) +from typing import Union, Optional, Literal +from spatialstencil.lowering.versioning import Versioning + + +class ReduceOptimizer(): + name: str | None + parameters: list[Parameter] + arguments: list[KernelArgument] + body: list[PlaceBlock | DataflowBlock | ComputeBlock | Phase] + _communication_patterns: Optional[dict[str, dict[tuple[int, int], list[list[list[int]]]]]] = None + reduce_operations: dict[str, dict[str, Union[int, Literal['OP_SUM'], list[int]]]] = {} + 
def __init__(self, kernel: Kernel) -> None:
    """Take over the contents of *kernel* and start with empty pass state."""
    self.name = kernel.name
    self.parameters = kernel.parameters
    self.arguments = kernel.arguments
    self.body = kernel.body
    self.versioning = Versioning[Identifier](Identifier)

    # Per-instance state of the pass; all of it starts out empty.
    self._communication_patterns = None
    self.reduce_operations = {}
    self.grid_streams = {}
    self.snake_streams = {}
    self.pipelined = {}


def reduce_subroutine(self) -> Kernel:
    """Entry point: rewrite the reduce statements and return the new kernel.

    The data blocks are always processed; the compute blocks are only
    re-tiled and rewritten when a reduce operation was registered.
    """
    self.change_data_blocks()
    if self.reduce_operations:
        self.fix_subgrid()
        self.change_compute_blocks()
    return Kernel(
        name=self.name,
        parameters=self.parameters,
        arguments=self.arguments,
        body=self.body,
    )


def create_send_statement(self, stmt, pipelined_send, index) -> SendStatement:
    """Template: send one element of ``stmt.local_array`` on ``pipelined_send[index]``.

    The element is addressed with the current version of ``reduce_runner``.
    """
    runner = self.versioning.current_version("reduce_runner")
    element = ArraySlice(array=stmt.local_array, indices=[Expression(value=runner)])
    return SendStatement(
        local_array=element,
        stream_name=pipelined_send[index],
        completion_name=None,
    )


def create_receive_statement(self, stmt, pipelined_receive, index) -> ReceiveStatement:
    """Template: receive from ``pipelined_receive[index]`` into the pipeline helper.

    The target is the current version of ``pipeline_helper``; *stmt* is not used.
    """
    helper = self.versioning.current_version("pipeline_helper")
    return ReceiveStatement(
        local_array=helper,
        stream_name=pipelined_receive[index],
        completion_name=None,
    )


def create_binary_operation(self, stmt, current_op, rhs) -> AssignmentStatement:
    """Template: ``local_array[runner] = local_array[runner] <current_op> rhs``.

    ``runner`` is the current version of ``reduce_runner``. The destination and
    the left operand are two separate ArraySlice nodes.
    """
    def runner_element():
        # A fresh slice per call, so the two uses do not share a node.
        runner = self.versioning.current_version("reduce_runner")
        return ArraySlice(array=stmt.local_array, indices=[Expression(value=runner)])

    destination = runner_element()
    combined = BinaryOperator(
        left=Expression(value=runner_element()),
        op=current_op,
        right=Expression(value=rhs),
    )
    return AssignmentStatement(destination=destination, source=Expression(combined))
def _lookup_send_amount(self, send_identifier):
    """Return the first-dimension length of the array field *send_identifier*.

    Scans the direct child nodes of every body element; when several matching
    declarations exist the last one wins. Raises ``ValueError`` when a matching
    field is not an array or when no matching field exists.
    """
    send_amount = None
    for block in self.body:
        for node in block.iter_child_nodes():
            if isinstance(node, FieldDeclaration) and node.field_name == send_identifier:
                if not isinstance(node.dtype, ArrayType):
                    raise ValueError(f"Field {send_identifier} is not an array. Only arrays are currently supported.")
                send_amount = node.dtype.shape[0].value.value
    if send_amount is None:
        raise ValueError(f"Field {send_identifier} was not found; the length of the reduce cannot be determined.")
    return send_amount


def _element_range(self, send_amount):
    """Range ``0 .. send_amount`` (i32 literals, no step) used by every generated loop."""
    return RangeExpression(start=Expression(ConstantLiteral(0, ScalarType.i32)),
                           stop=Expression(ConstantLiteral(send_amount, ScalarType.i32)),
                           step=None)


def _reset_pipeline_helper(self):
    """Assignment that creates the next ``pipeline_helper`` version and sets it to 0."""
    return AssignmentStatement(
        destination=self.versioning.next_version("pipeline_helper"),
        source=Expression(ConstantLiteral(0, ScalarType.i32))
    )


def _receive_and_combine(self, stmt, current_op, send_amount, stream):
    """Non-pipelined receive: a foreach over *stream* that folds into ``stmt.local_array``.

    *stream* is a connection record whose item 0 is the stream name and whose
    item 2 carries the ``dtype`` of the received value.
    """
    # The versions are read before the loop variables are created, exactly as
    # the loop body has to reference them.
    bin_op = self.create_binary_operation(stmt, current_op, self.versioning.current_version("reduce_receive"))
    return ForeachStatement(
        variables=[TypedIdentifier(dtype=ScalarType.i32, identifier=self.versioning.next_version("reduce_runner"))],
        parameter_range=[self._element_range(send_amount)],
        stream_variable=TypedIdentifier(dtype=stream[2].dtype,
                                        identifier=self.versioning.next_version("reduce_receive")),
        receive_stream=ReceiveGenerator(stream_name=stream[0]),
        body=[bin_op],
        completion_name=None
    )


def _pipelined_loop(self, stmt, current_op, send_amount, receive_streams, send_streams):
    """Pipelined step: one for-loop that receives, combines and forwards element-wise.

    For every stream name in *receive_streams* the loop body gets a receive
    followed by the combine assignment; when *send_streams* is not empty the
    body ends with a send on its first entry.
    """
    send = self.create_send_statement(stmt, send_streams, 0) if send_streams else None
    loop_body = []
    if receive_streams:
        receives = [self.create_receive_statement(stmt, receive_streams, i)
                    for i in range(len(receive_streams))]
        bin_op = self.create_binary_operation(stmt, current_op, self.versioning.current_version("pipeline_helper"))
        for receive in receives:
            loop_body.append(receive)
            loop_body.append(bin_op)
    if send is not None:
        loop_body.append(send)
    return ForStatement(
        variables=[TypedIdentifier(dtype=ScalarType.i32, identifier=self.versioning.next_version("reduce_runner"))],
        range_expression=[self._element_range(send_amount)],
        body=loop_body,
    )


def _find_snake_stream(self, connections, position, role):
    """Return the first snake connection on which *position* acts as *role*.

    *role* is ``'receiver'`` or ``'sender'``. A connection matches when the
    position lies inside one of its detailed regions and either the horizontal
    direction and the position's edge select it, or the region is tagged with
    the role (for vertical or pipelined connections). Returns ``None`` when
    nothing matches.
    """
    # Direction value for which the position must not sit on the region's stop edge;
    # for the opposite direction it must not sit on the start edge.
    stop_edge_direction = -1 if role == 'receiver' else 1
    for con in connections:
        for detail in con[5]:
            if (position[0] >= detail[0] and position[1] <= detail[1]
                    and position[2] >= detail[2] and position[3] <= detail[3]
                    and ((detail[4] == stop_edge_direction and not position[1] == detail[1])
                         or (detail[4] == -stop_edge_direction and not position[0] == detail[0])
                         or (detail[4] == 0 and detail[8] == role)
                         or (con[6] == True and detail[8] == role))):
                return con
    return None


def replace_reduce(self, stmt, elem) -> list[Expression]:
    """Return the statements that implement reduce statement *stmt* on block *elem*.

    *stmt* is the reduce statement being replaced; *elem* is the compute block
    it lives in and supplies the position on the grid. Depending on the
    communication pattern registered for the stream (grid or snake, pipelined
    or not) the result is a mix of receive loops, sends and pipelined
    receive/combine/send loops.

    Raises:
        ValueError: the stream has no communication pattern, the reduced field
            is missing or not an array, or the block takes no part in the
            pattern.
        NotImplementedError: the operation is neither CL_SUM nor CL_PROD.
    """
    position = [elem.subgrid.x_range.start.value.value,
                elem.subgrid.x_range.stop.value.value,
                elem.subgrid.y_range.start.value.value,
                elem.subgrid.y_range.stop.value.value]
    newstatements = []
    stream_name = stmt.stream_name.name
    operation = self.reduce_operations[stream_name]
    operation_id = operation[0]['op']
    root = operation[1]
    origin = operation[4]

    # send_amount is the length of the array that is being sent - needed for
    # the generated loops; it is looked up once and cached on the operation.
    send_amount = operation[6]
    if send_amount is None:
        send_amount = self._lookup_send_amount(operation[5])
        operation[6] = send_amount

    if stream_name in self.grid_streams:
        connections = self.grid_streams[stream_name]
    elif stream_name in self.snake_streams:
        connections = self.snake_streams[stream_name]
    else:
        raise ValueError(f"Stream name {stream_name} not found in grid_streams or snake_streams.")

    if operation_id == "CL_SUM":
        current_op = '+'
    elif operation_id == "CL_PROD":
        current_op = '*'
    else:
        raise NotImplementedError("Currently only CL_SUM and CL_PROD are supported.")

    if stream_name in self.grid_streams:
        if not connections[0][4]:
            # not pipelined: index of the position/region edge that decides
            # whether this block receives or sends for a given direction
            receive_edge = {'left': 1, 'right': 0, 'top': 3, 'bottom': 2}
            send_edge = {'left': 0, 'right': 1, 'top': 2, 'bottom': 3}
            for con in connections:
                region = con[1]
                if not (position[0] >= region[0] and position[1] <= region[1]
                        and position[2] >= region[2] and position[3] <= region[3]):
                    continue
                direction = con[3]

                if direction in receive_edge and position[receive_edge[direction]] != region[receive_edge[direction]]:
                    newstatements.append(self._receive_and_combine(stmt, current_op, send_amount, con))

                if direction in send_edge and position[send_edge[direction]] != region[send_edge[direction]]:
                    newstatements.append(
                        SendStatement(
                            local_array=stmt.local_array,
                            stream_name=con[0],
                            completion_name=None
                        )
                    )
        else:
            # pipelined: collect every stream this block sends on / receives from
            pipelined_send = []
            pipelined_receive = []
            for con_list in connections:
                for con in con_list[1]:
                    if (position[0] >= con[0] and position[1] <= con[1]
                            and position[2] >= con[2] and position[3] <= con[3]):
                        if con[8] == 'sender':
                            pipelined_send.append(con_list[0])
                        elif con[8] == 'receiver':
                            pipelined_receive.append(con_list[0])

            if not pipelined_send and not pipelined_receive:
                raise ValueError(f"No pipelined send or receive found for position {position}.")
            if pipelined_receive:
                # receiving needs the helper variable the values are received into
                newstatements.append(self._reset_pipeline_helper())
            newstatements.append(
                self._pipelined_loop(stmt, current_op, send_amount, pipelined_receive, pipelined_send)
            )

    else:
        # snake pattern
        at_origin = position[0] == origin[0] and position[2] == origin[1]
        at_root = position[0] == root[0] and position[2] == root[1]
        receive_stream = None

        if not at_origin:
            # everything but the starting point receives first
            receive_stream = self._find_snake_stream(connections, position, 'receiver')
            if receive_stream is None:
                raise ValueError(f"No receive stream found for position {position} on stream {stream_name}.")

            if not receive_stream[6]:
                # not pipelined
                newstatements.append(self._receive_and_combine(stmt, current_op, send_amount, receive_stream))
            elif at_root:
                # pipelined root: receives and combines, never forwards
                newstatements.append(self._reset_pipeline_helper())
                newstatements.append(
                    self._pipelined_loop(stmt, current_op, send_amount, [receive_stream[0]], [])
                )
            # pipelined blocks in the middle are handled together with their send below

        if not at_root:
            # only root does not send
            send_stream = self._find_snake_stream(connections, position, 'sender')
            if send_stream is None:
                raise ValueError(f"No send stream found for position {position} on stream {stream_name}.")

            if not send_stream[6]:
                # not pipelined
                newstatements.append(
                    SendStatement(
                        local_array=stmt.local_array,
                        stream_name=send_stream[0],
                        completion_name=None
                    )
                )
            elif at_origin:
                # pipelined origin: only forwards its own values
                newstatements.append(
                    self._pipelined_loop(stmt, current_op, send_amount, [], [send_stream[0]])
                )
            else:
                # pipelined middle block: receive, combine, forward
                newstatements.append(self._reset_pipeline_helper())
                newstatements.append(
                    self._pipelined_loop(stmt, current_op, send_amount, [receive_stream[0]], [send_stream[0]])
                )

    return newstatements


def replace_stmt(self, stmt, elem, to_replace) -> list[Expression]:
    """Recursively replace statements of type *to_replace* below *stmt*.

    Returns the list of statements that take the place of *stmt*. Loop, map and
    async statements are rebuilt with a rewritten body.
    """
    # Function-scope import: only this method needs it.
    import copy

    if isinstance(stmt, to_replace):
        if to_replace == ReduceStatement:
            return self.replace_reduce(stmt, elem)

    # all of these use body
    elif isinstance(stmt, (ForeachStatement, ForStatement, MapStatement, AsyncBlock)):
        new_body = []
        for body_stmt in stmt.body:
            new_body.extend(self.replace_stmt(body_stmt, elem, to_replace))
        # The replacement depends on the position of *elem*. Statement nodes may
        # be shared between compute blocks (presumably the tiles produced by
        # re-tiling reuse one statement list - verify against fix_subgrid), so
        # the rewritten body goes on a shallow copy, not on the shared node.
        stmt = copy.copy(stmt)
        stmt.body = new_body
    return [stmt]
def snake_communication_pattern(self, x_start, x_stop, y_start, y_stop, x, y, name, pipelined) -> None:
    """Register the snake communication records for the stream *name*.

    ``(x, y)`` is the root and must lie on the first row (``y == y_start``) or
    the last row (``y == y_stop - 1``) of the grid
    ``[x_start, x_stop) x [y_start, y_stop)``; any other row raises
    ``NotImplementedError``. Every record has the form
    ``[x_lo, x_hi, y_lo, y_hi, x_direction, y_direction, 1, tag]`` and the
    resulting list is stored in ``self.snake_streams[name]``.
    """
    if y == y_start:
        root_on_first_row = True
    elif y == y_stop - 1:
        root_on_first_row = False
    else:
        raise NotImplementedError("Only the corners are implemented for 'snake'")

    width = x_stop - x_start
    height = y_stop - y_start
    odd_height = height % 2 != 0

    # The four row spans the records are built from.
    all_rows = (y_start, y_stop)
    inner_rows = (y_start + 1, y_stop - 1)
    without_last = (y_start, y_stop - 1)
    without_first = (y_start + 1, y_stop)

    # toward/away: row spans of the two horizontal record groups.
    # near/far: row spans of the vertical records in the root's column and in
    # the opposite edge column.
    if odd_height:
        toward_rows, away_rows = all_rows, inner_rows
        if root_on_first_row:
            near_rows, far_rows = without_first, without_last
        else:
            near_rows, far_rows = without_last, without_first
    else:
        if root_on_first_row:
            toward_rows, away_rows = without_last, without_first
        else:
            toward_rows, away_rows = without_first, without_last
        near_rows, far_rows = inner_rows, all_rows

    # The "last row, even height" case uses a lower width threshold for the
    # pipelined records and no width guard for the plain ones.
    last_row_even = (not root_on_first_row) and (not odd_height)
    toward_dir = -1 if x == x_start else 1
    horizontal_groups = ((toward_rows, toward_dir), (away_rows, -toward_dir))

    communication = []

    # horizontal movement
    if pipelined and width > (1 if last_row_even else 2):
        if width % 2 != 0:
            column_spans = ((x_start, x_stop - 1), (x_start + 1, x_stop))
        else:
            column_spans = ((x_start, x_stop), (x_start + 1, x_stop - 1))
        for rows, direction in horizontal_groups:
            for col_lo, col_hi in column_spans:
                communication.append([col_lo, col_hi, rows[0], rows[1], direction, 0, 1, 2])
    elif last_row_even or width > 1:
        # Only the "first row, odd height" case tags its plain records with 1.
        tag = 1 if (root_on_first_row and odd_height) else 2
        for rows, direction in horizontal_groups:
            communication.append([x_start, x_stop, rows[0], rows[1], direction, 0, 1, tag])

    # vertical movement (not dependent on pipelined); both edge checks run, so
    # a grid that is one column wide produces the records of both.
    vertical_dir = -1 if root_on_first_row else 1
    left_column = (x_start, x_start + 1)
    right_column = (x_stop - 1, x_stop)
    edge_cases = (
        (x == x_start, left_column, right_column),
        (x == x_stop - 1, right_column, left_column),
    )
    for root_is_here, near_column, far_column in edge_cases:
        if not root_is_here:
            continue
        if height > 2:
            communication.append([near_column[0], near_column[1], near_rows[0], near_rows[1],
                                  0, vertical_dir, 1, 1])
        if height > 1:
            communication.append([far_column[0], far_column[1], far_rows[0], far_rows[1],
                                  0, vertical_dir, 1, 1])

    self.snake_streams.update({name: communication})
if not pipelined or y_stop - y_start <= 2: + communication.append([x_start, x_start + 1, y_start, y_stop, 0, -1, 1, 1]) + else: + if (y_stop - y_start) % 2 == 0: + communication.append([x_start, x_start + 1, y_start, y_stop, 0, -1, 1, 1]) + communication.append([x_start, x_start + 1, y_start + 1, y_stop - 1, 0, -1, 1, 1]) + else: + communication.append([x_start, x_start + 1, y_start, y_stop - 1, 0, -1, 1, 1]) + communication.append([x_start, x_start + 1, y_start + 1, y_stop, 0, -1, 1, 1]) + elif y == y_stop - 1: + # print('lower left corner') + if not pipelined or y_stop - y_start <= 2: + communication.append([x_start, x_start + 1, y_start, y_stop, 0, 1, 1, 1]) + else: + if (y_stop - y_start) % 2 == 0: + communication.append([x_start, x_start + 1, y_start, y_stop, 0, 1, 1, 1]) + communication.append([x_start, x_start + 1, y_start + 1, y_stop - 1, 0, 1, 1, 1]) + else: + communication.append([x_start, x_start + 1, y_start, y_stop - 1, 0, 1, 1, 1]) + communication.append([x_start, x_start + 1, y_start + 1, y_stop, 0, 1, 1, 1]) + else: + # print('left edge') + if not pipelined: + communication.append([x_start, x_start + 1, y_start, y + 1, 0, 1, 1, 1]) + communication.append([x_start, x_start + 1, y, y_stop, 0, -1, 1, 1]) + else: + # upper part + if (y - y_start) >= 2: # y is inclusive while y_stop is exclusive + if (y - y_start) % 2 == 0: + communication.append([x_start, x_start + 1, y_start, y, 0, 1, 1, 1]) + communication.append([x_start, x_start + 1, y_start + 1, y + 1, 0, 1, 1, 1]) + else: + communication.append([x_start, x_start + 1, y_start, y + 1, 0, 1, 1, 1]) + communication.append([x_start, x_start + 1, y_start + 1, y, 0, 1, 1, 1]) + else: + communication.append([x_start, x_start + 1, y_start, y + 1, 0, 1, 1, 1]) + + # lower part + if (y_stop - y) > 2: + if (y_stop - y) % 2 == 0: + communication.append([x_start, x_start + 1, y, y_stop, 0, -1, 1, 1]) + communication.append([x_start, x_start + 1, y + 1, y_stop - 1, 0, -1, 1, 1]) + else: + 
communication.append([x_start, x_start + 1, y, y_stop - 1, 0, -1, 1, 1]) + communication.append([x_start, x_start + 1, y + 1, y_stop, 0, -1, 1, 1]) + else: + communication.append([x_start, x_start + 1, y, y_stop, 0, -1, 1, 1]) + + + + elif x == x_stop - 1: + # horizontal movement + if x_start == x_stop - 1: + # print('no horizontal movement needed') + pass + elif x_stop - x_start == 2: + # print('left to right') + communication.append([x_start, x_stop, y_start, y_stop, 1, 0, 1, 1]) + else: + # print('left to right') + if pipelined: + if (x_stop - x_start) % 2 == 0: + communication.append([x_start, x_stop, y_start, y_stop, 1, 0, 1, 1]) + communication.append([x_start + 1, x_stop - 1, y_start, y_stop, 1, 0, 1, 1]) + else: + communication.append([x_start, x_stop - 1, y_start, y_stop, 1, 0, 1, 1]) + communication.append([x_start + 1, x_stop, y_start, y_stop, 1, 0, 1, 1]) + else: + communication.append([x_start, x_stop, y_start, y_stop, 1, 0, 1, 1]) + + # vertical movement + if y_start == y_stop - 1: + # print('no vertical movement needed') + pass + elif y == y_start: + # print('upper right corner') + if not pipelined or y_stop - y_start <= 2: + communication.append([x_stop - 1, x_stop, y_start, y_stop, 0, -1, 1, 1]) + else: + if (y_stop - y_start) % 2 == 0: + communication.append([x_stop - 1, x_stop, y_start, y_stop, 0, -1, 1, 1]) + communication.append([x_stop - 1, x_stop, y_start + 1, y_stop - 1, 0, -1, 1, 1]) + else: + communication.append([x_stop - 1, x_stop, y_start, y_stop - 1, 0, -1, 1, 1]) + communication.append([x_stop - 1, x_stop, y_start + 1, y_stop, 0, -1, 1, 1]) + elif y == y_stop - 1: + # print('lower right corner') + if not pipelined or y_stop - y_start <= 2: + communication.append([x_stop - 1, x_stop, y_start, y_stop, 0, 1, 1, 1]) + else: + if (y_stop - y_start) % 2 == 0: + communication.append([x_stop - 1, x_stop, y_start, y_stop, 0, 1, 1, 1]) + communication.append([x_stop - 1, x_stop, y_start + 1, y_stop - 1, 0, 1, 1, 1]) + else: + 
communication.append([x_stop - 1, x_stop, y_start, y_stop - 1, 0, 1, 1, 1]) + communication.append([x_stop - 1, x_stop, y_start + 1, y_stop, 0, 1, 1, 1]) + else: + # print('right edge') + if not pipelined: + communication.append([x_stop - 1, x_stop, y_start, y + 1, 0, 1, 1, 1]) + communication.append([x_stop - 1, x_stop, y, y_stop, 0, -1, 1, 1]) + else: + # upper part + if (y - y_start) >= 2: # y is inclusive while y_stop is exclusive + if (y - y_start) % 2 == 0: + communication.append([x_stop - 1, x_stop, y_start, y, 0, 1, 1, 1]) + communication.append([x_stop - 1, x_stop, y_start + 1, y + 1, 0, 1, 1, 1]) + else: + communication.append([x_stop - 1, x_stop, y_start, y + 1, 0, 1, 1, 1]) + communication.append([x_stop - 1, x_stop, y_start + 1, y, 0, 1, 1, 1]) + else: + communication.append([x_stop - 1, x_stop, y_start, y + 1, 0, 1, 1, 1]) + + # lower part + if (y_stop - y) > 2: + if (y_stop - y) % 2 == 0: + communication.append([x_stop - 1, x_stop, y, y_stop, 0, -1, 1, 1]) + communication.append([x_stop - 1, x_stop, y + 1, y_stop - 1, 0, -1, 1, 1]) + else: + communication.append([x_stop - 1, x_stop, y, y_stop - 1, 0, -1, 1, 1]) + communication.append([x_stop - 1, x_stop, y + 1, y_stop, 0, -1, 1, 1]) + else: + communication.append([x_stop - 1, x_stop, y, y_stop, 0, -1, 1, 1]) + + else: + # horizontal movement + # print('middle') + if not pipelined: + communication.append([x_start, x + 1, y_start, y_stop, 1, 0, 1, 1]) # left to middle + communication.append([x, x_stop, y_start, y_stop, -1, 0, 1, 1]) # right to middle + else: + # left + if (x - x_start) >= 2: # x is inclusive while x_stop is exclusive + if (x - x_start) % 2 == 0: + communication.append([x_start, x, y_start, y_stop, 1, 0, 1, 1]) + communication.append([x_start + 1, x + 1, y_start, y_stop, 1, 0, 1, 1]) + else: + communication.append([x_start, x + 1, y_start, y_stop, 1, 0, 1, 1]) + communication.append([x_start + 1, x, y_start, y_stop, 1, 0, 1, 1]) + else: + communication.append([x_start, x + 1, y_start, 
y_stop, 1, 0, 1, 1]) + + # right + if (x_stop - x) > 2: + if (x_stop - x) % 2 == 0: + communication.append([x, x_stop, y_start, y_stop, -1, 0, 1, 1]) + communication.append([x + 1, x_stop - 1, y_start, y_stop, -1, 0, 1, 1]) + else: + communication.append([x, x_stop - 1, y_start, y_stop, -1, 0, 1, 1]) + communication.append([x + 1, x_stop, y_start, y_stop, -1, 0, 1, 1]) + else: + communication.append([x, x_stop, y_start, y_stop, -1, 0, 1, 1]) + + # vertical movement + if y_start == y_stop - 1: + # print('no vertical movement needed') + pass + elif y == y_start: + # print('upper edge') + if not pipelined or y_stop - y_start <= 2: + communication.append([x, x + 1, y_start, y_stop, 0, -1, 1, 1]) + else: + if (y_stop - y_start) % 2 == 0: + communication.append([x, x + 1, y_start, y_stop, 0, -1, 1, 1]) + communication.append([x, x + 1, y_start + 1, y_stop - 1, 0, -1, 1, 1]) + else: + communication.append([x, x + 1, y_start, y_stop - 1, 0, -1, 1, 1]) + communication.append([x, x + 1, y_start + 1, y_stop, 0, -1, 1, 1]) + elif y == y_stop - 1: + # print('lower edge') + if not pipelined or y_stop - y_start <= 2: + communication.append([x, x + 1, y_start, y_stop, 0, 1, 1, 1]) + else: + if (y_stop - y_start) % 2 == 0: + communication.append([x, x + 1, y_start, y_stop, 0, 1, 1, 1]) + communication.append([x, x + 1, y_start + 1, y_stop - 1, 0, 1, 1, 1]) + else: + communication.append([x, x + 1, y_start, y_stop - 1, 0, 1, 1, 1]) + communication.append([x, x + 1, y_start + 1, y_stop, 0, 1, 1, 1]) + else: + # print('center') + if not pipelined: + communication.append([x, x + 1, y_start, y + 1, 0, 1, 1, 1]) + communication.append([x, x + 1, y, y_stop, 0, -1, 1, 1]) + else: + # upper part + if (y - y_start) >= 2: # y is inclusive while y_stop is exclusive + if (y - y_start) % 2 == 0: + communication.append([x, x + 1, y_start, y, 0, 1, 1, 1]) + communication.append([x, x + 1, y_start + 1, y + 1, 0, 1, 1, 1]) + else: + communication.append([x, x + 1, y_start, y + 1, 0, 1, 1, 1]) + 
communication.append([x, x + 1, y_start + 1, y, 0, 1, 1, 1]) + else: + communication.append([x, x + 1, y_start, y + 1, 0, 1, 1, 1]) + + # lower part + if (y_stop - y) > 2: + if (y_stop - y) % 2 == 0: + communication.append([x, x + 1, y, y_stop, 0, -1, 1, 1]) + communication.append([x, x + 1, y + 1, y_stop - 1, 0, -1, 1, 1]) + else: + communication.append([x, x + 1, y, y_stop - 1, 0, -1, 1, 1]) + communication.append([x, x + 1, y + 1, y_stop, 0, -1, 1, 1]) + else: + communication.append([x, x + 1, y, y_stop, 0, -1, 1, 1]) + + self.grid_streams.update({name : communication}) + + + ## + # Creates the communication patterns for the reduce operation (snake or grid) + ## + def create_communication_patterns(self, x_start, x_stop, y_start, y_stop, x, y, name, algorithm, pipelined) -> None: + if x < x_start or x >= x_stop or y < y_start or y >= y_stop: + if x == x_stop or y == y_stop: + raise ValueError(f"The communication point (x, y) = ({x}, {y}) is not within the subgrid" + + f"[x_start, x_stop, y_start, y_stop] = [{x_start}, {x_stop}, {y_start}, {y_stop}] for the operation {name}." 
+ + f" Remember that the stop value is exclusive.") + raise ValueError(f"The communication point (x, y) = ({x}, {y}) is not within the subgrid" + + f"[x_start, x_stop, y_start, y_stop] = [{x_start}, {x_stop}, {y_start}, {y_stop}] for the operation {name}.") + communication = [] + self.pipelined.update({name : False}) + mode = algorithm if algorithm != 'auto' else 'grid' + if mode == 'snake': + self.snake_communication_pattern(x_start, x_stop, y_start, y_stop, x, y, name, pipelined) + + elif mode == 'grid': + self.grid_communication_pattern(x_start, x_stop, y_start, y_stop, x, y, name, pipelined) + + else: + raise NotImplementedError(f"Communication mode '{mode}' is not implemented.") + + return None + + + ## + # Updates the datablocks with the new communication patterns + ## + def change_data_blocks(self) -> None: + newbody = [] + self.reduce_operations = {} + for elem in self.body: + if isinstance(elem, DataflowBlock): + olddataflobblock = [] + newdataflobblocks = [] + for stmt in elem.statements: + if isinstance(stmt, MulStreamDeclaration) and isinstance(stmt.routing, ReduceRoutingDeclaration): + self.create_communication_patterns(elem.subgrid.x_range.start.value.value, + elem.subgrid.x_range.stop.value.value, + elem.subgrid.y_range.start.value.value, + elem.subgrid.y_range.stop.value.value, + stmt.x.value.value, + stmt.y.value.value, + stmt.stream_name.name, + stmt.routing.algorithm, + stmt.routing.pipelined) + + self.reduce_operations.update({stmt.stream_name.name: [{'op': stmt.routing.op}, [stmt.x.value.value, stmt.y.value.value], + [elem.subgrid.x_range.start.value.value, elem.subgrid.x_range.stop.value.value], + [elem.subgrid.y_range.start.value.value, elem.subgrid.y_range.stop.value.value], + [stmt.x.value.value if (elem.subgrid.y_range.stop.value.value - elem.subgrid.y_range.start.value.value) % 2 == 0 else (elem.subgrid.x_range.stop.value.value - stmt.x.value.value - 1), + elem.subgrid.y_range.stop.value.value - 1 if elem.subgrid.y_range.start.value.value 
== stmt.y.value.value else elem.subgrid.y_range.start.value.value], + None, + None]}) + new_grid_streams = [] + new_snake_streams = [] + + if stmt.stream_name.name in self.grid_streams: + current_grid_streams = self.grid_streams[stmt.stream_name.name] + elif stmt.stream_name.name in self.snake_streams: + current_grid_streams = self.snake_streams[stmt.stream_name.name] + + # create intermediate datastructure to express all communication + for com in current_grid_streams: + newdataflobblocks.append([[com[0], com[1]], [com[2], com[3]], + RelativeStreamDeclaration( + dtype=StreamType(stmt.dtype.dtype), + stream_name=self.versioning.next_version("reduce"), + dx=Expression(ConstantLiteral(com[4], ScalarType.i32)), + dy=Expression(ConstantLiteral(com[5], ScalarType.i32)) + ), + [com[6], com[7]]], + ) + if stmt.stream_name.name in self.grid_streams: + if not stmt.routing.pipelined: + if com[4] == -1: + new_grid_streams.append([self.versioning.current_version("reduce"), com, StreamType(stmt.dtype.dtype), 'left', stmt.routing.pipelined]) + elif com[4] == 1: + new_grid_streams.append([self.versioning.current_version("reduce"), com, StreamType(stmt.dtype.dtype), 'right', stmt.routing.pipelined]) + elif com[5] == -1: + new_grid_streams.append([self.versioning.current_version("reduce"), com, StreamType(stmt.dtype.dtype), 'top', stmt.routing.pipelined]) + elif com[5] == 1: + new_grid_streams.append([self.versioning.current_version("reduce"), com, StreamType(stmt.dtype.dtype), 'bottom', stmt.routing.pipelined]) + else: + if com[4] == -1: + unrolled_com = [] + for i in range(com[1], com[0], -1): + if i % 2 == com[1] % 2: + for j in range(com[2], com[3]): + unrolled_com.append([i-1, i, j, j + 1, com[4], com[5], com[6], com[7], 'sender']) + else: + for j in range(com[2], com[3]): + unrolled_com.append([i-1, i, j, j + 1, com[4], com[5], com[6], com[7], 'receiver']) + new_grid_streams.append([self.versioning.current_version("reduce"), unrolled_com, StreamType(stmt.dtype.dtype), 'left', 
stmt.routing.pipelined]) + elif com[4] == 1: + unrolled_com = [] + for i in range(com[0], com[1]): + if i % 2 == com[0] % 2: + for j in range(com[2], com[3]): + unrolled_com.append([i, i+1, j, j + 1, com[4], com[5], com[6], com[7], 'sender']) + else: + for j in range(com[2], com[3]): + unrolled_com.append([i, i+1, j, j + 1, com[4], com[5], com[6], com[7], 'receiver']) + new_grid_streams.append([self.versioning.current_version("reduce"), unrolled_com, StreamType(stmt.dtype.dtype), 'right', stmt.routing.pipelined]) + elif com[5] == -1: + unrolled_com = [] + for i in range(com[3], com[2], -1): + if i % 2 == com[3] % 2: + unrolled_com.append([com[0], com[1], i-1, i, com[4], com[5], com[6], com[7], 'sender']) + else: + unrolled_com.append([com[0], com[1], i-1, i, com[4], com[5], com[6], com[7], 'receiver']) + new_grid_streams.append([self.versioning.current_version("reduce"), unrolled_com, StreamType(stmt.dtype.dtype), 'top', stmt.routing.pipelined]) + elif com[5] == 1: + unrolled_com = [] + for i in range(com[2], com[3]): + if i % 2 == com[2] % 2: + unrolled_com.append([com[0], com[1], i, i+1, com[4], com[5], com[6], com[7], 'sender']) + else: + unrolled_com.append([com[0], com[1], i, i+1, com[4], com[5], com[6], com[7], 'receiver']) + new_grid_streams.append([self.versioning.current_version("reduce"), unrolled_com, StreamType(stmt.dtype.dtype), 'bottom', stmt.routing.pipelined]) + elif stmt.stream_name.name in self.snake_streams: + if com[4] == -1: + unrolled_com = [] + for i in range(com[2], com[3]): + if (i - com[2]) % com[7] == 0: + if not stmt.routing.pipelined: + unrolled_com.append([com[0], com[1], i, i+1, com[4], com[5], com[6], com[7]]) + else: + for j in range(com[1], com[0], -1): + if j % 2 == com[1] % 2: + unrolled_com.append([j-1, j, i, i+1, com[4], com[5], com[6], com[7], 'sender']) + else: + unrolled_com.append([j-1, j, i, i+1, com[4], com[5], com[6], com[7], 'receiver']) + new_snake_streams.append([self.versioning.current_version("reduce"), com, 
StreamType(stmt.dtype.dtype), 'left', 'horizontal', unrolled_com, stmt.routing.pipelined]) + elif com[4] == 1: + unrolled_com = [] + for i in range(com[2], com[3]): + if (i - com[2]) % com[7] == 0: + if not stmt.routing.pipelined: + unrolled_com.append([com[0], com[1], i, i+1, com[4], com[5], com[6], com[7]]) + else: + for j in range(com[0], com[1]): + if j % 2 == com[0] % 2: + unrolled_com.append([j, j+1, i, i+1, com[4], com[5], com[6], com[7], 'sender']) + else: + unrolled_com.append([j, j+1, i, i+1, com[4], com[5], com[6], com[7], 'receiver']) + new_snake_streams.append([self.versioning.current_version("reduce"), com, StreamType(stmt.dtype.dtype), 'right', 'horizontal', unrolled_com, stmt.routing.pipelined]) + elif com[5] == -1: + unrolled_com = [] + for i in range(com[3], com[2], -1): + if i % 2 == com[3] % 2: + unrolled_com.append([com[0], com[1], i-1, i, com[4], com[5], com[6], com[7], 'sender']) + else: + unrolled_com.append([com[0], com[1], i-1, i, com[4], com[5], com[6], com[7], 'receiver']) + new_snake_streams.append([self.versioning.current_version("reduce"), com, StreamType(stmt.dtype.dtype), 'top', 'vertical', unrolled_com, stmt.routing.pipelined]) + elif com[5] == 1: + unrolled_com = [] + for i in range(com[2], com[3]): + if i % 2 == com[2] % 2: + unrolled_com.append([com[0], com[1], i, i+1, com[4], com[5], com[6], com[7], 'sender']) + else: + unrolled_com.append([com[0], com[1], i, i+1, com[4], com[5], com[6], com[7], 'receiver']) + new_snake_streams.append([self.versioning.current_version("reduce"), com, StreamType(stmt.dtype.dtype), 'bottom', 'vertical', unrolled_com, stmt.routing.pipelined]) + + if stmt.stream_name.name in self.grid_streams: + self.grid_streams.update({stmt.stream_name.name: new_grid_streams}) + elif stmt.stream_name.name in self.snake_streams: + self.snake_streams.update({stmt.stream_name.name: new_snake_streams}) + + else: + olddataflobblock.append(stmt) + + if olddataflobblock != []: + 
newbody.append(DataflowBlock(variables=elem.variables, subgrid=elem.subgrid, statements=olddataflobblock)) + for newdataflobblock in newdataflobblocks: + newbody.append( + DataflowBlock( + variables=elem.variables, + subgrid=SubgridExpression( + x_range=RangeExpression( + start=Expression( + ConstantLiteral(newdataflobblock[0][0], ScalarType.i32) + ), + stop=Expression( + ConstantLiteral(newdataflobblock[0][1], ScalarType.i32) + ), + step=Expression( + ConstantLiteral(newdataflobblock[3][0], ScalarType.i32) + ) + ), + y_range=RangeExpression( + start=Expression( + ConstantLiteral(newdataflobblock[1][0], ScalarType.i32) + ), + stop=Expression( + ConstantLiteral(newdataflobblock[1][1], ScalarType.i32) + ), + step=Expression( + ConstantLiteral(newdataflobblock[3][1], ScalarType.i32) + ) + ), + ), + statements=[newdataflobblock[2]])) + else: + newbody.append(elem) + self.body = newbody + + + ## + # Updates the compute block tiling for the new communication patterns + ## + def fix_subgrid(self) -> None: + newbody = [] + + for elem in self.body: + if isinstance(elem, ComputeBlock): + x_start = elem.subgrid.x_range.start.value.value + x_stop = elem.subgrid.x_range.stop.value.value + y_start = elem.subgrid.y_range.start.value.value + y_stop = elem.subgrid.y_range.stop.value.value + grid = [[[x_start, x_stop], [y_start, y_stop]]] + + for stmt in elem.statements: + red_stmt = [] + nodes = [stmt] + + # get all the reduce statements + while len(nodes) > 0: + for intermediate_stmt in nodes: + if isinstance(intermediate_stmt, ForeachStatement) or isinstance(intermediate_stmt, ForStatement) or isinstance(intermediate_stmt, MapStatement) or isinstance(intermediate_stmt, AsyncBlock): + for element in intermediate_stmt.body: + nodes.append(element) + elif isinstance(intermediate_stmt, ReduceStatement): + red_stmt.append(intermediate_stmt) + nodes.remove(intermediate_stmt) + + if isinstance(stmt, ReduceStatement): + red_stmt.append(stmt) + + for stmt in red_stmt: + stream_name = 
stmt.stream_name.name + + if self.reduce_operations[stream_name][5] == None: + for tst in stmt.iter_child_nodes(): + self.reduce_operations[stream_name][5] = tst + break + + if stmt.stream_name.name in self.grid_streams: + connections = self.grid_streams[stream_name] + + if not connections[0][4]: + #not pipelined + reduce_connections = [] + send_connections = [] + for con in connections: + if con[3] == 'left': + send_connections.append([con[1][1] - 1, con[1][1], con[1][2], con[1][3]]) + elif con[3] == 'right': + send_connections.append([con[1][0], con[1][0] + 1, con[1][2], con[1][3]]) + elif con[3] == 'top': + send_connections.append([con[1][0], con[1][1], con[1][3] - 1, con[1][3]]) + elif con[3] == 'bottom': + send_connections.append([con[1][0], con[1][1], con[1][2], con[1][2] + 1]) + reduce_connections.append(con[1]) + root = self.reduce_operations[stmt.stream_name.name][1] + for send in send_connections: + reduce_connections.append(send) + + reduce_connections.append([root[0], root[0] + 1, root[1], root[1] + 1]) + + for com_grid in reduce_connections: + to_remove = [] + for sub_grid in grid: + if com_grid[0] > sub_grid[0][0] and com_grid[0] < sub_grid[0][1]: + # print("left") + sub_x_start = sub_grid[0][0] + sub_x_stop = sub_grid[0][1] + sub_y_start = sub_grid[1][0] + sub_y_stop = sub_grid[1][1] + grid.append([[sub_x_start, com_grid[0]], [sub_y_start, sub_y_stop]]) + grid.append([[com_grid[0], sub_x_stop], [sub_y_start, sub_y_stop]]) + to_remove.append(sub_grid) + elif com_grid[1] > sub_grid[0][0] and com_grid[1] < sub_grid[0][1]: + # print("right") + sub_x_start = sub_grid[0][0] + sub_x_stop = sub_grid[0][1] + sub_y_start = sub_grid[1][0] + sub_y_stop = sub_grid[1][1] + grid.append([[sub_x_start, com_grid[1]], [sub_y_start, sub_y_stop]]) + grid.append([[com_grid[1], sub_x_stop], [sub_y_start, sub_y_stop]]) + to_remove.append(sub_grid) + elif com_grid[2] > sub_grid[1][0] and com_grid[2] < sub_grid[1][1] and com_grid[0] <= sub_grid[0][0] and com_grid[1] >= 
sub_grid[0][1]: + # print("top") + sub_x_start = sub_grid[0][0] + sub_x_stop = sub_grid[0][1] + sub_y_start = sub_grid[1][0] + sub_y_stop = sub_grid[1][1] + grid.append([[sub_x_start, sub_x_stop], [sub_y_start, com_grid[2]]]) + grid.append([[sub_x_start, sub_x_stop], [com_grid[2], sub_y_stop]]) + to_remove.append(sub_grid) + elif com_grid[3] > sub_grid[1][0] and com_grid[3] < sub_grid[1][1] and com_grid[0] <= sub_grid[0][0] and com_grid[1] >= sub_grid[0][1]: + # print("bottom") + sub_x_start = sub_grid[0][0] + sub_x_stop = sub_grid[0][1] + sub_y_start = sub_grid[1][0] + sub_y_stop = sub_grid[1][1] + grid.append([[sub_x_start, sub_x_stop], [sub_y_start, com_grid[3]]]) + grid.append([[sub_x_start, sub_x_stop], [com_grid[3], sub_y_stop]]) + to_remove.append(sub_grid) + # delete old unused + for rmv in to_remove: + grid.remove(rmv) + else: + #pipelined + new_grid = [] + for i in range(x_start, x_stop): + for j in range(y_start, y_stop): + new_grid.append([[i, i + 1], [j, j + 1]]) + grid = new_grid + + if self.snake_streams != {}: + new_grid = [] + complete_grid = [] + pipelined = False + + for name in self.snake_streams: + complete_grid = [self.reduce_operations[name][2], self.reduce_operations[name][3]] + pipelined = self.snake_streams[name][0][6] + + + if not pipelined: + + list_grid = [[x] for x in grid] + + for com_grid in list_grid: + to_remove = [] + for com in com_grid: + if com[0][0] == complete_grid[0][0] and com[0][1] != complete_grid[0][0] + 1: + # print("left") + com_grid.append([[complete_grid[0][0], complete_grid[0][0] + 1], [com[1][0], com[1][1]]]) + com_grid.append([[complete_grid[0][0] + 1, com[0][1]], [com[1][0], com[1][1]]]) + to_remove.append(com) + elif com[0][1] == complete_grid[0][1] and com[0][0] != complete_grid[0][1] - 1: + # print("right") + com_grid.append([[complete_grid[0][1] - 1, complete_grid[0][1]], [com[1][0], com[1][1]]]) + com_grid.append([[com[0][0], complete_grid[0][1] - 1], [com[1][0], com[1][1]]]) + to_remove.append(com) + elif 
com[1][1] - com[1][0] != 1: + # print('multiple rows') + for i in range(com[1][0], com[1][1]): + com_grid.append([[com[0][0], com[0][1]], [i, i + 1]]) + to_remove.append(com) + + for rmv in to_remove: + com_grid.remove(rmv) + + for com in com_grid: + new_grid.append(com) + + else: + for i in range(x_start, x_stop): + for j in range(y_start, y_stop): + new_grid.append([[i, i + 1], [j, j + 1]]) + + grid = new_grid + + + for com_grid in grid: + newbody.append( + ComputeBlock( + elem.variables, + SubgridExpression( + RangeExpression( + start=Expression(ConstantLiteral(com_grid[0][0], ScalarType.i32)), + stop=Expression(ConstantLiteral(com_grid[0][1], ScalarType.i32)) + ), + RangeExpression( + start=Expression(ConstantLiteral(com_grid[1][0], ScalarType.i32)), + stop=Expression(ConstantLiteral(com_grid[1][1], ScalarType.i32)) + ) + ), + elem.statements + ) + ) + else: + newbody.append(elem) + + self.body = newbody + return None + + + ## + # Changes the occurences of reduce statements in the compute blocks + ## + def change_compute_blocks(self) -> None: + finalbody = [] + for elem in self.body: + if isinstance(elem, ComputeBlock): + statements = [] + for stmt in elem.statements: + new_stmts = self.replace_stmt(stmt, elem, ReduceStatement) + for nstmt in new_stmts: + statements.append(nstmt) + finalbody.append(ComputeBlock(elem.variables, elem.subgrid, statements)) + else: + finalbody.append(elem) + + self.body = finalbody + + return None \ No newline at end of file diff --git a/spatialstencil/syntax/spatial_ir/irnodes.py b/spatialstencil/syntax/spatial_ir/irnodes.py index de21e956..0637c638 100644 --- a/spatialstencil/syntax/spatial_ir/irnodes.py +++ b/spatialstencil/syntax/spatial_ir/irnodes.py @@ -5,6 +5,8 @@ from spatialstencil.syntax.common.types import ScalarType, IRType from spatialstencil.syntax.spatial_ir.grid_geometry import Rectangle +from lark import Tree + @dataclass class SpatialNode(BaseNode): @@ -39,7 +41,7 @@ def validate(self) -> None: def as_ir(self, 
@dataclass
class MultiStreamType(SpatialNode, IRType):
    """
    A multistream type that handles collective communication patterns.
    Printed as ``multistream<dtype>``.
    """
    dtype: ScalarType

    def validate(self) -> None:
        assert isinstance(self.dtype, ScalarType)

    def as_ir(self, indent: int = 0) -> str:
        return f'multistream<{self.dtype.as_ir()}>'


@dataclass
class BroadcastRoutingDeclaration(SpatialNode):
    """
    A routing declaration for a broadcast multistream, optionally specifying
    the channels.
    """
    channels: Union[int, Literal["auto"]] = "auto"  # Channel ID or 'auto'

    def validate(self) -> None:
        # The parser may hand over a lark Tree instead of a plain value; its
        # `data` is used in that case.
        # NOTE (from the original author): this doesn't work for
        # self.channels != "auto" - check this.
        if isinstance(self.channels, Tree):
            self.channels = self.channels.data

    def as_ir(self, indent: int = 0) -> str:
        # NOTE(review): the indent unit was copied from whitespace-collapsed
        # source and may originally have been wider -- confirm.
        indent_str = ' ' * indent
        channels_str = "auto" if self.channels == "auto" else str(self.channels)
        return f"{indent_str}channels = {channels_str}"


@dataclass
class ReduceRoutingDeclaration(SpatialNode):
    """
    A routing declaration for a reduce multistream: the algorithm, the
    reduction op and whether the reduction is pipelined.
    """
    algorithm: str = ''
    op: str = ''
    pipelined: bool = False

    def validate(self) -> None:
        assert isinstance(self.algorithm, str)
        assert isinstance(self.op, str)
        assert isinstance(self.pipelined, bool)

    def as_ir(self, indent: int = 0) -> str:
        # NOTE(review): `pipelined` is printed with Python's bool spelling
        # (True/False); confirm this matches the grammar's bool literal.
        indent_str = ' ' * indent
        return (
            f"{indent_str}algorithm = {self.algorithm},\n"
            f"{indent_str}op = {self.op},\n"
            f"{indent_str}pipelined = {self.pipelined}"
        )


@dataclass
class MulStreamDeclaration(SpatialNode):
    """
    A multistream declaration inside a dataflow block.  It is printed as a
    ``reduce_stream(x, y)`` or ``broadcast_stream(x, y)`` declaration,
    depending on the kind of its routing declaration.
    """
    dtype: MultiStreamType
    stream_name: Identifier
    x: Expression
    y: Expression
    routing: Optional[Union[BroadcastRoutingDeclaration, ReduceRoutingDeclaration]] = None

    def validate(self) -> None:
        assert isinstance(self.dtype, MultiStreamType)
        assert isinstance(self.stream_name, Identifier)
        assert isinstance(self.x, Expression)
        assert isinstance(self.y, Expression)
        if self.routing:
            # A tuple works on every Python version; isinstance() with a
            # typing.Union raises TypeError before Python 3.10.
            assert isinstance(self.routing, (BroadcastRoutingDeclaration, ReduceRoutingDeclaration))

    def as_ir(self, indent: int = 0) -> str:
        # The routing kind is the only thing that tells a reduce from a
        # broadcast, so a declaration without (or with an unknown) routing
        # cannot be printed.
        if isinstance(self.routing, ReduceRoutingDeclaration):
            constructor = 'reduce_stream'
        elif isinstance(self.routing, BroadcastRoutingDeclaration):
            constructor = 'broadcast_stream'
        else:
            raise ValueError("Invalid routing declaration")

        indent_str = ' ' * indent
        routing_str = f" {{\n{self.routing.as_ir(indent + 1)}\n{' ' * indent}}}"
        return (
            f'{indent_str}multistream<{self.dtype.dtype.as_ir()}> {self.stream_name.as_ir()}'
            f' = {constructor}({self.x.as_ir()}, {self.y.as_ir()}){routing_str}'
        )
@dataclass
class BroadcastStatement(Statement):
    """
    Statement that broadcasts a local array through a stream.  Printed as
    ``<completion> = broadcast(...)`` when a completion is given, and as
    ``await broadcast(...)`` otherwise.
    """
    local_array: Union[Identifier, ArraySlice]
    stream_name: Union[Identifier, ArraySlice]
    completion_name: Optional[Completion] = None

    def validate(self) -> None:
        assert isinstance(self.local_array, (Identifier, ArraySlice))
        assert isinstance(self.stream_name, (Identifier, ArraySlice))
        if self.completion_name:
            assert isinstance(self.completion_name, Completion)

    def as_ir(self, indent: int = 0) -> str:
        # NOTE(review): the indent unit was copied from whitespace-collapsed
        # source and may originally have been wider -- confirm.
        indent_str = ' ' * indent
        call = f'broadcast({self.local_array.as_ir()}, {self.stream_name.as_ir()})'
        if self.completion_name:
            return f'{indent_str}{self.completion_name.as_ir()} = {call}'
        return f'{indent_str}await {call}'


@dataclass
class ReduceStatement(Statement):
    """
    Statement that reduces a local array through a stream.  Printed as
    ``<completion> = reduce(...)`` when a completion is given, and as
    ``await reduce(...)`` otherwise.
    """
    local_array: Union[Identifier, ArraySlice]
    stream_name: Union[Identifier, ArraySlice]
    completion_name: Optional[Completion] = None

    def validate(self) -> None:
        assert isinstance(self.local_array, (Identifier, ArraySlice))
        assert isinstance(self.stream_name, (Identifier, ArraySlice))
        if self.completion_name:
            assert isinstance(self.completion_name, Completion)

    def as_ir(self, indent: int = 0) -> str:
        # This used to print `receive(...)`, which made a reduce
        # indistinguishable from a plain receive in the printed IR.
        # NOTE(review): the keyword `reduce` mirrors `broadcast` above --
        # confirm against the statement rule in language.lark.
        indent_str = ' ' * indent
        call = f'reduce({self.local_array.as_ir()}, {self.stream_name.as_ir()})'
        if self.completion_name:
            return f'{indent_str}{self.completion_name.as_ir()} = {call}'
        return f'{indent_str}await {call}'
""" - dtype: Union[ScalarType, ArrayType, StreamType] + dtype: Union[ScalarType, ArrayType, StreamType, MultiStreamType] identifier: Identifier readonly: bool = False writeonly: bool = False compiletime: bool = False def validate(self) -> None: - assert isinstance(self.dtype, (ScalarType, ArrayType, StreamType)) + assert isinstance(self.dtype, (ScalarType, ArrayType, StreamType, MultiStreamType)) assert isinstance(self.identifier, Identifier) assert not self.readonly or not self.writeonly assert not self.compiletime or not self.writeonly @@ -844,6 +979,8 @@ def subgrids(self) -> list[Subgrid]: (0, elem))) return rectangles + + # Specialized visitors diff --git a/spatialstencil/syntax/spatial_ir/language.lark b/spatialstencil/syntax/spatial_ir/language.lark index 684cd7e7..8673c1b8 100644 --- a/spatialstencil/syntax/spatial_ir/language.lark +++ b/spatialstencil/syntax/spatial_ir/language.lark @@ -18,7 +18,8 @@ hexadecimal_literal : "0x" hex_digits negated_integer_literal : "-" integer_literal ?posneg_integer_literal : integer_literal | negated_integer_literal float_literal : /[-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?/ -string_literal : ESCAPED_STRING +non_escaped_string : letters (underscore letters)* +string_literal : ESCAPED_STRING | non_escaped_string ?constant_literal : bool_literal | posneg_integer_literal | float_literal | string_literal // Identifier syntax (loosely following MLIR conventions) @@ -32,7 +33,7 @@ identifier : suffix_id ("#" digits)? 
!uint_type : "u8" | "u16" | "u32" !bool_type : "bool" ?scalar_type : float_type | int_type | uint_type | bool_type -stream_type : "stream" "<" scalar_type ">" +stream_type : "stream" "<" scalar_type ">" | "multistream" "<" scalar_type ">" ?standard_type : scalar_type | stream_type // Array types @@ -115,9 +116,15 @@ subgrid_expression_2d : "[" range_expression "," range_expression "]" hop : "(" posneg_integer_literal "," posneg_integer_literal ")" // 2D at the moment, might expand hops : "[" hop ("," hop)* "]" routing : "hops" "=" (auto | hops) "," "channel" "=" (auto | integer_literal) +broadcast_routing : "channels" "=" (auto | integer_literal) +reduce_routing : "algorithm" "=" (auto | constant_literal) "," "op" "=" constant_literal "," "pipelined" "=" bool_literal field_declaration : builtin_type identifier (";")? //(";" | NEWLINE) -stream_declaration : "stream" "<" scalar_type ">" identifier "=" "relative_stream" "(" value_expr "," value_expr ")" ("{" routing "}")? (";")? +stream_declaration : classic_stream | mul_stream +classic_stream : "stream" "<" scalar_type ">" identifier "=" "relative_stream" "(" value_expr "," value_expr ")" ("{" routing "}")? (";")? +mul_stream : bcast | red +bcast : "multistream" "<" scalar_type ">" identifier "=" "broadcast_stream" "(" value_expr "," value_expr ")" ("{" broadcast_routing "}")? (";")? +red : "multistream" "<" scalar_type ">" identifier "=" "reduce_stream" "(" value_expr "," value_expr ")" ("{" reduce_routing "}")? (";")? 
vars : identifier ("," identifier)*
 typed_var : scalar_type identifier
 typed_vars : typed_var ("," typed_var)*
diff --git a/spatialstencil/syntax/spatial_ir/lark_to_ir.py b/spatialstencil/syntax/spatial_ir/lark_to_ir.py
index 192274a5..bf8aa010 100644
--- a/spatialstencil/syntax/spatial_ir/lark_to_ir.py
+++ b/spatialstencil/syntax/spatial_ir/lark_to_ir.py
@@ -2,7 +2,7 @@
 from spatialstencil.syntax.common.types import ScalarType
 from spatialstencil.syntax.spatial_ir import irnodes
-from spatialstencil.syntax.spatial_ir.irnodes import StreamType, Identifier
+from spatialstencil.syntax.spatial_ir.irnodes import StreamType, MultiStreamType, Identifier
 class TreeToSpatialIR(lark.Transformer):
@@ -36,7 +36,13 @@ def hexadecimal_literal(self, *digits):
     @lark.v_args(inline=True)
     def string_literal(self, s):
-        return irnodes.StringLiteral(s[1:-1].replace('\\"', '"'))
+        if isinstance(s, lark.Tree):  # presumably the unquoted non_escaped_string alternative, which arrives as a Tree
+            combined_string = ''
+            for child in s.children:
+                combined_string += child
+            return combined_string  # NOTE(review): plain str here, StringLiteral in the other branch - confirm this is intended
+        else:
+            return irnodes.StringLiteral(s[1:-1].replace('\\"', '"'))
     @lark.v_args(inline=True)
     def bare_id(self, *elements):
@@ -122,6 +128,10 @@ def function_call(self, args, meta=None):
             return irnodes.SendStatement(*arguments, completion_name=completion)
         elif func == 'receive':
             return irnodes.ReceiveStatement(*arguments, completion_name=completion)
+        elif func == 'broadcast':
+            return irnodes.BroadcastStatement(*arguments, completion_name=completion)
+        elif func == 'reduce':
+            return irnodes.ReduceStatement(*arguments, completion_name=completion)
         raise SyntaxError(f'Unrecognized free function call to "{func}"')
     subscript = irnodes.ArraySlice.from_lark
@@ -133,6 +143,8 @@ def function_call(self, args, meta=None):
     # Declarations and routing
     hop = irnodes.RoutingHop.from_lark
     routing = irnodes.RoutingDeclaration.from_lark
+    broadcast_routing = irnodes.BroadcastRoutingDeclaration.from_lark
+    reduce_routing = irnodes.ReduceRoutingDeclaration.from_lark
field_declaration = irnodes.FieldDeclaration.from_lark
     subgrid_expression_2d = irnodes.SubgridExpression.from_lark
@@ -142,8 +154,18 @@ def hop(self, args):
         return irnodes.RoutingHop(o)
     def stream_declaration(self, args):
-        args[0] = StreamType(args[0])
-        return irnodes.RelativeStreamDeclaration(*args)
+        if args[0].data == 'classic_stream':
+            args[0].children[0] = StreamType(args[0].children[0])
+            return irnodes.RelativeStreamDeclaration(*args[0].children)
+        elif args[0].data == 'mul_stream':  # the children are read one level deeper than for classic_stream
+            args[0].children[0].children[0] = MultiStreamType(args[0].children[0].children[0])
+            return irnodes.MulStreamDeclaration(*args[0].children[0].children)
+        else:
+            raise NotImplementedError('Only classic and mul stream declarations are supported at the moment')
+
+        # original code
+        # args[0] = StreamType(args[0])
+        # return irnodes.RelativeStreamDeclaration(*args)
     # Scopes
     def _scope_wrapper(self, cls, args):
diff --git a/tests/test_collective_ir_parser.py b/tests/test_collective_ir_parser.py
new file mode 100644
index 00000000..88b0b79c
--- /dev/null
+++ b/tests/test_collective_ir_parser.py
@@ -0,0 +1,108 @@
+from spatialstencil.syntax.spatial_ir import parser
+from spatialstencil.optimizations.optimization_pass import optimization_pass
+import os
+
+
+
+def _load_ref_file(file) -> list[str]:
+    # drop the last five characters (the '.sptl' suffix of the sample paths) and append '.ref_tile'
+    file = file[:-5] + '.ref_tile'
+    # read the reference file and return its lines with surrounding whitespace stripped
+    with open(file, 'r') as f:
+        return [line.strip() for line in f.readlines()]
+
+def _tiling_test(file):
+    """
+    Checks the optimized IR of a sample against its '.ref_tile' reference file.
+ + :param file: + :return: + """ + program = parser.parse_file(file) + program_optimized = optimization_pass(program) + ir_1 = program_optimized.as_ir() + + ir_ref = _load_ref_file(file) + num_dataflow = ir_ref[-1] + ir_ref = ir_ref[:-1] + + count_ref = 0 + for line in ir_ref: + assert ("compute i16 i, i16 j in " + line) in ir_1 + count_ref += 1 + count = 0 + for line in ir_1.splitlines(): + if "compute i16 i, i16 j in " in line: + count += 1 + assert count == count_ref + count_dataflow = 0 + for line in ir_1.splitlines(): + if "dataflow i16 i, i16 j in" in line: + count_dataflow += 1 + assert count_dataflow == int(num_dataflow) + + + + +def test_simple_bcast(): + file = os.path.join(os.path.dirname(__file__), '..', 'samples', 'collective', 'simple_bcast.sptl') + _tiling_test(file) + +def test_simple_reduce_grid_pipelined_1(): + file = os.path.join(os.path.dirname(__file__), '..', 'samples', 'collective', 'simple_reduce_grid_pipelined_1.sptl') + _tiling_test(file) + +def test_simple_reduce_grid_pipelined_2(): + file = os.path.join(os.path.dirname(__file__), '..', 'samples', 'collective', 'simple_reduce_grid_pipelined_2.sptl') + _tiling_test(file) + +def test_simple_reduce_grid_1(): + file = os.path.join(os.path.dirname(__file__), '..', 'samples', 'collective', 'simple_reduce_grid_1.sptl') + _tiling_test(file) + +def test_simple_reduce_grid_2(): + file = os.path.join(os.path.dirname(__file__), '..', 'samples', 'collective', 'simple_reduce_grid_2.sptl') + _tiling_test(file) + +def test_simple_reduce_snake_pipelined_1(): + file = os.path.join(os.path.dirname(__file__), '..', 'samples', 'collective', 'simple_reduce_snake_pipelined_1.sptl') + _tiling_test(file) + +def test_simple_reduce_snake_1(): + file = os.path.join(os.path.dirname(__file__), '..', 'samples', 'collective', 'simple_reduce_snake_1.sptl') + _tiling_test(file) + +def test_simple_reduce_looped(): + file = os.path.join(os.path.dirname(__file__), '..', 'samples', 'collective', 'simple_reduce_looped.sptl') 
+ _tiling_test(file) + +def test_medium_reduce_grid_1(): + file = os.path.join(os.path.dirname(__file__), '..', 'samples', 'collective', 'medium_reduce_grid_1.sptl') + _tiling_test(file) + +def test_hard_reduce_1(): + file = os.path.join(os.path.dirname(__file__), '..', 'samples', 'collective', 'hard_reduce_1.sptl') + _tiling_test(file) + +def test_hard_reduce_2(): + file = os.path.join(os.path.dirname(__file__), '..', 'samples', 'collective', 'hard_reduce_2.sptl') + _tiling_test(file) + +def test_hard_reduce_3(): + file = os.path.join(os.path.dirname(__file__), '..', 'samples', 'collective', 'hard_reduce_3.sptl') + _tiling_test(file) + + +if __name__ == '__main__': + test_simple_bcast() + test_simple_reduce_grid_pipelined_1() + test_simple_reduce_grid_pipelined_2() + test_simple_reduce_grid_1() + test_simple_reduce_grid_2() + test_simple_reduce_snake_pipelined_1() + test_simple_reduce_snake_1() + test_simple_reduce_looped() + test_medium_reduce_grid_1() + test_hard_reduce_1() + test_hard_reduce_2() + test_hard_reduce_3() \ No newline at end of file diff --git a/tests/test_spatial_ir_parser.py b/tests/test_spatial_ir_parser.py index 931c6814..d82b36e9 100644 --- a/tests/test_spatial_ir_parser.py +++ b/tests/test_spatial_ir_parser.py @@ -1,4 +1,5 @@ from spatialstencil.syntax.spatial_ir import irnodes as spast, parser +from spatialstencil.optimizations.optimization_pass import optimization_pass import os @@ -93,7 +94,8 @@ def _rountrip_test(file): :return: """ program = parser.parse_file(file) - ir_1 = program.as_ir() + program_optimized = optimization_pass(program) + ir_1 = program_optimized.as_ir() program2 = parser.parse_string(ir_1) ir_2 = program2.as_ir() assert ir_1 == ir_2 @@ -108,6 +110,7 @@ def test_spatial_roundtrip_two_phase_split(): _rountrip_test(file) + if __name__ == '__main__': test_spatial_roundtrip_laplacian() test_spatial_visitor() @@ -115,4 +118,4 @@ def test_spatial_roundtrip_two_phase_split(): test_spatial_roundtrip_two_phase_unrouted() 
test_spatial_roundtrip_two_phase_split() test_spatial_roundtrip_forward() - test_spatial_roundtrip_backward() + test_spatial_roundtrip_backward() \ No newline at end of file