feat[next-dace]: Added gt_replace_concat_where_node()#2482

Open
philip-paul-mueller wants to merge 84 commits into GridTools:main from philip-paul-mueller:concat_where_copy_to_map_tasklet

Conversation

@philip-paul-mueller
Contributor

@philip-paul-mueller philip-paul-mueller commented Feb 16, 2026

The transformation replaces some of the nodes that are generated by a `concat_where` expression: a Memlet access to the data is replaced with a conditional access to the data that generates it.
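The idea can be illustrated with a small NumPy sketch (hypothetical helper names, not the actual GT4Py/DaCe API): instead of reading from a materialized `concat_where` buffer, the consumer reads conditionally from the producer data directly.

```python
import numpy as np

def read_via_concat_buffer(cond, a, b, idx):
    # Baseline lowering: `concat_where` materializes an intermediate
    # buffer combining the two producers, and the Memlet reads from it.
    concat = np.where(cond, a, b)
    return concat[idx]

def read_conditionally(cond, a, b, idx):
    # After the transformation: the access to the intermediate data is
    # replaced by a conditional access to the producers, so the
    # intermediate buffer is never materialized.
    return a[idx] if cond[idx] else b[idx]
```

Both reads yield the same value for every index; only the second avoids the intermediate buffer.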

TODO:

  • Run Blueline
  • Run Graupel (see here)

Now full steam ahead to nested SDFGs or meetings?
Contributor

@edopao edopao left a comment


Very good! I only have some minor comments.

is accessed inside nested SDFGs.

The function has currently the following limitations:
- `concat_node` must be single use data (might be lifted).
Contributor

Suggested change
- `concat_node` must be single use data (might be lifted).
- `concat_node` must be single use data.

Since we remove the concat_where node, my understanding is that it must be single use data.

Contributor Author

I agree that we should remove that part.
However, it can be lifted: you just have to create the AccessNodes for the producer in the other state and then rewire them.
You just have to make sure that the producer data is not modified in the meantime.
But since in most cases we have one state, it is not particularly important.


# Check if there is already a connection between `parent_scope` and
# `scope_to_handle` and if so use it. Otherwise create a new connection.
parent_to_current_scope_edge = next(
Contributor

In this case, should we also check that the existing connection copies the full shape into the nested scope?

Contributor Author

@philip-paul-mueller philip-paul-mueller Feb 17, 2026

No, you are right.
I refactored it; it is now much cleaner.

It is actually done by induction.
To see that, let's assume that the source listed in producer_specs[i].data_source[parent_scope] already does this; thus every outgoing edge from connector parent_source.conn of parent_source.node provides the full data.
Now we just have to look at the base case, which connects the AccessNode (at the top) with a MapEntry (also at the top); this is done in the later parts of the function.
As you can see, the test is done there.

However, your question made me realize that I should probably refactor this function a bit.

Contributor

@iomaganaris iomaganaris left a comment

Some initial comments. I'm still going through this.

Comment on lines 148 to 149
- `concat_node` must be single use data (might be lifted).
- All producers of `concat_node` must be AccessNodes (might be lifted).
Contributor

@iomaganaris iomaganaris Feb 17, 2026

It would be interesting to see, in the whole dycore, whether there are cases where this transformation could be applied and would be beneficial with these restrictions lifted as well. Could you try running either the blueline or the model/atmosphere/dycore/tests/dycore/integration_tests/test_benchmark_solve_nonhydro.py::test_benchmark_solve_nonhydro[False-False] test in icon4py and print whether there are such cases?

Actually, in the case where the producers were not access nodes, it would be very difficult to handle (and probably not beneficial), because one branch of the if statement would be a write and the other would be the internal computations of a whole map.

# Replace `concat_where` nodes
# TODO(phimuell): Are there better locations for this transformation? An
# alternative would be to call after the loop.
gtx_transformations.gt_apply_concat_where_replacement_on_sdfg(
Contributor

I would expect this transformation to apply only once to the SDFG. However, I believe it should probably be applied after the map splitting/fusion, to avoid adding edges to the maps that might interfere with this pass.

Contributor Author

@philip-paul-mueller philip-paul-mueller Feb 17, 2026

So you would place it after the loop?
However, I would keep the todo.

genuine_consumer_specs, descending_points = _find_consumer_specs_single_source_single_level(
state, concat_node, for_check=False
)
assert all(_check_descending_point(descending_point) for descending_point in descending_points)
Contributor

Isn't this checked already in gt_check_if_concat_where_node_is_replaceable? It looks like an expensive operation, so I'm wondering if it's okay to remove the assert here.

Contributor Author

You are right, it is expensive.

producer_specs: Sequence[_ProducerSpec],
) -> None:
"""Helper function of `_map_data_into_nested_scopes()`."""
if scope_to_handle in producer_specs[0].data_source:
Contributor

Not sure I understood this part. We exit here if scope_to_handle is in producer_specs[0].data_source because this means that scope_to_handle is the same scope as the first producer, and so the concat_node as well?
Could you please add a comment with an explanation here?

Contributor Author

I added a comment.
The check is to see if we have handled the scope already.
The main reason for this is that we process them in a somewhat arbitrary order (you could compute a proper order).
We use [0] because if a scope is handled for one producer, it is handled for every producer.
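A minimal sketch of that early-exit logic (hypothetical names, not the actual function): scopes arrive in arbitrary order, and checking the first producer's record is sufficient because a scope is always handled for all producers together.

```python
def handle_scope(scope_to_handle, data_sources):
    # `data_sources[i]` mirrors `producer_specs[i].data_source`: the set
    # of scopes already wired up for producer `i`.
    if scope_to_handle in data_sources[0]:
        # Already handled; checking `[0]` suffices because scopes are
        # handled for all producers at once (all-or-none).
        return False
    for ds in data_sources:
        ds.add(scope_to_handle)  # handle the scope for every producer
    return True
```

A second call with the same scope is then a no-op, regardless of the order in which scopes are visited.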

Comment on lines +159 to +164
host_buffer = host_rand(dace_sym.evaluate(desc.total_size, used_symbols), desc.dtype)
shape = tuple(dace_sym.evaluate(s, used_symbols) for s in desc.shape)
dtype = desc.dtype.as_numpy_dtype()
strides = tuple(
dace_sym.evaluate(s, used_symbols) * desc.dtype.bytes for s in desc.strides
)
Contributor

I'm wondering what triggered these changes. Why do we have to copy data to the GPU now and not before?

Contributor Author

Because (as of now) there is no test that needs GPU memory.
I had to rework this function to handle cases where the data descriptors contain symbols, so I copied it from DaCe, and for some reason I also copied the GPU part.

assert consumer_spec.edge.data.wcr is None
assert consumer_spec.edge.data.subset.num_elements() == 1

tlet_inputs: list[str] = [f"__inp{i}" for i in range(len(producer_specs))]
Contributor

I would like to have a look at whether there's any difference between having different pointers for each branch versus one for every outer access node we access

Contributor Author

@philip-paul-mueller philip-paul-mueller Feb 19, 2026

Do you mean that if we have the same producer twice, then we use two different pointer variables (each containing the same value) for the access?

I think we should discuss this further.

@philip-paul-mueller
Contributor Author

The comparison was done by running graupel 10 times (each time a new compilation), each time on atm_R2B06.nc with 100 repetitions.
For MAIN the staging branch 9a25a94699 was used.
For "This Transformation" 5612a48cf3f728bf79cb was used, merged with main (a9de666c) and the staging branch (9a25a946).

MAIN                vs   This Transformation
0.7839205265045166  vs.  0.7592110633850098
0.7873868942260742  vs.  0.7683718204498291
0.7889411449432373  vs.  0.7686924934387207
0.7904205322265625  vs.  0.7687270641326904
0.7986536026000977  vs.  0.7690911293029785
0.8011176586151123  vs.  0.7742154598236084
0.8041496276855469  vs.  0.7803781032562256
0.806187629699707   vs.  0.7814192771911621
0.8070793151855469  vs.  0.7899436950683594
0.8198916912078857  vs.  0.8004758358001709

I would say we see an increase in performance of a bit less than 2%, which is in the expected range.
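As a quick, purely illustrative sanity check, the column means of the runtimes listed above can be compared:

```python
# Runtimes (seconds) copied from the table above.
main_times = [0.7839205265045166, 0.7873868942260742, 0.7889411449432373,
              0.7904205322265625, 0.7986536026000977, 0.8011176586151123,
              0.8041496276855469, 0.806187629699707, 0.8070793151855469,
              0.8198916912078857]
pr_times = [0.7592110633850098, 0.7683718204498291, 0.7686924934387207,
            0.7687270641326904, 0.7690911293029785, 0.7742154598236084,
            0.7803781032562256, 0.7814192771911621, 0.7899436950683594,
            0.8004758358001709]

main_mean = sum(main_times) / len(main_times)
pr_mean = sum(pr_times) / len(pr_times)
# Relative runtime reduction of the branch versus MAIN, in percent.
speedup_pct = (1.0 - pr_mean / main_mean) * 100.0
```

The branch's mean runtime is lower than MAIN's, consistent with a small single-digit-percent improvement.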

Contributor

@edopao edopao left a comment

LGTM

Contributor

@edopao edopao left a comment

I still think that the check for the scan case could be simplified to just check whether the nested SDFG label starts with "scan_", because we know by construction that this is the only case where a field operator lambda expression is lowered to a nested SDFG inside a map scope, and that the scan lambda only operates on local points. However, I am also fine with keeping the current check. I just have some small suggestions.

if type(parent_desc) is not type(nested_desc):
return False

if not all((start == 0) == True for start in consumed_subset.min_element()): # noqa: E712 [true-false-comparison] # SymPy comparison
Contributor

@edopao edopao Feb 20, 2026

I am not sure about the check on start == 0. Could it be that we see 0 only because we are using static domains in muphys? If we were using symbolic domain, we would probably see the _range_0 symbol.

The check should be (to be moved a bit down inside the while):
if not all((start == map_range[0]) == True for start, map_range in zip(consumed_subset.min_element(), edge.src.map.ndrange)):

It could be fused into the existing for loop:


    for i, (consumed_range, map_range) in enumerate(zip(consumed_subset, edge.src.map.ndrange)):
        if str(consumed_range.max_element()).isdigit():
            if (consumed_range.max_element() + 1) != parent_desc.shape[i]:  # `+1` because of storage format.
                return False
        else:
            if (consumed_range.min_element() == map_range[0]) != True:
                return False
            if str(consumed_range.max_element()) not in map_params:
                return False

Contributor Author

Honestly, I am not sure where the 0 is coming from exactly; my bet is that it is already there.
I am also thinking that if the lower bound is not 0 at lowering/optimization time, then the array is probably not mapped in entirely, and if so, it is more of a coincidence and not general.

There are several points related to the proposed change of the lower-bound test:

  • It should not be in the else branch but directly beneath the for.
  • You should not compare against map_range[0]. Consider a pathological Memlet with 0:i that is inside a Map with the range i=3:10 (think of 3 as strange_icon_constant): you "clearly" map in the whole array, but you would reject it.

But I like your suggestion regarding the map parameter test.
I have extended it such that a parameter can only be used once.
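The pathological case above can be made concrete with a toy check (plain Python ranges standing in for DaCe's symbolic subsets): a Memlet subset `0:i` inside a map range `i = 3:10` reads the full array, yet a lower-bound comparison against `map_range[0]` would reject it.

```python
def covered_indices(map_lo, map_hi):
    # Union of the Memlet subsets `0:i` (inclusive) over all iterations
    # of the map parameter i = map_lo..map_hi (inclusive bounds).
    covered = set()
    for i in range(map_lo, map_hi + 1):
        covered |= set(range(0, i + 1))
    return covered

# Map range `i = 3:10` (i.e. i = 3..9), array of size 10: the whole
# array is read across the map iterations...
full_coverage = covered_indices(3, 9) == set(range(10))
# ...yet the Memlet's lower bound (0) differs from the map's lower
# bound (3), so a `start == map_range[0]` check rejects this case.
lower_bound_check_rejects = (0 != 3)
```

This is why comparing the subset's lower bound against the map's lower bound is too strict for deciding whether the whole array is mapped in.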

Contributor

@edopao edopao Feb 20, 2026

Looks good! I also tried to run graupel with symbolic domain but I had forgotten that gt4py takes forever to lower it to GTIR, so I could not run it.

@philip-paul-mueller
Contributor Author

Here are the results for Blueline (run with merged main but without the compile-time domain).

[benchmark plot: bench_blueline_stencil_compute]
