[Refactor] Use the first dimension of tensor memory tensor as lane dim #151
Merged
yaoyaoding merged 1 commit into main on May 4, 2026
… name handling
PR0 of the NVFP4 enablement plan: migrate the TMEM tensor convention so
shape[0] (instead of shape[-2]) is the lane axis. With dim 0 fixed as the
lane, all remaining dims are column-strided sub-axes — a rank-stable
convention that scales naturally to higher-rank TMEM tensors (e.g. SF
tensors with separate M-fold and K-chunk axes).
Lane dim convention update:
- TMemoryLayout.create / TMemoryTensor.create / Tcgen05AllocInst:
shape[-2] -> shape[0] for the {32, 64, 128} lane validation; same for
column_strides[-2] -> [0] (lane stride must be zero).
- Tcgen05SliceInst.create: require dim 0 (lane) and dim -1 (innermost
column) in slice_dims; lane offset reads from offsets[0].
- ldst inference / emitter: read lane size from tmem.shape[0].
- alloc emitter: column-count formula updated from
prod(shape[:-2]) * shape[-1] to prod(shape[1:]); view emitter
preserves middle-dim strides via column_strides[1:-1].
- tmemory_row_major: zero stride placed at dim 0; columns flow row-major
from dim 1 onward.
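
The new convention can be modelled in a few lines of plain Python. This is a minimal sketch, not the Tilus implementation: the helper names `row_major_column_strides`, `num_columns`, and `old_num_columns` are hypothetical, and the column count is expressed in elements, so any dtype-width scaling the real alloc emitter applies is left out.

```python
from math import prod

VALID_LANES = (32, 64, 128)  # lane sizes named in the validation rule above


def row_major_column_strides(shape):
    """Hypothetical model of tmemory_row_major under the new convention.

    dim 0 is the lane axis and gets stride 0; dims 1.. are laid out
    row-major across columns.
    """
    if shape[0] not in VALID_LANES:
        raise ValueError(f"lane dim shape[0]={shape[0]} must be one of {VALID_LANES}")
    strides = [0] * len(shape)
    running = 1
    for dim in range(len(shape) - 1, 0, -1):  # innermost column dim first
        strides[dim] = running
        running *= shape[dim]
    return strides


def num_columns(shape):
    """New formula: every dim after the lane dim contributes columns."""
    return prod(shape[1:])


def old_num_columns(shape):
    """Old formula, where shape[-2] was the lane dim."""
    return prod(shape[:-2]) * shape[-1]


print(row_major_column_strides([128, 2, 64]))  # [0, 64, 1]
print(num_columns([128, 2, 64]), old_num_columns([2, 128, 64]))  # 128 128
```

The two formulas agree once the lane dim is moved from position -2 to position 0, which is why the migration only permutes shapes and does not change allocation sizes.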
Examples migrated:
- matmul_v5/matmul_v6: rank-3 t_acc shape changed from
[mma_stages, block_m, block_n] to [block_m, mma_stages, block_n] so dim 0
is the lane axis. All `t_acc[stage]` indexing rewritten to
`t_acc[:, stage, :]` (column-dim slicing of a rank-3 TMEM tensor).
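
To illustrate what `t_acc[:, stage, :]` selects, the sketch below models the slice as a column offset plus a rank-2 view. It assumes the row-major column strides described above; `slice_stage` is a hypothetical helper for illustration, not a Tilus API.

```python
def slice_stage(shape, column_strides, stage):
    """Model t_acc[:, stage, :] on a [block_m, mma_stages, block_n] tensor.

    Returns (column_offset, view_shape, view_column_strides). All lanes are
    kept, so only the column offset changes between stages.
    """
    block_m, mma_stages, block_n = shape
    if not 0 <= stage < mma_stages:
        raise IndexError(f"stage {stage} out of range for {mma_stages} stages")
    column_offset = stage * column_strides[1]
    view_shape = [block_m, block_n]
    view_strides = [column_strides[0], column_strides[2]]
    return column_offset, view_shape, view_strides


shape = [128, 2, 64]    # [block_m, mma_stages, block_n]
strides = [0, 64, 1]    # lane stride 0, then row-major columns
for stage in range(shape[1]):
    print(stage, slice_stage(shape, strides, stage))
# 0 (0, [128, 64], [0, 1])
# 1 (64, [128, 64], [0, 1])
```

Each stage is a `[block_m, block_n]` view over the same lanes, shifted by `block_n` columns per stage.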
Function-Var name handling (codegen bugfix):
- After the Var.name + Var.hint merge in 942c27e, the Namer began
disambiguating function-Var references by identity (kernel,
kernel_1, ...). When passes that re-instantiate function-typed Vars
(InlineRegisterTensorPass, FlattenTensorIndexPass, InlineFunctionPass,
...) ran, the resulting C source called an undefined kernel_1 symbol
even though the kernel was defined as kernel.
- Fix: in tilus/hidet/ir/tools/printer.py and tilus/hidet/backend/codegen.py,
print FuncType-typed Vars using Var.name verbatim — function-Vars are
global symbols whose name must match the function definition. The C
codegen still applies canonize_funcname for the tilus_ prefix.
- This makes matmul_v5 / matmul_v6 (which trigger the offending pass
combinations) compile and run.
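
The failure mode and the fix can be reproduced with a toy namer. This is a simplified sketch under stated assumptions: `Var`, `Namer`, and the `verbatim_functions` flag are hypothetical stand-ins, not the classes in tilus/hidet/ir/tools/printer.py or tilus/hidet/backend/codegen.py.

```python
class Var:
    """Toy variable; equality and hashing are by identity, as for plain objects."""

    def __init__(self, name, is_function=False):
        self.name = name
        self.is_function = is_function


class Namer:
    """Assigns a unique printed name to each distinct Var object."""

    def __init__(self):
        self._names = {}    # Var object -> printed name
        self._used = set()  # printed names already handed out

    def name_of(self, var, verbatim_functions):
        # Fix: function-typed Vars are global symbols, so print their name as-is.
        if verbatim_functions and var.is_function:
            return var.name
        if var not in self._names:
            candidate, suffix = var.name, 0
            while candidate in self._used:
                suffix += 1
                candidate = f"{var.name}_{suffix}"
            self._used.add(candidate)
            self._names[var] = candidate
        return self._names[var]


definition = Var("kernel", is_function=True)
reinstantiated = Var("kernel", is_function=True)  # a pass rebuilt the Var

buggy = Namer()
print(buggy.name_of(definition, False), buggy.name_of(reinstantiated, False))
# kernel kernel_1   (call site refers to an undefined symbol)

fixed = Namer()
print(fixed.name_of(definition, True), fixed.name_of(reinstantiated, True))
# kernel kernel
```

Identity-based disambiguation remains appropriate for local variables, where two distinct Vars with the same name really are different entities; only function-typed Vars bypass it.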
Tests: tcgen05 instruction suite (43 tests) and matmul_v2 smoke test all
pass; matmul_v5 builds and validates against torch reference.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
[-2] -> [0]