diff --git a/tests/device_tests/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/aic/aic_pv_matmul.cpp b/tests/device_tests/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/aic/aic_pv_matmul.cpp index 02e03184..669b0d13 100644 --- a/tests/device_tests/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/aic/aic_pv_matmul.cpp +++ b/tests/device_tests/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/aic/aic_pv_matmul.cpp @@ -22,8 +22,6 @@ #include #include -#define N_UNROLL 64 - #include "tensor.h" using namespace pto; @@ -42,7 +40,7 @@ static __aicore__ void pv_matmul_n_impl( __gm__ bfloat16_t* val_base, __gm__ float* oi_base, uint64_t n_blocks, - uint64_t* block_indices) { + __gm__ int32_t* block_table) { using GlobalA = GlobalTensor, Stride>; using GlobalB = GlobalTensor, Stride>; @@ -74,7 +72,7 @@ static __aicore__ void pv_matmul_n_impl( // Pre-load first iteration's tiles into ping buffers GlobalA pijGlobal_0(pij_base); - GlobalB vjGlobal_0(val_base + block_indices[0] * K * N); + GlobalB vjGlobal_0(val_base + block_table[0] * K * N); TLOAD(aMatTile_ping, pijGlobal_0); TLOAD(bMatTile_ping, vjGlobal_0); @@ -111,7 +109,7 @@ static __aicore__ void pv_matmul_n_impl( TileMatA& nxtA = (i % 2 == 0) ? aMatTile_pong : aMatTile_ping; TileMatB& nxtB = (i % 2 == 0) ? bMatTile_pong : bMatTile_ping; GlobalA pijGlobal_next(pij_base + (i + 1) * M * K); - GlobalB vjGlobal_next(val_base + block_indices[i + 1] * K * N); + GlobalB vjGlobal_next(val_base + block_table[i + 1] * K * N); TLOAD(nxtA, pijGlobal_next); TLOAD(nxtB, vjGlobal_next); } @@ -130,10 +128,7 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { __gm__ TensorData* value_cache = reinterpret_cast<__gm__ TensorData*>(args[1]); __gm__ TensorData* oi_new = reinterpret_cast<__gm__ TensorData*>(args[2]); uint64_t n_blocks = static_cast(args[3]); - uint64_t block_indices[N_UNROLL]; - for (uint64_t j = 0; j < n_blocks; j++) { - block_indices[j] = static_cast(args[4 + j]); - } + __gm__ int32_t* block_table = reinterpret_cast<__gm__ int32_t*>(args[4]); __gm__ bfloat16_t* pij_base = reinterpret_cast<__gm__ bfloat16_t*>(pij_buf->buffer.addr) + pij_buf->start_offset; __gm__ bfloat16_t* val_base = reinterpret_cast<__gm__ bfloat16_t*>(value_cache->buffer.addr); @@ -142,8 +137,8 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { uint64_t q_tile_size = static_cast(pij_buf->shapes[0]); if (q_tile_size == 16) { - pv_matmul_n_impl<16, 128, 128>(pij_base, val_base, oi_base, n_blocks, block_indices); + pv_matmul_n_impl<16, 128, 128>(pij_base, val_base, oi_base, n_blocks, block_table); } else { - pv_matmul_n_impl<64, 64, 128>(pij_base, val_base, oi_base, n_blocks, block_indices); + pv_matmul_n_impl<64, 64, 128>(pij_base, val_base, oi_base, n_blocks, block_table); } } diff --git a/tests/device_tests/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/aic/aic_qk_matmul.cpp b/tests/device_tests/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/aic/aic_qk_matmul.cpp index ae16200a..b87e3940 100644 --- a/tests/device_tests/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/aic/aic_qk_matmul.cpp +++ b/tests/device_tests/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/aic/aic_qk_matmul.cpp @@ -21,8 +21,6 @@ #include "tensor.h" -#define N_UNROLL 64 - using namespace pto; #ifndef __gm__ @@ -39,7 +37,7 @@ static __aicore__ void qk_matmul_n_impl( __gm__ bfloat16_t* key_base, __gm__ float* sij_base, uint64_t n_blocks, - uint64_t* block_indices) { + __gm__ int32_t* block_table) { using GlobalA = GlobalTensor, Stride>; using GlobalB = GlobalTensor, Stride, Layout::DN>; @@ -69,7 +67,7 @@ static __aicore__ void qk_matmul_n_impl( TLOAD(aMatTile, qiGlobal); for (uint64_t i = 0; i < n_blocks; i++) { - GlobalB kjGlobal(key_base + block_indices[i] * N * K); + GlobalB kjGlobal(key_base + block_table[i] * N * K); GlobalOut sijGlobal(sij_base + i * M * N); // Load only B each iteration (qi already in L1 from hoist) @@ -105,10 +103,7 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { __gm__ TensorData* key_cache = reinterpret_cast<__gm__ TensorData*>(args[1]); __gm__ TensorData* sij_buf = reinterpret_cast<__gm__ TensorData*>(args[2]); uint64_t n_blocks = static_cast(args[3]); - uint64_t block_indices[N_UNROLL]; - for (uint64_t j = 0; j < n_blocks; j++) { - block_indices[j] = static_cast(args[4 + j]); - } + __gm__ int32_t* block_table = reinterpret_cast<__gm__ int32_t*>(args[4]); __gm__ bfloat16_t* qi_base = reinterpret_cast<__gm__ bfloat16_t*>(qi->buffer.addr) + qi->start_offset; __gm__ bfloat16_t* key_base = reinterpret_cast<__gm__ bfloat16_t*>(key_cache->buffer.addr); @@ -117,8 +112,8 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { uint64_t q_tile_size = static_cast(qi->shapes[0]); if (q_tile_size == 16) { - qk_matmul_n_impl<16, 128, 128>(qi_base, key_base, sij_base, n_blocks, block_indices); + qk_matmul_n_impl<16, 128, 128>(qi_base, key_base, sij_base, n_blocks, block_table); } else { - qk_matmul_n_impl<64, 128, 64>(qi_base, key_base, sij_base, n_blocks, block_indices); + qk_matmul_n_impl<64, 128, 64>(qi_base, key_base, sij_base, n_blocks, block_table); } } diff --git a/tests/device_tests/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/orchestration/paged_attention_orch.cpp b/tests/device_tests/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/orchestration/paged_attention_orch.cpp index bd06e1c9..087f917e 100644 --- a/tests/device_tests/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/orchestration/paged_attention_orch.cpp +++ b/tests/device_tests/a2a3/tensormap_and_ringbuffer/paged_attention_unroll/kernels/orchestration/paged_attention_orch.cpp @@ -181,7 +181,6 @@ __attribute__((visibility("default"))) void aicpu_orchestration_entry(PTO2Runtim // Reusable PTOParam objects — reset() before each use avoids // repeated stack-frame construction in the inner loop. - // params_qk must persist until params_pv.copy_scalars_from(). PTOParam params_qk, params_sf, params_pv, params_up; for (uint64_t bn = 0; bn < bn_this_batch; bn += N_UNROLL) { @@ -203,7 +202,7 @@ __attribute__((visibility("default"))) void aicpu_orchestration_entry(PTO2Runtim params_qk.add_input(key_cache); params_qk.add_output(sij_buf); params_qk.add_scalar(n_blocks); - params_qk.add_scalars_i32(bt_base + bn, N_UNROLL); + params_qk.add_scalar(reinterpret_cast(bt_base + bn)); CYCLE_COUNT_LAP(prof_param_setup); pto2_rt_submit_aic_task(rt, FUNC_QK_MATMUL, params_qk); prof_submit_count++; @@ -241,7 +240,7 @@ __attribute__((visibility("default"))) void aicpu_orchestration_entry(PTO2Runtim params_pv.add_input(value_cache); params_pv.add_output(oi_new); params_pv.add_scalar(n_blocks); - params_pv.copy_scalars_from(params_qk, 1, N_UNROLL); + params_pv.add_scalar(reinterpret_cast(bt_base + bn)); CYCLE_COUNT_LAP(prof_param_setup); pto2_rt_submit_aic_task(rt, FUNC_PV_MATMUL, params_pv); prof_submit_count++;