-
Notifications
You must be signed in to change notification settings - Fork 33
Pass block table pointer instead of 64 individual scalar args #321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,8 +21,6 @@ | |
|
|
||
| #include "tensor.h" | ||
|
|
||
| #define N_UNROLL 64 | ||
|
|
||
| using namespace pto; | ||
|
|
||
| #ifndef __gm__ | ||
|
|
@@ -39,7 +37,7 @@ static __aicore__ void qk_matmul_n_impl( | |
| __gm__ bfloat16_t* key_base, | ||
| __gm__ float* sij_base, | ||
| uint64_t n_blocks, | ||
| uint64_t* block_indices) { | ||
| __gm__ int32_t* block_table) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. While passing the block table as a pointer avoids the 64 individual scalar arguments, each `block_table[i]` access inside the loop is now a separate read from global memory. Consider pre-loading the block indices into a stack-allocated local array at the beginning of this function. This would perform a single burst read from GMEM and subsequent accesses within the loop would be much faster. For example: // At the top of the file, you could define:
constexpr int kMaxBlocks = 64;
// Then, inside qk_matmul_n_impl:
int32_t local_block_table[kMaxBlocks];
// A simple loop or a memcpy-like instruction could be used to load the data.
for (uint64_t i = 0; i < n_blocks; ++i) {
local_block_table[i] = block_table[i];
}
// And in the main loop:
for (uint64_t i = 0; i < n_blocks; i++) {
GlobalB kjGlobal(key_base + local_block_table[i] * N * K);
// ...
}Although you've noted performance is neutral, this is a good practice that could yield benefits, especially if |
||
|
|
||
| using GlobalA = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>; | ||
| using GlobalB = GlobalTensor<bfloat16_t, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, 1, K>, Layout::DN>; | ||
|
|
@@ -69,7 +67,7 @@ static __aicore__ void qk_matmul_n_impl( | |
| TLOAD(aMatTile, qiGlobal); | ||
|
|
||
| for (uint64_t i = 0; i < n_blocks; i++) { | ||
| GlobalB kjGlobal(key_base + block_indices[i] * N * K); | ||
| GlobalB kjGlobal(key_base + block_table[i] * N * K); | ||
| GlobalOut sijGlobal(sij_base + i * M * N); | ||
|
|
||
| // Load only B each iteration (qi already in L1 from hoist) | ||
|
|
@@ -105,10 +103,7 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { | |
| __gm__ TensorData* key_cache = reinterpret_cast<__gm__ TensorData*>(args[1]); | ||
| __gm__ TensorData* sij_buf = reinterpret_cast<__gm__ TensorData*>(args[2]); | ||
| uint64_t n_blocks = static_cast<uint64_t>(args[3]); | ||
| uint64_t block_indices[N_UNROLL]; | ||
| for (uint64_t j = 0; j < n_blocks; j++) { | ||
| block_indices[j] = static_cast<uint64_t>(args[4 + j]); | ||
| } | ||
| __gm__ int32_t* block_table = reinterpret_cast<__gm__ int32_t*>(args[4]); | ||
|
|
||
| __gm__ bfloat16_t* qi_base = reinterpret_cast<__gm__ bfloat16_t*>(qi->buffer.addr) + qi->start_offset; | ||
| __gm__ bfloat16_t* key_base = reinterpret_cast<__gm__ bfloat16_t*>(key_cache->buffer.addr); | ||
|
|
@@ -117,8 +112,8 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { | |
| uint64_t q_tile_size = static_cast<uint64_t>(qi->shapes[0]); | ||
|
|
||
| if (q_tile_size == 16) { | ||
| qk_matmul_n_impl<16, 128, 128>(qi_base, key_base, sij_base, n_blocks, block_indices); | ||
| qk_matmul_n_impl<16, 128, 128>(qi_base, key_base, sij_base, n_blocks, block_table); | ||
| } else { | ||
| qk_matmul_n_impl<64, 128, 64>(qi_base, key_base, sij_base, n_blocks, block_indices); | ||
| qk_matmul_n_impl<64, 128, 64>(qi_base, key_base, sij_base, n_blocks, block_table); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the
`qk_matmul` kernel, this function now accesses `block_table` directly from global memory. This happens at line 75 (`block_table[0]`) and inside the loop at line 112 (`block_table[i + 1]`). To avoid potential performance degradation from repeated GMEM reads, you could pre-load the necessary block indices into a local array at the start of the function. This would consolidate GMEM access into a single burst read.
Example:
This change would make the kernel more robust to memory latency variations.