Skip to content

[CuTeDSL] Make @cute.struct instances flattenable across scf.if / scf.while#3270

Open
cheshire wants to merge 1 commit into
NVIDIA:mainfrom
cheshire:fix/storage-get-tensor-in-if
Open

[CuTeDSL] Make @cute.struct instances flattenable across scf.if / scf.while#3270
cheshire wants to merge 1 commit into
NVIDIA:mainfrom
cheshire:fix/storage-get-tensor-in-if

Conversation

@cheshire
Copy link
Copy Markdown

Fixes #3268

A @cute.struct instance captured into an scf.if branch or scf.while body fails the DSL trace with:

    DSLRuntimeError: The 'if' statement encountered a user-defined Python
    object, which cannot be automatically converted into an dynamic
    expression.

This blocks the natural warp-specialization pattern, where each if warp_idx == <role>: branch reads its tile from a shared storage struct.

A struct instance is fully described by its base pointer (already DynamicExpression-aware via _Pointer); every field instance is re-derived from base + static offsets on construction. Implement the DynamicExpression protocol on each decorated class by forwarding __get_mlir_types__ / __extract_mlir_values__ to base, and __new_from_mlir_values__ to a fresh decorator invocation that re-derives the fields from a rebuilt base pointer.

….while

Fixes NVIDIA#3268

A `@cute.struct` instance captured into an `scf.if` branch or `scf.while`
body fails the DSL trace with:

    DSLRuntimeError: The 'if' statement encountered a user-defined Python
    object, which cannot be automatically converted into an dynamic
    expression.

This blocks the natural warp-specialization pattern, where each
`if warp_idx == <role>:` branch reads its tile from a shared storage
struct.

A struct instance is fully described by its `base` pointer (already
DynamicExpression-aware via `_Pointer`); every field instance is
re-derived from `base + static offsets` on construction. Implement the
DynamicExpression protocol on each decorated class by forwarding
`__get_mlir_types__` / `__extract_mlir_values__` to `base`, and
`__new_from_mlir_values__` to a fresh decorator invocation that
re-derives the fields from a rebuilt base pointer.

Tested in Docker on cutlass-dsl 4.5.1 with six new unit tests in
test/python/CuTeDSL/test_struct_in_if.py covering:
  * the original failing case (storage.get_tensor inside dynamic if),
  * regression: plain non-branched struct usage still works,
  * nested struct (struct-of-struct) inside a dynamic if,
  * if/else with both branches accessing the struct,
  * if/elif/elif/else (the actual warp-specialization shape),
  * scf.while body capturing the struct.
@anakinxc
Copy link
Copy Markdown

LGTM @grypp to double check

Copy link
Copy Markdown
Contributor

@grypp grypp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me. Thanks for contributing



@cute.struct
class _Outer:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could add a test with union as well.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add one when landing this internally

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] [CuTeDSL] storage.<field>.get_tensor(...) inside a dynamic if block fails with "encountered a user-defined Python object"

3 participants