feat[next-dace]: Added gt_replace_concat_where_node() #2482
philip-paul-mueller wants to merge 84 commits into GridTools:main
Conversation
…d multiple contexts.
…t is not implemented, but now I can begin.
Now full steam ahead to nested SDFGs or meetings?
edopao
left a comment
Very good! I only have some minor comments.
src/gt4py/next/program_processors/runners/dace/transformations/concat_where_mapper.py
> is accessed inside nested SDFGs.
>
> The function has currently the following limitations:
> - `concat_node` must be single use data (might be lifted).
Suggested change:
-  `concat_node` must be single use data (might be lifted).
+  `concat_node` must be single use data.
Since we remove the concat_where node, my understanding is that it must be single use data.
I agree that we should remove that part.
However, it can be lifted: for that you just have to create the AccessNodes for the producer in the other state and then rewire them.
You just have to make sure that the producer data is not modified in the meantime.
But since in most cases we have one state, it is not particularly important.
> # Check if there is already a connection between `parent_scope` and
> # `scope_to_handle` and if so use it. Otherwise create a new connection.
> parent_to_current_scope_edge = next(
In this case, should we also check that the existing connection copies the full shape into the nested scope?
No, you are right.
I refactored it; it is now much cleaner.
It is actually done by induction.
To see that, let's assume that the source listed in `producer_specs[i].data_source[parent_scope]` already does this, thus every outgoing edge from connector `parent_source.conn` of `parent_source.node` provides the full data.
Now we just have to look at the base case, which connects the AccessNode (at the top) with a MapEntry (also at the top); this is done in the later parts of the function.
As you can see, there the test is done.
However, your question made me realize that I should probably refactor this function a bit.
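The base-case test mentioned above can be sketched in plain Python (a hypothetical helper, not the DaCe API): an edge between the top-level AccessNode and the top-level MapEntry provides the full data exactly when its subset spans every dimension of the descriptor's shape.

```python
# Hypothetical sketch of the base-case "full data" test; `subset` mimics
# DaCe's Range storage format, i.e. a list of (start, stop) pairs with an
# inclusive `stop` (hence the `+ 1` when comparing against the shape).
def provides_full_data(subset, shape):
    return all(
        start == 0 and stop + 1 == dim
        for (start, stop), dim in zip(subset, shape)
    )

# An edge carrying `0:10, 0:5` covers an array of shape (10, 5) ...
print(provides_full_data([(0, 9), (0, 4)], (10, 5)))  # True
# ... while one starting at index 1 in the first dimension does not.
print(provides_full_data([(1, 9), (0, 4)], (10, 5)))  # False
```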
iomaganaris
left a comment
Some initial comments. I'm still going through this.
...sts/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py
> - `concat_node` must be single use data (might be lifted).
> - All producers of `concat_node` must be AccessNodes (might be lifted).
It would be interesting to see in the whole dycore if there are cases where this transformation can be applied and would be beneficial with these restrictions lifted as well. Could you try running either the blueline or the model/atmosphere/dycore/tests/dycore/integration_tests/test_benchmark_solve_nonhydro.py::test_benchmark_solve_nonhydro[False-False] test in icon4py and print if there are such cases?
Actually, in the case that the producers were not access nodes, it would be very difficult to handle (and probably not beneficial), because one branch of the if statement would be a write and the other would be the internal computations of a whole map.
> # Replace `concat_where` nodes
> # TODO(phimuell): Are there better locations for this transformation? An
> # alternative would be to call after the loop.
> gtx_transformations.gt_apply_concat_where_replacement_on_sdfg(
I would expect this transformation to apply only once to the SDFG. However, I believe it should probably apply after the map splitting/fusion, to avoid adding edges to the maps that might interfere with this pass.
So you would place it after the loop?
However, I would keep the TODO.
> genuine_consumer_specs, descending_points = _find_consumer_specs_single_source_single_level(
>     state, concat_node, for_check=False
> )
> assert all(_check_descending_point(descending_point) for descending_point in descending_points)
Isn't this checked already in `gt_check_if_concat_where_node_is_replaceable`? It looks like an expensive operation, so I'm wondering if it's okay to remove the assert here.
You are right, it is expensive.
> producer_specs: Sequence[_ProducerSpec],
> ) -> None:
>     """Helper function of `_map_data_into_nested_scopes()`."""
>     if scope_to_handle in producer_specs[0].data_source:
Not sure I understood this part. We exit here if `scope_to_handle` is in `producer_specs[0].data_source` because this means that `scope_to_handle` is the same scope as the first producer, and so the `concat_node` as well?
Could you please add a comment with an explanation here?
I added a comment.
The check is to see if we have handled the scope already.
The main reason for this is that we process the scopes in a kind of arbitrary order (you could compute a proper order).
We use `[0]` because if a scope is handled for one producer, it is handled for every producer.
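The early-exit pattern described in the reply can be sketched in plain Python (a hypothetical stand-in for `_ProducerSpec`; names are illustrative):

```python
# Hypothetical stand-in for `_ProducerSpec`: `data_source` maps a scope to
# whatever provides the producer's data inside that scope.
class ProducerSpec:
    def __init__(self):
        self.data_source = {}

def handle_scope(scope, producer_specs):
    # Scopes are visited in an arbitrary order; a scope that was already
    # wired up was wired up for *all* producers at once, so inspecting
    # `producer_specs[0]` is sufficient.
    if scope in producer_specs[0].data_source:
        return False  # already handled
    for i, spec in enumerate(producer_specs):
        spec.data_source[scope] = f"source_{i}"  # wire every producer in
    return True

specs = [ProducerSpec(), ProducerSpec()]
print(handle_scope("nested_map", specs))  # True: first visit does the work
print(handle_scope("nested_map", specs))  # False: second visit exits early
```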
> host_buffer = host_rand(dace_sym.evaluate(desc.total_size, used_symbols), desc.dtype)
> shape = tuple(dace_sym.evaluate(s, used_symbols) for s in desc.shape)
> dtype = desc.dtype.as_numpy_dtype()
> strides = tuple(
>     dace_sym.evaluate(s, used_symbols) * desc.dtype.bytes for s in desc.strides
> )
I'm wondering what triggered these changes. Why do we have to copy data to the GPU now and not before?
Because (as of now) there is no test that needs GPU memory.
I had to rework this function to handle cases where the data descriptors contain symbols, so I copied it from DaCe, and for some reason I also copied the GPU part.
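The pattern of evaluating symbolic descriptor shapes and strides can be illustrated with a self-contained stand-in (plain string expressions instead of sympy, and a mock `Desc` class; the real code uses `dace_sym.evaluate` on a DaCe data descriptor):

```python
# Hypothetical stand-in for `dace_sym.evaluate`: substitute concrete values
# for the symbols appearing in a descriptor's symbolic expressions.
def evaluate(expr, symbols):
    if isinstance(expr, int):
        return expr
    # Evaluate a simple arithmetic expression with the symbol values bound.
    return eval(expr, {"__builtins__": {}}, dict(symbols))

# A mock data descriptor with symbolic shape and strides, as descriptors
# look after lowering with a symbolic domain.
class Desc:
    shape = ("N", "K + 1")
    strides = ("K + 1", 1)  # row-major, in elements
    itemsize = 8            # bytes per element (e.g. float64)

used_symbols = {"N": 4, "K": 2}

shape = tuple(evaluate(s, used_symbols) for s in Desc.shape)
strides = tuple(evaluate(s, used_symbols) * Desc.itemsize for s in Desc.strides)

print(shape)    # (4, 3)
print(strides)  # (24, 8)
```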
> assert consumer_spec.edge.data.wcr is None
> assert consumer_spec.edge.data.subset.num_elements() == 1
>
> tlet_inputs: list[str] = [f"__inp{i}" for i in range(len(producer_specs))]
I would like to have a look at whether there's any difference between having different pointers for each branch versus one for every outer access node we access
Do you mean that if we have the same producer two times, then we use two different pointer variables (each containing the same value) for the access?
I think we should discuss this further.
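The two connector-naming schemes under discussion can be sketched in plain Python (names are illustrative; `producers` stands for the data names of the outer producer AccessNodes):

```python
producers = ["A", "B", "A"]  # the same producer feeds two branches

# One input connector per branch:
per_branch = [f"__inp{i}" for i in range(len(producers))]

# One input connector per distinct outer AccessNode (the alternative):
per_producer: dict[str, str] = {}
for name in producers:
    # only allocate a new connector the first time a producer is seen
    per_producer.setdefault(name, f"__inp{len(per_producer)}")

print(per_branch)    # ['__inp0', '__inp1', '__inp2']
print(per_producer)  # {'A': '__inp0', 'B': '__inp1'}
```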
The comparison was done by running graupel 10 times (each time with a new compilation). I would say we see an increase in performance of a bit less than 2%, which is in the expected range.
edopao
left a comment
I still think that the check for the scan case could be simplified to just checking whether the nested SDFG label starts with "scan_", because we know by construction that this is the only case where a field operator lambda expression is lowered to a nested SDFG inside a map scope, and that the scan lambda only operates on local points. However, I am also fine with keeping the current check. I just have some small suggestions.
> if type(parent_desc) is not type(nested_desc):
>     return False
>
> if not all((start == 0) == True for start in consumed_subset.min_element()):  # noqa: E712 [true-false-comparison] # SymPy comparison
I am not sure about the check on `start == 0`. Could it be that we see 0 only because we are using static domains in muphys? If we were using a symbolic domain, we would probably see the `_range_0` symbol.
The check should be (to be moved a bit down inside the while):

    if not all(
        (start == map_range[0]) == True
        for start, map_range in zip(consumed_subset.min_element(), edge.src.map.ndrange)
    ):

It could be fused into the existing for loop:

    for i, (consumed_range, map_range) in enumerate(zip(consumed_subset, edge.src.map.ndrange)):
        if str(consumed_range.max_element()).isdigit():
            if (consumed_range.max_element() + 1) != parent_desc.shape[i]:  # `+1` because of storage format.
                return False
        else:
            if (consumed_range.min_element() == map_range[0]) != True:
                return False
            if str(consumed_range.max_element()) not in map_params:
                return False
Honestly, I am not sure where the 0 is coming from exactly; my bet is that it is already there.
I am also thinking that if the lower bound is not 0 at lowering/optimization time, then the array is probably not mapped in entirely, and if so it is more of a coincidence and not general.
There are several points related to the proposed change of the lower bound test:
- It should not be in the `else` branch but directly beneath the `for`.
- You should not compare against `map_range[0]`. Consider a pathological Memlet with `0:i` that is inside a Map with the range `i=3:10` (think of 3 as `strange_icon_constant`): you "clearly" map in the whole array, but you would reject it.

But I like your suggestion regarding the map parameter test.
I have extended it such that a parameter can only be used once.
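The pathological case from the second bullet, in concrete numbers (illustrative values only): the Memlet reads `0:i` from an array of shape `i`, so it covers the full array, yet a test against the map range's start would reject it.

```python
memlet_lower_bound = 0  # the Memlet subset is `0:i`
array_lower_bound = 0   # arrays always start at index 0
map_range_start = 3     # the map range is `i = 3:10`

# The read starts at the array's first element, so the lower bound of the
# full array is covered ...
covers_full_array_start = memlet_lower_bound == array_lower_bound
# ... but comparing against the map range's start (3) rejects the Memlet.
rejected_by_map_range_test = memlet_lower_bound != map_range_start

print(covers_full_array_start)     # True
print(rejected_by_map_range_test)  # True
```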
Looks good! I also tried to run graupel with a symbolic domain, but I had forgotten that gt4py takes forever to lower it to GTIR, so I could not run it.

The transformation allows replacing some nodes that are generated by a `concat_where` expression. It is able to replace a Memlet access to the data with a conditional access to the data that generates it.

TODO:
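The described effect can be illustrated with a plain-Python analogue (illustrative only; the actual transformation rewires SDFG Memlets, not Python lists): instead of materializing the `concat_where` result into a temporary buffer and reading from it, each consumer reads conditionally from the producer that owns the accessed index.

```python
def read_via_materialized_buffer(k, split, a, b):
    # Before the transformation: build the concatenated buffer, then access it.
    buf = a[:split] + b[split:]
    return buf[k]

def read_via_conditional_access(k, split, a, b):
    # After the transformation: a conditional access straight to the producers.
    return a[k] if k < split else b[k]

a = [10, 11, 12, 13]
b = [20, 21, 22, 23]
# Both forms agree at every index; the second needs no temporary buffer.
for k in range(4):
    assert read_via_materialized_buffer(k, 2, a, b) == read_via_conditional_access(k, 2, a, b)
print(read_via_conditional_access(3, 2, a, b))  # 23
```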