diff --git a/examples/blackwell_matmul/matmul_v5.py b/examples/blackwell_matmul/matmul_v5.py
index 23ee5135..b0715d43 100644
--- a/examples/blackwell_matmul/matmul_v5.py
+++ b/examples/blackwell_matmul/matmul_v5.py
@@ -152,7 +152,8 @@ def __call__(
         s_a = self.shared_tensor(dtype=float16, shape=[tma_stages, block_m, block_k])
         s_b = self.shared_tensor(dtype=float16, shape=[tma_stages, block_n, block_k])
         # multi-stage accumulator: allows MMA and epilogue to overlap via mma_pipe
-        t_acc = self.tcgen05.alloc(dtype=float32, shape=[mma_stages, block_m, block_n])
+        # lane dim (block_m) is first; the stages axis is a column-strided sub-axis
+        t_acc = self.tcgen05.alloc(dtype=float32, shape=[block_m, mma_stages, block_n])
 
         # 16-byte buffer for CLC responses (cancel result + blockIdx)
         s_clc_response = self.shared_tensor(dtype=int32, shape=[clc_stages, 4])
@@ -214,7 +215,7 @@ def __call__(
                 self.tcgen05.mma(
                     s_a[tma_pipe.consumer_stage],
                     s_b[tma_pipe.consumer_stage].transpose(),
-                    t_acc[mma_pipe.producer_stage],
+                    t_acc[:, mma_pipe.producer_stage, :],
                     enable_input_d=offset_k != 0,
                 )
                 self.tcgen05.commit(mbarrier=tma_pipe.consumer_barrier())
@@ -261,7 +262,7 @@ def __call__(
 
                 for e_offset_n in range(0, block_n, e_block_n):
                     t_acc_slice = self.tcgen05.slice(
-                        t_acc[mma_pipe.consumer_stage],
+                        t_acc[:, mma_pipe.consumer_stage, :],
                         offsets=[0, e_offset_n],
                         shape=[block_m, e_block_n],
                         dims=[0, 1],
diff --git a/examples/blackwell_matmul/matmul_v6.py b/examples/blackwell_matmul/matmul_v6.py
index 1fe1788b..52eefb4d 100644
--- a/examples/blackwell_matmul/matmul_v6.py
+++ b/examples/blackwell_matmul/matmul_v6.py
@@ -69,6 +69,16 @@ def consumer_advance(self):
 @tilus.autotune("mma_stages", [2])
 @tilus.autotune("swizzle_size", [4, 8, 16])
 class BlackwellMatmulV6(tilus.Script):
+    debug_schedule = dict(
+        block_m=256,
+        block_n=256,
+        e_block_n=16,
+        block_k=64,
+        tma_stages=5,
+        mma_stages=2,
+        swizzle_size=8,
+    )
+
     def __init__(
         self,
         block_m: int,
@@ -156,8 +166,9 @@ def __call__(
         s_a = self.shared_tensor(dtype=float16, shape=[tma_stages, block_m // 2, block_k])
         s_b = self.shared_tensor(dtype=float16, shape=[tma_stages, block_n // 2, block_k])
         # cta_group=2: distributed MMA reads shared memory from both CTAs
+        # lane dim (block_m // 2 per CTA) is first; the stages axis is a column-strided sub-axis
         t_acc = self.tcgen05.alloc(
-            dtype=float32, shape=[mma_stages, block_m // 2, block_n], cta_group=2
+            dtype=float32, shape=[block_m // 2, mma_stages, block_n], cta_group=2
         )
 
         s_clc_response = self.shared_tensor(dtype=int32, shape=[clc_stages, 4])
@@ -225,7 +236,7 @@ def __call__(
                 self.tcgen05.mma(
                     s_a[tma_pipe.consumer_stage],
                     s_b[tma_pipe.consumer_stage].transpose(),
-                    t_acc[mma_pipe.producer_stage],
+                    t_acc[:, mma_pipe.producer_stage, :],
                     enable_input_d=offset_k != 0,
                     cta_group=2,
                 )
@@ -282,7 +293,7 @@ def __call__(
 
                 for e_offset_n in range(0, block_n, e_block_n):
                     t_acc_slice = self.tcgen05.slice(
-                        t_acc[mma_pipe.consumer_stage],
+                        t_acc[:, mma_pipe.consumer_stage, :],
                         offsets=[0, e_offset_n],
                         shape=[block_m // 2, e_block_n],
                         dims=[0, 1],
diff --git a/python/tilus/backends/emitters/cuda/tcgen05/alloc.py b/python/tilus/backends/emitters/cuda/tcgen05/alloc.py
index 8bfd4968..ecc13e59 100644
--- a/python/tilus/backends/emitters/cuda/tcgen05/alloc.py
+++ b/python/tilus/backends/emitters/cuda/tcgen05/alloc.py
@@ -38,13 +38,15 @@ class Tcgen05AllocDeallocEmitter(BaseInstEmitter):
     def get_num_columns(self, tmem_tensor: TMemoryTensor) -> int:
         shape = tmem_tensor.shape
-        if shape[-2] != 128:
-            raise NotImplementedError(f"The emitter currently only supports shape[-2] == 128, but got {shape[-2]}")
+        if shape[0] != 128:
+            raise NotImplementedError(f"The emitter currently only supports shape[0] == 128, but got {shape[0]}")
         if shape[-1] * tmem_tensor.dtype.nbits % 32 != 0:
             raise ValueError(
                 f"shape[-1] * dtype.nbits must be divisible by 32, but got {shape[-1]} * {tmem_tensor.dtype.nbits} = {shape[-1] * tmem_tensor.dtype.nbits}"
             )
 
-        num_columns = prod(shape[:-2]) * shape[-1] * tmem_tensor.dtype.nbits // 32
+        # All dimensions after the lane dim (shape[0]) are column-strided; total column count
+        # is the product of those dims, scaled by the dtype's bits-per-element / 32.
+        num_columns = prod(shape[1:]) * tmem_tensor.dtype.nbits // 32
 
         # the number of columns must be a power-of-two and in the range [32, 512]
         # normalize it to be at least 32 and be a power of two
@@ -142,9 +144,10 @@ def emit(self, inst: Tcgen05ViewInst) -> None:
         ):
             raise ValueError("The total number of bits must be the same as the original tensor.")
 
-        if not same_list(tmem_tensor.layout.column_strides[:-2], output_tmem_tensor.layout.column_strides[:-2]):
+        if not same_list(tmem_tensor.layout.column_strides[1:-1], output_tmem_tensor.layout.column_strides[1:-1]):
             raise ValueError(
-                "The column strides of the leading dimensions (all dimensions except the last two ones) must be the same as the original tensor."
+                "The column strides of the middle dimensions (all dimensions except the lane (dim 0) and "
+                "the innermost column (dim -1)) must be the same as the original tensor."
             )
 
         tmem_addr = self.get_or_allocate_var(tmem_tensor)
diff --git a/python/tilus/backends/emitters/cuda/tcgen05/ldst.py b/python/tilus/backends/emitters/cuda/tcgen05/ldst.py
index 8716ee83..6ea35061 100644
--- a/python/tilus/backends/emitters/cuda/tcgen05/ldst.py
+++ b/python/tilus/backends/emitters/cuda/tcgen05/ldst.py
@@ -68,7 +68,7 @@ def emit_tcgen05_instructions(
             raise ValueError(
                 "Lane mismatch: the first lane of the tmem tensor must be the same as the thread group begin"
             )
-        if self.current_num_threads != tmem_tensor.shape[-2]:
+        if self.current_num_threads != tmem_tensor.shape[0]:
             raise ValueError(
                 "The number of threads in the current thread group must be the same as the number of lanes in the tmem tensor"
             )
diff --git a/python/tilus/hidet/backend/codegen.py b/python/tilus/hidet/backend/codegen.py
index 5023e7cb..0ab4e402 100644
--- a/python/tilus/hidet/backend/codegen.py
+++ b/python/tilus/hidet/backend/codegen.py
@@ -502,14 +502,16 @@ def visit_Let(self, e: Let):
         raise ValueError("please run 'expand_let_expr' pass before codegen")
 
     def visit_Var(self, e: Var):
+        # Function-typed Vars are global symbols: their Var.name is the canonical
+        # identifier that must match the function definition. Use it verbatim
+        # instead of routing through the Namer's identity-based disambiguation.
+        if isinstance(e.type, FuncType) and e.name is not None:
+            return Text(self.canonize_funcname(e.name))
         cast2int = {"threadIdx.x", "threadIdx.y", "threadIdx.z", "blockIdx.x", "blockIdx.y", "blockIdx.z"}
         name = self.namer.get_name(e)
         if name in cast2int:
             return Text(f"(int){name}")
-        else:
-            if isinstance(e.type, FuncType):
-                name = self.canonize_funcname(name)
-            return Text(name)
+        return Text(name)
 
     def visit_Constant(self, e: Constant):
         if e.is_string():
diff --git a/python/tilus/hidet/ir/tools/printer.py b/python/tilus/hidet/ir/tools/printer.py
index dbbe801d..3d0155c1 100644
--- a/python/tilus/hidet/ir/tools/printer.py
+++ b/python/tilus/hidet/ir/tools/printer.py
@@ -366,6 +366,11 @@ def visit_Address(self, e: Address):
         return Text("&") + self(e.expr)
 
     def visit_Var(self, e: Var):
+        # Function-typed Vars are global symbols: their Var.name is a canonical
+        # identifier that must match the function definition, so use it verbatim
+        # instead of going through the Namer's identity-based disambiguation.
+        if isinstance(e.type, FuncType) and e.name is not None:
+            return Text(e.name)
         return Text(self.namer.get_name(e))
 
     def visit_Constant(self, e: Constant):
diff --git a/python/tilus/ir/instructions/cuda/tcgen05.py b/python/tilus/ir/instructions/cuda/tcgen05.py
index 9dda1678..62229a86 100644
--- a/python/tilus/ir/instructions/cuda/tcgen05.py
+++ b/python/tilus/ir/instructions/cuda/tcgen05.py
@@ -30,7 +30,7 @@ class Tcgen05AllocInst(Instruction):
     @staticmethod
     def create(dtype: DataType, shape: Sequence[int], cta_group: int) -> Tcgen05AllocInst:
         assert len(shape) >= 2, "Tcgen05AllocInst only supports tensors with rank >= 2."
-        assert shape[-2] in (32, 64, 128), "The second last dimension must be 32, 64, or 128."
+        assert shape[0] in (32, 64, 128), "The first (lane) dimension must be 32, 64, or 128."
 
         output = TMemoryTensor.create(dtype=dtype, shape=shape)
         return Tcgen05AllocInst(output=output, inputs=(), cta_group=cta_group)
@@ -65,10 +65,11 @@ def create(
     ) -> Tcgen05SliceInst:
         assert len(tmem.shape) == len(offsets)
         assert len(slice_shape) == len(slice_dims)
-        assert len(slice_dims) >= 2 and all(len(tmem.shape) - 1 - i in slice_dims for i in range(2)), (
-            "The last two dimensions must be included in the slice."
+        # The lane dim (0) and the innermost column dim (-1) must always be in the slice.
+        assert len(slice_dims) >= 2 and 0 in slice_dims and (len(tmem.shape) - 1) in slice_dims, (
+            "The lane dim (0) and the innermost column dim (-1) must be included in the slice."
         )
-        assert isinstance(offsets[-2], Constant), "The row-offset must be a constant."
+        assert isinstance(offsets[0], Constant), "The lane (row) offset must be a constant."
 
         output = TMemoryTensor.create(dtype=tmem.dtype, shape=slice_shape)
         return Tcgen05SliceInst(output=output, inputs=(tmem,), offsets=tuple(offsets), slice_dims=tuple(slice_dims))
diff --git a/python/tilus/ir/layout/inference/inference_rules/tcgen05/ldst.py b/python/tilus/ir/layout/inference/inference_rules/tcgen05/ldst.py
index 6d62812a..bb9dcd93 100644
--- a/python/tilus/ir/layout/inference/inference_rules/tcgen05/ldst.py
+++ b/python/tilus/ir/layout/inference/inference_rules/tcgen05/ldst.py
@@ -91,7 +91,7 @@ def inference(
 
     # check that the lane matches the threads
    lane_begin = tmem_tensor.layout.lane_offset
-    lane_end = lane_begin + tmem_tensor.shape[-2]
+    lane_end = lane_begin + tmem_tensor.shape[0]
     thread_begin = ctx.thread_begin
     thread_end = ctx.thread_end
 
diff --git a/python/tilus/ir/layout/inference/inference_rules/tcgen05/slice.py b/python/tilus/ir/layout/inference/inference_rules/tcgen05/slice.py
index e0896083..f0442ee7 100644
--- a/python/tilus/ir/layout/inference/inference_rules/tcgen05/slice.py
+++ b/python/tilus/ir/layout/inference/inference_rules/tcgen05/slice.py
@@ -31,7 +31,7 @@ def inference(ctx: LayoutInferenceContext, inst: Tcgen05SliceInst) -> dict[TMemo
     return {
         tmem: tmemory_slice(
             tmem_layout=inst.tmemory_input.optional_layout,
-            lane_offset=int(inst.offsets[-2]),
+            lane_offset=int(inst.offsets[0]),
             slice_dims=inst.slice_dims,
             shape=tmem.shape,
         )
diff --git a/python/tilus/ir/layout/ops/tmemory_ops.py b/python/tilus/ir/layout/ops/tmemory_ops.py
index 2ae5198d..d49a0d9f 100644
--- a/python/tilus/ir/layout/ops/tmemory_ops.py
+++ b/python/tilus/ir/layout/ops/tmemory_ops.py
@@ -18,10 +18,8 @@ def tmemory_row_major(shape: Sequence[int]) -> TMemoryLayout:
-    column_strides = []
+    # Convention: dim 0 is the lane axis (stride 0); dims 1..-1 are column-strided
+    # in row-major order (innermost dim has stride 1).
+    column_strides = [0] * len(shape)
     stride = 1
-    for dim in reversed(range(len(shape))):
-        if dim == len(shape) - 2:
-            column_strides.append(0)
-        else:
-            column_strides.append(stride)
-        stride *= shape[dim]
-    column_strides = list(reversed(column_strides))
+    for dim in reversed(range(1, len(shape))):
+        column_strides[dim] = stride
+        stride *= shape[dim]
     return TMemoryLayout.create(shape, column_strides, lane_offset=0)
diff --git a/python/tilus/ir/layout/tmem_layout.py b/python/tilus/ir/layout/tmem_layout.py
index fe9aed32..5c087889 100644
--- a/python/tilus/ir/layout/tmem_layout.py
+++ b/python/tilus/ir/layout/tmem_layout.py
@@ -36,12 +36,13 @@ def create(shape: Sequence[int], column_strides: Sequence[int], lane_offset: int
         )
         if len(shape) < 2:
             raise ValueError("TMemLayout requires at least 2 dimensions, got {}".format(len(shape)))
-        if shape[-2] not in [32, 64, 128]:
-            raise ValueError("The number of rows (shape[-2]) must be 32, 64, or 128, got {}".format(shape[-2]))
-        if column_strides[-2] != 0:
+        # Convention: shape[0] is the lane (row) dimension; all other dims are column-strided.
+        if shape[0] not in [32, 64, 128]:
+            raise ValueError("The number of rows (shape[0]) must be 32, 64, or 128, got {}".format(shape[0]))
+        if column_strides[0] != 0:
             raise ValueError(
-                "The column stride for the row dimension (column_strides[-2]) must be 0, got {}".format(
-                    column_strides[-2]
+                "The column stride for the row dimension (column_strides[0]) must be 0, got {}".format(
+                    column_strides[0]
                 )
             )
         return TMemoryLayout(shape=tuple(shape), column_strides=tuple(column_strides), lane_offset=lane_offset)
diff --git a/python/tilus/ir/tensor.py b/python/tilus/ir/tensor.py
index 0e36abb2..2e33294b 100644
--- a/python/tilus/ir/tensor.py
+++ b/python/tilus/ir/tensor.py
@@ -661,15 +661,15 @@ class TMemoryTensor(Tensor):
     Tensor memory is a dedicated on-chip memory available on Blackwell (SM 10.0+) GPUs, private to the
     SM's tensor cores. It is organized as a 2D structure of lanes (rows) and columns, with each cell
     being 32 bits. The number
-    of lanes (``shape[-2]``) must be 32, 64, or 128.
+    of lanes (``shape[0]``) must be 32, 64, or 128.
 
     Attributes
     ----------
     dtype: DataType
         The data type of the tensor elements.
     shape: tuple[int, ...]
-        The shape of the tensor. Must have at least 2 dimensions. The second-to-last dimension (lanes) must be
-        32, 64, or 128.
+        The shape of the tensor. Must have at least 2 dimensions. The first dimension (lanes) must be
+        32, 64, or 128. All remaining dimensions are column-strided.
     optional_layout: TMemoryLayout, optional
         The layout of the tensor, which is optional. When not provided, the layout will be automatically
         inferred with compiler pass.
@@ -687,8 +687,8 @@ def create(dtype: DataType, shape: Sequence[int], optional_layout: Optional[TMem
         dtype: DataType
             The data type of the tensor elements.
         shape: Sequence[int]
-            The shape of the tensor. Must have at least 2 dimensions, with the second-to-last
-            dimension (lanes) being 32, 64, or 128.
+            The shape of the tensor. Must have at least 2 dimensions, with the first
+            dimension (lanes) being 32, 64, or 128. All remaining dimensions are column-strided.
         optional_layout: TMemoryLayout, optional
             The layout of the tensor. If not provided, the layout will be inferred later.
 
@@ -699,8 +699,8 @@ def create(dtype: DataType, shape: Sequence[int], optional_layout: Optional[TMem
         """
         if len(shape) < 2:
             raise ValueError("TMemoryTensor requires at least 2 dimensions, got {}".format(len(shape)))
-        if shape[-2] not in (32, 64, 128):
-            raise ValueError("The number of rows (shape[-2]) must be 32, 64, or 128, got {}".format(shape[-2]))
+        if shape[0] not in (32, 64, 128):
+            raise ValueError("The number of rows (shape[0]) must be 32, 64, or 128, got {}".format(shape[0]))
         if optional_layout is not None and tuple(shape) != tuple(optional_layout.shape):
             raise ValueError(
                 f"Shape mismatch: provided shape {shape} does not match layout shape {optional_layout.shape}."
@@ -757,5 +757,5 @@
         converted in the Tilus Script transpiler defined in tilus.lang.transpiler module.
         """
 
-    def __getitem__(self, indices: tuple[Expr | int, ...] | Expr | int) -> TMemoryTensor:
+    def __getitem__(self, indices: tuple[Expr | int | None | slice, ...] | Expr | int | None | slice) -> TMemoryTensor:
         raise RuntimeError("tmemory_tensor[...] could only be used in Tilus Script.")
diff --git a/python/tilus/lang/instructions/tcgen05.py b/python/tilus/lang/instructions/tcgen05.py
index 006dacf3..1dc23341 100644
--- a/python/tilus/lang/instructions/tcgen05.py
+++ b/python/tilus/lang/instructions/tcgen05.py
@@ -57,8 +57,9 @@ def alloc(self, dtype: DataType, shape: Sequence[int], cta_group: int = 1) -> TM
         dtype: DataType
             The data type of the tensor elements (e.g., ``float32``, ``float16``).
         shape: Sequence[int]
-            The shape of the tensor. Must have at least 2 dimensions. The second-to-last
-            dimension (``shape[-2]``) must be 32, 64, or 128.
+            The shape of the tensor. Must have at least 2 dimensions. The first
+            dimension (``shape[0]``) is the lane axis and must be 32, 64, or 128.
+            All remaining dimensions are column-strided.
         cta_group: int
             The CTA group size for the allocation. Must be 1 or 2. When 2, the tensor
             is shared across two CTAs in the same cluster.
@@ -78,8 +79,8 @@ def alloc(self, dtype: DataType, shape: Sequence[int], cta_group: int = 1) -> TM
             raise InstructionError("cta_group must be 1 or 2")
         if len(shape) < 2:
             raise InstructionError("shape must be a sequence of length 2 or more, got {}".format(shape))
-        if shape[-2] not in (32, 64, 128):
-            raise InstructionError("shape[-2] must be 32, 64, or 128, got {}".format(shape[-2]))
+        if shape[0] not in (32, 64, 128):
+            raise InstructionError("shape[0] must be 32, 64, or 128, got {}".format(shape[0]))
         if 128 % dtype.nbits != 0:
             raise InstructionError("dtype must be 1, 2, 4, 8, 16, 32, 64, or 128 bit, got {}".format(dtype))
         ret = self._builder.tcgen05_alloc(dtype, shape, cta_group)