From b4546859f9d103ae721ecaba6cf42b4b689b40db Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Mon, 4 May 2026 03:37:19 +0000 Subject: [PATCH] [Layout][Codegen] Add TMemoryLayout duplication and tcgen05.copy multicasts Two coupled additions, paving the way for block-scaled MMAs (NVFP4, MXFP4) that need scale-factor tensors replicated across TMEM warp sub-partitions. TMemoryLayout duplication: - New `TMemoryDuplication` enum: NONE / WARPX4 / WARPX2_02_13 / WARPX2_01_23. - `TMemoryLayout.create()` cross-validates shape[0] against duplication: WARPX4 requires shape[0]=32; WARPX2_* require shape[0]=64; NONE accepts {32, 64, 128}. WARPX* layouts disallow non-zero lane_offset (they span all 128 physical lanes via replication); NONE requires lane_offset + shape[0] <= 128. - `tmemory_row_major` accepts a `duplication` kwarg. - `tmemory_slice` preserves the parent's duplication on column-dim slicing; rejects lane-dim slicing on WARPX* tensors. - `Tcgen05AllocRule` pins WARPX4 for shape[0]=32 and NONE for shape[0]=128; defers shape[0]=64 (ambiguous between NONE and WARPX2_*) to the copy rule. - 21 new unit tests under tests/ir/layout/test_tmem_layout.py. tcgen05.copy multicast kwarg + emitter: - `Tcgen05CopyInst` gains a `multicast: str = ""` field (primitive type so IR functors walk it generically; converted to Tcgen05CopyMulticastKind in the codegen emitter). - User-facing `Tcgen05Module.copy(..., multicast=None)` accepts None, "warpx4", "warpx2_02_13", or "warpx2_01_23"; unknown names raise `InstructionError`. - Emitter implements all three multicast modes by mapping each to its required source-row count (32 / 64) and shape kind (R32x128B / R64x128B), loosening the K-major alignment to `inst_n % T == 0` so the smaller shape kinds are accepted. - Alloc emitter accepts shape[0] in {32, 64, 128}. - `Tcgen05CopyRule` infers TMEM duplication from the multicast kwarg (covers the deferred shape[0]=64 cases). 
- Fixed the PTX modifier strings on `Tcgen05CopyMulticastKind` (`.warpx2_02_13` -> `.warpx2::02_13`, `.warpx2_01_23` -> `.warpx2::01_23`) as required by ptxas. - New SMEM->TMEM->global round-trip tests in test_tcgen05_copy.py exercise all three multicast modes plus the unknown-name validation. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Yaoyao Ding --- .../backends/emitters/cuda/tcgen05/alloc.py | 4 +- .../backends/emitters/cuda/tcgen05/copy.py | 65 +++++-- .../tilus/hidet/ir/primitives/cuda/tcgen05.py | 4 +- python/tilus/ir/builders/stmt_builder.py | 4 +- python/tilus/ir/instructions/cuda/tcgen05.py | 10 +- .../inference_rules/tcgen05/alloc.py | 16 +- .../inference/inference_rules/tcgen05/copy.py | 44 +++-- python/tilus/ir/layout/ops/tmemory_ops.py | 24 ++- python/tilus/ir/layout/tmem_layout.py | 74 +++++++- python/tilus/lang/instructions/tcgen05.py | 39 +++- tests/instructions/test_tcgen05_copy.py | 115 +++++++++++ tests/ir/layout/test_tmem_layout.py | 179 ++++++++++++++++++ 12 files changed, 530 insertions(+), 48 deletions(-) create mode 100644 tests/ir/layout/test_tmem_layout.py diff --git a/python/tilus/backends/emitters/cuda/tcgen05/alloc.py b/python/tilus/backends/emitters/cuda/tcgen05/alloc.py index ecc13e59..a117ae33 100644 --- a/python/tilus/backends/emitters/cuda/tcgen05/alloc.py +++ b/python/tilus/backends/emitters/cuda/tcgen05/alloc.py @@ -38,8 +38,8 @@ class Tcgen05AllocDeallocEmitter(BaseInstEmitter): def get_num_columns(self, tmem_tensor: TMemoryTensor) -> int: shape = tmem_tensor.shape - if shape[0] != 128: - raise NotImplementedError(f"The emitter currently only supports shape[0] == 128, but got {shape[0]}") + if shape[0] not in (32, 64, 128): + raise NotImplementedError(f"The emitter only supports shape[0] in (32, 64, 128), but got {shape[0]}") if shape[-1] * tmem_tensor.dtype.nbits % 32 != 0: raise ValueError( f"shape[-1] * dtype.nbits must be divisible by 32, but got {shape[-1]} * {tmem_tensor.dtype.nbits} = {shape[-1] * 
tmem_tensor.dtype.nbits}" diff --git a/python/tilus/backends/emitters/cuda/tcgen05/copy.py b/python/tilus/backends/emitters/cuda/tcgen05/copy.py index 19f5842d..9cfb7a0e 100644 --- a/python/tilus/backends/emitters/cuda/tcgen05/copy.py +++ b/python/tilus/backends/emitters/cuda/tcgen05/copy.py @@ -56,10 +56,42 @@ def __str__(self) -> str: return "Tcgen05CopyInstMeta(" + ",\n ".join(items) + "\n)" +_MULTICAST_NAME_TO_KIND: dict[str, Tcgen05CopyMulticastKind] = { + "": Tcgen05CopyMulticastKind.NONE, + "warpx4": Tcgen05CopyMulticastKind.WARP_X4, + "warpx2_02_13": Tcgen05CopyMulticastKind.WARP_X2_02_13, + "warpx2_01_23": Tcgen05CopyMulticastKind.WARP_X2_01_23, +} + +# Per-multicast: required source-row count in SMEM and the candidate shape kinds. +# Smaller shape kinds (R32x128B / R64x128B) are only valid with multicast — the +# table maps each multicast to the shape kind PTX requires for it. +_MULTICAST_TO_SOURCE_ROWS: dict[Tcgen05CopyMulticastKind, int] = { + Tcgen05CopyMulticastKind.NONE: 128, + Tcgen05CopyMulticastKind.WARP_X4: 32, + Tcgen05CopyMulticastKind.WARP_X2_02_13: 64, + Tcgen05CopyMulticastKind.WARP_X2_01_23: 64, +} + +_MULTICAST_TO_SHAPE_KINDS: dict[Tcgen05CopyMulticastKind, tuple[Tcgen05CopyShapeKind, ...]] = { + Tcgen05CopyMulticastKind.NONE: ( + Tcgen05CopyShapeKind.R128x256B, + Tcgen05CopyShapeKind.R128x128B, + ), + Tcgen05CopyMulticastKind.WARP_X4: (Tcgen05CopyShapeKind.R32x128B,), + Tcgen05CopyMulticastKind.WARP_X2_02_13: (Tcgen05CopyShapeKind.R64x128B,), + Tcgen05CopyMulticastKind.WARP_X2_01_23: (Tcgen05CopyShapeKind.R64x128B,), +} + + @register_emitter(Tcgen05CopyInst, target=nvgpu_sm100) class Tcgen05CopyEmitter(BaseInstEmitter): def split_canonical_layout( - self, smem_addr: Expr, canonical: CanonicalSharedLayout, shape_kind: Tcgen05CopyShapeKind + self, + smem_addr: Expr, + canonical: CanonicalSharedLayout, + shape_kind: Tcgen05CopyShapeKind, + multicast: Tcgen05CopyMulticastKind = Tcgen05CopyMulticastKind.NONE, ) -> list[Tcgen05CopyInstMeta]: 
""" Split the canonical shared layout into multiple sub-tensors that can be copied by tcgen05.copy instructions. @@ -95,7 +127,7 @@ def split_canonical_layout( raise GenerationFailedError( "The number of rows or columns in the shape kind must be divisible by the number of rows or columns in the canonical layout" ) - if canonical.major_kind == "K" and (inst_m % 8 != 0 or inst_n % (canonical.T * 2) != 0): + if canonical.major_kind == "K" and (inst_m % 8 != 0 or inst_n % canonical.T != 0): raise GenerationFailedError( "The number of rows or columns in the shape kind must be divisible by the number of rows or columns in the canonical layout" ) @@ -151,7 +183,7 @@ def split_canonical_layout( instructions.append( Tcgen05CopyInstMeta( shape_kind=shape_kind, - multicast=Tcgen05CopyMulticastKind.NONE, + multicast=multicast, cta_group=Tcgen05CtaGroupKind.CTA_1, tmem_offset=tmem_offset, shared_descriptor=s_desc, @@ -161,7 +193,10 @@ def split_canonical_layout( return instructions def generate_instructions( - self, tmem_tensor: TMemoryTensor, shared_tensor: SharedTensor + self, + tmem_tensor: TMemoryTensor, + shared_tensor: SharedTensor, + multicast: Tcgen05CopyMulticastKind = Tcgen05CopyMulticastKind.NONE, ) -> list[Tcgen05CopyInstMeta]: dtype = shared_tensor.dtype canonical_layout: CanonicalSharedLayout | None = canonicalize_shared_layout( @@ -176,12 +211,9 @@ def generate_instructions( raise ValueError("\n".join(msg)) smem_addr = self.shared_tensor_shared_space_addr[shared_tensor] - for shape_kind in [ - Tcgen05CopyShapeKind.R128x256B, - Tcgen05CopyShapeKind.R128x128B, - ]: + for shape_kind in _MULTICAST_TO_SHAPE_KINDS[multicast]: try: - return self.split_canonical_layout(smem_addr, canonical_layout, shape_kind) + return self.split_canonical_layout(smem_addr, canonical_layout, shape_kind, multicast) except GenerationFailedError: continue @@ -193,19 +225,28 @@ def emit(self, inst: Tcgen05CopyInst) -> None: self.assert_is_warp_aligned(inst, "tcgen05.copy is a warp-cooperative 
instruction") + multicast_kind = _MULTICAST_NAME_TO_KIND.get(inst.multicast) + if multicast_kind is None: + raise ValueError("Unknown multicast {!r} on Tcgen05CopyInst".format(inst.multicast)) + if len(shared_tensor.shape) != 2: raise ValueError("The shared tensor must be a 2D tensor, got shape {}".format(shared_tensor.shape)) if len(tmem_tensor.shape) != 2: raise ValueError("The tensor memory tensor must be a 2D tensor, got shape {}".format(tmem_tensor.shape)) - if shared_tensor.shape[0] != 128: - raise NotImplementedError("The number of rows in the shared tensor must be 128") + expected_rows = _MULTICAST_TO_SOURCE_ROWS[multicast_kind] + if shared_tensor.shape[0] != expected_rows: + raise ValueError( + "tcgen05.copy multicast={!r} requires the shared tensor to have {} rows, got {}".format( + inst.multicast, expected_rows, shared_tensor.shape[0] + ) + ) if tmem_tensor.layout.lane_offset != 0: raise NotImplementedError("The first lane of the tmem tensor must be 0") tmem_base_addr = self.tensor2var[tmem_tensor] with self.single_thread(): - insts = self.generate_instructions(tmem_tensor, shared_tensor) + insts = self.generate_instructions(tmem_tensor, shared_tensor, multicast_kind) for inst_meta in insts: s_desc = self.declare_var("s_desc", tp=uint64, init=inst_meta.shared_descriptor.encoded()) t_addr = tmem_base_addr + inst_meta.tmem_offset diff --git a/python/tilus/hidet/ir/primitives/cuda/tcgen05.py b/python/tilus/hidet/ir/primitives/cuda/tcgen05.py index e1cb00df..1415ed12 100644 --- a/python/tilus/hidet/ir/primitives/cuda/tcgen05.py +++ b/python/tilus/hidet/ir/primitives/cuda/tcgen05.py @@ -155,8 +155,8 @@ def m(self) -> int: class Tcgen05CopyMulticastKind(Enum): NONE = "" - WARP_X2_02_13 = ".warpx2_02_13" - WARP_X2_01_23 = ".warpx2_01_23" + WARP_X2_02_13 = ".warpx2::02_13" + WARP_X2_01_23 = ".warpx2::01_23" WARP_X4 = ".warpx4" diff --git a/python/tilus/ir/builders/stmt_builder.py b/python/tilus/ir/builders/stmt_builder.py index be8ba34d..6b19d6ca 100644 --- 
a/python/tilus/ir/builders/stmt_builder.py +++ b/python/tilus/ir/builders/stmt_builder.py @@ -1624,8 +1624,8 @@ def tcgen05_wait_store(self) -> None: inst = Tcgen05WaitInst.create(wait_load=False, wait_store=True) self.append(inst) - def tcgen05_copy(self, src: SharedTensor, dst: TMemoryTensor) -> None: - inst = Tcgen05CopyInst.create(src=src, dst=dst) + def tcgen05_copy(self, src: SharedTensor, dst: TMemoryTensor, multicast: str = "") -> None: + inst = Tcgen05CopyInst.create(src=src, dst=dst, multicast=multicast) self.append(inst) def tcgen05_commit( diff --git a/python/tilus/ir/instructions/cuda/tcgen05.py b/python/tilus/ir/instructions/cuda/tcgen05.py index 62229a86..642085ba 100644 --- a/python/tilus/ir/instructions/cuda/tcgen05.py +++ b/python/tilus/ir/instructions/cuda/tcgen05.py @@ -119,10 +119,16 @@ def create(wait_load: bool, wait_store: bool) -> Tcgen05WaitInst: @dataclass(frozen=True, eq=False) class Tcgen05CopyInst(Instruction): + # Multicast pattern as a primitive string ("" for no multicast, "warpx4", + # "warpx2_02_13", "warpx2_01_23"). Converted to Tcgen05CopyMulticastKind in + # the codegen emitter — primitive types are kept here so IR functors can + # walk the field generically. 
+ multicast: str = "" + @staticmethod - def create(src: SharedTensor, dst: TMemoryTensor) -> Tcgen05CopyInst: + def create(src: SharedTensor, dst: TMemoryTensor, multicast: str = "") -> Tcgen05CopyInst: # Note: 2D validation is performed at the lang layer (Tcgen05InstructionGroup.copy) - return Tcgen05CopyInst(output=None, inputs=(dst, src)) + return Tcgen05CopyInst(output=None, inputs=(dst, src), multicast=multicast) @dataclass(frozen=True, eq=False) diff --git a/python/tilus/ir/layout/inference/inference_rules/tcgen05/alloc.py b/python/tilus/ir/layout/inference/inference_rules/tcgen05/alloc.py index 7811bc3a..ad9becb8 100644 --- a/python/tilus/ir/layout/inference/inference_rules/tcgen05/alloc.py +++ b/python/tilus/ir/layout/inference/inference_rules/tcgen05/alloc.py @@ -20,8 +20,17 @@ register_rule, ) from tilus.ir.layout.ops.tmemory_ops import tmemory_row_major +from tilus.ir.layout.tmem_layout import TMemoryDuplication from tilus.ir.tensor import TMemoryTensor +# Map from shape[0] -> duplication mode that's uniquely determined by the lane size. +# shape[0]=64 is ambiguous (NONE / WARPX2_02_13 / WARPX2_01_23) and is left for +# downstream inference (e.g. from a tcgen05.copy multicast hint). +_LANE_TO_FORCED_DUPLICATION: dict[int, TMemoryDuplication] = { + 32: TMemoryDuplication.WARPX4, + 128: TMemoryDuplication.NONE, +} + @register_rule(Tcgen05AllocInst) class Tcgen05AllocRule(LayoutInferenceRule): @@ -30,4 +39,9 @@ def inference(ctx: LayoutInferenceContext, inst: Tcgen05AllocInst) -> dict[TMemo tmem = inst.tmemory_output if tmem.optional_layout is not None: return {} - return {tmem: tmemory_row_major(tmem.shape)} + lane_size = tmem.shape[0] + if lane_size not in _LANE_TO_FORCED_DUPLICATION: + # shape[0]=64 — defer to downstream inference (writers/consumers). 
+ return {} + duplication = _LANE_TO_FORCED_DUPLICATION[lane_size] + return {tmem: tmemory_row_major(tmem.shape, duplication=duplication)} diff --git a/python/tilus/ir/layout/inference/inference_rules/tcgen05/copy.py b/python/tilus/ir/layout/inference/inference_rules/tcgen05/copy.py index 55ec729c..8d7cb372 100644 --- a/python/tilus/ir/layout/inference/inference_rules/tcgen05/copy.py +++ b/python/tilus/ir/layout/inference/inference_rules/tcgen05/copy.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from tilus.ir.instructions.cuda.tcgen05 import Tcgen05CopyInst +from tilus.ir.layout import SharedLayout, TMemoryLayout from tilus.ir.layout.cuda.tcgen05.smem import ( Tcgen05SwizzleMode, generate_canonical_layout, @@ -24,22 +25,43 @@ LayoutInferenceRule, register_rule, ) -from tilus.ir.tensor import SharedLayout, SharedTensor +from tilus.ir.layout.ops.tmemory_ops import tmemory_row_major +from tilus.ir.layout.tmem_layout import TMemoryDuplication +from tilus.ir.tensor import SharedTensor, TMemoryTensor + +# Map the user-facing multicast string on Tcgen05CopyInst to the corresponding +# TMEM duplication mode. Empty string means no multicast (no replication). 
+_MULTICAST_TO_DUPLICATION: dict[str, TMemoryDuplication] = { + "": TMemoryDuplication.NONE, + "warpx4": TMemoryDuplication.WARPX4, + "warpx2_02_13": TMemoryDuplication.WARPX2_02_13, + "warpx2_01_23": TMemoryDuplication.WARPX2_01_23, +} @register_rule(Tcgen05CopyInst) class Tcgen05CopyRule(LayoutInferenceRule): @staticmethod - def inference(ctx: LayoutInferenceContext, inst: Tcgen05CopyInst) -> dict[SharedTensor, SharedLayout]: + def inference( + ctx: LayoutInferenceContext, inst: Tcgen05CopyInst + ) -> dict[SharedTensor | TMemoryTensor, SharedLayout | TMemoryLayout]: + result: dict[SharedTensor | TMemoryTensor, SharedLayout | TMemoryLayout] = {} + dst = inst.inputs[0].as_tmemory_tensor() src = inst.inputs[1].as_shared_tensor() - if src.has_layout(): - return {} + if not src.has_layout(): + if len(src.shape) != 2: + raise LayoutInferenceError(f"Only 2D SharedTensor is supported in copy, got shape {src.shape}") + canonical_layout = generate_canonical_layout( + (src.shape[0], src.shape[1]), src.dtype, "K", Tcgen05SwizzleMode.NO_SWIZZLE + ) + result[src] = get_shared_layout_from_canonical(canonical_layout) + + if not dst.has_layout(): + # Use the multicast kwarg to pick the right duplication. The + # Tcgen05AllocRule already pins layouts for unambiguous shape[0] + # (32 -> WARPX4, 128 -> NONE) and defers shape[0]=64 to here. 
+ duplication = _MULTICAST_TO_DUPLICATION.get(inst.multicast, TMemoryDuplication.NONE) + result[dst] = tmemory_row_major(dst.shape, duplication=duplication) - if len(src.shape) != 2: - raise LayoutInferenceError(f"Only 2D SharedTensor is supported in copy, got shape {src.shape}") - canonical_layout = generate_canonical_layout( - (src.shape[0], src.shape[1]), src.dtype, "K", Tcgen05SwizzleMode.NO_SWIZZLE - ) - shared_layout = get_shared_layout_from_canonical(canonical_layout) - return {src: shared_layout} + return result diff --git a/python/tilus/ir/layout/ops/tmemory_ops.py b/python/tilus/ir/layout/ops/tmemory_ops.py index d49a0d9f..6298763f 100644 --- a/python/tilus/ir/layout/ops/tmemory_ops.py +++ b/python/tilus/ir/layout/ops/tmemory_ops.py @@ -15,9 +15,10 @@ from typing import Sequence from tilus.ir.layout import TMemoryLayout +from tilus.ir.layout.tmem_layout import TMemoryDuplication -def tmemory_row_major(shape: Sequence[int]) -> TMemoryLayout: +def tmemory_row_major(shape: Sequence[int], duplication: TMemoryDuplication = TMemoryDuplication.NONE) -> TMemoryLayout: # Convention: dim 0 is the lane axis (stride 0); dims 1..-1 are column-strided # in row-major order (innermost dim has stride 1). column_strides = [0] * len(shape) @@ -25,12 +26,27 @@ def tmemory_row_major(shape: Sequence[int]) -> TMemoryLayout: for dim in reversed(range(1, len(shape))): column_strides[dim] = stride stride *= shape[dim] - return TMemoryLayout.create(shape, column_strides, lane_offset=0) + return TMemoryLayout.create(shape, column_strides, lane_offset=0, duplication=duplication) def tmemory_slice( tmem_layout: TMemoryLayout, lane_offset: int, slice_dims: Sequence[int], shape: Sequence[int] ) -> TMemoryLayout: - lane_offset = tmem_layout.lane_offset + lane_offset + # Column-dim slicing preserves the parent's duplication and lane_offset. + # Lane-dim slicing (shape[0] differs from parent's shape[0]) is only legal + # for NONE-duplicated tensors and adds the requested lane_offset shift. 
+ new_lane_offset = tmem_layout.lane_offset + lane_offset + if shape[0] != tmem_layout.shape[0] and tmem_layout.duplication != TMemoryDuplication.NONE: + raise ValueError( + "Lane-dim slicing of a TMEM tensor with duplication={} is not supported " + "(parent shape[0]={}, sliced shape[0]={}); only NONE-duplicated tensors " + "may have their lane dim sliced.".format(tmem_layout.duplication.value, tmem_layout.shape[0], shape[0]) + ) + if tmem_layout.duplication != TMemoryDuplication.NONE and lane_offset != 0: + raise ValueError( + "duplication={} disallows non-zero lane_offset shift, got {}".format( + tmem_layout.duplication.value, lane_offset + ) + ) strides = [tmem_layout.column_strides[dim] for dim in slice_dims] - return TMemoryLayout.create(shape, strides, lane_offset) + return TMemoryLayout.create(shape, strides, new_lane_offset, duplication=tmem_layout.duplication) diff --git a/python/tilus/ir/layout/tmem_layout.py b/python/tilus/ir/layout/tmem_layout.py index 5c087889..c02366bc 100644 --- a/python/tilus/ir/layout/tmem_layout.py +++ b/python/tilus/ir/layout/tmem_layout.py @@ -15,19 +15,67 @@ from __future__ import annotations from dataclasses import dataclass +from enum import Enum from typing import Sequence from tilus.ir.node import IRNode +class TMemoryDuplication(Enum): + """How a TMEM tensor's lane data is replicated across sub-partitions. + + SM100 TMEM has 128 lanes per CTA, organized as 4 sub-partitions of 32 + lanes. Some MMA kinds (e.g. block-scaled MMAs that read SFs) require + operands to be replicated across sub-partitions; ``tcgen05.cp`` provides + the multicast modes that produce these layouts. The enum value is the + canonical name of the multicast pattern (matches the ``multicast=`` kwarg + on :meth:`tilus.lang.cuda.Tcgen05Module.copy`). + + Attributes + ---------- + NONE + No replication; each lane holds unique data. Compatible with + ``shape[0] in {32, 64, 128}`` (occupies that many physical lanes + starting at ``lane_offset``). 
+ WARPX4 + 32 unique lanes broadcast to all 4 sub-partitions; ``shape[0] == 32``. + WARPX2_02_13 + 64 unique lanes broadcast to two warp-pairs (warps {0,2} share half, + warps {1,3} share the other half); ``shape[0] == 64``. + WARPX2_01_23 + 64 unique lanes broadcast to two warp-pairs (warps {0,1} share half, + warps {2,3} share the other half); ``shape[0] == 64``. + """ + + NONE = "none" + WARPX4 = "warpx4" + WARPX2_02_13 = "warpx2_02_13" + WARPX2_01_23 = "warpx2_01_23" + + +# Required ``shape[0]`` (unique lane count) for each duplication mode. +_DUPLICATION_LANE_SIZE: dict[TMemoryDuplication, set[int]] = { + TMemoryDuplication.NONE: {32, 64, 128}, + TMemoryDuplication.WARPX4: {32}, + TMemoryDuplication.WARPX2_02_13: {64}, + TMemoryDuplication.WARPX2_01_23: {64}, +} + + @dataclass(frozen=True, eq=False) class TMemoryLayout(IRNode): shape: tuple[int, ...] column_strides: tuple[int, ...] lane_offset: int + duplication: TMemoryDuplication = TMemoryDuplication.NONE @staticmethod - def create(shape: Sequence[int], column_strides: Sequence[int], lane_offset: int) -> TMemoryLayout: + def create( + shape: Sequence[int], + column_strides: Sequence[int], + lane_offset: int, + duplication: TMemoryDuplication = TMemoryDuplication.NONE, + ) -> TMemoryLayout: if len(shape) != len(column_strides): raise ValueError( "Dimension mismatch: shape has length {}, but column_strides has length {}".format( @@ -45,4 +93,26 @@ def create(shape: Sequence[int], column_strides: Sequence[int], lane_offset: int column_strides[0] ) ) - return TMemoryLayout(shape=tuple(shape), column_strides=tuple(column_strides), lane_offset=lane_offset) + # Cross-validate shape[0] against the duplication mode. 
+ allowed_sizes = _DUPLICATION_LANE_SIZE[duplication] + if shape[0] not in allowed_sizes: + raise ValueError( + "duplication={} requires shape[0] in {}, got shape[0]={}".format( + duplication.value, sorted(allowed_sizes), shape[0] + ) + ) + # WARPX* modes always span all 128 physical lanes (replicated), so the + # tensor must start at lane 0. lane_offset is only meaningful for NONE. + if duplication != TMemoryDuplication.NONE and lane_offset != 0: + raise ValueError("duplication={} requires lane_offset=0, got {}".format(duplication.value, lane_offset)) + # NONE: data fits in one CTA's TMEM lanes [lane_offset, lane_offset + shape[0]). + if duplication == TMemoryDuplication.NONE and lane_offset + shape[0] > 128: + raise ValueError( + "lane_offset({}) + shape[0]({}) exceeds 128 physical TMEM lanes".format(lane_offset, shape[0]) + ) + return TMemoryLayout( + shape=tuple(shape), + column_strides=tuple(column_strides), + lane_offset=lane_offset, + duplication=duplication, + ) diff --git a/python/tilus/lang/instructions/tcgen05.py b/python/tilus/lang/instructions/tcgen05.py index 1dc23341..da842622 100644 --- a/python/tilus/lang/instructions/tcgen05.py +++ b/python/tilus/lang/instructions/tcgen05.py @@ -21,6 +21,12 @@ from .root import InstructionGroup +# Allowed user-facing names for the tcgen05.copy multicast kwarg. These match +# the duplication-mode names on TMemoryLayout. Stored as plain strings on the +# IR (no enum) so functors walk the field generically; converted to +# Tcgen05CopyMulticastKind at codegen time. +_VALID_COPY_MULTICAST_NAMES: tuple[str, ...] = ("warpx4", "warpx2_02_13", "warpx2_01_23") + class Tcgen05InstructionGroup(InstructionGroup): """Tensor Core Generation 05 (tcgen05) instructions for Blackwell GPUs. 
@@ -247,30 +253,43 @@ def wait_store(self) -> None: """ self._builder.tcgen05_wait_store() - def copy(self, src: SharedTensor, dst: TMemoryTensor) -> None: + def copy(self, src: SharedTensor, dst: TMemoryTensor, multicast: Optional[str] = None) -> None: """Copy data from shared memory to tensor memory. - Asynchronously copies a 2D shared tensor into a 2D tensor memory tensor. Use + Asynchronously copies a shared tensor into a tensor memory tensor. Use ``tcgen05.commit`` to signal completion via an mbarrier. Parameters ---------- src: SharedTensor - The source shared tensor. Must be 2D. + The source shared tensor. dst: TMemoryTensor - The destination tensor memory tensor. Must be 2D. + The destination tensor memory tensor. + multicast: Optional[str] + Multicast pattern for replicating ``src`` across TMEM sub-partitions. + + - ``None`` (default): plain 1:1 copy (no replication). + - ``"warpx4"``: replicate ``src`` to all 4 warp-aligned 32-lane stripes + of TMEM. Source ``src`` has 32 unique lane rows; ``dst`` is a TMEM + tensor with ``WARPX4`` duplication (``shape[0] == 32``). + - ``"warpx2_02_13"`` / ``"warpx2_01_23"``: replicate ``src`` to two + warp-pairs (by parity / by halves). Source has 64 unique lane + rows; ``dst`` has the matching ``WARPX2_*`` duplication + (``shape[0] == 64``). Notes ----- - **Thread group**: Must be executed by a warp-aligned thread group. - **Hardware**: Requires compute capability 10.0+ (sm_100). - - **PTX**: ``tcgen05.cp`` + - **PTX**: ``tcgen05.cp[.warpx4 / .warpx2_02_13 / .warpx2_01_23]`` """ - if len(src.shape) != 2: - raise InstructionError("copy requires a 2D shared tensor, got shape {}".format(src.shape)) - if len(dst.shape) != 2: - raise InstructionError("copy requires a 2D tensor memory tensor, got shape {}".format(dst.shape)) - self._builder.tcgen05_copy(src, dst) + if multicast is not None and multicast not in _VALID_COPY_MULTICAST_NAMES: + raise InstructionError( + "Unknown multicast mode {!r}. 
Expected None or one of: {}".format( + multicast, ", ".join(repr(k) for k in _VALID_COPY_MULTICAST_NAMES) + ) + ) + self._builder.tcgen05_copy(src, dst, multicast=multicast or "") def commit(self, mbarrier: Expr | RegisterTensor, cta_group: int = 1, multicast_mask: Optional[int] = None) -> None: """Commit pending tcgen05 async operations and signal an mbarrier. diff --git a/tests/instructions/test_tcgen05_copy.py b/tests/instructions/test_tcgen05_copy.py index 827eba71..bcc89fec 100644 --- a/tests/instructions/test_tcgen05_copy.py +++ b/tests/instructions/test_tcgen05_copy.py @@ -104,5 +104,120 @@ def test_tcgen05_copy(major_kind, swizzle_mode): torch.testing.assert_close(x, y) +class TmemCopyMulticastExample(tilus.Script): + """Round-trip SMEM -> TMEM (multicast) -> register -> global. + + Warps are chosen so the read-back covers the unique source rows of each + multicast. Multicasts replicate the source across TMEM sub-partitions; the load + instruction is constrained to ``current_num_threads == tmem.shape[0]``. + Reading exactly ``shape[0]`` rows back lets us verify the data the kernel + placed into TMEM. For ``warpx2_01_23`` the two warp pairs receive duplicate + halves, so the expected output is a tiling of the first 32 source rows. 
+ """ + + def __init__(self, multicast: str, num_warps: int, block_m: int): + super().__init__() + self.multicast = multicast + self.num_warps = num_warps + self.block_m = block_m + self.block_n = 32 + self.shared_layout = generate_canonical_layout( + shape=(self.block_m, self.block_n), + dtype=int32, + major_kind="K", + swizzle_mode=Tcgen05SwizzleMode.NO_SWIZZLE, + ).as_shared_layout() + + def __call__(self, m_size: int, n_size: int, x_ptr: ~int32, y_ptr: ~int32): + self.attrs.blocks = cdiv(m_size, self.block_m), cdiv(n_size, self.block_n) + self.attrs.warps = self.num_warps + + m_offset = self.blockIdx.x * self.block_m + n_offset = self.blockIdx.y * self.block_n + + g_x = self.global_view(x_ptr, dtype=int32, shape=[m_size, n_size]) + g_y = self.global_view(y_ptr, dtype=int32, shape=[m_size, n_size]) + + s_x = self.shared_tensor(dtype=int32, shape=[self.block_m, self.block_n]) + t_x = self.tcgen05.alloc(dtype=int32, shape=[self.block_m, self.block_n]) + + barriers = self.mbarrier.alloc(counts=[1]) + + self.copy_async(src=g_x, dst=s_x, offsets=[m_offset, n_offset]) + self.copy_async_wait_all() + self.sync() + + with self.single_warp(): + self.tcgen05.copy(src=s_x, dst=t_x, multicast=self.multicast) + self.tcgen05.commit(mbarrier=barriers[0]) + self.mbarrier.wait(barriers[0], phase=0) + + r_y = self.tcgen05.load(t_x) + self.tcgen05.wait_load() + self.sync() + + self.store_global(g_y, r_y, offsets=[m_offset, n_offset]) + + self.tcgen05.dealloc(t_x) + + self.annotate_layout(s_x, self.shared_layout) + + +@tilus.testing.requires.nvgpu_sm100a +@pytest.mark.parametrize( + "multicast, num_warps, block_m", + [ + ("warpx4", 1, 32), + ("warpx2_02_13", 2, 64), + ("warpx2_01_23", 2, 64), + ], +) +def test_tcgen05_copy_multicast(multicast: str, num_warps: int, block_m: int): + n_size = 32 + x = torch.randint(0, 128, [block_m, n_size], dtype=torch.int32, device="cuda") + y = torch.zeros([block_m, n_size], dtype=torch.int32, device="cuda") + kernel = 
TmemCopyMulticastExample(multicast=multicast, num_warps=num_warps, block_m=block_m) + kernel(block_m, n_size, x, y) + torch.cuda.synchronize() + if multicast == "warpx2_01_23": + # Warps 0 and 1 share the first 32 source rows; the loaded TMEM + # therefore tiles src[0:32] across both halves of the output. + expected = torch.cat([x[:32], x[:32]], dim=0) + else: + expected = x + torch.testing.assert_close(y, expected) + + +def test_tcgen05_copy_multicast_invalid_name(): + """Reject unknown multicast names with a clear ``Unknown multicast`` error. + + The lang-layer ``copy(..., multicast=...)`` validates the string against + the allowed multicast-name set. + """ + + class _BadMulticast(tilus.Script): + def __init__(self): + super().__init__() + + def __call__(self, x_ptr: ~int32, y_ptr: ~int32): + self.attrs.blocks = 1 + self.attrs.warps = 4 + g_x = self.global_view(x_ptr, dtype=int32, shape=[128, 32]) + s_x = self.shared_tensor(dtype=int32, shape=[128, 32]) + t_x = self.tcgen05.alloc(dtype=int32, shape=[128, 32]) + self.copy_async(src=g_x, dst=s_x, offsets=[0, 0]) + self.copy_async_wait_all() + self.sync() + with self.single_warp(): + self.tcgen05.copy(src=s_x, dst=t_x, multicast="warpx_bogus") + self.tcgen05.dealloc(t_x) + + kernel = _BadMulticast() + x = torch.zeros(128, 32, dtype=torch.int32, device="cuda") + y = torch.zeros(128, 32, dtype=torch.int32, device="cuda") + with pytest.raises(Exception, match="Unknown multicast"): + kernel(x, y) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/ir/layout/test_tmem_layout.py b/tests/ir/layout/test_tmem_layout.py new file mode 100644 index 00000000..e25d88d2 --- /dev/null +++ b/tests/ir/layout/test_tmem_layout.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for :class:`TMemoryLayout` shape / duplication / lane_offset cross-validation.""" + +import pytest +from tilus.ir.layout import TMemoryLayout +from tilus.ir.layout.ops.tmemory_ops import tmemory_row_major, tmemory_slice +from tilus.ir.layout.tmem_layout import TMemoryDuplication + +# --------------------------------------------------------------------------- +# Construction: shape, column_strides, lane_offset, duplication +# --------------------------------------------------------------------------- + + +class TestCreate: + def test_default_duplication_is_none(self): + layout = TMemoryLayout.create(shape=(128, 64), column_strides=(0, 1), lane_offset=0) + assert layout.duplication == TMemoryDuplication.NONE + + @pytest.mark.parametrize("lane_size", [32, 64, 128]) + def test_none_duplication_accepts_all_lane_sizes(self, lane_size): + layout = TMemoryLayout.create( + shape=(lane_size, 16), + column_strides=(0, 1), + lane_offset=0, + duplication=TMemoryDuplication.NONE, + ) + assert layout.shape[0] == lane_size + + def test_warpx4_requires_lane_32(self): + # WARPX4 with shape[0]=32 — ok + TMemoryLayout.create( + shape=(32, 4), + column_strides=(0, 1), + lane_offset=0, + duplication=TMemoryDuplication.WARPX4, + ) + # WARPX4 with shape[0]=64 — illegal + with pytest.raises(ValueError, match="WARPX4|warpx4"): + TMemoryLayout.create( + shape=(64, 4), + column_strides=(0, 1), + 
            lane_offset=0,
            duplication=TMemoryDuplication.WARPX4,
        )
        # WARPX4 with shape[0]=128 — illegal
        with pytest.raises(ValueError, match="WARPX4|warpx4"):
            TMemoryLayout.create(
                shape=(128, 4),
                column_strides=(0, 1),
                lane_offset=0,
                duplication=TMemoryDuplication.WARPX4,
            )

    @pytest.mark.parametrize("duplication", [TMemoryDuplication.WARPX2_02_13, TMemoryDuplication.WARPX2_01_23])
    def test_warpx2_requires_lane_64(self, duplication):
        """Both WARPX2 pairings accept only shape[0] == 64; 32 and 128 are rejected."""
        # legal: shape[0] == 64
        TMemoryLayout.create(shape=(64, 4), column_strides=(0, 1), lane_offset=0, duplication=duplication)
        # illegal: shape[0] == 32 or 128
        for bad in (32, 128):
            with pytest.raises(ValueError):
                TMemoryLayout.create(
                    shape=(bad, 4),
                    column_strides=(0, 1),
                    lane_offset=0,
                    duplication=duplication,
                )

    @pytest.mark.parametrize(
        "duplication",
        [TMemoryDuplication.WARPX4, TMemoryDuplication.WARPX2_02_13, TMemoryDuplication.WARPX2_01_23],
    )
    def test_duplicated_layout_disallows_nonzero_lane_offset(self, duplication):
        """Any WARPX* duplication rejects a non-zero lane_offset even at the valid lane size."""
        # pick the lane size that is otherwise legal for this duplication mode
        # (WARPX4 -> 32 lanes, WARPX2_* -> 64 lanes), so only lane_offset can fail
        lane_size = 32 if duplication == TMemoryDuplication.WARPX4 else 64
        with pytest.raises(ValueError, match="lane_offset"):
            TMemoryLayout.create(
                shape=(lane_size, 4),
                column_strides=(0, 1),
                lane_offset=32,
                duplication=duplication,
            )

    def test_none_duplication_lane_offset_must_fit_in_128(self):
        """With duplication=NONE, the tensor must fit in the 128 physical lanes."""
        # legal: lane_offset + shape[0] <= 128
        TMemoryLayout.create(shape=(64, 4), column_strides=(0, 1), lane_offset=64)
        TMemoryLayout.create(shape=(32, 4), column_strides=(0, 1), lane_offset=96)
        # illegal: overflows 128
        with pytest.raises(ValueError, match="exceeds 128"):
            TMemoryLayout.create(shape=(64, 4), column_strides=(0, 1), lane_offset=96)
        with pytest.raises(ValueError, match="exceeds 128"):
            TMemoryLayout.create(shape=(128, 4), column_strides=(0, 1), lane_offset=1)

    def test_invalid_lane_size(self):
        """shape[0] outside {32, 64, 128} is rejected regardless of duplication."""
        with pytest.raises(ValueError, match="must be 32, 64, or 128"):
            TMemoryLayout.create(shape=(16, 4), column_strides=(0, 1), lane_offset=0)

    def test_dim_mismatch(self):
        """len(column_strides) must equal len(shape)."""
        with pytest.raises(ValueError, match="Dimension mismatch"):
            TMemoryLayout.create(shape=(128, 4), column_strides=(0,), lane_offset=0)

    def test_lane_must_have_zero_stride(self):
        """The lane dimension carries no column stride: column_strides[0] must be 0."""
        with pytest.raises(ValueError, match="column_strides\\[0\\]"):
            TMemoryLayout.create(shape=(128, 4), column_strides=(1, 1), lane_offset=0)

    def test_higher_rank(self):
        """create() accepts more than two dims as long as dim 0 is the lane dim."""
        # rank 3: lane=128, then two column-strided dims
        layout = TMemoryLayout.create(shape=(128, 4, 8), column_strides=(0, 8, 1), lane_offset=0)
        assert layout.shape == (128, 4, 8)
        assert layout.column_strides == (0, 8, 1)


# ---------------------------------------------------------------------------
# tmemory_row_major
# ---------------------------------------------------------------------------


class TestRowMajor:
    """Tests for the `tmemory_row_major` convenience constructor."""

    def test_default_none_duplication(self):
        # duplication defaults to NONE when the kwarg is omitted
        layout = tmemory_row_major(shape=(128, 64))
        assert layout.duplication == TMemoryDuplication.NONE
        assert layout.column_strides == (0, 1)

    def test_warpx4_via_row_major(self):
        # the duplication kwarg is forwarded through to the created layout
        layout = tmemory_row_major(shape=(32, 4, 4), duplication=TMemoryDuplication.WARPX4)
        assert layout.duplication == TMemoryDuplication.WARPX4
        # row-major columns: dim 2 stride=1, dim 1 stride=4, dim 0 stride=0 (lane)
        assert layout.column_strides == (0, 4, 1)

    def test_higher_rank_strides(self):
        # row-major over the column dims of shape (2, 3, 4):
        # rightmost stride = 1, then 4, then 12; lane stride = 0
        layout = tmemory_row_major(shape=(128, 2, 3, 4))
        assert layout.column_strides == (0, 12, 4, 1)


# ---------------------------------------------------------------------------
# tmemory_slice
# ---------------------------------------------------------------------------


class TestSlice:
    """Tests for `tmemory_slice` interaction with duplication modes."""

    def test_column_slice_preserves_duplication(self):
        # parent: WARPX4 [32, M_fold=4, K=4]
        parent = tmemory_row_major(shape=(32, 4, 4), duplication=TMemoryDuplication.WARPX4)
        # slice at M_fold=1, keep [32, 4]; slice_dims=[0, 2] drops the M_fold dim
        child = tmemory_slice(parent, lane_offset=0, slice_dims=[0, 2], shape=(32, 4))
        # column-dim slicing must carry the parent's duplication through
        assert child.duplication == TMemoryDuplication.WARPX4
        assert child.shape == (32, 4)
        assert child.lane_offset == 0

    def test_lane_slice_only_for_none(self):
        # NONE [128, 64] -> [64, 64] starting at lane 64 — legal
        parent = tmemory_row_major(shape=(128, 64))
        child = tmemory_slice(parent, lane_offset=64, slice_dims=[0, 1], shape=(64, 64))
        assert child.duplication == TMemoryDuplication.NONE
        assert child.shape == (64, 64)
        assert child.lane_offset == 64

    def test_lane_slice_warpx4_raises(self):
        parent = tmemory_row_major(shape=(32, 4, 4), duplication=TMemoryDuplication.WARPX4)
        # try to lane-slice (would change shape[0] from 32 to a different value
        # OR keep 32 but apply non-zero lane_offset — both illegal for WARPX4).
        with pytest.raises(ValueError, match="duplication=warpx4|lane_offset"):
            tmemory_slice(parent, lane_offset=32, slice_dims=[0, 1, 2], shape=(32, 4, 4))


if __name__ == "__main__":
    # allow running this module directly without the pytest CLI
    pytest.main([__file__, "-v"])