Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/tilus/backends/emitters/cuda/tcgen05/alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
class Tcgen05AllocDeallocEmitter(BaseInstEmitter):
def get_num_columns(self, tmem_tensor: TMemoryTensor) -> int:
shape = tmem_tensor.shape
if shape[0] != 128:
raise NotImplementedError(f"The emitter currently only supports shape[0] == 128, but got {shape[0]}")
if shape[0] not in (32, 64, 128):
raise NotImplementedError(f"The emitter only supports shape[0] in (32, 64, 128), but got {shape[0]}")
if shape[-1] * tmem_tensor.dtype.nbits % 32 != 0:
raise ValueError(
f"shape[-1] * dtype.nbits must be divisible by 32, but got {shape[-1]} * {tmem_tensor.dtype.nbits} = {shape[-1] * tmem_tensor.dtype.nbits}"
Expand Down
65 changes: 53 additions & 12 deletions python/tilus/backends/emitters/cuda/tcgen05/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,42 @@ def __str__(self) -> str:
return "Tcgen05CopyInstMeta(" + ",\n ".join(items) + "\n)"


# Map the user-facing multicast string carried on Tcgen05CopyInst.multicast to
# the backend enum used by the emitter. "" (the default) means no multicast.
# Unknown strings are rejected with ValueError at the lookup site in emit().
_MULTICAST_NAME_TO_KIND: dict[str, Tcgen05CopyMulticastKind] = {
    "": Tcgen05CopyMulticastKind.NONE,
    "warpx4": Tcgen05CopyMulticastKind.WARP_X4,
    "warpx2_02_13": Tcgen05CopyMulticastKind.WARP_X2_02_13,
    "warpx2_01_23": Tcgen05CopyMulticastKind.WARP_X2_01_23,
}

# Per-multicast: required source-row count in SMEM and the candidate shape kinds.
# Smaller shape kinds (R32x128B / R64x128B) are only valid with multicast — the
# table maps each multicast to the shape kind PTX requires for it.
# NOTE(review): the row counts (warpx4 -> 32, warpx2_* -> 64, none -> 128)
# presumably mirror the lane extent each multicast replicates over — confirm
# against the PTX ISA tcgen05.copy qualifier tables.
_MULTICAST_TO_SOURCE_ROWS: dict[Tcgen05CopyMulticastKind, int] = {
    Tcgen05CopyMulticastKind.NONE: 128,
    Tcgen05CopyMulticastKind.WARP_X4: 32,
    Tcgen05CopyMulticastKind.WARP_X2_02_13: 64,
    Tcgen05CopyMulticastKind.WARP_X2_01_23: 64,
}

# Candidate tcgen05.copy shape kinds for each multicast. generate_instructions
# tries them in order, falling through on GenerationFailedError until one
# splits the canonical shared layout successfully; only NONE has a fallback.
_MULTICAST_TO_SHAPE_KINDS: dict[Tcgen05CopyMulticastKind, tuple[Tcgen05CopyShapeKind, ...]] = {
    Tcgen05CopyMulticastKind.NONE: (
        Tcgen05CopyShapeKind.R128x256B,
        Tcgen05CopyShapeKind.R128x128B,
    ),
    Tcgen05CopyMulticastKind.WARP_X4: (Tcgen05CopyShapeKind.R32x128B,),
    Tcgen05CopyMulticastKind.WARP_X2_02_13: (Tcgen05CopyShapeKind.R64x128B,),
    Tcgen05CopyMulticastKind.WARP_X2_01_23: (Tcgen05CopyShapeKind.R64x128B,),
}


@register_emitter(Tcgen05CopyInst, target=nvgpu_sm100)
class Tcgen05CopyEmitter(BaseInstEmitter):
def split_canonical_layout(
self, smem_addr: Expr, canonical: CanonicalSharedLayout, shape_kind: Tcgen05CopyShapeKind
self,
smem_addr: Expr,
canonical: CanonicalSharedLayout,
shape_kind: Tcgen05CopyShapeKind,
multicast: Tcgen05CopyMulticastKind = Tcgen05CopyMulticastKind.NONE,
) -> list[Tcgen05CopyInstMeta]:
"""
Split the canonical shared layout into multiple sub-tensors that can be copied by tcgen05.copy instructions.
Expand Down Expand Up @@ -95,7 +127,7 @@ def split_canonical_layout(
raise GenerationFailedError(
"The number of rows or columns in the shape kind must be divisible by the number of rows or columns in the canonical layout"
)
if canonical.major_kind == "K" and (inst_m % 8 != 0 or inst_n % (canonical.T * 2) != 0):
if canonical.major_kind == "K" and (inst_m % 8 != 0 or inst_n % canonical.T != 0):
raise GenerationFailedError(
"The number of rows or columns in the shape kind must be divisible by the number of rows or columns in the canonical layout"
)
Expand Down Expand Up @@ -151,7 +183,7 @@ def split_canonical_layout(
instructions.append(
Tcgen05CopyInstMeta(
shape_kind=shape_kind,
multicast=Tcgen05CopyMulticastKind.NONE,
multicast=multicast,
cta_group=Tcgen05CtaGroupKind.CTA_1,
tmem_offset=tmem_offset,
shared_descriptor=s_desc,
Expand All @@ -161,7 +193,10 @@ def split_canonical_layout(
return instructions

def generate_instructions(
self, tmem_tensor: TMemoryTensor, shared_tensor: SharedTensor
self,
tmem_tensor: TMemoryTensor,
shared_tensor: SharedTensor,
multicast: Tcgen05CopyMulticastKind = Tcgen05CopyMulticastKind.NONE,
) -> list[Tcgen05CopyInstMeta]:
dtype = shared_tensor.dtype
canonical_layout: CanonicalSharedLayout | None = canonicalize_shared_layout(
Expand All @@ -176,12 +211,9 @@ def generate_instructions(
raise ValueError("\n".join(msg))
smem_addr = self.shared_tensor_shared_space_addr[shared_tensor]

for shape_kind in [
Tcgen05CopyShapeKind.R128x256B,
Tcgen05CopyShapeKind.R128x128B,
]:
for shape_kind in _MULTICAST_TO_SHAPE_KINDS[multicast]:
try:
return self.split_canonical_layout(smem_addr, canonical_layout, shape_kind)
return self.split_canonical_layout(smem_addr, canonical_layout, shape_kind, multicast)
except GenerationFailedError:
continue

Expand All @@ -193,19 +225,28 @@ def emit(self, inst: Tcgen05CopyInst) -> None:

self.assert_is_warp_aligned(inst, "tcgen05.copy is a warp-cooperative instruction")

multicast_kind = _MULTICAST_NAME_TO_KIND.get(inst.multicast)
if multicast_kind is None:
raise ValueError("Unknown multicast {!r} on Tcgen05CopyInst".format(inst.multicast))

if len(shared_tensor.shape) != 2:
raise ValueError("The shared tensor must be a 2D tensor, got shape {}".format(shared_tensor.shape))
if len(tmem_tensor.shape) != 2:
raise ValueError("The tensor memory tensor must be a 2D tensor, got shape {}".format(tmem_tensor.shape))
if shared_tensor.shape[0] != 128:
raise NotImplementedError("The number of rows in the shared tensor must be 128")
expected_rows = _MULTICAST_TO_SOURCE_ROWS[multicast_kind]
if shared_tensor.shape[0] != expected_rows:
raise ValueError(
"tcgen05.copy multicast={!r} requires the shared tensor to have {} rows, got {}".format(
inst.multicast, expected_rows, shared_tensor.shape[0]
)
)
if tmem_tensor.layout.lane_offset != 0:
raise NotImplementedError("The first lane of the tmem tensor must be 0")

tmem_base_addr = self.tensor2var[tmem_tensor]

with self.single_thread():
insts = self.generate_instructions(tmem_tensor, shared_tensor)
insts = self.generate_instructions(tmem_tensor, shared_tensor, multicast_kind)
for inst_meta in insts:
s_desc = self.declare_var("s_desc", tp=uint64, init=inst_meta.shared_descriptor.encoded())
t_addr = tmem_base_addr + inst_meta.tmem_offset
Expand Down
4 changes: 2 additions & 2 deletions python/tilus/hidet/ir/primitives/cuda/tcgen05.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def m(self) -> int:

class Tcgen05CopyMulticastKind(Enum):
NONE = ""
WARP_X2_02_13 = ".warpx2_02_13"
WARP_X2_01_23 = ".warpx2_01_23"
WARP_X2_02_13 = ".warpx2::02_13"
WARP_X2_01_23 = ".warpx2::01_23"
WARP_X4 = ".warpx4"


Expand Down
4 changes: 2 additions & 2 deletions python/tilus/ir/builders/stmt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,8 +1624,8 @@ def tcgen05_wait_store(self) -> None:
inst = Tcgen05WaitInst.create(wait_load=False, wait_store=True)
self.append(inst)

def tcgen05_copy(self, src: SharedTensor, dst: TMemoryTensor) -> None:
inst = Tcgen05CopyInst.create(src=src, dst=dst)
def tcgen05_copy(self, src: SharedTensor, dst: TMemoryTensor, multicast: str = "") -> None:
inst = Tcgen05CopyInst.create(src=src, dst=dst, multicast=multicast)
self.append(inst)

def tcgen05_commit(
Expand Down
10 changes: 8 additions & 2 deletions python/tilus/ir/instructions/cuda/tcgen05.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,16 @@ def create(wait_load: bool, wait_store: bool) -> Tcgen05WaitInst:

@dataclass(frozen=True, eq=False)
class Tcgen05CopyInst(Instruction):
# Multicast pattern as a primitive string ("" for no multicast, "warpx4",
# "warpx2_02_13", "warpx2_01_23"). Converted to Tcgen05CopyMulticastKind in
# the codegen emitter — primitive types are kept here so IR functors can
# walk the field generically.
multicast: str = ""

@staticmethod
def create(src: SharedTensor, dst: TMemoryTensor) -> Tcgen05CopyInst:
def create(src: SharedTensor, dst: TMemoryTensor, multicast: str = "") -> Tcgen05CopyInst:
# Note: 2D validation is performed at the lang layer (Tcgen05InstructionGroup.copy)
return Tcgen05CopyInst(output=None, inputs=(dst, src))
return Tcgen05CopyInst(output=None, inputs=(dst, src), multicast=multicast)


@dataclass(frozen=True, eq=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,17 @@
register_rule,
)
from tilus.ir.layout.ops.tmemory_ops import tmemory_row_major
from tilus.ir.layout.tmem_layout import TMemoryDuplication
from tilus.ir.tensor import TMemoryTensor

# Map from shape[0] -> duplication mode that's uniquely determined by the lane size.
# shape[0]=64 is ambiguous (NONE / WARPX2_02_13 / WARPX2_01_23) and is left for
# downstream inference (e.g. from a tcgen05.copy multicast hint).
# Consulted by Tcgen05AllocRule.inference: a missing key means "defer" (the
# rule returns {} so another rule/consumer can pin the layout later).
_LANE_TO_FORCED_DUPLICATION: dict[int, TMemoryDuplication] = {
    32: TMemoryDuplication.WARPX4,
    128: TMemoryDuplication.NONE,
}


@register_rule(Tcgen05AllocInst)
class Tcgen05AllocRule(LayoutInferenceRule):
Expand All @@ -30,4 +39,9 @@ def inference(ctx: LayoutInferenceContext, inst: Tcgen05AllocInst) -> dict[TMemo
tmem = inst.tmemory_output
if tmem.optional_layout is not None:
return {}
return {tmem: tmemory_row_major(tmem.shape)}
lane_size = tmem.shape[0]
if lane_size not in _LANE_TO_FORCED_DUPLICATION:
# shape[0]=64 — defer to downstream inference (writers/consumers).
return {}
duplication = _LANE_TO_FORCED_DUPLICATION[lane_size]
return {tmem: tmemory_row_major(tmem.shape, duplication=duplication)}
44 changes: 33 additions & 11 deletions python/tilus/ir/layout/inference/inference_rules/tcgen05/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from tilus.ir.instructions.cuda.tcgen05 import Tcgen05CopyInst
from tilus.ir.layout import SharedLayout, TMemoryLayout
from tilus.ir.layout.cuda.tcgen05.smem import (
Tcgen05SwizzleMode,
generate_canonical_layout,
Expand All @@ -24,22 +25,43 @@
LayoutInferenceRule,
register_rule,
)
from tilus.ir.tensor import SharedLayout, SharedTensor
from tilus.ir.layout.ops.tmemory_ops import tmemory_row_major
from tilus.ir.layout.tmem_layout import TMemoryDuplication
from tilus.ir.tensor import SharedTensor, TMemoryTensor

# Map the user-facing multicast string on Tcgen05CopyInst to the corresponding
# TMEM duplication mode. Empty string means no multicast (no replication).
# NOTE(review): the use site looks this up with .get(..., TMemoryDuplication.NONE),
# so an unrecognized multicast string silently infers a NONE-duplicated layout
# here, while the codegen emitter raises ValueError for the same string —
# verify the silent fallback is intended.
_MULTICAST_TO_DUPLICATION: dict[str, TMemoryDuplication] = {
    "": TMemoryDuplication.NONE,
    "warpx4": TMemoryDuplication.WARPX4,
    "warpx2_02_13": TMemoryDuplication.WARPX2_02_13,
    "warpx2_01_23": TMemoryDuplication.WARPX2_01_23,
}


@register_rule(Tcgen05CopyInst)
class Tcgen05CopyRule(LayoutInferenceRule):
@staticmethod
def inference(ctx: LayoutInferenceContext, inst: Tcgen05CopyInst) -> dict[SharedTensor, SharedLayout]:
def inference(
ctx: LayoutInferenceContext, inst: Tcgen05CopyInst
) -> dict[SharedTensor | TMemoryTensor, SharedLayout | TMemoryLayout]:
result: dict[SharedTensor | TMemoryTensor, SharedLayout | TMemoryLayout] = {}
dst = inst.inputs[0].as_tmemory_tensor()
src = inst.inputs[1].as_shared_tensor()

if src.has_layout():
return {}
if not src.has_layout():
if len(src.shape) != 2:
raise LayoutInferenceError(f"Only 2D SharedTensor is supported in copy, got shape {src.shape}")
canonical_layout = generate_canonical_layout(
(src.shape[0], src.shape[1]), src.dtype, "K", Tcgen05SwizzleMode.NO_SWIZZLE
)
result[src] = get_shared_layout_from_canonical(canonical_layout)

if not dst.has_layout():
# Use the multicast kwarg to pick the right duplication. The
# Tcgen05AllocRule already pins layouts for unambiguous shape[0]
# (32 -> WARPX4, 128 -> NONE) and defers shape[0]=64 to here.
duplication = _MULTICAST_TO_DUPLICATION.get(inst.multicast, TMemoryDuplication.NONE)
result[dst] = tmemory_row_major(dst.shape, duplication=duplication)

if len(src.shape) != 2:
raise LayoutInferenceError(f"Only 2D SharedTensor is supported in copy, got shape {src.shape}")
canonical_layout = generate_canonical_layout(
(src.shape[0], src.shape[1]), src.dtype, "K", Tcgen05SwizzleMode.NO_SWIZZLE
)
shared_layout = get_shared_layout_from_canonical(canonical_layout)
return {src: shared_layout}
return result
24 changes: 20 additions & 4 deletions python/tilus/ir/layout/ops/tmemory_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,38 @@
from typing import Sequence

from tilus.ir.layout import TMemoryLayout
from tilus.ir.layout.tmem_layout import TMemoryDuplication


def tmemory_row_major(shape: Sequence[int]) -> TMemoryLayout:
def tmemory_row_major(shape: Sequence[int], duplication: TMemoryDuplication = TMemoryDuplication.NONE) -> TMemoryLayout:
# Convention: dim 0 is the lane axis (stride 0); dims 1..-1 are column-strided
# in row-major order (innermost dim has stride 1).
column_strides = [0] * len(shape)
stride = 1
for dim in reversed(range(1, len(shape))):
column_strides[dim] = stride
stride *= shape[dim]
return TMemoryLayout.create(shape, column_strides, lane_offset=0)
return TMemoryLayout.create(shape, column_strides, lane_offset=0, duplication=duplication)


def tmemory_slice(
tmem_layout: TMemoryLayout, lane_offset: int, slice_dims: Sequence[int], shape: Sequence[int]
) -> TMemoryLayout:
lane_offset = tmem_layout.lane_offset + lane_offset
# Column-dim slicing preserves the parent's duplication and lane_offset.
# Lane-dim slicing (shape[0] differs from parent's shape[0]) is only legal
# for NONE-duplicated tensors and adds the requested lane_offset shift.
new_lane_offset = tmem_layout.lane_offset + lane_offset
if shape[0] != tmem_layout.shape[0] and tmem_layout.duplication != TMemoryDuplication.NONE:
raise ValueError(
"Lane-dim slicing of a TMEM tensor with duplication={} is not supported "
"(parent shape[0]={}, sliced shape[0]={}); only NONE-duplicated tensors "
"may have their lane dim sliced.".format(tmem_layout.duplication.value, tmem_layout.shape[0], shape[0])
)
if tmem_layout.duplication != TMemoryDuplication.NONE and lane_offset != 0:
raise ValueError(
"duplication={} disallows non-zero lane_offset shift, got {}".format(
tmem_layout.duplication.value, lane_offset
)
)
strides = [tmem_layout.column_strides[dim] for dim in slice_dims]
return TMemoryLayout.create(shape, strides, lane_offset)
return TMemoryLayout.create(shape, strides, new_lane_offset, duplication=tmem_layout.duplication)
Loading
Loading