Skip to content

[Layout] Add tmem duplication attribute and tcgen05.copy multicast#152

Open
yaoyaoding wants to merge 1 commit intomainfrom
feat/enhance-tmem
Open

[Layout] Add tmem duplication attribute and tcgen05.copy multicast#152
yaoyaoding wants to merge 1 commit intomainfrom
feat/enhance-tmem

Conversation

@yaoyaoding
Copy link
Copy Markdown
Member

Two coupled additions, paving the way for block-scaled MMAs (NVFP4, MXFP4) that need scale-factor tensors replicated across TMEM warp sub-partitions.

TMemoryLayout duplication (PR1a):

  • New TMemoryDuplication enum: NONE / WARPX4 / WARPX2_02_13 / WARPX2_01_23.
  • TMemoryLayout.create() cross-validates shape[0] against duplication: WARPX4 requires shape[0]=32; WARPX2_* require shape[0]=64; NONE accepts {32, 64, 128}. WARPX* layouts disallow non-zero lane_offset (they span all 128 physical lanes via replication); NONE requires lane_offset + shape[0] <= 128.
  • tmemory_row_major accepts a duplication kwarg.
  • tmemory_slice preserves the parent's duplication on column-dim slicing; rejects lane-dim slicing on WARPX* tensors.
  • Tcgen05AllocRule (layout inference) pins WARPX4 for shape[0]=32 and NONE for shape[0]=128; defers shape[0]=64 (ambiguous between NONE and WARPX2_*) to downstream inference (e.g. a copy with multicast hint).
  • 21 new unit tests under tests/ir/layout/test_tmem_layout.py.

tcgen05.copy multicast kwarg (PR1b):

  • Tcgen05CopyInst gains a multicast: str = "" field (primitive type so IR functors walk it generically; converted to Tcgen05CopyMulticastKind in the codegen emitter).
  • User-facing Tcgen05Module.copy(..., multicast=None) accepts None, "warpx4", "warpx2_02_13", or "warpx2_01_23"; unknown names raise InstructionError.
  • The emitter routes the multicast through split_canonical_layout / generate_instructions / the tcgen05_copy PTX call. The NONE path is unchanged. Multicast modes raise NotImplementedError ("not yet supported in emitter") until the smaller-shape codegen lands as part of the NVFP4 matmul work.
  • 1 new test in test_tcgen05_copy.py exercises the user-facing validation for unknown multicast names.

Tests: 61 passed, 4 xfailed across all TMEM-related instruction + layout suites. matmul_v0..v6 examples still build cleanly.

…icasts

Two coupled additions, paving the way for block-scaled MMAs (NVFP4, MXFP4)
that need scale-factor tensors replicated across TMEM warp sub-partitions.

TMemoryLayout duplication:
- New `TMemoryDuplication` enum: NONE / WARPX4 / WARPX2_02_13 / WARPX2_01_23.
- `TMemoryLayout.create()` cross-validates shape[0] against duplication:
  WARPX4 requires shape[0]=32; WARPX2_* require shape[0]=64; NONE accepts
  {32, 64, 128}. WARPX* layouts disallow non-zero lane_offset (they span all
  128 physical lanes via replication); NONE requires
  lane_offset + shape[0] <= 128.
- `tmemory_row_major` accepts a `duplication` kwarg.
- `tmemory_slice` preserves the parent's duplication on column-dim slicing;
  rejects lane-dim slicing on WARPX* tensors.
- `Tcgen05AllocRule` pins WARPX4 for shape[0]=32 and NONE for shape[0]=128;
  defers shape[0]=64 (ambiguous between NONE and WARPX2_*) to the copy rule.
- 21 new unit tests under tests/ir/layout/test_tmem_layout.py.

tcgen05.copy multicast kwarg + emitter:
- `Tcgen05CopyInst` gains a `multicast: str = ""` field (primitive type so
  IR functors walk it generically; converted to Tcgen05CopyMulticastKind in
  the codegen emitter).
- User-facing `Tcgen05Module.copy(..., multicast=None)` accepts None,
  "warpx4", "warpx2_02_13", or "warpx2_01_23"; unknown names raise
  `InstructionError`.
- Emitter implements all three multicast modes by mapping each to its
  required source-row count (32 / 64) and shape kind (R32x128B / R64x128B),
  loosening the K-major alignment to `inst_n % T == 0` so the smaller shape
  kinds are accepted.
- Alloc emitter accepts shape[0] in {32, 64, 128}.
- `Tcgen05CopyRule` infers TMEM duplication from the multicast kwarg
  (covers the deferred shape[0]=64 cases).
- Fixed the PTX modifier strings on `Tcgen05CopyMulticastKind`
  (`.warpx2_02_13` -> `.warpx2::02_13`, `.warpx2_01_23` -> `.warpx2::01_23`)
  as required by ptxas.
- New SMEM->TMEM->global round-trip tests in test_tcgen05_copy.py exercise
  all three multicast modes plus the unknown-name validation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
@yaoyaoding yaoyaoding force-pushed the feat/enhance-tmem branch from 3b59fc7 to b454685 Compare May 4, 2026 03:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant