Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
#include <cstdint>
#include <pto/pto-inst.hpp>

#define N_UNROLL 64

#include "tensor.h"

using namespace pto;
Expand All @@ -42,7 +40,7 @@ static __aicore__ void pv_matmul_n_impl(
__gm__ bfloat16_t* val_base,
__gm__ float* oi_base,
uint64_t n_blocks,
uint64_t* block_indices) {
__gm__ int32_t* block_table) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the qk_matmul kernel, this function now accesses block_table directly from global memory. This happens at line 75 (block_table[0]) and inside the loop at line 112 (block_table[i + 1]).

To avoid potential performance degradation from repeated GMEM reads, you could pre-load the necessary block indices into a local array at the start of the function. This would consolidate GMEM access into a single burst read.

Example:

// At the top of the file:
constexpr int kMaxBlocks = 64;

// In pv_matmul_n_impl:
int32_t local_block_table[kMaxBlocks];
// NOTE: this assumes n_blocks <= kMaxBlocks; assert or clamp if the
// caller does not guarantee that invariant.
for (uint64_t i = 0; i < n_blocks; ++i) {
    local_block_table[i] = block_table[i];
}

// Then use local_block_table for accesses.
GlobalB vjGlobal_0(val_base + local_block_table[0] * K * N);
// ... and in the loop ...
GlobalB vjGlobal_next(val_base + local_block_table[i + 1] * K * N);

This change would make the kernel more robust to memory latency variations.


using GlobalA = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
using GlobalB = GlobalTensor<bfloat16_t, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, N, 1>>;
Expand Down Expand Up @@ -74,7 +72,7 @@ static __aicore__ void pv_matmul_n_impl(

// Pre-load first iteration's tiles into ping buffers
GlobalA pijGlobal_0(pij_base);
GlobalB vjGlobal_0(val_base + block_indices[0] * K * N);
GlobalB vjGlobal_0(val_base + block_table[0] * K * N);
TLOAD(aMatTile_ping, pijGlobal_0);
TLOAD(bMatTile_ping, vjGlobal_0);

Expand Down Expand Up @@ -111,7 +109,7 @@ static __aicore__ void pv_matmul_n_impl(
TileMatA& nxtA = (i % 2 == 0) ? aMatTile_pong : aMatTile_ping;
TileMatB& nxtB = (i % 2 == 0) ? bMatTile_pong : bMatTile_ping;
GlobalA pijGlobal_next(pij_base + (i + 1) * M * K);
GlobalB vjGlobal_next(val_base + block_indices[i + 1] * K * N);
GlobalB vjGlobal_next(val_base + block_table[i + 1] * K * N);
TLOAD(nxtA, pijGlobal_next);
TLOAD(nxtB, vjGlobal_next);
}
Expand All @@ -130,10 +128,7 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {
__gm__ TensorData* value_cache = reinterpret_cast<__gm__ TensorData*>(args[1]);
__gm__ TensorData* oi_new = reinterpret_cast<__gm__ TensorData*>(args[2]);
uint64_t n_blocks = static_cast<uint64_t>(args[3]);
uint64_t block_indices[N_UNROLL];
for (uint64_t j = 0; j < n_blocks; j++) {
block_indices[j] = static_cast<uint64_t>(args[4 + j]);
}
__gm__ int32_t* block_table = reinterpret_cast<__gm__ int32_t*>(args[4]);

__gm__ bfloat16_t* pij_base = reinterpret_cast<__gm__ bfloat16_t*>(pij_buf->buffer.addr) + pij_buf->start_offset;
__gm__ bfloat16_t* val_base = reinterpret_cast<__gm__ bfloat16_t*>(value_cache->buffer.addr);
Expand All @@ -142,8 +137,8 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {
uint64_t q_tile_size = static_cast<uint64_t>(pij_buf->shapes[0]);

if (q_tile_size == 16) {
pv_matmul_n_impl<16, 128, 128>(pij_base, val_base, oi_base, n_blocks, block_indices);
pv_matmul_n_impl<16, 128, 128>(pij_base, val_base, oi_base, n_blocks, block_table);
} else {
pv_matmul_n_impl<64, 64, 128>(pij_base, val_base, oi_base, n_blocks, block_indices);
pv_matmul_n_impl<64, 64, 128>(pij_base, val_base, oi_base, n_blocks, block_table);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

#include "tensor.h"

#define N_UNROLL 64

using namespace pto;

#ifndef __gm__
Expand All @@ -39,7 +37,7 @@ static __aicore__ void qk_matmul_n_impl(
__gm__ bfloat16_t* key_base,
__gm__ float* sij_base,
uint64_t n_blocks,
uint64_t* block_indices) {
__gm__ int32_t* block_table) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While passing the block_table pointer is a great simplification for the orchestration, accessing it directly from global memory inside the loop at line 70 (block_table[i]) might introduce performance overhead due to repeated GMEM reads.

Consider pre-loading the block indices into a stack-allocated local array at the beginning of this function. This would perform a single burst read from GMEM, and subsequent accesses within the loop would be much faster.

For example:

// At the top of the file, you could define:
constexpr int kMaxBlocks = 64;

// Then, inside qk_matmul_n_impl:
int32_t local_block_table[kMaxBlocks];
// NOTE: this assumes n_blocks <= kMaxBlocks; assert or clamp if the
// caller does not guarantee that invariant.
// A simple loop or a memcpy-like instruction could be used to load the data.
for (uint64_t i = 0; i < n_blocks; ++i) {
    local_block_table[i] = block_table[i];
}

// And in the main loop:
for (uint64_t i = 0; i < n_blocks; i++) {
    GlobalB kjGlobal(key_base + local_block_table[i] * N * K);
    // ...
}

Although you've noted performance is neutral, this is a good practice that could yield benefits, especially if n_blocks were larger or if memory access patterns change.


using GlobalA = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
using GlobalB = GlobalTensor<bfloat16_t, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, 1, K>, Layout::DN>;
Expand Down Expand Up @@ -69,7 +67,7 @@ static __aicore__ void qk_matmul_n_impl(
TLOAD(aMatTile, qiGlobal);

for (uint64_t i = 0; i < n_blocks; i++) {
GlobalB kjGlobal(key_base + block_indices[i] * N * K);
GlobalB kjGlobal(key_base + block_table[i] * N * K);
GlobalOut sijGlobal(sij_base + i * M * N);

// Load only B each iteration (qi already in L1 from hoist)
Expand Down Expand Up @@ -105,10 +103,7 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {
__gm__ TensorData* key_cache = reinterpret_cast<__gm__ TensorData*>(args[1]);
__gm__ TensorData* sij_buf = reinterpret_cast<__gm__ TensorData*>(args[2]);
uint64_t n_blocks = static_cast<uint64_t>(args[3]);
uint64_t block_indices[N_UNROLL];
for (uint64_t j = 0; j < n_blocks; j++) {
block_indices[j] = static_cast<uint64_t>(args[4 + j]);
}
__gm__ int32_t* block_table = reinterpret_cast<__gm__ int32_t*>(args[4]);

__gm__ bfloat16_t* qi_base = reinterpret_cast<__gm__ bfloat16_t*>(qi->buffer.addr) + qi->start_offset;
__gm__ bfloat16_t* key_base = reinterpret_cast<__gm__ bfloat16_t*>(key_cache->buffer.addr);
Expand All @@ -117,8 +112,8 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {
uint64_t q_tile_size = static_cast<uint64_t>(qi->shapes[0]);

if (q_tile_size == 16) {
qk_matmul_n_impl<16, 128, 128>(qi_base, key_base, sij_base, n_blocks, block_indices);
qk_matmul_n_impl<16, 128, 128>(qi_base, key_base, sij_base, n_blocks, block_table);
} else {
qk_matmul_n_impl<64, 128, 64>(qi_base, key_base, sij_base, n_blocks, block_indices);
qk_matmul_n_impl<64, 128, 64>(qi_base, key_base, sij_base, n_blocks, block_table);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ __attribute__((visibility("default"))) void aicpu_orchestration_entry(PTO2Runtim

// Reusable PTOParam objects — reset() before each use avoids
// repeated stack-frame construction in the inner loop.
// params_qk must persist until params_pv.copy_scalars_from().
PTOParam params_qk, params_sf, params_pv, params_up;

for (uint64_t bn = 0; bn < bn_this_batch; bn += N_UNROLL) {
Expand All @@ -203,7 +202,7 @@ __attribute__((visibility("default"))) void aicpu_orchestration_entry(PTO2Runtim
params_qk.add_input(key_cache);
params_qk.add_output(sij_buf);
params_qk.add_scalar(n_blocks);
params_qk.add_scalars_i32(bt_base + bn, N_UNROLL);
params_qk.add_scalar(reinterpret_cast<uint64_t>(bt_base + bn));
CYCLE_COUNT_LAP(prof_param_setup);
pto2_rt_submit_aic_task(rt, FUNC_QK_MATMUL, params_qk);
prof_submit_count++;
Expand Down Expand Up @@ -241,7 +240,7 @@ __attribute__((visibility("default"))) void aicpu_orchestration_entry(PTO2Runtim
params_pv.add_input(value_cache);
params_pv.add_output(oi_new);
params_pv.add_scalar(n_blocks);
params_pv.copy_scalars_from(params_qk, 1, N_UNROLL);
params_pv.add_scalar(reinterpret_cast<uint64_t>(bt_base + bn));
CYCLE_COUNT_LAP(prof_param_setup);
pto2_rt_submit_aic_task(rt, FUNC_PV_MATMUL, params_pv);
prof_submit_count++;
Expand Down
Loading