[Refactor] Use the first dimension of tensor memory tensor as lane dim #151

Merged
yaoyaoding merged 1 commit into main from refactor/tmem-lane-dim on May 4, 2026
Conversation

@yaoyaoding
Member

PR0 of the NVFP4 enablement plan: migrate the TMEM tensor convention so shape[0] (instead of shape[-2]) is the lane axis. With dim 0 fixed as the lane, all remaining dims are column-strided sub-axes — a rank-stable convention that scales naturally to higher-rank TMEM tensors (e.g. SF tensors with separate M-fold and K-chunk axes).

Lane dim convention update:

  • TMemoryLayout.create / TMemoryTensor.create / Tcgen05AllocInst: shape[-2] -> shape[0] for the {32, 64, 128} lane validation; same for column_strides[-2] -> [0] (lane stride must be zero).
  • Tcgen05SliceInst.create: require dim 0 (lane) and dim -1 (innermost column) in slice_dims; lane offset reads from offsets[0].
  • ldst inference / emitter: read lane size from tmem.shape[0].
  • alloc emitter: column-count formula updated from prod(shape[:-2]) * shape[-1] to prod(shape[1:]); view emitter preserves middle-dim strides via column_strides[1:-1].
  • tmemory_row_major: zero stride placed at dim 0; columns flow row-major from dim 1 onward.
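The layout changes above can be sketched with a few standalone helpers. These are illustrative stand-ins for the tilus internals named in the bullets (`tmemory_row_major`, the alloc-emitter column-count formula, the lane validation); the real signatures may differ:

```python
from math import prod

def validate_lane_dim(shape, column_strides):
    # New convention: dim 0 is the lane axis, restricted to {32, 64, 128},
    # and its stride must be zero (columns within a lane share a lane index).
    assert shape[0] in (32, 64, 128), "lane dim must be 32, 64, or 128"
    assert column_strides[0] == 0, "lane stride must be zero"

def num_columns(shape):
    # Old formula (lane at dim -2): prod(shape[:-2]) * shape[-1]
    # New formula (lane at dim 0):  prod(shape[1:])
    return prod(shape[1:])

def tmemory_row_major(shape):
    # Zero stride at dim 0 (lane); remaining dims flow row-major from dim 1.
    strides = [0] * len(shape)
    acc = 1
    for i in range(len(shape) - 1, 0, -1):
        strides[i] = acc
        acc *= shape[i]
    return strides

# Rank-3 TMEM tensor: [block_m, mma_stages, block_n] = [128, 2, 64]
shape = [128, 2, 64]
strides = tmemory_row_major(shape)
validate_lane_dim(shape, strides)
print(num_columns(shape))  # 2 * 64 = 128 columns
print(strides)             # [0, 64, 1]
```

Because the lane axis is pinned at dim 0, both formulas stay valid for any rank, which is what makes the convention scale to higher-rank SF tensors.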

Examples migrated:

  • matmul_v5/matmul_v6: rank-3 t_acc shape changed from [mma_stages, block_m, block_n] to [block_m, mma_stages, block_n] so dim 0 is the lane axis. All t_acc[stage] indexing rewritten to t_acc[:, stage, :] (column-dim slicing of a rank-3 TMEM tensor).
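The indexing rewrite follows ordinary multi-dimensional slicing semantics; a NumPy analogy (not the tilus tensor API) shows why `t_acc[stage]` had to become `t_acc[:, stage, :]` once the stage axis moved to the middle:

```python
import numpy as np

# Old layout: [mma_stages, block_m, block_n] -- t_acc[stage] picked a stage.
# New layout: [block_m, mma_stages, block_n] -- dim 0 is now the lane axis,
# so selecting a stage means slicing the middle (column) dimension.
block_m, mma_stages, block_n = 128, 2, 64
t_acc = np.zeros((block_m, mma_stages, block_n))

stage = 1
stage_view = t_acc[:, stage, :]  # keep all 128 lanes, pick one stage
print(stage_view.shape)          # (128, 64)
```

Under the old layout, `t_acc[stage]` would have sliced away the lane axis itself, which the new convention forbids.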

Function-Var name handling (codegen bugfix):

  • After the Var.name + Var.hint merge in 942c27e, the Namer began disambiguating function-Var references by identity (kernel, kernel_1, ...). When passes that re-instantiate function-typed Vars (InlineRegisterTensorPass, FlattenTensorIndexPass, InlineFunctionPass, ...) ran, the resulting C source called an undefined kernel_1 symbol even though the kernel was defined as kernel.
  • Fix: in tilus/hidet/ir/tools/printer.py and tilus/hidet/backend/codegen.py, print FuncType-typed Vars using Var.name verbatim — function-Vars are global symbols whose name must match the function definition. The C codegen still applies canonize_funcname for the tilus_ prefix.
  • This makes matmul_v5 / matmul_v6 (which trigger the offending pass combinations) compile and run.
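The failure mode and the fix can be sketched with a toy namer. The `Namer`, `Var`, and `FuncType` classes here are simplified illustrations of the hidet IR machinery described above, not its actual code:

```python
class FuncType:
    """Marker for function-typed Vars (global symbols)."""

class Var:
    def __init__(self, name, type_):
        self.name = name
        self.type = type_

class Namer:
    """Disambiguates Var names by object identity, so two distinct Vars
    both named 'kernel' would print as 'kernel' and 'kernel_1'."""
    def __init__(self):
        self.obj2name = {}
        self.used = set()

    def get_name(self, var):
        # Fix: function-Vars are global symbols, so their printed name must
        # match the function definition verbatim -- even if a later pass
        # re-instantiated the Var object.
        if isinstance(var.type, FuncType):
            return var.name
        if id(var) not in self.obj2name:
            name, i = var.name, 0
            while name in self.used:
                i += 1
                name = f'{var.name}_{i}'
            self.used.add(name)
            self.obj2name[id(var)] = name
        return self.obj2name[id(var)]

namer = Namer()
definition = Var('kernel', FuncType())
reinstantiated = Var('kernel', FuncType())  # created by a later pass
print(namer.get_name(definition))     # kernel
print(namer.get_name(reinstantiated)) # kernel -- not kernel_1
```

Without the early return for `FuncType`, the second Var would fall through to identity-based disambiguation and the emitted C would call an undefined `kernel_1`.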

Tests: tcgen05 instruction suite (43 tests) and matmul_v2 smoke test all pass; matmul_v5 builds and validates against torch reference.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
@yaoyaoding yaoyaoding changed the title from "[Refactor] TMEM lane dim convention [-2] -> [0]" to "[Refactor] TMEM lane dim convention from the last second to the first dim" on May 4, 2026
@yaoyaoding yaoyaoding changed the title from "[Refactor] TMEM lane dim convention from the last second to the first dim" to "[Refactor] Use the first dimension of tensor memory tensor as lane dim" on May 4, 2026
@yaoyaoding yaoyaoding merged commit c6ac728 into main May 4, 2026
18 of 26 checks passed
@yaoyaoding yaoyaoding deleted the refactor/tmem-lane-dim branch May 4, 2026 02:50
github-actions Bot added a commit that referenced this pull request May 4, 2026