Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions examples/blackwell_matmul/matmul_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ def __call__(
s_a = self.shared_tensor(dtype=float16, shape=[tma_stages, block_m, block_k])
s_b = self.shared_tensor(dtype=float16, shape=[tma_stages, block_n, block_k])
# multi-stage accumulator: allows MMA and epilogue to overlap via mma_pipe
t_acc = self.tcgen05.alloc(dtype=float32, shape=[mma_stages, block_m, block_n])
# lane dim (block_m) is first; the stages axis is a column-strided sub-axis
t_acc = self.tcgen05.alloc(dtype=float32, shape=[block_m, mma_stages, block_n])

# 16-byte buffer for CLC responses (cancel result + blockIdx)
s_clc_response = self.shared_tensor(dtype=int32, shape=[clc_stages, 4])
Expand Down Expand Up @@ -214,7 +215,7 @@ def __call__(
self.tcgen05.mma(
s_a[tma_pipe.consumer_stage],
s_b[tma_pipe.consumer_stage].transpose(),
t_acc[mma_pipe.producer_stage],
t_acc[:, mma_pipe.producer_stage, :],
enable_input_d=offset_k != 0,
)
self.tcgen05.commit(mbarrier=tma_pipe.consumer_barrier())
Expand Down Expand Up @@ -261,7 +262,7 @@ def __call__(

for e_offset_n in range(0, block_n, e_block_n):
t_acc_slice = self.tcgen05.slice(
t_acc[mma_pipe.consumer_stage],
t_acc[:, mma_pipe.consumer_stage, :],
offsets=[0, e_offset_n],
shape=[block_m, e_block_n],
dims=[0, 1],
Expand Down
18 changes: 15 additions & 3 deletions examples/blackwell_matmul/matmul_v6.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tilus.utils import benchmark_func, cdiv

tilus.option.cache_dir("./cache")
tilus.option.debug.dump_ir()


class Pipeline(tilus.Class): # same as V4/V5
Expand Down Expand Up @@ -69,6 +70,16 @@ def consumer_advance(self):
@tilus.autotune("mma_stages", [2])
@tilus.autotune("swizzle_size", [4, 8, 16])
class BlackwellMatmulV6(tilus.Script):
debug_schedule = dict(
block_m=256,
block_n=256,
e_block_n=16,
block_k=64,
tma_stages=5,
mma_stages=2,
swizzle_size=8,
)

def __init__(
self,
block_m: int,
Expand Down Expand Up @@ -156,8 +167,9 @@ def __call__(
s_a = self.shared_tensor(dtype=float16, shape=[tma_stages, block_m // 2, block_k])
s_b = self.shared_tensor(dtype=float16, shape=[tma_stages, block_n // 2, block_k])
# cta_group=2: distributed MMA reads shared memory from both CTAs
# lane dim (block_m // 2 per CTA) is first; the stages axis is a column-strided sub-axis
t_acc = self.tcgen05.alloc(
dtype=float32, shape=[mma_stages, block_m // 2, block_n], cta_group=2
dtype=float32, shape=[block_m // 2, mma_stages, block_n], cta_group=2
)

s_clc_response = self.shared_tensor(dtype=int32, shape=[clc_stages, 4])
Expand Down Expand Up @@ -225,7 +237,7 @@ def __call__(
self.tcgen05.mma(
s_a[tma_pipe.consumer_stage],
s_b[tma_pipe.consumer_stage].transpose(),
t_acc[mma_pipe.producer_stage],
t_acc[:, mma_pipe.producer_stage, :],
enable_input_d=offset_k != 0,
cta_group=2,
)
Expand Down Expand Up @@ -282,7 +294,7 @@ def __call__(

for e_offset_n in range(0, block_n, e_block_n):
t_acc_slice = self.tcgen05.slice(
t_acc[mma_pipe.consumer_stage],
t_acc[:, mma_pipe.consumer_stage, :],
offsets=[0, e_offset_n],
shape=[block_m // 2, e_block_n],
dims=[0, 1],
Expand Down
13 changes: 8 additions & 5 deletions python/tilus/backends/emitters/cuda/tcgen05/alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@
class Tcgen05AllocDeallocEmitter(BaseInstEmitter):
def get_num_columns(self, tmem_tensor: TMemoryTensor) -> int:
shape = tmem_tensor.shape
if shape[-2] != 128:
raise NotImplementedError(f"The emitter currently only supports shape[-2] == 128, but got {shape[-2]}")
if shape[0] != 128:
raise NotImplementedError(f"The emitter currently only supports shape[0] == 128, but got {shape[0]}")
if shape[-1] * tmem_tensor.dtype.nbits % 32 != 0:
raise ValueError(
f"shape[-1] * dtype.nbits must be divisible by 32, but got {shape[-1]} * {tmem_tensor.dtype.nbits} = {shape[-1] * tmem_tensor.dtype.nbits}"
)
num_columns = prod(shape[:-2]) * shape[-1] * tmem_tensor.dtype.nbits // 32
# All dimensions after the lane dim (shape[0]) are column-strided; total column count
# is the product of those dims, scaled by the dtype's bits-per-element / 32.
num_columns = prod(shape[1:]) * tmem_tensor.dtype.nbits // 32

# the number of columns must be a power-of-two and in the range [32, 512]
# normalize it to be at least 32 and be a power of two
Expand Down Expand Up @@ -142,9 +144,10 @@ def emit(self, inst: Tcgen05ViewInst) -> None:
):
raise ValueError("The total number of bits must be the same as the original tensor.")

if not same_list(tmem_tensor.layout.column_strides[:-2], output_tmem_tensor.layout.column_strides[:-2]):
if not same_list(tmem_tensor.layout.column_strides[1:-1], output_tmem_tensor.layout.column_strides[1:-1]):
raise ValueError(
"The column strides of the leading dimensions (all dimensions except the last two ones) must be the same as the original tensor."
"The column strides of the middle dimensions (all dimensions except the lane (dim 0) and "
"the innermost column (dim -1)) must be the same as the original tensor."
)

tmem_addr = self.get_or_allocate_var(tmem_tensor)
Expand Down
2 changes: 1 addition & 1 deletion python/tilus/backends/emitters/cuda/tcgen05/ldst.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def emit_tcgen05_instructions(
raise ValueError(
"Lane mismatch: the first lane of the tmem tensor must be the same as the thread group begin"
)
if self.current_num_threads != tmem_tensor.shape[-2]:
if self.current_num_threads != tmem_tensor.shape[0]:
raise ValueError(
"The number of threads in the current thread group must be the same as the number of lanes in the tmem tensor"
)
Expand Down
10 changes: 6 additions & 4 deletions python/tilus/hidet/backend/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,14 +502,16 @@ def visit_Let(self, e: Let):
raise ValueError("please run 'expand_let_expr' pass before codegen")

def visit_Var(self, e: Var):
# Function-typed Vars are global symbols: their Var.name is the canonical
# identifier that must match the function definition. Use it verbatim
# instead of routing through the Namer's identity-based disambiguation.
if isinstance(e.type, FuncType) and e.name is not None:
return Text(self.canonize_funcname(e.name))
cast2int = {"threadIdx.x", "threadIdx.y", "threadIdx.z", "blockIdx.x", "blockIdx.y", "blockIdx.z"}
name = self.namer.get_name(e)
if name in cast2int:
return Text(f"(int){name}")
else:
if isinstance(e.type, FuncType):
name = self.canonize_funcname(name)
return Text(name)
return Text(name)

def visit_Constant(self, e: Constant):
if e.is_string():
Expand Down
5 changes: 5 additions & 0 deletions python/tilus/hidet/ir/tools/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,11 @@ def visit_Address(self, e: Address):
return Text("&") + self(e.expr)

def visit_Var(self, e: Var):
# Function-typed Vars are global symbols: their Var.name is a canonical
# identifier that must match the function definition, so use it verbatim
# instead of going through the Namer's identity-based disambiguation.
if isinstance(e.type, FuncType) and e.name is not None:
return Text(e.name)
return Text(self.namer.get_name(e))

def visit_Constant(self, e: Constant):
Expand Down
9 changes: 5 additions & 4 deletions python/tilus/ir/instructions/cuda/tcgen05.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Tcgen05AllocInst(Instruction):
@staticmethod
def create(dtype: DataType, shape: Sequence[int], cta_group: int) -> Tcgen05AllocInst:
assert len(shape) >= 2, "Tcgen05AllocInst only supports tensors with rank >= 2."
assert shape[-2] in (32, 64, 128), "The second last dimension must be 32, 64, or 128."
assert shape[0] in (32, 64, 128), "The first (lane) dimension must be 32, 64, or 128."
output = TMemoryTensor.create(dtype=dtype, shape=shape)
return Tcgen05AllocInst(output=output, inputs=(), cta_group=cta_group)

Expand Down Expand Up @@ -65,10 +65,11 @@ def create(
) -> Tcgen05SliceInst:
assert len(tmem.shape) == len(offsets)
assert len(slice_shape) == len(slice_dims)
assert len(slice_dims) >= 2 and all(len(tmem.shape) - 1 - i in slice_dims for i in range(2)), (
"The last two dimensions must be included in the slice."
# The lane dim (0) and the innermost column dim (-1) must always be in the slice.
assert len(slice_dims) >= 2 and 0 in slice_dims and (len(tmem.shape) - 1) in slice_dims, (
"The lane dim (0) and the innermost column dim (-1) must be included in the slice."
)
assert isinstance(offsets[-2], Constant), "The row-offset must be a constant."
assert isinstance(offsets[0], Constant), "The lane (row) offset must be a constant."
output = TMemoryTensor.create(dtype=tmem.dtype, shape=slice_shape)
return Tcgen05SliceInst(output=output, inputs=(tmem,), offsets=tuple(offsets), slice_dims=tuple(slice_dims))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def inference(

# check that the lane matches the threads
lane_begin = tmem_tensor.layout.lane_offset
lane_end = lane_begin + tmem_tensor.shape[-2]
lane_end = lane_begin + tmem_tensor.shape[0]
thread_begin = ctx.thread_begin
thread_end = ctx.thread_end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def inference(ctx: LayoutInferenceContext, inst: Tcgen05SliceInst) -> dict[TMemo
return {
tmem: tmemory_slice(
tmem_layout=inst.tmemory_input.optional_layout,
lane_offset=int(inst.offsets[-2]),
lane_offset=int(inst.offsets[0]),
slice_dims=inst.slice_dims,
shape=tmem.shape,
)
Expand Down
14 changes: 6 additions & 8 deletions python/tilus/ir/layout/ops/tmemory_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@


def tmemory_row_major(shape: Sequence[int]) -> TMemoryLayout:
column_strides = []
# Convention: dim 0 is the lane axis (stride 0); dims 1..-1 are column-strided
# in row-major order (innermost dim has stride 1).
column_strides = [0] * len(shape)
stride = 1
for dim in reversed(range(len(shape))):
if dim == len(shape) - 2:
column_strides.append(0)
else:
column_strides.append(stride)
stride *= shape[dim]
column_strides = list(reversed(column_strides))
for dim in reversed(range(1, len(shape))):
column_strides[dim] = stride
stride *= shape[dim]
return TMemoryLayout.create(shape, column_strides, lane_offset=0)


Expand Down
11 changes: 6 additions & 5 deletions python/tilus/ir/layout/tmem_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@ def create(shape: Sequence[int], column_strides: Sequence[int], lane_offset: int
)
if len(shape) < 2:
raise ValueError("TMemLayout requires at least 2 dimensions, got {}".format(len(shape)))
if shape[-2] not in [32, 64, 128]:
raise ValueError("The number of rows (shape[-2]) must be 32, 64, or 128, got {}".format(shape[-2]))
if column_strides[-2] != 0:
# Convention: shape[0] is the lane (row) dimension; all other dims are column-strided.
if shape[0] not in [32, 64, 128]:
raise ValueError("The number of rows (shape[0]) must be 32, 64, or 128, got {}".format(shape[0]))
if column_strides[0] != 0:
raise ValueError(
"The column stride for the row dimension (column_strides[-2]) must be 0, got {}".format(
column_strides[-2]
"The column stride for the row dimension (column_strides[0]) must be 0, got {}".format(
column_strides[0]
)
)
return TMemoryLayout(shape=tuple(shape), column_strides=tuple(column_strides), lane_offset=lane_offset)
16 changes: 8 additions & 8 deletions python/tilus/ir/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,15 +661,15 @@ class TMemoryTensor(Tensor):

Tensor memory is a dedicated on-chip memory available on Blackwell (SM 10.0+) GPUs, private to the SM's tensor
cores. It is organized as a 2D structure of lanes (rows) and columns, with each cell being 32 bits. The number
of lanes (``shape[-2]``) must be 32, 64, or 128.
of lanes (``shape[0]``) must be 32, 64, or 128.

Attributes
----------
dtype: DataType
The data type of the tensor elements.
shape: tuple[int, ...]
The shape of the tensor. Must have at least 2 dimensions. The second-to-last dimension (lanes) must be
32, 64, or 128.
The shape of the tensor. Must have at least 2 dimensions. The first dimension (lanes) must be
32, 64, or 128. All remaining dimensions are column-strided.
optional_layout: TMemoryLayout, optional
The layout of the tensor, which is optional. When not provided, the layout will be automatically inferred
by a compiler pass.
Expand All @@ -687,8 +687,8 @@ def create(dtype: DataType, shape: Sequence[int], optional_layout: Optional[TMem
dtype: DataType
The data type of the tensor elements.
shape: Sequence[int]
The shape of the tensor. Must have at least 2 dimensions, with the second-to-last
dimension (lanes) being 32, 64, or 128.
The shape of the tensor. Must have at least 2 dimensions, with the first
dimension (lanes) being 32, 64, or 128. All remaining dimensions are column-strided.
optional_layout: TMemoryLayout, optional
The layout of the tensor. If not provided, the layout will be inferred later.

Expand All @@ -699,8 +699,8 @@ def create(dtype: DataType, shape: Sequence[int], optional_layout: Optional[TMem
"""
if len(shape) < 2:
raise ValueError("TMemoryTensor requires at least 2 dimensions, got {}".format(len(shape)))
if shape[-2] not in (32, 64, 128):
raise ValueError("The number of rows (shape[-2]) must be 32, 64, or 128, got {}".format(shape[-2]))
if shape[0] not in (32, 64, 128):
raise ValueError("The number of rows (shape[0]) must be 32, 64, or 128, got {}".format(shape[0]))
if optional_layout is not None and tuple(shape) != tuple(optional_layout.shape):
raise ValueError(
f"Shape mismatch: provided shape {shape} does not match layout shape {optional_layout.shape}."
Expand Down Expand Up @@ -757,7 +757,7 @@ def with_layout(self, layout: TMemoryLayout) -> TMemoryTensor:
converted in the Tilus Script transpiler defined in the tilus.lang.transpiler module.
"""

def __getitem__(self, indices: tuple[Expr | int, ...] | Expr | int) -> TMemoryTensor:
def __getitem__(self, indices: tuple[Expr | int | None | slice, ...] | Expr | int | None | slice) -> TMemoryTensor:
raise RuntimeError("tmemory_tensor[...] could only be used in Tilus Script.")


Expand Down
9 changes: 5 additions & 4 deletions python/tilus/lang/instructions/tcgen05.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ def alloc(self, dtype: DataType, shape: Sequence[int], cta_group: int = 1) -> TM
dtype: DataType
The data type of the tensor elements (e.g., ``float32``, ``float16``).
shape: Sequence[int]
The shape of the tensor. Must have at least 2 dimensions. The second-to-last
dimension (``shape[-2]``) must be 32, 64, or 128.
The shape of the tensor. Must have at least 2 dimensions. The first
dimension (``shape[0]``) is the lane axis and must be 32, 64, or 128.
All remaining dimensions are column-strided.
cta_group: int
The CTA group size for the allocation. Must be 1 or 2. When 2, the tensor is
shared across two CTAs in the same cluster.
Expand All @@ -78,8 +79,8 @@ def alloc(self, dtype: DataType, shape: Sequence[int], cta_group: int = 1) -> TM
raise InstructionError("cta_group must be 1 or 2")
if len(shape) < 2:
raise InstructionError("shape must be a sequence of length 2 or more, got {}".format(shape))
if shape[-2] not in (32, 64, 128):
raise InstructionError("shape[-2] must be 32, 64, or 128, got {}".format(shape[-2]))
if shape[0] not in (32, 64, 128):
raise InstructionError("shape[0] must be 32, 64, or 128, got {}".format(shape[0]))
if 128 % dtype.nbits != 0:
raise InstructionError("dtype must be 1, 2, 4, 8, 16, 32, 64, or 128 bit, got {}".format(dtype))
ret = self._builder.tcgen05_alloc(dtype, shape, cta_group)
Expand Down
Loading