diff --git a/.gitignore b/.gitignore index 106fd907..7c4671a2 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,6 @@ outputs # Mid-work documentation .docs + +# Debug dumps (generated locally) +examples/host_build_graph/cpt_and_comm/debug/ diff --git a/examples/host_build_graph/allgather_Manual/README.md b/examples/host_build_graph/allgather_Manual/README.md new file mode 100644 index 00000000..e31309a4 --- /dev/null +++ b/examples/host_build_graph/allgather_Manual/README.md @@ -0,0 +1,18 @@ +# AllGather (Manual RDMA 实现) + +多卡 AllGather 通信算子:使用 **直接 RDMA 读取**(HcclRemotePtr + TLOAD/TSTORE)实现。每个 rank 获得所有 rank 数据的拼接结果。 + +**实现方式**:`WindowMemCopyIn -> CommBarrier -> AllGatherManual -> WindowMemCopyOut -> CommBarrier(post)` + +- 无 TGATHER 集体调用,所有 rank 并行执行 +- 适用于性能对比测试 + +## 运行 + +```bash +python examples/scripts/multi_card_run_example.py \ + -k examples/host_build_graph/allgather_Manual/kernels \ + -g examples/host_build_graph/allgather_Manual/golden.py +``` + +需要设置 `PTO_COMM_ISA_ROOT` 指向 pto-comm-isa 根目录,以及多卡 HCCL 环境。 diff --git a/examples/host_build_graph/allgather_Manual/golden.py b/examples/host_build_graph/allgather_Manual/golden.py new file mode 100644 index 00000000..595aea8d --- /dev/null +++ b/examples/host_build_graph/allgather_Manual/golden.py @@ -0,0 +1,59 @@ +""" +Golden reference for AllGather (Manual RDMA variant, no compute). + +Each rank contributes GATHER_COUNT float32 elements. +After AllGather, EVERY rank holds the concatenation of all ranks' data. +""" + +import ctypes +import numpy as np + +GATHER_COUNT = 64 + +ALL_CASES = {"Default": {}} +DEFAULT_CASE = "Default" +__outputs__ = ["out"] +RTOL = 1e-4 +ATOL = 1e-4 + + +def generate_inputs(params: dict) -> list: + """Return flat argument list. 
For requires_comm, params includes device_ctx_ptr, win_in_base, win_out_base, n_ranks, root, rank_id.""" + rank_id = params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + + np.random.seed(42 + rank_id) + src = np.random.randn(GATHER_COUNT).astype(np.float32) * 0.1 + out = np.zeros((n_ranks * GATHER_COUNT,), dtype=np.float32) + + result = [ + ("src", src), + ("out", out), + ("size_src", ctypes.c_int64(src.nbytes)), + ("size_out", ctypes.c_int64(out.nbytes)), + ] + + if "device_ctx_ptr" in params and "win_in_base" in params and "win_out_base" in params: + result.extend([ + ("device_ctx_ptr", ctypes.c_uint64(params["device_ctx_ptr"])), + ("win_in_base", ctypes.c_uint64(params["win_in_base"])), + ("win_out_base", ctypes.c_uint64(params["win_out_base"])), + ("n_ranks", ctypes.c_int32(n_ranks)), + ("root", ctypes.c_int32(root)), + ("rank_id", ctypes.c_int32(rank_id)), + ]) + + return result + + +def compute_golden(tensors: dict, params: dict) -> None: + """AllGather: every rank gets the full concatenation of all ranks' data.""" + n_ranks = params.get("n_ranks", 2) + out = tensors["out"] + + out_np = out.cpu().numpy() if hasattr(out, 'cpu') else np.asarray(out) + for r in range(n_ranks): + np.random.seed(42 + r) + src_r = np.random.randn(GATHER_COUNT).astype(np.float32) * 0.1 + out_np[r * GATHER_COUNT : (r + 1) * GATHER_COUNT] = src_r[:GATHER_COUNT] diff --git a/examples/host_build_graph/allgather_Manual/kernels/aiv/allgather_manual_kernel.cpp b/examples/host_build_graph/allgather_Manual/kernels/aiv/allgather_manual_kernel.cpp new file mode 100644 index 00000000..6a93a85c --- /dev/null +++ b/examples/host_build_graph/allgather_Manual/kernels/aiv/allgather_manual_kernel.cpp @@ -0,0 +1,64 @@ +/** + * Manual AllGather kernel - direct RDMA reads, no TGATHER. + * + * Each rank independently reads from all ranks' win_src via HcclRemotePtr + * and writes to local dst. No collective TGATHER call, so no deadlock. 
+ * All ranks can run in parallel (single kernel, single barrier). + * + * Args: dst, src, ctx, nranks, rank_id (rank_id unused, for API compatibility) + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static constexpr size_t GATHER_COUNT = 64; + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* src = reinterpret_cast<__gm__ float*>(args[1]); + __gm__ HcclDeviceContext* hcclCtx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[2]); + int nranks = static_cast(args[3]); + (void)args[4]; /* rank_id unused */ + + using ShapeDyn = pto::Shape; + using StrideDyn = pto::Stride; + using Global = pto::GlobalTensor; + using TileData = pto::Tile; + + ShapeDyn sliceShape(1, 1, 1, 1, GATHER_COUNT); + StrideDyn sliceStride(GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, 1); + + TileData ubTile(1, GATHER_COUNT); + TASSIGN(ubTile, 0x0); + + int actual_nranks = (nranks > 16) ? 
16 : nranks; + for (int r = 0; r < actual_nranks; ++r) { + __gm__ float* remote_src = HcclRemotePtr(hcclCtx, src, r); + __gm__ float* local_dst = dst + static_cast(r) * GATHER_COUNT; + + Global srcG(remote_src, sliceShape, sliceStride); + Global dstG(local_dst, sliceShape, sliceStride); + + TLOAD(ubTile, srcG); + set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); + TSTORE(dstG, ubTile); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + } + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/allgather_Manual/kernels/aiv/comm_barrier_all_kernel.cpp b/examples/host_build_graph/allgather_Manual/kernels/aiv/comm_barrier_all_kernel.cpp new file mode 100644 index 00000000..b1711665 --- /dev/null +++ b/examples/host_build_graph/allgather_Manual/kernels/aiv/comm_barrier_all_kernel.cpp @@ -0,0 +1,52 @@ +/** + * All-to-all barrier (多对多): every rank waits for every other rank. + * + * Used by AllGather where every rank reads from all ranks' windows. + * Unlike comm_barrier_kernel (many-to-one), ALL ranks do TWAIT here. + * + * Flow: + * 1. Each rank TNOTIFY to root's barrier slot[my_rank] + * 2. 
Each rank TWAIT on root's barrier until all n_ranks slots >= 1 + * + * Args: + * args[0] = barrier_base (local barrier buffer; root's is used for sync) + * args[1] = device_ctx_ptr (HcclDeviceContext*) + * args[2] = n_ranks + * args[3] = root (whose barrier buffer is the sync point) + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ int32_t* local_barrier = reinterpret_cast<__gm__ int32_t*>(args[0]); + __gm__ HcclDeviceContext* ctx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[1]); + int n_ranks = static_cast(args[2]); + int root = static_cast(args[3]); + int my_rank = static_cast(ctx->rankId); + + // Step 1: Each rank writes flag=1 to root's barrier slot[my_rank] via RDMA. + __gm__ int32_t* remote_slot = HcclRemotePtr(ctx, local_barrier, root) + my_rank; + pto::comm::Signal sig(remote_slot); + pto::comm::TNOTIFY(sig, 1, pto::comm::NotifyOp::Set); + + // Step 2: ALL ranks wait until every rank's flag is >= 1 (multi-to-multi). + __gm__ int32_t* root_barrier = HcclRemotePtr(ctx, local_barrier, root); + for (int i = 0; i < n_ranks; ++i) { + pto::comm::Signal slot(root_barrier + i); + pto::comm::TWAIT(slot, 1, pto::comm::WaitCmp::GE); + } + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/allgather_Manual/kernels/aiv/window_memcopy_in.cpp b/examples/host_build_graph/allgather_Manual/kernels/aiv/window_memcopy_in.cpp new file mode 100644 index 00000000..38408baa --- /dev/null +++ b/examples/host_build_graph/allgather_Manual/kernels/aiv/window_memcopy_in.cpp @@ -0,0 +1,26 @@ +/** + * WindowMemCopyIn: Copy device buffer to HCCL window. + * Used before AllGather so remote ranks can read. 
+ */ + +#include +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* win_dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* dev_src = reinterpret_cast<__gm__ float*>(args[1]); + int count = static_cast(args[2]); + + for (int i = 0; i < count; ++i) { + win_dst[i] = dev_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/allgather_Manual/kernels/aiv/window_memcopy_out.cpp b/examples/host_build_graph/allgather_Manual/kernels/aiv/window_memcopy_out.cpp new file mode 100644 index 00000000..99e83e76 --- /dev/null +++ b/examples/host_build_graph/allgather_Manual/kernels/aiv/window_memcopy_out.cpp @@ -0,0 +1,26 @@ +/** + * WindowMemCopyOut: Copy HCCL window to device buffer. + * After AllGather, every rank copies gathered result to device. + */ + +#include +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* dev_dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* win_src = reinterpret_cast<__gm__ float*>(args[1]); + int count = static_cast(args[2]); + + for (int i = 0; i < count; ++i) { + dev_dst[i] = win_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/allgather_Manual/kernels/kernel_config.py b/examples/host_build_graph/allgather_Manual/kernels/kernel_config.py new file mode 100644 index 00000000..0a3e8133 --- /dev/null +++ b/examples/host_build_graph/allgather_Manual/kernels/kernel_config.py @@ -0,0 +1,30 @@ +""" +AllGather (Manual): direct RDMA reads for performance comparison. + +Flow: WindowMemCopyIn -> CommBarrier -> AllGatherManual (HcclRemotePtr+TLOAD/TSTORE) +-> WindowMemCopyOut -> CommBarrier(post). +No TGATHER, all ranks run in parallel. 
Requires HCCL, PTO_COMM_ISA_ROOT. +""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "allgather_orch.cpp"), + "function_name": "build_allgather_graph", +} + +KERNELS = [ + {"func_id": 0, "name": "WindowMemCopyIn", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_in.cpp"), "core_type": "aiv"}, + {"func_id": 1, "name": "AllGatherManual", "source": str(_KERNELS_ROOT / "aiv" / "allgather_manual_kernel.cpp"), "core_type": "aiv"}, + {"func_id": 2, "name": "WindowMemCopyOut", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_out.cpp"), "core_type": "aiv"}, + {"func_id": 3, "name": "CommBarrierAll", "source": str(_KERNELS_ROOT / "aiv" / "comm_barrier_all_kernel.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "host_build_graph", + "n_devices": 2, + "first_device_id": 0, + "requires_comm": True, +} diff --git a/examples/host_build_graph/allgather_Manual/kernels/orchestration/allgather_orch.cpp b/examples/host_build_graph/allgather_Manual/kernels/orchestration/allgather_orch.cpp new file mode 100644 index 00000000..54aa5b81 --- /dev/null +++ b/examples/host_build_graph/allgather_Manual/kernels/orchestration/allgather_orch.cpp @@ -0,0 +1,115 @@ +/** + * AllGather (Manual): direct RDMA reads, no TGATHER. + * + * Flow: WindowMemCopyIn -> CommBarrier -> AllGatherManual -> WindowMemCopyOut + * -> CommBarrier(post). All ranks run in parallel. 
+ * + * Args (10): [0] host_src, [1] host_out, [2] size_src, [3] size_out, + * [4] device_ctx_ptr, [5] win_in_base, [6] win_out_base, + * [7] n_ranks, [8] root (unused), [9] rank_id + */ + +#include "runtime.h" +#include +#include +#include + +extern "C" { + +constexpr int GATHER_COUNT = 64; +constexpr size_t HCCL_WIN_SYNC_PREFIX = 64 * sizeof(int32_t); + +#define FUNC_WIN_MEMCOPY_IN 0 +#define FUNC_ALLGATHER 1 +#define FUNC_WIN_MEMCOPY_OUT 2 +#define FUNC_COMM_BARRIER 3 + +int build_allgather_graph(Runtime* runtime, uint64_t* args, int arg_count) { + if (arg_count < 10) { + std::cerr << "build_allgather_graph: Expected at least 10 args, got " << arg_count << '\n'; + return -1; + } + + void* host_src = reinterpret_cast(args[0]); + void* host_out = reinterpret_cast(args[1]); + size_t size_src = static_cast(args[2]); + size_t size_out = static_cast(args[3]); + uint64_t device_ctx_ptr = args[4]; + uint64_t win_in_base = args[5]; + uint64_t win_out_base = args[6]; + int n_ranks = static_cast(args[7]); + int rank_id = static_cast(args[9]); + + std::cout << "\n=== build_allgather_graph (Manual RDMA) ===" << '\n'; + std::cout << " n_ranks=" << n_ranks << " rank_id=" << rank_id << '\n'; + + size_t barrier_size = static_cast(n_ranks) * sizeof(int32_t); + uint64_t barrier_base_pre = win_in_base + HCCL_WIN_SYNC_PREFIX; + uint64_t barrier_base_post = barrier_base_pre + barrier_size; + uint64_t win_src = barrier_base_post + barrier_size; + uint64_t win_dst = win_src + GATHER_COUNT * sizeof(float); + + int32_t zeros[64] = {}; + std::memset(zeros, 0, sizeof(zeros)); + runtime->host_api.copy_to_device(reinterpret_cast(barrier_base_pre), + zeros, barrier_size); + runtime->host_api.copy_to_device(reinterpret_cast(barrier_base_post), + zeros, barrier_size); + + void* dev_src = runtime->host_api.device_malloc(size_src); + if (!dev_src) return -1; + runtime->host_api.copy_to_device(dev_src, host_src, size_src); + + void* dev_out = runtime->host_api.device_malloc(size_out); + if 
(!dev_out) { + runtime->host_api.device_free(dev_src); + return -1; + } + runtime->record_tensor_pair(host_out, dev_out, size_out); + + uint64_t args_wmin[3] = { + win_src, + reinterpret_cast(dev_src), + static_cast(GATHER_COUNT) + }; + int t0 = runtime->add_task(args_wmin, 3, FUNC_WIN_MEMCOPY_IN, CoreType::AIV); + + uint64_t args_barrier_pre[4] = { + barrier_base_pre, device_ctx_ptr, + static_cast(n_ranks), static_cast(0) + }; + int t1 = runtime->add_task(args_barrier_pre, 4, FUNC_COMM_BARRIER, CoreType::AIV); + runtime->add_successor(t0, t1); + + uint64_t args_allgather[5] = { + win_dst, win_src, device_ctx_ptr, + static_cast(n_ranks), static_cast(rank_id) + }; + int t2 = runtime->add_task(args_allgather, 5, FUNC_ALLGATHER, CoreType::AIV); + runtime->add_successor(t1, t2); + + uint64_t args_wmout[3] = { + reinterpret_cast(dev_out), + win_dst, + static_cast(n_ranks * GATHER_COUNT) + }; + int t3 = runtime->add_task(args_wmout, 3, FUNC_WIN_MEMCOPY_OUT, CoreType::AIV); + runtime->add_successor(t2, t3); + + uint64_t args_barrier_post[4] = { + barrier_base_post, device_ctx_ptr, + static_cast(n_ranks), static_cast(0) + }; + int t4 = runtime->add_task(args_barrier_post, 4, FUNC_COMM_BARRIER, CoreType::AIV); + runtime->add_successor(t3, t4); + + std::cout << " task" << t0 << ": WindowMemCopyIn [AIV]\n"; + std::cout << " task" << t1 << ": CommBarrierAll (pre) [AIV]\n"; + std::cout << " task" << t2 << ": AllGatherManual [AIV]\n"; + std::cout << " task" << t3 << ": WindowMemCopyOut [AIV]\n"; + std::cout << " task" << t4 << ": CommBarrierAll (post) [AIV]\n"; + + return 0; +} + +} // extern "C" diff --git a/examples/host_build_graph/allgather_Tgather/README.md b/examples/host_build_graph/allgather_Tgather/README.md new file mode 100644 index 00000000..73a496f2 --- /dev/null +++ b/examples/host_build_graph/allgather_Tgather/README.md @@ -0,0 +1,18 @@ +# AllGather (TGATHER 实现) + +多卡 AllGather 通信算子:使用 **N 次顺序 TGATHER** 实现。每个 rank 获得所有 rank 数据的拼接结果。 + +**实现方式**:`for r in [0, 
n_ranks): Barrier -> Gather(root=r) -> [rank r: WindowMemCopyOut] -> Barrier(post)` + +- 仅 root 调用 TGATHER,避免多 rank 同时调用导致的死锁 +- 适用于性能对比测试 + +## 运行 + +```bash +python examples/scripts/multi_card_run_example.py \ + -k examples/host_build_graph/allgather_Tgather/kernels \ + -g examples/host_build_graph/allgather_Tgather/golden.py +``` + +需要设置 `PTO_COMM_ISA_ROOT` 指向 pto-comm-isa 根目录,以及多卡 HCCL 环境。 diff --git a/examples/host_build_graph/allgather_Tgather/golden.py b/examples/host_build_graph/allgather_Tgather/golden.py new file mode 100644 index 00000000..cc637788 --- /dev/null +++ b/examples/host_build_graph/allgather_Tgather/golden.py @@ -0,0 +1,59 @@ +""" +Golden reference for AllGather (TGATHER variant, no compute). + +Each rank contributes GATHER_COUNT float32 elements. +After AllGather, EVERY rank holds the concatenation of all ranks' data. +""" + +import ctypes +import numpy as np + +GATHER_COUNT = 64 + +ALL_CASES = {"Default": {}} +DEFAULT_CASE = "Default" +__outputs__ = ["out"] +RTOL = 1e-4 +ATOL = 1e-4 + + +def generate_inputs(params: dict) -> list: + """Return flat argument list. 
For requires_comm, params includes device_ctx_ptr, win_in_base, win_out_base, n_ranks, root, rank_id.""" + rank_id = params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + + np.random.seed(42 + rank_id) + src = np.random.randn(GATHER_COUNT).astype(np.float32) * 0.1 + out = np.zeros((n_ranks * GATHER_COUNT,), dtype=np.float32) + + result = [ + ("src", src), + ("out", out), + ("size_src", ctypes.c_int64(src.nbytes)), + ("size_out", ctypes.c_int64(out.nbytes)), + ] + + if "device_ctx_ptr" in params and "win_in_base" in params and "win_out_base" in params: + result.extend([ + ("device_ctx_ptr", ctypes.c_uint64(params["device_ctx_ptr"])), + ("win_in_base", ctypes.c_uint64(params["win_in_base"])), + ("win_out_base", ctypes.c_uint64(params["win_out_base"])), + ("n_ranks", ctypes.c_int32(n_ranks)), + ("root", ctypes.c_int32(root)), + ("rank_id", ctypes.c_int32(rank_id)), + ]) + + return result + + +def compute_golden(tensors: dict, params: dict) -> None: + """AllGather: every rank gets the full concatenation of all ranks' data.""" + n_ranks = params.get("n_ranks", 2) + out = tensors["out"] + + out_np = out.cpu().numpy() if hasattr(out, 'cpu') else np.asarray(out) + for r in range(n_ranks): + np.random.seed(42 + r) + src_r = np.random.randn(GATHER_COUNT).astype(np.float32) * 0.1 + out_np[r * GATHER_COUNT : (r + 1) * GATHER_COUNT] = src_r[:GATHER_COUNT] diff --git a/examples/host_build_graph/allgather_Tgather/kernels/aiv/comm_barrier_all_kernel.cpp b/examples/host_build_graph/allgather_Tgather/kernels/aiv/comm_barrier_all_kernel.cpp new file mode 100644 index 00000000..b1711665 --- /dev/null +++ b/examples/host_build_graph/allgather_Tgather/kernels/aiv/comm_barrier_all_kernel.cpp @@ -0,0 +1,52 @@ +/** + * All-to-all barrier (多对多): every rank waits for every other rank. + * + * Used by AllGather where every rank reads from all ranks' windows. + * Unlike comm_barrier_kernel (many-to-one), ALL ranks do TWAIT here. + * + * Flow: + * 1. 
Each rank TNOTIFY to root's barrier slot[my_rank] + * 2. Each rank TWAIT on root's barrier until all n_ranks slots >= 1 + * + * Args: + * args[0] = barrier_base (local barrier buffer; root's is used for sync) + * args[1] = device_ctx_ptr (HcclDeviceContext*) + * args[2] = n_ranks + * args[3] = root (whose barrier buffer is the sync point) + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ int32_t* local_barrier = reinterpret_cast<__gm__ int32_t*>(args[0]); + __gm__ HcclDeviceContext* ctx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[1]); + int n_ranks = static_cast(args[2]); + int root = static_cast(args[3]); + int my_rank = static_cast(ctx->rankId); + + // Step 1: Each rank writes flag=1 to root's barrier slot[my_rank] via RDMA. + __gm__ int32_t* remote_slot = HcclRemotePtr(ctx, local_barrier, root) + my_rank; + pto::comm::Signal sig(remote_slot); + pto::comm::TNOTIFY(sig, 1, pto::comm::NotifyOp::Set); + + // Step 2: ALL ranks wait until every rank's flag is >= 1 (multi-to-multi). + __gm__ int32_t* root_barrier = HcclRemotePtr(ctx, local_barrier, root); + for (int i = 0; i < n_ranks; ++i) { + pto::comm::Signal slot(root_barrier + i); + pto::comm::TWAIT(slot, 1, pto::comm::WaitCmp::GE); + } + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/allgather_Tgather/kernels/aiv/window_memcopy_in.cpp b/examples/host_build_graph/allgather_Tgather/kernels/aiv/window_memcopy_in.cpp new file mode 100644 index 00000000..38408baa --- /dev/null +++ b/examples/host_build_graph/allgather_Tgather/kernels/aiv/window_memcopy_in.cpp @@ -0,0 +1,26 @@ +/** + * WindowMemCopyIn: Copy device buffer to HCCL window. + * Used before AllGather so remote ranks can read. 
+ */ + +#include +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* win_dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* dev_src = reinterpret_cast<__gm__ float*>(args[1]); + int count = static_cast(args[2]); + + for (int i = 0; i < count; ++i) { + win_dst[i] = dev_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/allgather_Tgather/kernels/aiv/window_memcopy_out.cpp b/examples/host_build_graph/allgather_Tgather/kernels/aiv/window_memcopy_out.cpp new file mode 100644 index 00000000..99e83e76 --- /dev/null +++ b/examples/host_build_graph/allgather_Tgather/kernels/aiv/window_memcopy_out.cpp @@ -0,0 +1,26 @@ +/** + * WindowMemCopyOut: Copy HCCL window to device buffer. + * After AllGather, every rank copies gathered result to device. + */ + +#include +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* dev_dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* win_src = reinterpret_cast<__gm__ float*>(args[1]); + int count = static_cast(args[2]); + + for (int i = 0; i < count; ++i) { + dev_dst[i] = win_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/allgather_Tgather/kernels/kernel_config.py b/examples/host_build_graph/allgather_Tgather/kernels/kernel_config.py new file mode 100644 index 00000000..3387b59a --- /dev/null +++ b/examples/host_build_graph/allgather_Tgather/kernels/kernel_config.py @@ -0,0 +1,31 @@ +""" +AllGather (TGATHER): N sequential Gathers for performance comparison. + +Flow: WindowMemCopyIn -> for each r in [0,n_ranks): Barrier -> Gather(root=r) +-> [if rank==r: WindowMemCopyOut] -> Barrier(post). 
+Only root calls TGATHER per round. Requires HCCL, PTO_COMM_ISA_ROOT. +""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent +_GATHER_ROOT = _KERNELS_ROOT.parent.parent / "gather" / "kernels" + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "allgather_orch.cpp"), + "function_name": "build_allgather_graph", +} + +KERNELS = [ + {"func_id": 0, "name": "WindowMemCopyIn", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_in.cpp"), "core_type": "aiv"}, + {"func_id": 1, "name": "Gather", "source": str(_GATHER_ROOT / "aiv" / "gather_kernel.cpp"), "core_type": "aiv"}, + {"func_id": 2, "name": "WindowMemCopyOut", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_out.cpp"), "core_type": "aiv"}, + {"func_id": 3, "name": "CommBarrierAll", "source": str(_KERNELS_ROOT / "aiv" / "comm_barrier_all_kernel.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "host_build_graph", + "n_devices": 2, + "first_device_id": 0, + "requires_comm": True, +} diff --git a/examples/host_build_graph/allgather_Tgather/kernels/orchestration/allgather_orch.cpp b/examples/host_build_graph/allgather_Tgather/kernels/orchestration/allgather_orch.cpp new file mode 100644 index 00000000..6c2fdf84 --- /dev/null +++ b/examples/host_build_graph/allgather_Tgather/kernels/orchestration/allgather_orch.cpp @@ -0,0 +1,121 @@ +/** + * AllGather (TGATHER): N sequential Gathers for performance comparison. + * + * Flow: for r in [0, n_ranks): Barrier -> Gather(root=r) -> [rank r: WindowMemCopyOut] + * Only root calls TGATHER per round. Avoids deadlock when all ranks call TGATHER. 
+ * + * Args (10): [0] host_src, [1] host_out, [2] size_src, [3] size_out, + * [4] device_ctx_ptr, [5] win_in_base, [6] win_out_base, + * [7] n_ranks, [8] root (unused), [9] rank_id + */ + +#include "runtime.h" +#include +#include +#include + +extern "C" { + +constexpr int GATHER_COUNT = 64; +constexpr size_t HCCL_WIN_SYNC_PREFIX = 64 * sizeof(int32_t); + +#define FUNC_WIN_MEMCOPY_IN 0 +#define FUNC_GATHER 1 +#define FUNC_WIN_MEMCOPY_OUT 2 +#define FUNC_COMM_BARRIER 3 + +int build_allgather_graph(Runtime* runtime, uint64_t* args, int arg_count) { + if (arg_count < 10) { + std::cerr << "build_allgather_graph: Expected at least 10 args, got " << arg_count << '\n'; + return -1; + } + + void* host_src = reinterpret_cast(args[0]); + void* host_out = reinterpret_cast(args[1]); + size_t size_src = static_cast(args[2]); + size_t size_out = static_cast(args[3]); + uint64_t device_ctx_ptr = args[4]; + uint64_t win_in_base = args[5]; + uint64_t win_out_base = args[6]; + int n_ranks = static_cast(args[7]); + int rank_id = static_cast(args[9]); + + std::cout << "\n=== build_allgather_graph (TGATHER, N sequential) ===" << '\n'; + std::cout << " n_ranks=" << n_ranks << " rank_id=" << rank_id << '\n'; + + size_t barrier_size = static_cast(n_ranks) * sizeof(int32_t); + size_t total_barrier_bytes = barrier_size * (static_cast(n_ranks) + 1); + uint64_t barrier_base_0 = win_in_base + HCCL_WIN_SYNC_PREFIX; + uint64_t win_src = barrier_base_0 + total_barrier_bytes; + uint64_t win_dst = win_src + GATHER_COUNT * sizeof(float); + + int32_t zeros[64] = {}; + std::memset(zeros, 0, sizeof(zeros)); + runtime->host_api.copy_to_device(reinterpret_cast(barrier_base_0), + zeros, total_barrier_bytes); + + void* dev_src = runtime->host_api.device_malloc(size_src); + if (!dev_src) return -1; + runtime->host_api.copy_to_device(dev_src, host_src, size_src); + + void* dev_out = runtime->host_api.device_malloc(size_out); + if (!dev_out) { + runtime->host_api.device_free(dev_src); + return -1; + } + 
runtime->record_tensor_pair(host_out, dev_out, size_out); + + uint64_t args_wmin[3] = { + win_src, + reinterpret_cast(dev_src), + static_cast(GATHER_COUNT) + }; + int t0 = runtime->add_task(args_wmin, 3, FUNC_WIN_MEMCOPY_IN, CoreType::AIV); + + int t_prev = t0; + for (int r = 0; r < n_ranks; r++) { + uint64_t barrier_base_r = barrier_base_0 + static_cast(r) * barrier_size; + uint64_t args_barrier[4] = { + barrier_base_r, device_ctx_ptr, + static_cast(n_ranks), static_cast(0) + }; + int t_barrier = runtime->add_task(args_barrier, 4, FUNC_COMM_BARRIER, CoreType::AIV); + runtime->add_successor(t_prev, t_barrier); + + uint64_t args_gather[5] = { + win_dst, win_src, device_ctx_ptr, + static_cast(n_ranks), static_cast(r) + }; + int t_gather = runtime->add_task(args_gather, 5, FUNC_GATHER, CoreType::AIV); + runtime->add_successor(t_barrier, t_gather); + + if (rank_id == r) { + uint64_t args_wmout[3] = { + reinterpret_cast(dev_out), + win_dst, + static_cast(n_ranks * GATHER_COUNT) + }; + int t_wmout = runtime->add_task(args_wmout, 3, FUNC_WIN_MEMCOPY_OUT, CoreType::AIV); + runtime->add_successor(t_gather, t_wmout); + t_prev = t_wmout; + } else { + t_prev = t_gather; + } + } + + uint64_t barrier_base_post = barrier_base_0 + static_cast(n_ranks) * barrier_size; + uint64_t args_barrier_post[4] = { + barrier_base_post, device_ctx_ptr, + static_cast(n_ranks), static_cast(0) + }; + int t_post = runtime->add_task(args_barrier_post, 4, FUNC_COMM_BARRIER, CoreType::AIV); + runtime->add_successor(t_prev, t_post); + + std::cout << " task" << t0 << ": WindowMemCopyIn [AIV]\n"; + std::cout << " tasks: " << n_ranks << "x (Barrier -> Gather(root=r) -> [rank r: WinCopyOut])\n"; + std::cout << " task" << t_post << ": CommBarrierAll (post) [AIV]\n"; + + return 0; +} + +} // extern "C" diff --git a/examples/host_build_graph/cpt_and_comm/README.md b/examples/host_build_graph/cpt_and_comm/README.md new file mode 100644 index 00000000..6fe20649 --- /dev/null +++ 
b/examples/host_build_graph/cpt_and_comm/README.md @@ -0,0 +1,44 @@ +# cpt_and_comm + +多卡「先计算,再通信」示例:GEMM → WindowMemCopyIn → TGATHER → WindowMemCopyOut。 + +## 流程 + +1. **GEMM**:每卡执行 C = A @ B(64x64) +2. **WindowMemCopyIn**:将 dev_C 前 64 元素拷贝到 HCCL window +3. **TGATHER**:root 从各 rank 收集到本地 window +4. **WindowMemCopyOut**:root 将 gathered 结果拷贝到 dev_out + +## 依赖 + +- CANN(libhccl.so、libacl.so) +- pto-comm-isa(`pto::comm::TGATHER`、`ParallelGroup`) +- a2a3 真机(不支持 sim) + +## 运行 + +```bash +# 1. 先 source CANN 环境(与跑 pto-comm-isa comm case 相同) +source /usr/local/Ascend/ascend-toolkit/latest/bin/setenv.bash + +# 2. 编译 C++ HCCL 辅助库(与 pto-comm-isa 同方式链接,Python 通过 ctypes 调用) +cd examples/scripts/hccl_helper && mkdir -p build && cd build && cmake .. && make && cd ../../../../.. + +# 3. 设置 pto-comm-isa 路径(注意用单个 =) +export PTO_COMM_ISA_ROOT=/path/to/pto-comm-isa + +# 验证:应存在 include/pto/pto-inst.hpp +ls $PTO_COMM_ISA_ROOT/include/pto/pto-inst.hpp + +# 4. 2 卡运行 +python examples/scripts/multi_card_run_example.py \ + -k examples/host_build_graph/cpt_and_comm/kernels \ + -g examples/host_build_graph/cpt_and_comm/golden.py \ + --n-devices 2 --first-device 0 +``` + +## 配置 + +- `RUNTIME_CONFIG.requires_comm`: True +- `RUNTIME_CONFIG.n_devices`: 2 +- `RUNTIME_CONFIG.root`: 0 diff --git a/examples/host_build_graph/cpt_and_comm/golden.py b/examples/host_build_graph/cpt_and_comm/golden.py new file mode 100644 index 00000000..34f5f912 --- /dev/null +++ b/examples/host_build_graph/cpt_and_comm/golden.py @@ -0,0 +1,84 @@ +""" +Golden reference for cpt_and_comm: GEMM then TGATHER. + +Each rank: C = A @ B (64x64), gather first 64 elements to root. +Root output: [rank0_first64, rank1_first64, ...] +""" + +import ctypes +import numpy as np + +# GEMM 64x64, gather 64 elements per rank +TILE = 64 +GATHER_COUNT = 64 + +ALL_CASES = {"Default": {}} +DEFAULT_CASE = "Default" +__outputs__ = ["out"] +RTOL = 1e-4 +ATOL = 1e-4 + + +def generate_inputs(params: dict) -> list: + """Return flat argument list. 
For requires_comm, params includes device_ctx_ptr, win_in_base, win_out_base, n_ranks, root, rank_id.""" + rank_id = params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + + # A, B: 64x64 per rank (different data per rank) + np.random.seed(42 + rank_id) + a = np.random.randn(TILE, TILE).astype(np.float32) * 0.1 + b = np.random.randn(TILE, TILE).astype(np.float32) * 0.1 + c = np.zeros((TILE, TILE), dtype=np.float32) # GEMM output + out = np.zeros((n_ranks * GATHER_COUNT,), dtype=np.float32) # root only + + result = [ + ("a", a), + ("b", b), + ("c", c), + ("out", out), + ("size_a", ctypes.c_int64(a.nbytes)), + ("size_b", ctypes.c_int64(b.nbytes)), + ("size_c", ctypes.c_int64(c.nbytes)), + ("size_out", ctypes.c_int64(out.nbytes)), + ] + + if "device_ctx_ptr" in params and "win_in_base" in params and "win_out_base" in params: + result.extend([ + ("device_ctx_ptr", ctypes.c_uint64(params["device_ctx_ptr"])), + ("win_in_base", ctypes.c_uint64(params["win_in_base"])), + ("win_out_base", ctypes.c_uint64(params["win_out_base"])), + ("n_ranks", ctypes.c_int32(n_ranks)), + ("root", ctypes.c_int32(root)), + ("rank_id", ctypes.c_int32(rank_id)), + ]) + + return result + + +def compute_golden(tensors: dict, params: dict) -> None: + """Compute expected: BGEMM then gather first GATHER_COUNT from each rank.""" + rank_id = params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + + a = tensors["a"] + b = tensors["b"] + c = tensors["c"] + out = tensors["out"] + + # GEMM: C = A @ B + c[:] = a @ b + + # Gather: root collects first GATHER_COUNT from each rank + if rank_id == root: + # out is torch.Tensor (CodeRunner converts); use numpy view for assignment + out_np = out.cpu().numpy() + for r in range(n_ranks): + # Simulate rank r's GEMM output (we only have our own, so for golden we compute all) + np.random.seed(42 + r) + ar = np.random.randn(TILE, TILE).astype(np.float32) * 0.1 + br = np.random.randn(TILE, 
TILE).astype(np.float32) * 0.1 + cr = ar @ br + flat = cr.flatten() + out_np[r * GATHER_COUNT : (r + 1) * GATHER_COUNT] = flat[:GATHER_COUNT] diff --git a/examples/host_build_graph/cpt_and_comm/kernels/aic/kernel_gemm_tile.cpp b/examples/host_build_graph/cpt_and_comm/kernels/aic/kernel_gemm_tile.cpp new file mode 100644 index 00000000..92c93d32 --- /dev/null +++ b/examples/host_build_graph/cpt_and_comm/kernels/aic/kernel_gemm_tile.cpp @@ -0,0 +1,90 @@ +/** + * Tile-based Matrix Multiplication Kernel (Cube Core) + * + * Computes: output = input_a @ input_b (64x64 tile matmul) + * Uses TMATMUL instruction + */ + +#include +#include +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +AICORE constexpr inline T CeilAlign(T num_1, T num_2) { + if (num_2 == 0) { + return 0; + } + return (num_1 + num_2 - 1) / num_2 * num_2; +} + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* input_a = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* input_b = reinterpret_cast<__gm__ float*>(args[1]); + __gm__ float* output = reinterpret_cast<__gm__ float*>(args[2]); + + constexpr int TILE = 64; + constexpr int blockAlign = C0_SIZE_BYTE / sizeof(float); + constexpr int M = CeilAlign(TILE, 16); + constexpr int K = CeilAlign(TILE, blockAlign); + constexpr int N = CeilAlign(TILE, blockAlign); + + using GlobalDataA = GlobalTensor, + Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; + using GlobalDataB = GlobalTensor, + Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; + using GlobalDataC = GlobalTensor, + Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; + + GlobalDataA src0Global(input_a); + GlobalDataB src1Global(input_b); + GlobalDataC dstGlobal(output); + + using TileMatA = Tile; + using TileMatB = Tile; + + using LeftTile = TileLeft; + using RightTile = TileRight; + 
using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile bTile; + AccTile cTile; + TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); + TASSIGN(cTile, 0x0); + + TLOAD(aMatTile, src0Global); + TLOAD(bMatTile, src1Global); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(dstGlobal, cTile); +} diff --git a/examples/host_build_graph/cpt_and_comm/kernels/aiv/comm_barrier_kernel.cpp b/examples/host_build_graph/cpt_and_comm/kernels/aiv/comm_barrier_kernel.cpp new file mode 100644 index 00000000..7e210a16 --- /dev/null +++ b/examples/host_build_graph/cpt_and_comm/kernels/aiv/comm_barrier_kernel.cpp @@ -0,0 +1,51 @@ +/** + * Device-side cross-rank barrier using TNOTIFY/TWAIT from pto-comm-isa. + * + * Each rank notifies root that it has finished the compute phase by writing + * a flag to root's barrier slot. Root then spins until all ranks have + * reported, guaranteeing that every rank's window data is visible before + * TGATHER reads it. 
+ * + * Args: + * args[0] = barrier_base (local barrier signal buffer in own windowsIn) + * args[1] = device_ctx_ptr (HcclDeviceContext*) + * args[2] = n_ranks + * args[3] = root + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ int32_t* local_barrier = reinterpret_cast<__gm__ int32_t*>(args[0]); + __gm__ HcclDeviceContext* ctx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[1]); + int n_ranks = static_cast(args[2]); + int root = static_cast(args[3]); + int my_rank = static_cast(ctx->rankId); + + // Each rank writes flag=1 to root's barrier slot[my_rank] via RDMA. + __gm__ int32_t* remote_slot = HcclRemotePtr(ctx, local_barrier, root) + my_rank; + pto::comm::Signal sig(remote_slot); + pto::comm::TNOTIFY(sig, 1, pto::comm::NotifyOp::Set); + + // Root waits until every rank's flag is >= 1. + if (my_rank == root) { + for (int i = 0; i < n_ranks; ++i) { + pto::comm::Signal slot(local_barrier + i); + pto::comm::TWAIT(slot, 1, pto::comm::WaitCmp::GE); + } + } + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/cpt_and_comm/kernels/aiv/gather_kernel.cpp b/examples/host_build_graph/cpt_and_comm/kernels/aiv/gather_kernel.cpp new file mode 100644 index 00000000..2d972cfa --- /dev/null +++ b/examples/host_build_graph/cpt_and_comm/kernels/aiv/gather_kernel.cpp @@ -0,0 +1,62 @@ +/** + * TGATHER collective kernel - root gathers from all ranks. + * Requires pto-comm-isa (PTO_ISA_ROOT or PTO_COMM_ISA_ROOT). 
+ */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static constexpr size_t GATHER_COUNT = 64; + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* src = reinterpret_cast<__gm__ float*>(args[1]); + __gm__ HcclDeviceContext* hcclCtx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[2]); + int nranks = static_cast(args[3]); + int root = static_cast(args[4]); + + using ShapeDyn = pto::Shape; + using StrideDyn = pto::Stride; + using Global = pto::GlobalTensor; + + using TileData = pto::Tile; + + int my_rank = static_cast(hcclCtx->rankId); + + ShapeDyn srcShape(1, 1, 1, 1, GATHER_COUNT); + StrideDyn srcStride(GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, 1); + + ShapeDyn dstShape(1, 1, 1, nranks, GATHER_COUNT); + StrideDyn dstStride(nranks * GATHER_COUNT, nranks * GATHER_COUNT, nranks * GATHER_COUNT, GATHER_COUNT, 1); + Global dstG(dst, dstShape, dstStride); + + Global tensors[16]; + int actual_nranks = (nranks > 16) ? 
16 : nranks; + for (int i = 0; i < actual_nranks; ++i) { + __gm__ float* remoteSrc = HcclRemotePtr(hcclCtx, src, i); + tensors[i] = Global(remoteSrc, srcShape, srcStride); + } + + pto::comm::ParallelGroup pg(tensors, actual_nranks, root); + + TileData ubTile(1, GATHER_COUNT); + TASSIGN(ubTile, 0x0); + + if (my_rank == root) { + pto::comm::TGATHER(pg, dstG, ubTile); + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/cpt_and_comm/kernels/aiv/window_memcopy_in.cpp b/examples/host_build_graph/cpt_and_comm/kernels/aiv/window_memcopy_in.cpp new file mode 100644 index 00000000..73504fa1 --- /dev/null +++ b/examples/host_build_graph/cpt_and_comm/kernels/aiv/window_memcopy_in.cpp @@ -0,0 +1,26 @@ +/** + * WindowMemCopyIn: Copy device buffer to HCCL window. + * Used before TGATHER so remote ranks can read. + */ + +#include +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* win_dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* dev_src = reinterpret_cast<__gm__ float*>(args[1]); + int count = static_cast(args[2]); + + for (int i = 0; i < count; ++i) { + win_dst[i] = dev_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/cpt_and_comm/kernels/aiv/window_memcopy_out.cpp b/examples/host_build_graph/cpt_and_comm/kernels/aiv/window_memcopy_out.cpp new file mode 100644 index 00000000..3f2ef586 --- /dev/null +++ b/examples/host_build_graph/cpt_and_comm/kernels/aiv/window_memcopy_out.cpp @@ -0,0 +1,26 @@ +/** + * WindowMemCopyOut: Copy HCCL window to device buffer. + * Root only - after TGATHER, copy gathered result to device. 
+ */ + +#include +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* dev_dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* win_src = reinterpret_cast<__gm__ float*>(args[1]); + int count = static_cast(args[2]); + + for (int i = 0; i < count; ++i) { + dev_dst[i] = win_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/cpt_and_comm/kernels/kernel_config.py b/examples/host_build_graph/cpt_and_comm/kernels/kernel_config.py new file mode 100644 index 00000000..660c791a --- /dev/null +++ b/examples/host_build_graph/cpt_and_comm/kernels/kernel_config.py @@ -0,0 +1,31 @@ +""" +Kernel configuration for cpt_and_comm (compute then communicate). + +Flow: GEMM -> WindowMemCopyIn -> CommBarrier -> TGATHER -> WindowMemCopyOut (root only). +CommBarrier uses TNOTIFY/TWAIT for device-side cross-rank synchronization. +Requires HCCL (multi-card), PTO_COMM_ISA_ROOT pointing to pto-comm-isa for comm headers. 
+""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "cpt_and_comm_orch.cpp"), + "function_name": "build_cpt_and_comm_graph", +} + +KERNELS = [ + {"func_id": 0, "name": "GEMM", "source": str(_KERNELS_ROOT / "aic" / "kernel_gemm_tile.cpp"), "core_type": "aic"}, + {"func_id": 1, "name": "WindowMemCopyIn", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_in.cpp"), "core_type": "aiv"}, + {"func_id": 2, "name": "Gather", "source": str(_KERNELS_ROOT / "aiv" / "gather_kernel.cpp"), "core_type": "aiv"}, + {"func_id": 3, "name": "WindowMemCopyOut", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_out.cpp"), "core_type": "aiv"}, + {"func_id": 4, "name": "CommBarrier", "source": str(_KERNELS_ROOT / "aiv" / "comm_barrier_kernel.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "host_build_graph", + "n_devices": 2, + "first_device_id": 0, + "requires_comm": True, +} diff --git a/examples/host_build_graph/cpt_and_comm/kernels/orchestration/cpt_and_comm_orch.cpp b/examples/host_build_graph/cpt_and_comm/kernels/orchestration/cpt_and_comm_orch.cpp new file mode 100644 index 00000000..e6fdb0d1 --- /dev/null +++ b/examples/host_build_graph/cpt_and_comm/kernels/orchestration/cpt_and_comm_orch.cpp @@ -0,0 +1,147 @@ +/** + * cpt_and_comm orchestration: GEMM -> WindowMemCopyIn -> CommBarrier -> TGATHER -> WindowMemCopyOut (root only). + * + * CommBarrier uses TNOTIFY/TWAIT to synchronize all ranks at the device level, + * guaranteeing every rank's window data is visible before TGATHER reads it. 
+ * + * Args: host_A, host_B, host_C, host_out, size_A, size_B, size_C, size_out, + * device_ctx_ptr, win_in_base, win_out_base, n_ranks, root, rank_id + */ + +#include "runtime.h" +#include +#include +#include + +extern "C" { + +constexpr int TILE = 64; +constexpr int GATHER_COUNT = 64; +constexpr size_t HCCL_WIN_SYNC_PREFIX = 64 * sizeof(int32_t); + +int build_cpt_and_comm_graph(Runtime* runtime, uint64_t* args, int arg_count) { + if (arg_count < 14) { + std::cerr << "build_cpt_and_comm_graph: Expected at least 14 args, got " << arg_count << '\n'; + return -1; + } + + void* host_A = reinterpret_cast(args[0]); + void* host_B = reinterpret_cast(args[1]); + void* host_C = reinterpret_cast(args[2]); + void* host_out = reinterpret_cast(args[3]); + size_t size_A = static_cast(args[4]); + size_t size_B = static_cast(args[5]); + size_t size_C = static_cast(args[6]); + size_t size_out = static_cast(args[7]); + uint64_t device_ctx_ptr = args[8]; + uint64_t win_in_base = args[9]; + uint64_t win_out_base = args[10]; + int n_ranks = static_cast(args[11]); + int root = static_cast(args[12]); + int rank_id = static_cast(args[13]); + + std::cout << "\n=== build_cpt_and_comm_graph ===" << '\n'; + std::cout << " n_ranks=" << n_ranks << " root=" << root + << " rank_id=" << rank_id << '\n'; + + // ── Window layout ──────────────────────────────────────────────── + // [0, SYNC_PREFIX) : HCCL sync prefix (reserved) + // [SYNC_PREFIX, SYNC_PREFIX + n_ranks*4) : barrier signals (int32 per rank) + // [SYNC_PREFIX + n_ranks*4, ...) : src, then dst + size_t barrier_size = static_cast(n_ranks) * sizeof(int32_t); + uint64_t barrier_base = win_in_base + HCCL_WIN_SYNC_PREFIX; + uint64_t win_src = barrier_base + barrier_size; + uint64_t win_dst = win_src + GATHER_COUNT * sizeof(float); + + // Zero-initialize barrier slots so TWAIT starts from a clean state. 
+ int32_t zeros[64] = {}; + std::memset(zeros, 0, sizeof(zeros)); + runtime->host_api.copy_to_device(reinterpret_cast(barrier_base), zeros, + barrier_size); + + // ── Allocate device memory for GEMM operands ───────────────────── + void* dev_A = runtime->host_api.device_malloc(size_A); + if (!dev_A) return -1; + runtime->host_api.copy_to_device(dev_A, host_A, size_A); + + void* dev_B = runtime->host_api.device_malloc(size_B); + if (!dev_B) { runtime->host_api.device_free(dev_A); return -1; } + runtime->host_api.copy_to_device(dev_B, host_B, size_B); + + void* dev_C = runtime->host_api.device_malloc(size_C); + if (!dev_C) { + runtime->host_api.device_free(dev_A); + runtime->host_api.device_free(dev_B); + return -1; + } + runtime->host_api.copy_to_device(dev_C, host_C, size_C); + + void* dev_out = nullptr; + if (rank_id == root) { + dev_out = runtime->host_api.device_malloc(size_out); + if (!dev_out) { + runtime->host_api.device_free(dev_A); + runtime->host_api.device_free(dev_B); + runtime->host_api.device_free(dev_C); + return -1; + } + runtime->record_tensor_pair(host_out, dev_out, size_out); + } + + // ── Task 0: GEMM C = A @ B [AIC] ────────────────────────────── + uint64_t args_gemm[3]; + args_gemm[0] = reinterpret_cast(dev_A); + args_gemm[1] = reinterpret_cast(dev_B); + args_gemm[2] = reinterpret_cast(dev_C); + int t0 = runtime->add_task(args_gemm, 3, 0, CoreType::AIC); + + // ── Task 1: WindowMemCopyIn [AIV] ─────────────────────────────── + uint64_t args_wmin[3]; + args_wmin[0] = win_src; + args_wmin[1] = reinterpret_cast(dev_C); + args_wmin[2] = static_cast(GATHER_COUNT); + int t1 = runtime->add_task(args_wmin, 3, 1, CoreType::AIV); + + // ── Task 2: CommBarrier (TNOTIFY/TWAIT) [AIV] ─────────────────── + uint64_t args_barrier[4]; + args_barrier[0] = barrier_base; + args_barrier[1] = device_ctx_ptr; + args_barrier[2] = static_cast(n_ranks); + args_barrier[3] = static_cast(root); + int t2 = runtime->add_task(args_barrier, 4, 4, CoreType::AIV); + + // ── Task 
3: Gather [AIV] ──────────────────────────────────────── + uint64_t args_gather[5]; + args_gather[0] = win_dst; + args_gather[1] = win_src; + args_gather[2] = device_ctx_ptr; + args_gather[3] = static_cast(n_ranks); + args_gather[4] = static_cast(root); + int t3 = runtime->add_task(args_gather, 5, 2, CoreType::AIV); + + // Dependencies: GEMM → MemCopyIn → CommBarrier → Gather + runtime->add_successor(t0, t1); + runtime->add_successor(t1, t2); + runtime->add_successor(t2, t3); + + int t4 = -1; + if (dev_out != nullptr) { + // ── Task 4: WindowMemCopyOut (root only) [AIV] ────────────── + uint64_t args_wmout[3]; + args_wmout[0] = reinterpret_cast(dev_out); + args_wmout[1] = win_dst; + args_wmout[2] = static_cast(n_ranks * GATHER_COUNT); + t4 = runtime->add_task(args_wmout, 3, 3, CoreType::AIV); + runtime->add_successor(t3, t4); + } + + std::cout << " task" << t0 << ": GEMM [AIC]\n"; + std::cout << " task" << t1 << ": WindowMemCopyIn [AIV]\n"; + std::cout << " task" << t2 << ": CommBarrier (TNOTIFY/TWAIT) [AIV]\n"; + std::cout << " task" << t3 << ": Gather [AIV]\n"; + if (t4 >= 0) std::cout << " task" << t4 << ": WindowMemCopyOut [AIV]\n"; + + return 0; +} + +} // extern "C" diff --git a/examples/host_build_graph/gather/README.md b/examples/host_build_graph/gather/README.md new file mode 100644 index 00000000..58559603 --- /dev/null +++ b/examples/host_build_graph/gather/README.md @@ -0,0 +1,17 @@ +# Gather + +纯 gather 通信算子:仅保留多卡之间的 TGATHER 通信,无计算。 + +流程:WindowMemCopyIn → CommBarrier → TGATHER → WindowMemCopyOut (root only) + +每个 rank 有本地 src 数据(GATHER_COUNT=64 个 float),root 将各 rank 的前 GATHER_COUNT 个元素收集到 out 中。 + +## 运行 + +```bash +python examples/scripts/multi_card_run_example.py \ + -k examples/host_build_graph/gather/kernels \ + -g examples/host_build_graph/gather/golden.py +``` + +需要设置 `PTO_COMM_ISA_ROOT` 指向 pto-comm-isa 根目录,以及多卡 HCCL 环境。 diff --git a/examples/host_build_graph/gather/golden.py b/examples/host_build_graph/gather/golden.py new file mode 100644 
index 00000000..d341a3f3 --- /dev/null +++ b/examples/host_build_graph/gather/golden.py @@ -0,0 +1,64 @@ +""" +Golden reference for gather: multi-card TGATHER only, no computation. + +Each rank has local src data (GATHER_COUNT elements). Root gathers first +GATHER_COUNT from each rank into out: [rank0_data, rank1_data, ...]. +""" + +import ctypes +import numpy as np + +GATHER_COUNT = 64 + +ALL_CASES = {"Default": {}} +DEFAULT_CASE = "Default" +__outputs__ = ["out"] +RTOL = 1e-4 +ATOL = 1e-4 + + +def generate_inputs(params: dict) -> list: + """Return flat argument list. For requires_comm, params includes device_ctx_ptr, win_in_base, win_out_base, n_ranks, root, rank_id.""" + rank_id = params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + + # Per-rank src data (different per rank) + np.random.seed(42 + rank_id) + src = np.random.randn(GATHER_COUNT).astype(np.float32) * 0.1 + out = np.zeros((n_ranks * GATHER_COUNT,), dtype=np.float32) # root only + + result = [ + ("src", src), + ("out", out), + ("size_src", ctypes.c_int64(src.nbytes)), + ("size_out", ctypes.c_int64(out.nbytes)), + ] + + if "device_ctx_ptr" in params and "win_in_base" in params and "win_out_base" in params: + result.extend([ + ("device_ctx_ptr", ctypes.c_uint64(params["device_ctx_ptr"])), + ("win_in_base", ctypes.c_uint64(params["win_in_base"])), + ("win_out_base", ctypes.c_uint64(params["win_out_base"])), + ("n_ranks", ctypes.c_int32(n_ranks)), + ("root", ctypes.c_int32(root)), + ("rank_id", ctypes.c_int32(rank_id)), + ]) + + return result + + +def compute_golden(tensors: dict, params: dict) -> None: + """Compute expected: gather first GATHER_COUNT from each rank to root.""" + rank_id = params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + + out = tensors["out"] + + if rank_id == root: + out_np = out.cpu().numpy() + for r in range(n_ranks): + np.random.seed(42 + r) + src_r = 
np.random.randn(GATHER_COUNT).astype(np.float32) * 0.1 + out_np[r * GATHER_COUNT : (r + 1) * GATHER_COUNT] = src_r[:GATHER_COUNT] diff --git a/examples/host_build_graph/gather/kernels/aiv/comm_barrier_kernel.cpp b/examples/host_build_graph/gather/kernels/aiv/comm_barrier_kernel.cpp new file mode 100644 index 00000000..7e210a16 --- /dev/null +++ b/examples/host_build_graph/gather/kernels/aiv/comm_barrier_kernel.cpp @@ -0,0 +1,51 @@ +/** + * Device-side cross-rank barrier using TNOTIFY/TWAIT from pto-comm-isa. + * + * Each rank notifies root that it has finished the compute phase by writing + * a flag to root's barrier slot. Root then spins until all ranks have + * reported, guaranteeing that every rank's window data is visible before + * TGATHER reads it. + * + * Args: + * args[0] = barrier_base (local barrier signal buffer in own windowsIn) + * args[1] = device_ctx_ptr (HcclDeviceContext*) + * args[2] = n_ranks + * args[3] = root + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ int32_t* local_barrier = reinterpret_cast<__gm__ int32_t*>(args[0]); + __gm__ HcclDeviceContext* ctx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[1]); + int n_ranks = static_cast(args[2]); + int root = static_cast(args[3]); + int my_rank = static_cast(ctx->rankId); + + // Each rank writes flag=1 to root's barrier slot[my_rank] via RDMA. + __gm__ int32_t* remote_slot = HcclRemotePtr(ctx, local_barrier, root) + my_rank; + pto::comm::Signal sig(remote_slot); + pto::comm::TNOTIFY(sig, 1, pto::comm::NotifyOp::Set); + + // Root waits until every rank's flag is >= 1. 
+ if (my_rank == root) { + for (int i = 0; i < n_ranks; ++i) { + pto::comm::Signal slot(local_barrier + i); + pto::comm::TWAIT(slot, 1, pto::comm::WaitCmp::GE); + } + } + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/gather/kernels/aiv/gather_kernel.cpp b/examples/host_build_graph/gather/kernels/aiv/gather_kernel.cpp new file mode 100644 index 00000000..2d972cfa --- /dev/null +++ b/examples/host_build_graph/gather/kernels/aiv/gather_kernel.cpp @@ -0,0 +1,62 @@ +/** + * TGATHER collective kernel - root gathers from all ranks. + * Requires pto-comm-isa (PTO_ISA_ROOT or PTO_COMM_ISA_ROOT). + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static constexpr size_t GATHER_COUNT = 64; + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* src = reinterpret_cast<__gm__ float*>(args[1]); + __gm__ HcclDeviceContext* hcclCtx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[2]); + int nranks = static_cast(args[3]); + int root = static_cast(args[4]); + + using ShapeDyn = pto::Shape; + using StrideDyn = pto::Stride; + using Global = pto::GlobalTensor; + + using TileData = pto::Tile; + + int my_rank = static_cast(hcclCtx->rankId); + + ShapeDyn srcShape(1, 1, 1, 1, GATHER_COUNT); + StrideDyn srcStride(GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, 1); + + ShapeDyn dstShape(1, 1, 1, nranks, GATHER_COUNT); + StrideDyn dstStride(nranks * GATHER_COUNT, nranks * GATHER_COUNT, nranks * GATHER_COUNT, GATHER_COUNT, 1); + Global dstG(dst, dstShape, dstStride); + + Global tensors[16]; + int actual_nranks = (nranks > 16) ? 
16 : nranks; + for (int i = 0; i < actual_nranks; ++i) { + __gm__ float* remoteSrc = HcclRemotePtr(hcclCtx, src, i); + tensors[i] = Global(remoteSrc, srcShape, srcStride); + } + + pto::comm::ParallelGroup pg(tensors, actual_nranks, root); + + TileData ubTile(1, GATHER_COUNT); + TASSIGN(ubTile, 0x0); + + if (my_rank == root) { + pto::comm::TGATHER(pg, dstG, ubTile); + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/gather/kernels/aiv/window_memcopy_in.cpp b/examples/host_build_graph/gather/kernels/aiv/window_memcopy_in.cpp new file mode 100644 index 00000000..73504fa1 --- /dev/null +++ b/examples/host_build_graph/gather/kernels/aiv/window_memcopy_in.cpp @@ -0,0 +1,26 @@ +/** + * WindowMemCopyIn: Copy device buffer to HCCL window. + * Used before TGATHER so remote ranks can read. + */ + +#include +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* win_dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* dev_src = reinterpret_cast<__gm__ float*>(args[1]); + int count = static_cast(args[2]); + + for (int i = 0; i < count; ++i) { + win_dst[i] = dev_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/gather/kernels/aiv/window_memcopy_out.cpp b/examples/host_build_graph/gather/kernels/aiv/window_memcopy_out.cpp new file mode 100644 index 00000000..3f2ef586 --- /dev/null +++ b/examples/host_build_graph/gather/kernels/aiv/window_memcopy_out.cpp @@ -0,0 +1,26 @@ +/** + * WindowMemCopyOut: Copy HCCL window to device buffer. + * Root only - after TGATHER, copy gathered result to device. 
+ */ + +#include +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* dev_dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* win_src = reinterpret_cast<__gm__ float*>(args[1]); + int count = static_cast(args[2]); + + for (int i = 0; i < count; ++i) { + dev_dst[i] = win_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/gather/kernels/kernel_config.py b/examples/host_build_graph/gather/kernels/kernel_config.py new file mode 100644 index 00000000..fed896b0 --- /dev/null +++ b/examples/host_build_graph/gather/kernels/kernel_config.py @@ -0,0 +1,30 @@ +""" +Gather-only: multi-card TGATHER communication, no computation. + +Flow: WindowMemCopyIn -> CommBarrier -> TGATHER -> WindowMemCopyOut (root only). +CommBarrier uses TNOTIFY/TWAIT for device-side cross-rank synchronization. +Requires HCCL (multi-card), PTO_COMM_ISA_ROOT for comm headers. 
+""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "gather_orch.cpp"), + "function_name": "build_gather_graph", +} + +KERNELS = [ + {"func_id": 0, "name": "WindowMemCopyIn", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_in.cpp"), "core_type": "aiv"}, + {"func_id": 1, "name": "Gather", "source": str(_KERNELS_ROOT / "aiv" / "gather_kernel.cpp"), "core_type": "aiv"}, + {"func_id": 2, "name": "WindowMemCopyOut", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_out.cpp"), "core_type": "aiv"}, + {"func_id": 3, "name": "CommBarrier", "source": str(_KERNELS_ROOT / "aiv" / "comm_barrier_kernel.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "host_build_graph", + "n_devices": 2, + "first_device_id": 0, + "requires_comm": True, +} diff --git a/examples/host_build_graph/gather/kernels/orchestration/gather_orch.cpp b/examples/host_build_graph/gather/kernels/orchestration/gather_orch.cpp new file mode 100644 index 00000000..422a465b --- /dev/null +++ b/examples/host_build_graph/gather/kernels/orchestration/gather_orch.cpp @@ -0,0 +1,126 @@ +/** + * Gather-only orchestration: WindowMemCopyIn -> CommBarrier -> TGATHER -> WindowMemCopyOut (root only). + * + * No computation. Each rank has local src data; first GATHER_COUNT elements are gathered to root. + * CommBarrier uses TNOTIFY/TWAIT for device-side cross-rank synchronization. 
+ * + * Args (10): + * [0] host_src + * [1] host_out (root only, output buffer) + * [2] size_src + * [3] size_out + * [4] device_ctx_ptr + * [5] win_in_base + * [6] win_out_base + * [7] n_ranks + * [8] root + * [9] rank_id + */ + +#include "runtime.h" +#include +#include +#include + +extern "C" { + +constexpr int GATHER_COUNT = 64; +constexpr size_t HCCL_WIN_SYNC_PREFIX = 64 * sizeof(int32_t); + +#define FUNC_WIN_MEMCOPY_IN 0 +#define FUNC_GATHER 1 +#define FUNC_WIN_MEMCOPY_OUT 2 +#define FUNC_COMM_BARRIER 3 + +int build_gather_graph(Runtime* runtime, uint64_t* args, int arg_count) { + if (arg_count < 10) { + std::cerr << "build_gather_graph: Expected at least 10 args, got " << arg_count << '\n'; + return -1; + } + + void* host_src = reinterpret_cast(args[0]); + void* host_out = reinterpret_cast(args[1]); + size_t size_src = static_cast(args[2]); + size_t size_out = static_cast(args[3]); + uint64_t device_ctx_ptr = args[4]; + uint64_t win_in_base = args[5]; + uint64_t win_out_base = args[6]; + int n_ranks = static_cast(args[7]); + int root = static_cast(args[8]); + int rank_id = static_cast(args[9]); + + std::cout << "\n=== build_gather_graph ===" << '\n'; + std::cout << " n_ranks=" << n_ranks << " root=" << root + << " rank_id=" << rank_id << '\n'; + + /* ── Window layout ──────────────────────────────────────────────── */ + size_t barrier_size = static_cast(n_ranks) * sizeof(int32_t); + uint64_t barrier_base = win_in_base + HCCL_WIN_SYNC_PREFIX; + uint64_t win_src = barrier_base + barrier_size; + uint64_t win_dst = win_src + GATHER_COUNT * sizeof(float); + + int32_t zeros[64] = {}; + std::memset(zeros, 0, sizeof(zeros)); + runtime->host_api.copy_to_device(reinterpret_cast(barrier_base), + zeros, barrier_size); + + /* ── Allocate device memory for src ─────────────────────────────── */ + void* dev_src = runtime->host_api.device_malloc(size_src); + if (!dev_src) return -1; + runtime->host_api.copy_to_device(dev_src, host_src, size_src); + + void* dev_out = nullptr; 
+ if (rank_id == root) { + dev_out = runtime->host_api.device_malloc(size_out); + if (!dev_out) { + runtime->host_api.device_free(dev_src); + return -1; + } + runtime->record_tensor_pair(host_out, dev_out, size_out); + } + + /* ── Task 0: WindowMemCopyIn ─────────────────────────────────────── */ + uint64_t args_wmin[3] = { + win_src, + reinterpret_cast(dev_src), + static_cast(GATHER_COUNT) + }; + int t0 = runtime->add_task(args_wmin, 3, FUNC_WIN_MEMCOPY_IN, CoreType::AIV); + + /* ── Task 1: CommBarrier (TNOTIFY/TWAIT) ─────────────────────────── */ + uint64_t args_barrier[4] = { + barrier_base, device_ctx_ptr, + static_cast(n_ranks), static_cast(root) + }; + int t1 = runtime->add_task(args_barrier, 4, FUNC_COMM_BARRIER, CoreType::AIV); + runtime->add_successor(t0, t1); + + /* ── Task 2: TGATHER ─────────────────────────────────────────────── */ + uint64_t args_gather[5] = { + win_dst, win_src, device_ctx_ptr, + static_cast(n_ranks), static_cast(root) + }; + int t2 = runtime->add_task(args_gather, 5, FUNC_GATHER, CoreType::AIV); + runtime->add_successor(t1, t2); + + int t3 = -1; + if (dev_out != nullptr) { + /* ── Task 3: WindowMemCopyOut (root only) ─────────────────────── */ + uint64_t args_wmout[3] = { + reinterpret_cast(dev_out), + win_dst, + static_cast(n_ranks * GATHER_COUNT) + }; + t3 = runtime->add_task(args_wmout, 3, FUNC_WIN_MEMCOPY_OUT, CoreType::AIV); + runtime->add_successor(t2, t3); + } + + std::cout << " task" << t0 << ": WindowMemCopyIn [AIV]\n"; + std::cout << " task" << t1 << ": CommBarrier [AIV]\n"; + std::cout << " task" << t2 << ": Gather [AIV]\n"; + if (t3 >= 0) std::cout << " task" << t3 << ": WindowMemCopyOut [AIV]\n"; + + return 0; +} + +} // extern "C" diff --git a/examples/host_build_graph/gemm_gather/README.md b/examples/host_build_graph/gemm_gather/README.md new file mode 100644 index 00000000..605933a6 --- /dev/null +++ b/examples/host_build_graph/gemm_gather/README.md @@ -0,0 +1,58 @@ +# GEMM + Gather Example (a2a3) + +本 case 包含两个 
kernel:一个计算 kernel(GEMM)、一个通信 kernel(Gather),面向 **a2a3** 平台。任务依赖:先计算(GEMM)再通信(Gather),即 t0 → t1。 + +## 1. pto-comm-isa 与 simpler-PTO 自带 _deps/pto-isa 的差异说明 + +### 1.1 结论(a2a3 tgather 相关) + +- **a2a3 下 `tests/npu/a2a3/src/st/testcase/tgather` 目录**: + **pto-comm-isa** 与 **simpler-PTO 自带 `examples/scripts/_deps/pto-isa`** 中对应路径下的实现 **内容一致**(同一套 tgather_kernel.cpp / tgather_common.h)。 +- 即:当前 a2a3 的 TGATHER 用例,两份仓库在实现上无差异,本 case 的 Gather 以 pto-comm-isa 的 a2a3 tgather 为参考即可,头文件与 PTO_ISA_ROOT 使用 simpler-PTO 自带的 _deps/pto-isa 即可。 + +### 1.2 两套仓库的定位与区别(概览) + +| 项目 | 说明 | +|------|------| +| **simpler-PTO 自带 `_deps/pto-isa`** | 来自 `gitcode.com/cann/pto-isa`,由 README 要求克隆到 `examples/scripts/_deps/pto-isa`,作为 **PTO ISA 头文件与接口** 的来源,供 AIC/AIV kernel 编译使用。 | +| **pto-comm-isa** | 独立仓库(PTO Tile Library),除与 pto-isa 同源的 include 外,还包含 **kernels/**、**tests/**(含 a2a3/a5/cpu 等)、**docs/** 等,用于参考实现与测试。 | + +- **Include / API**:两边在 a2a3 相关头文件(如 `pto-inst.hpp`、TGATHER 等)上同源或兼容,本 case 仅依赖 _deps/pto-isa 的 include。 +- **测试与用例**:pto-comm-isa 的 `tests/npu/a2a3/src/st/testcase/tgather` 作为 **实现参考**(尤其是 runTGather1D、shape、流水);本 case 的 kernel 写在 simpler-PTO 内,编译时用 _deps/pto-isa。 + +## 2. Case 说明 + +- **GEMM**:64×64 float 块乘,复用与 bgemm 同构的实现(AIC),单 task。 +- **Gather**:1D 索引 gather,`out = src0[src1]`,shape 与 pto-comm-isa a2a3 tgather 一致(src0: 32×1024 float,src1: 16×64 int32,out: 16×64 float),AIV,单 task。 +- **任务依赖**:t0(GEMM)→ t1(Gather),先计算再通信。 + +## 3. 目录与运行 + +``` +gemm_gather/ +├── README.md +├── golden.py +└── kernels/ + ├── kernel_config.py + ├── orchestration/ + │ └── gemm_gather_orch.cpp + ├── aic/ + │ └── kernel_gemm.cpp + └── aiv/ + ├── tgather_common.h + └── kernel_gather.cpp +``` + +**a2a3 运行示例**(需 Ascend 设备与 CANN 环境): + +```bash +python examples/scripts/run_example.py \ + -k examples/host_build_graph/gemm_gather/kernels \ + -g examples/host_build_graph/gemm_gather/golden.py \ + -p a2a3 -d +``` + +## 4. 
Shape 约定(与 comm-isa 一致) + +- **GEMM**:A(64,64), B(64,64), C(64,64),float。 +- **Gather**:src0 (32, 1024) float,src1 (16, 64) int32(索引),out (16, 64) float。 diff --git a/examples/host_build_graph/gemm_gather/golden.py b/examples/host_build_graph/gemm_gather/golden.py new file mode 100644 index 00000000..ea10a67d --- /dev/null +++ b/examples/host_build_graph/gemm_gather/golden.py @@ -0,0 +1,97 @@ +""" +Golden test for gemm_gather (a2a3). + +GEMM: C = A @ B, 64x64 float. +Gather: out = src0[src1] (linear index), src0 (32, 1024), src1 (16, 64) int32, out (16, 64) float. +Args: [A, B, C, src0, src1, out, size_A, size_B, size_C, size_src0, size_src1, size_out] +code_runner passes torch tensors to compute_golden; use .numel() and .item() for compatibility. +""" + +import ctypes +import numpy as np +import torch + +__outputs__ = ["C", "out"] +RTOL = 1e-3 +ATOL = 1e-3 + +# GEMM: single 64x64 tile +GEMM_TILE = 64 + +# Gather: comm-isa shape +GATHER_SRC0_ROWS = 32 +GATHER_SRC0_COLS = 1024 +GATHER_SRC1_ROWS = 16 +GATHER_SRC1_COLS = 64 + + +def generate_inputs(params: dict) -> list: + np.random.seed(42) + + # GEMM: A, B, C (64, 64) float + A = np.random.randn(GEMM_TILE, GEMM_TILE).astype(np.float32) * 0.01 + B = np.random.randn(GEMM_TILE, GEMM_TILE).astype(np.float32) * 0.01 + C = np.zeros((GEMM_TILE, GEMM_TILE), dtype=np.float32) + + # Gather: src0 (32, 1024), src1 (16, 64) int32 indices, out (16, 64) + src0 = np.random.randn(GATHER_SRC0_ROWS, GATHER_SRC0_COLS).astype(np.float32) * 0.01 + src0_flat = src0.flatten() + max_idx = src0_flat.size - 1 + src1 = np.random.randint(0, max_idx + 1, size=(GATHER_SRC1_ROWS, GATHER_SRC1_COLS), dtype=np.int32) + out = np.zeros((GATHER_SRC1_ROWS, GATHER_SRC1_COLS), dtype=np.float32) + + A_flat = A.flatten() + B_flat = B.flatten() + C_flat = C.flatten() + src0_flat = src0.flatten() + src1_flat = src1.flatten() + out_flat = out.flatten() + + return [ + ("A", A_flat), + ("B", B_flat), + ("C", C_flat), + ("src0", src0_flat), + ("src1", src1_flat), 
+ ("out", out_flat), + ("size_A", ctypes.c_int64(A_flat.nbytes)), + ("size_B", ctypes.c_int64(B_flat.nbytes)), + ("size_C", ctypes.c_int64(C_flat.nbytes)), + ("size_src0", ctypes.c_int64(src0_flat.nbytes)), + ("size_src1", ctypes.c_int64(src1_flat.nbytes)), + ("size_out", ctypes.c_int64(out_flat.nbytes)), + ] + + +def _numel(x): + """Element count for both numpy and torch (code_runner passes torch tensors).""" + if hasattr(x, "numel") and callable(x.numel): + return x.numel() + return int(x.size) + + +def compute_golden(tensors: dict, params: dict) -> None: + A = tensors["A"].reshape(GEMM_TILE, GEMM_TILE) + B = tensors["B"].reshape(GEMM_TILE, GEMM_TILE) + C = tensors["C"].reshape(GEMM_TILE, GEMM_TILE) + src0 = tensors["src0"].reshape(GATHER_SRC0_ROWS, GATHER_SRC0_COLS) + src1 = tensors["src1"].reshape(GATHER_SRC1_ROWS, GATHER_SRC1_COLS) + out = tensors["out"].reshape(GATHER_SRC1_ROWS, GATHER_SRC1_COLS) + + # GEMM: C = A @ B (tensors from code_runner are torch) + C.copy_(torch.matmul(A.float(), B.float())) + + # Gather: out[i,j] = src0.flatten()[src1[i,j]] + src0_flat = src0.flatten() + n_src0 = _numel(src0_flat) + for i in range(GATHER_SRC1_ROWS): + for j in range(GATHER_SRC1_COLS): + idx = int(src1[i, j].item() if hasattr(src1[i, j], "item") else src1[i, j]) + if idx < 0: + idx = 0 + if idx >= n_src0: + idx = n_src0 - 1 + out[i, j] = src0_flat[idx] + + tensors["C"][:] = C.flatten() + tensors["out"][:] = out.flatten() diff --git a/examples/host_build_graph/gemm_gather/kernels/aic/kernel_gemm.cpp b/examples/host_build_graph/gemm_gather/kernels/aic/kernel_gemm.cpp new file mode 100644 index 00000000..41f8acef --- /dev/null +++ b/examples/host_build_graph/gemm_gather/kernels/aic/kernel_gemm.cpp @@ -0,0 +1,90 @@ +/** + * Tile-based Matrix Multiplication Kernel (Cube Core, a2a3) + * + * Computes: output = input_a @ input_b (64x64 tile matmul) + * Uses TMATMUL instruction. Copied from host_build_graph/bgemm for standalone case. 
+ */ + +#include +#include +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +AICORE constexpr inline T CeilAlign(T num_1, T num_2) { + if (num_2 == 0) { + return 0; + } + return (num_1 + num_2 - 1) / num_2 * num_2; +} + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* input_a = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* input_b = reinterpret_cast<__gm__ float*>(args[1]); + __gm__ float* output = reinterpret_cast<__gm__ float*>(args[2]); + + constexpr int TILE = 64; + constexpr int blockAlign = C0_SIZE_BYTE / sizeof(float); + constexpr int M = CeilAlign(TILE, 16); + constexpr int K = CeilAlign(TILE, blockAlign); + constexpr int N = CeilAlign(TILE, blockAlign); + + using GlobalDataA = GlobalTensor, + Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; + using GlobalDataB = GlobalTensor, + Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; + using GlobalDataC = GlobalTensor, + Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; + + GlobalDataA src0Global(input_a); + GlobalDataB src1Global(input_b); + GlobalDataC dstGlobal(output); + + using TileMatA = Tile; + using TileMatB = Tile; + + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile bTile; + AccTile cTile; + TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); + TASSIGN(cTile, 0x0); + + TLOAD(aMatTile, src0Global); + TLOAD(bMatTile, src1Global); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, 
EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(dstGlobal, cTile); +} diff --git a/examples/host_build_graph/gemm_gather/kernels/aiv/kernel_gather.cpp b/examples/host_build_graph/gemm_gather/kernels/aiv/kernel_gather.cpp new file mode 100644 index 00000000..8277a892 --- /dev/null +++ b/examples/host_build_graph/gemm_gather/kernels/aiv/kernel_gather.cpp @@ -0,0 +1,75 @@ +/** + * 1D Gather Kernel (AIV, a2a3) + * + * out = src0[src1] with comm-isa shape: src0 (32, 1024) float, src1 (16, 64) int32, out (16, 64) float. + * Reference: pto-comm-isa tests/npu/a2a3/src/st/testcase/tgather (runTGather1D). + * Uses TGATHER on Vector pipe (PIPE_V). + * All logic in kernel_entry (same pattern as vector_example kernel_add.cpp) so set_flag/wait_flag + * are used only in __aicore__ context. + */ + +#include +#include +#include "tgather_common.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* out = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* src0 = reinterpret_cast<__gm__ float*>(args[1]); + __gm__ int32_t* src1 = reinterpret_cast<__gm__ int32_t*>(args[2]); + + constexpr int kGRows0_ = GATHER_SRC0_ROWS; + constexpr int kGCols0_ = GATHER_SRC0_COLS; + constexpr int kGRows1_ = GATHER_SRC1_ROWS; + constexpr int kGCols1_ = GATHER_SRC1_COLS; + constexpr int src0_row = kGRows0_; + constexpr int src0_col = kGCols0_; + constexpr int src1_row = kGRows1_; + constexpr int src1_col = kGCols1_; + constexpr int dst_row = kGRows1_; + constexpr int dst_col = kGCols1_; + + using DynShapeDim5_src0 = pto::Shape<1, 1, 1, kGRows0_, kGCols0_>; + using DynStridDim5_src0 = pto::Stride<1, 1, 1, kGCols0_, 1>; + using GlobalData_src0 = GlobalTensor; + using DynShapeDim5_src1 = pto::Shape<1, 1, 1, kGRows1_, kGCols1_>; + using DynStridDim5_src1 = pto::Stride<1, 1, 1, kGCols1_, 1>; + using 
GlobalData_src1 = GlobalTensor; + using DynShapeDim5_dst = pto::Shape<1, 1, 1, kGRows1_, kGCols1_>; + using DynStridDim5_dst = pto::Stride<1, 1, 1, kGCols1_, 1>; + using GlobalData_dst = GlobalTensor; + + using TileData_src0 = Tile; + using TileData_src1 = Tile; + using TileData_dst = Tile; + + TileData_src0 src0Tile(src0_row, src0_col); + TileData_src1 src1Tile(src1_row, src1_col); + TileData_dst dstTile(dst_row, dst_col); + + TASSIGN(src0Tile, 0x0); + TASSIGN(src1Tile, 0x20000); + TASSIGN(dstTile, 0x28000); + + GlobalData_src0 src0Global(src0); + GlobalData_src1 src1Global(src1); + GlobalData_dst dstGlobal(out); + + TLOAD(src0Tile, src0Global); + TLOAD(src1Tile, src1Global); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TGATHER(dstTile, src0Tile, src1Tile); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(dstGlobal, dstTile); +} diff --git a/examples/host_build_graph/gemm_gather/kernels/aiv/tgather_common.h b/examples/host_build_graph/gemm_gather/kernels/aiv/tgather_common.h new file mode 100644 index 00000000..b93b9a67 --- /dev/null +++ b/examples/host_build_graph/gemm_gather/kernels/aiv/tgather_common.h @@ -0,0 +1,14 @@ +/** + * Shape constants for TGATHER (a2a3, comm-isa aligned). + * Used by kernel_gather.cpp. Same values as pto-comm-isa a2a3 tgather. + */ +#ifndef GEMM_GATHER_TGATHER_COMMON_H +#define GEMM_GATHER_TGATHER_COMMON_H + +// runTGather1D_float: src0 (32, 1024), src1 (16, 64), out (16, 64) +#define GATHER_SRC0_ROWS 32 +#define GATHER_SRC0_COLS 1024 +#define GATHER_SRC1_ROWS 16 +#define GATHER_SRC1_COLS 64 + +#endif diff --git a/examples/host_build_graph/gemm_gather/kernels/kernel_config.py b/examples/host_build_graph/gemm_gather/kernels/kernel_config.py new file mode 100644 index 00000000..37525822 --- /dev/null +++ b/examples/host_build_graph/gemm_gather/kernels/kernel_config.py @@ -0,0 +1,26 @@ +""" +Kernel and orchestration configuration for gemm_gather (a2a3). 
+ +Two kernels: GEMM (AIC), Gather (AIV). +Dependency: compute first, then communication (t0 -> t1). +""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "gemm_gather_orch.cpp"), + "function_name": "build_gemm_gather_graph", +} + +KERNELS = [ + {"func_id": 0, "name": "GEMM", "source": str(_KERNELS_ROOT / "aic" / "kernel_gemm.cpp"), "core_type": "aic"}, + {"func_id": 1, "name": "Gather", "source": str(_KERNELS_ROOT / "aiv" / "kernel_gather.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "host_build_graph", + "aicpu_thread_num": 3, + "block_dim": 3, +} diff --git a/examples/host_build_graph/gemm_gather/kernels/orchestration/gemm_gather_orch.cpp b/examples/host_build_graph/gemm_gather/kernels/orchestration/gemm_gather_orch.cpp new file mode 100644 index 00000000..f1352c4d --- /dev/null +++ b/examples/host_build_graph/gemm_gather/kernels/orchestration/gemm_gather_orch.cpp @@ -0,0 +1,116 @@ +/** + * GEMM + Gather orchestration (a2a3). 
+ * + * Two independent tasks (no dependency): + * Task 0: C = A @ B (64x64 GEMM, AIC) + * Task 1: out = src0[src1] (Gather, AIV, comm-isa shape) + */ + +#include "runtime.h" +#include +#include + +extern "C" { + +constexpr int GEMM_TILE = 64; +constexpr int GATHER_SRC0_ROWS = 32; +constexpr int GATHER_SRC0_COLS = 1024; +constexpr int GATHER_SRC1_ROWS = 16; +constexpr int GATHER_SRC1_COLS = 64; + +int build_gemm_gather_graph(Runtime* runtime, uint64_t* args, int arg_count) { + if (arg_count < 12) { + std::cerr << "build_gemm_gather_graph: Expected at least 12 args, got " << arg_count << '\n'; + return -1; + } + + void* host_A = reinterpret_cast(args[0]); + void* host_B = reinterpret_cast(args[1]); + void* host_C = reinterpret_cast(args[2]); + void* host_src0 = reinterpret_cast(args[3]); + void* host_src1 = reinterpret_cast(args[4]); + void* host_out = reinterpret_cast(args[5]); + size_t size_A = static_cast(args[6]); + size_t size_B = static_cast(args[7]); + size_t size_C = static_cast(args[8]); + size_t size_src0 = static_cast(args[9]); + size_t size_src1 = static_cast(args[10]); + size_t size_out = static_cast(args[11]); + + std::cout << "\n=== build_gemm_gather_graph (a2a3) ===" << '\n'; + + // Allocate device memory and copy inputs + void* dev_A = runtime->host_api.device_malloc(size_A); + if (!dev_A) return -1; + runtime->host_api.copy_to_device(dev_A, host_A, size_A); + + void* dev_B = runtime->host_api.device_malloc(size_B); + if (!dev_B) { + runtime->host_api.device_free(dev_A); + return -1; + } + runtime->host_api.copy_to_device(dev_B, host_B, size_B); + + void* dev_C = runtime->host_api.device_malloc(size_C); + if (!dev_C) { + runtime->host_api.device_free(dev_A); + runtime->host_api.device_free(dev_B); + return -1; + } + runtime->record_tensor_pair(host_C, dev_C, size_C); + + void* dev_src0 = runtime->host_api.device_malloc(size_src0); + if (!dev_src0) { + runtime->host_api.device_free(dev_A); + runtime->host_api.device_free(dev_B); + 
runtime->host_api.device_free(dev_C); + return -1; + } + runtime->host_api.copy_to_device(dev_src0, host_src0, size_src0); + + void* dev_src1 = runtime->host_api.device_malloc(size_src1); + if (!dev_src1) { + runtime->host_api.device_free(dev_A); + runtime->host_api.device_free(dev_B); + runtime->host_api.device_free(dev_C); + runtime->host_api.device_free(dev_src0); + return -1; + } + runtime->host_api.copy_to_device(dev_src1, host_src1, size_src1); + + void* dev_out = runtime->host_api.device_malloc(size_out); + if (!dev_out) { + runtime->host_api.device_free(dev_A); + runtime->host_api.device_free(dev_B); + runtime->host_api.device_free(dev_C); + runtime->host_api.device_free(dev_src0); + runtime->host_api.device_free(dev_src1); + return -1; + } + runtime->record_tensor_pair(host_out, dev_out, size_out); + + // Task 0: GEMM C = A @ B (func_id=0, AIC) + uint64_t args_geom[3]; + args_geom[0] = reinterpret_cast(dev_A); + args_geom[1] = reinterpret_cast(dev_B); + args_geom[2] = reinterpret_cast(dev_C); + int t0 = runtime->add_task(args_geom, 3, 0, CoreType::AIC); + + // Task 1: Gather out = src0[src1] (func_id=1, AIV) + uint64_t args_gather[3]; + args_gather[0] = reinterpret_cast(dev_out); + args_gather[1] = reinterpret_cast(dev_src0); + args_gather[2] = reinterpret_cast(dev_src1); + int t1 = runtime->add_task(args_gather, 3, 1, CoreType::AIV); + + // Dependency: compute first, then communication (t0 -> t1) + runtime->add_successor(t0, t1); + + std::cout << " task" << t0 << ": GEMM C=A@B [AIC]\n"; + std::cout << " task" << t1 << ": Gather out=src0[src1] [AIV]\n"; + std::cout << " Dependency: t0 -> t1 (compute then communication).\n"; + + return 0; +} + +} // extern "C" diff --git a/examples/host_build_graph/multi_bgemm/PLAN_multi_device_compute.md b/examples/host_build_graph/multi_bgemm/PLAN_multi_device_compute.md new file mode 100644 index 00000000..c267c2be --- /dev/null +++ b/examples/host_build_graph/multi_bgemm/PLAN_multi_device_compute.md @@ -0,0 +1,109 @@ +# 
多卡独立跑计算算子 — 修改计划(multi_bgemm) + +目标:新增一个 case,**仅跑计算 kernel**(与 bgemm 完全一致),**一次调用**可在 **2 张或 4 张卡**上各自独立跑同一套 graph,**无卡间同步、无通信、无依赖**。 + +--- + +## 一、目标与约束 + +- **Case 名**:`multi_bgemm`(路径:`examples/host_build_graph/multi_bgemm/`) +- **单卡逻辑**:与现有 `bgemm` 完全一致(同一套 orchestration、同一套 kernel:GEMM + tile_add,同一套 golden) +- **多卡方式**:一次 `run_example.py` 调用可指定 2 或 4 张卡;每张卡独立跑同一 case,卡间无同步、无通信 +- **不引入**:HCCL、comm_gather、fork 建联等;沿用现有 host_build_graph 单卡 Runtime 流程 + +--- + +## 二、实现思路 + +- 每张卡执行的都是**同一套** host_build_graph 流程:build runtime → 编 orchestration + kernels → `set_device(device_id)` → 对同一组输入跑 `initialize` + `launch_runtime` → `finalize` → 与 golden 比对。 +- 多卡实现方式:**多进程**。主进程根据 `n_devices`(2 或 4)起 N 个子进程,每个子进程执行一次「单卡 run」:即当前已有的 `run_example.py -k ... -g ... -d ` 流程;主进程汇总 N 个子进程退出码,全部为 0 才认为通过。 +- 这样无需改 Runtime、不改 bindings、不改单卡逻辑,仅在一层「调度」上做多卡扩展。 + +--- + +## 三、端到端需修改/新增内容 + +### 3.1 新增 case 目录 `multi_bgemm` + +| 项 | 说明 | +|----|------| +| **目录** | `examples/host_build_graph/multi_bgemm/` | +| **与 bgemm 关系** | 与 bgemm 同构:同一套编排与 kernel、同一套 golden;仅配置上增加「多卡」相关项。 | +| **建议做法** | 从 bgemm 拷贝以下内容到 multi_bgemm,再按下面调整配置与文档:
• `kernels/`(含 `kernel_config.py`、`orchestration/`、`aic/`、`aiv/`)
• `golden.py`
• `README.md`(重写为 multi_bgemm 说明) | + +**需要改动的仅**: + +- **`kernels/kernel_config.py`** + - 在 `RUNTIME_CONFIG` 中增加 `n_devices`,默认 `2`(或 `4`,由你定);其余(ORCHESTRATION、KERNELS)与 bgemm 保持一致。 +- **`README.md`** + - 说明本 case 为「多卡独立跑 bgemm」,支持 2/4 卡;示例命令用 `--n-devices 2` 或 `--n-devices 4`。 + +### 3.2 命令行与 code_runner 入参 + +| 位置 | 修改内容 | +|------|----------| +| **run_example.py** | 新增可选参数 `--n-devices`(类型 `int`,默认 `None`)。解析后传入 `create_code_runner(..., n_devices=args.n_devices)`。 | +| **code_runner.create_code_runner** | 增加形参 `n_devices=None`,并传给 `CodeRunner`。 | +| **CodeRunner.__init__** | 增加 `n_devices` 参数。逻辑:若调用方传入 `n_devices is not None`,则用传入值;否则从 `kernel_config.RUNTIME_CONFIG.get("n_devices", 1)` 读取,默认 `1`。这样既支持「按 case 默认多卡」,也支持「命令行覆盖」。 | + +### 3.3 code_runner.run() 中「多卡分支」 + +- **触发条件**:在 `run()` 开头(在 `comm_gather` 分支之后、正常单卡 build 之前)判断:若 `self.n_devices > 1`,则走「多卡独立跑」分支,直接 return,不再走下面单卡 build/launch。 +- **多卡分支逻辑**: + 1. 解析当前脚本所在目录,得到 `run_example.py` 的路径(与 code_runner 同目录:`Path(__file__).resolve().parent / "run_example.py"`)。 + 2. 对 `device_id in range(self.n_devices)` 依次(或并行,见下)执行子进程: + - 命令:`[sys.executable, str(run_example.py), "-k", str(self.kernels_dir), "-g", str(self.golden_path), "-d", str(device_id), "-p", self.platform]` + - 可选:把当前 `--all` / `--case` / `--log-level` 等与单次 run 相关的参数一并传入,保证与单卡行为一致。 + 3. 等待所有子进程结束;若任一子进程 `returncode != 0`,则抛出异常或置失败,并带上设备 id 信息;若全部为 0,则视为通过。 +- **并行 vs 串行**:为简单起见先做成**串行**(按 device 0,1,... 
顺序跑),避免多进程同时 build 争抢;若你希望 2/4 卡并行跑,可再改为 `concurrent.futures.ProcessPoolExecutor` 或 `multiprocessing.Pool`,但需注意编译/资源竞争,建议首版串行。 +- **Build 次数**:每个子进程都会完整跑一遍 run_example(含 build runtime、编 orchestration、编 kernel),因此会 build N 次;实现简单,首版接受该冗余,后续若有需要可再做「只跑不编」的优化。 + +### 3.4 不需要改动的部分 + +- **Runtime / bindings / 单卡 launch**:不变。 +- **单卡 case(如 bgemm)**:不加 `n_devices` 时行为与现在完全一致(`n_devices` 默认 1,不进入多卡分支)。 +- **comm_gather**:不受影响,仍走 `_run_comm_gather()`。 + +--- + +## 四、文件与配置清单(实施顺序) + +| 步骤 | 操作 | +|------|------| +| 1 | 新建目录 `examples/host_build_graph/multi_bgemm/`。 | +| 2 | 从 `bgemm` 拷贝 `golden.py`、`kernels/`(整个目录:kernel_config.py、orchestration/、aic/、aiv/)到 `multi_bgemm/`。 | +| 3 | 修改 `multi_bgemm/kernels/kernel_config.py`:增加 `RUNTIME_CONFIG = { "runtime": "host_build_graph", "n_devices": 2 }`(或 4);若 bgemm 当前无 `RUNTIME_CONFIG`,则仅在 multi_bgemm 中显式写出,保持与 bgemm 相同的 ORCHESTRATION、KERNELS。 | +| 4 | 编写 `multi_bgemm/README.md`:说明本 case 为多卡独立跑计算(与 bgemm 同逻辑),示例命令 `--n-devices 2` / `--n-devices 4`,无卡间同步与通信。 | +| 5 | **run_example.py**:增加 `--n-devices` 参数,并传入 `create_code_runner(..., n_devices=args.n_devices)`。 | +| 6 | **code_runner.py**:`create_code_runner` 增加 `n_devices` 参数;`CodeRunner.__init__` 中读取并保存 `self.n_devices`(来自参数或 `RUNTIME_CONFIG`,默认 1)。 | +| 7 | **code_runner.run()**:在 `comm_gather` 分支之后、单卡 build 之前,若 `self.n_devices > 1`,则执行「多卡子进程」逻辑(循环或并行启动 N 个 `run_example.py -d 0..N-1`),全部成功则 return,否则抛错。 | + +--- + +## 五、运行示例(计划通过后) + +```bash +# 2 张卡独立跑 multi_bgemm(每卡跑同一套 bgemm 计算) +python examples/scripts/run_example.py \ + -k examples/host_build_graph/multi_bgemm/kernels \ + -g examples/host_build_graph/multi_bgemm/golden.py \ + --n-devices 2 + +# 4 张卡(若 kernel_config 默认 n_devices=4 也可不写) +python examples/scripts/run_example.py \ + -k examples/host_build_graph/multi_bgemm/kernels \ + -g examples/host_build_graph/multi_bgemm/golden.py \ + --n-devices 4 +``` + +每张卡使用与 bgemm 相同的输入(各自 `generate_inputs` 一份,或子进程内各自生成,因随机数可能不同;若需完全一致可后续改为主进程生成写文件再传入,首版可保持各进程独立 `generate_inputs`,golden 
仍按同一规则校验)。 + +--- + +## 六、小结 + +- **新 case**:`multi_bgemm`,与 bgemm 同编排、同 kernel、同 golden,仅配置多卡数。 +- **入口**:`run_example.py` 增加 `--n-devices`;code_runner 支持 `n_devices` 从配置或命令行传入。 +- **执行**:`n_devices > 1` 时用子进程对每张卡跑一次「单卡 run_example」流程,无卡间同步与通信,全部成功即通过。 + +确认该计划 OK 后,再按上述步骤改代码。 diff --git a/examples/host_build_graph/multi_bgemm/README.md b/examples/host_build_graph/multi_bgemm/README.md new file mode 100644 index 00000000..7bbc5525 --- /dev/null +++ b/examples/host_build_graph/multi_bgemm/README.md @@ -0,0 +1,47 @@ +# multi_bgemm — 多卡独立跑 BGEMM(无通信) + +在 2 张或 4 张卡上**并行**跑与 **bgemm** 完全相同的计算(C = A @ B,4x4x4 grid、64x64 tile),每张卡独立执行,无卡间同步、无通信。 + +## 用法 + +```bash +# 2 张卡,起始卡号 0 → device 0, 1 并行 +python examples/scripts/run_example.py \ + -k examples/host_build_graph/multi_bgemm/kernels \ + -g examples/host_build_graph/multi_bgemm/golden.py \ + --n-devices 2 --first-device 0 + +# 4 张卡,起始卡号 4 → device 4, 5, 6, 7 并行 +python examples/scripts/run_example.py \ + -k examples/host_build_graph/multi_bgemm/kernels \ + -g examples/host_build_graph/multi_bgemm/golden.py \ + --n-devices 4 --first-device 4 +``` + +- **--n-devices**:卡数(不传时使用 kernel_config 中 `RUNTIME_CONFIG["n_devices"]`,默认 2)。 +- **--first-device**:起始卡号(不传时使用 `RUNTIME_CONFIG["first_device_id"]`,默认 0)。 +- 设备号区间:`[first-device, first-device + n-devices)`。 + +单卡时可不传 `--n-devices` 或传 `--n-devices 1`,则走普通单卡流程。 + +## 行为说明 + +- 与 **bgemm** 使用同一套 orchestration(`build_bgemm_graph`)、同一套 kernel(GEMM + tile_add)、同一套 golden。 +- **编译与运行分离**:主进程先 `compile()` 一次,创建 N 个 CodeRunner(传入 compiled_artifacts),通过 `ProcessPoolExecutor` 多进程并行执行各 `runner.run()`,**无重复编译**。 +- 不引入 HCCL、通信算子或建联逻辑;与后续多卡通信方案兼容(通信 case 将使用独立 C++ 入口)。 + +## 目录结构 + +``` +multi_bgemm/ +├── golden.py +├── README.md +└── kernels/ + ├── kernel_config.py # 含 RUNTIME_CONFIG (n_devices, first_device_id) + ├── orchestration/ + │ └── bgemm_orch.cpp + ├── aic/ + │ └── kernel_gemm_tile.cpp + └── aiv/ + └── kernel_tile_add.cpp +``` diff --git 
a/examples/host_build_graph/multi_bgemm/golden.py b/examples/host_build_graph/multi_bgemm/golden.py new file mode 100644 index 00000000..84788315 --- /dev/null +++ b/examples/host_build_graph/multi_bgemm/golden.py @@ -0,0 +1,68 @@ +""" +Golden test specification for BGEMM (Host Build Graph Runtime). + +Computation: C = A @ B (tiled matrix multiplication) +Configuration: 4x4x4 grid, 64x64 tiles + +Args layout: [ptr_A, ptr_B, ptr_C, size_A, size_B, size_C] +""" + +import ctypes +import numpy as np + +__outputs__ = ["C"] +RTOL = 1e-3 +ATOL = 1e-3 + +TILE_M = 64 +TILE_K = 64 +TILE_N = 64 + +GRID_M = 4 +GRID_K = 4 +GRID_N = 4 +BATCH = 1 + +M = TILE_M * GRID_M +K = TILE_K * GRID_K +N = TILE_N * GRID_N + + +def generate_inputs(params: dict) -> list: + """Generate input tensors with tile-first memory layout.""" + A = np.random.randn(BATCH, GRID_M, GRID_K, TILE_M, TILE_K).astype(np.float32) * 0.01 + B = np.random.randn(BATCH, GRID_K, GRID_N, TILE_K, TILE_N).astype(np.float32) * 0.01 + C = np.zeros((BATCH, GRID_M, GRID_N, TILE_M, TILE_N), dtype=np.float32) + + A_flat = A.flatten() + B_flat = B.flatten() + C_flat = C.flatten() + + return [ + ("A", A_flat), + ("B", B_flat), + ("C", C_flat), + ("size_A", ctypes.c_int64(A_flat.nbytes)), + ("size_B", ctypes.c_int64(B_flat.nbytes)), + ("size_C", ctypes.c_int64(C_flat.nbytes)), + ] + + +def compute_golden(tensors: dict, params: dict) -> None: + """Compute golden result: C[m,n] = sum(k) A[m,k] @ B[k,n].""" + A = tensors["A"].reshape(BATCH, GRID_M, GRID_K, TILE_M, TILE_K) + B = tensors["B"].reshape(BATCH, GRID_K, GRID_N, TILE_K, TILE_N) + C = tensors["C"].reshape(BATCH, GRID_M, GRID_N, TILE_M, TILE_N) + + C[:] = 0.0 + + for batch in range(BATCH): + for m_idx in range(GRID_M): + for n_idx in range(GRID_N): + for k_idx in range(GRID_K): + C[batch, m_idx, n_idx] += np.matmul( + A[batch, m_idx, k_idx], + B[batch, k_idx, n_idx] + ) + + tensors["C"][:] = C.flatten() diff --git 
a/examples/host_build_graph/multi_bgemm/kernels/aic/kernel_gemm_tile.cpp b/examples/host_build_graph/multi_bgemm/kernels/aic/kernel_gemm_tile.cpp new file mode 100644 index 00000000..92c93d32 --- /dev/null +++ b/examples/host_build_graph/multi_bgemm/kernels/aic/kernel_gemm_tile.cpp @@ -0,0 +1,90 @@ +/** + * Tile-based Matrix Multiplication Kernel (Cube Core) + * + * Computes: output = input_a @ input_b (64x64 tile matmul) + * Uses TMATMUL instruction + */ + +#include +#include +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +AICORE constexpr inline T CeilAlign(T num_1, T num_2) { + if (num_2 == 0) { + return 0; + } + return (num_1 + num_2 - 1) / num_2 * num_2; +} + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* input_a = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* input_b = reinterpret_cast<__gm__ float*>(args[1]); + __gm__ float* output = reinterpret_cast<__gm__ float*>(args[2]); + + constexpr int TILE = 64; + constexpr int blockAlign = C0_SIZE_BYTE / sizeof(float); + constexpr int M = CeilAlign(TILE, 16); + constexpr int K = CeilAlign(TILE, blockAlign); + constexpr int N = CeilAlign(TILE, blockAlign); + + using GlobalDataA = GlobalTensor, + Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; + using GlobalDataB = GlobalTensor, + Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; + using GlobalDataC = GlobalTensor, + Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; + + GlobalDataA src0Global(input_a); + GlobalDataB src1Global(input_b); + GlobalDataC dstGlobal(output); + + using TileMatA = Tile; + using TileMatB = Tile; + + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile 
bTile; + AccTile cTile; + TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); + TASSIGN(cTile, 0x0); + + TLOAD(aMatTile, src0Global); + TLOAD(bMatTile, src1Global); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(dstGlobal, cTile); +} diff --git a/examples/host_build_graph/multi_bgemm/kernels/aiv/kernel_tile_add.cpp b/examples/host_build_graph/multi_bgemm/kernels/aiv/kernel_tile_add.cpp new file mode 100644 index 00000000..61cb59a6 --- /dev/null +++ b/examples/host_build_graph/multi_bgemm/kernels/aiv/kernel_tile_add.cpp @@ -0,0 +1,53 @@ +/** + * Tile-based Element-wise Addition Kernel (Vector Core) + * + * Computes: output = input_a + input_b (64x64 tile addition) + * Uses TADD instruction + */ + +#include +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* input_a = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* input_b = reinterpret_cast<__gm__ float*>(args[1]); + __gm__ float* output = reinterpret_cast<__gm__ float*>(args[2]); + + constexpr int TILE = 64; + + using DynShapeDim5 = Shape<1, 1, 1, TILE, TILE>; + using DynStridDim5 = Stride<1, 1, 1, TILE, 1>; + using GlobalData = GlobalTensor; + using TileData = Tile; + + TileData aTile(TILE, TILE); + TileData bTile(TILE, TILE); + TileData outTile(TILE, TILE); + TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x10000); + TASSIGN(outTile, 0x20000); + + GlobalData aGlobal(input_a); + GlobalData bGlobal(input_b); + GlobalData outGlobal(output); + + TLOAD(aTile, aGlobal); + TLOAD(bTile, bGlobal); + 
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TADD(outTile, aTile, bTile); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(outGlobal, outTile); +} diff --git a/examples/host_build_graph/multi_bgemm/kernels/kernel_config.py b/examples/host_build_graph/multi_bgemm/kernels/kernel_config.py new file mode 100644 index 00000000..7bef9f78 --- /dev/null +++ b/examples/host_build_graph/multi_bgemm/kernels/kernel_config.py @@ -0,0 +1,34 @@ +""" +Kernel configuration for multi_bgemm (multi-device BGEMM, Host Build Graph Runtime). + +Same orchestration and kernels as bgemm; supports running on multiple devices in parallel. +Use --n-devices and --first-device to specify card count and starting device ID. +""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "bgemm_orch.cpp"), + "function_name": "build_bgemm_graph", +} + +KERNELS = [ + { + "func_id": 0, + "source": str(_KERNELS_ROOT / "aic" / "kernel_gemm_tile.cpp"), + "core_type": "aic", + }, + { + "func_id": 1, + "source": str(_KERNELS_ROOT / "aiv" / "kernel_tile_add.cpp"), + "core_type": "aiv", + }, +] + +RUNTIME_CONFIG = { + "runtime": "host_build_graph", + "n_devices": 2, + "first_device_id": 0, +} diff --git a/examples/host_build_graph/multi_bgemm/kernels/orchestration/bgemm_orch.cpp b/examples/host_build_graph/multi_bgemm/kernels/orchestration/bgemm_orch.cpp new file mode 100644 index 00000000..6038c4ac --- /dev/null +++ b/examples/host_build_graph/multi_bgemm/kernels/orchestration/bgemm_orch.cpp @@ -0,0 +1,139 @@ +/** + * BGEMM Orchestration Function (Host Build Graph Runtime) + * + * Builds the task graph for tiled matrix multiplication: C = A @ B + * + * Configuration: + * - Tile size: 64 x 64 + * - Grid: 4 x 4 x 4 (GRID_M x GRID_K x GRID_N) + * + * Memory layout (tile-first): + * A: [BATCH, GRID_M, GRID_K, TILE_M, TILE_K] + * B: [BATCH, GRID_K, 
GRID_N, TILE_K, TILE_N] + * C: [BATCH, GRID_M, GRID_N, TILE_M, TILE_N] + * + * Task graph per output tile: + * for k in [0, GRID_K): + * P = A[m,k] @ B[k,n] (gemm_tile on Cube core) + * C[m,n] = C[m,n] + P (tile_add on Vector core) + */ + +#include "runtime.h" +#include +#include + +extern "C" { + +constexpr int TILE = 64; +constexpr int GRID_M = 4; +constexpr int GRID_K = 4; +constexpr int GRID_N = 4; +constexpr int BATCH = 1; + +constexpr size_t TILE_BYTES = TILE * TILE * sizeof(float); + +int build_bgemm_graph(Runtime* runtime, uint64_t* args, int arg_count) { + if (arg_count < 6) { + std::cerr << "build_bgemm_graph: Expected at least 6 args, got " << arg_count << '\n'; + return -1; + } + + void* host_A = reinterpret_cast(args[0]); + void* host_B = reinterpret_cast(args[1]); + void* host_C = reinterpret_cast(args[2]); + size_t size_A = static_cast(args[3]); + size_t size_B = static_cast(args[4]); + size_t size_C = static_cast(args[5]); + + std::cout << "\n=== build_bgemm_graph ===" << '\n'; + std::cout << "Grid: " << GRID_M << " x " << GRID_K << " x " << GRID_N << '\n'; + + // Allocate device memory and copy inputs + void* dev_A = runtime->host_api.device_malloc(size_A); + if (!dev_A) return -1; + runtime->host_api.copy_to_device(dev_A, host_A, size_A); + + void* dev_B = runtime->host_api.device_malloc(size_B); + if (!dev_B) { + runtime->host_api.device_free(dev_A); + return -1; + } + runtime->host_api.copy_to_device(dev_B, host_B, size_B); + + void* dev_C = runtime->host_api.device_malloc(size_C); + if (!dev_C) { + runtime->host_api.device_free(dev_A); + runtime->host_api.device_free(dev_B); + return -1; + } + runtime->host_api.copy_to_device(dev_C, host_C, size_C); + runtime->record_tensor_pair(host_C, dev_C, size_C); + + // Allocate intermediate P buffers (one per C tile) + constexpr int NUM_P_BUFFERS = BATCH * GRID_M * GRID_N; + std::vector dev_P(NUM_P_BUFFERS, nullptr); + for (int i = 0; i < NUM_P_BUFFERS; i++) { + dev_P[i] = 
runtime->host_api.device_malloc(TILE_BYTES); + if (!dev_P[i]) { + for (int j = 0; j < i; j++) { + runtime->host_api.device_free(dev_P[j]); + } + runtime->host_api.device_free(dev_A); + runtime->host_api.device_free(dev_B); + runtime->host_api.device_free(dev_C); + return -1; + } + } + + // Track last add task for each C tile (for K accumulation dependency) + std::vector last_add_task(BATCH * GRID_M * GRID_N, -1); + + // Build task graph: 4-level tiling loop + for (int batch = 0; batch < BATCH; batch++) { + for (int m_idx = 0; m_idx < GRID_M; m_idx++) { + for (int n_idx = 0; n_idx < GRID_N; n_idx++) { + for (int k_idx = 0; k_idx < GRID_K; k_idx++) { + // Calculate tile offsets + size_t A_offset = (batch * GRID_M * GRID_K + m_idx * GRID_K + k_idx) * TILE_BYTES; + size_t B_offset = (batch * GRID_K * GRID_N + k_idx * GRID_N + n_idx) * TILE_BYTES; + size_t C_offset = (batch * GRID_M * GRID_N + m_idx * GRID_N + n_idx) * TILE_BYTES; + + int c_tile_idx = batch * GRID_M * GRID_N + m_idx * GRID_N + n_idx; + + // Task 1: P = A[m,k] @ B[k,n] (gemm_tile on Cube core) + uint64_t args_gemm[6]; + args_gemm[0] = reinterpret_cast(static_cast(dev_A) + A_offset); + args_gemm[1] = reinterpret_cast(static_cast(dev_B) + B_offset); + args_gemm[2] = reinterpret_cast(dev_P[c_tile_idx]); + args_gemm[3] = TILE; + args_gemm[4] = TILE; + args_gemm[5] = TILE; + int t_gemm = runtime->add_task(args_gemm, 6, 0, CoreType::AIC); + + // Task 2: C[m,n] = C[m,n] + P (tile_add on Vector core) + uint64_t args_add[5]; + args_add[0] = reinterpret_cast(static_cast(dev_C) + C_offset); + args_add[1] = reinterpret_cast(dev_P[c_tile_idx]); + args_add[2] = reinterpret_cast(static_cast(dev_C) + C_offset); + args_add[3] = TILE; + args_add[4] = TILE; + int t_add = runtime->add_task(args_add, 5, 1, CoreType::AIV); + + // Dependency: gemm must complete before add + runtime->add_successor(t_gemm, t_add); + + // Dependency: previous add must complete before current gemm (K accumulation) + if (last_add_task[c_tile_idx] 
>= 0) { + runtime->add_successor(last_add_task[c_tile_idx], t_gemm); + } + last_add_task[c_tile_idx] = t_add; + } + } + } + } + + std::cout << "Created " << runtime->get_task_count() << " tasks\n"; + return 0; +} + +} // extern "C" diff --git a/examples/host_build_graph/paged_attention_allgather_Manual/README.md b/examples/host_build_graph/paged_attention_allgather_Manual/README.md new file mode 100644 index 00000000..1df38b14 --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Manual/README.md @@ -0,0 +1,14 @@ +# Paged Attention + AllGather (Manual) - host_build_graph + +Paged Attention 计算后 AllGather(直接 RDMA 读取)。 + +流程:QK → Softmax → PV → OnlineUpdate → WindowMemCopyIn → CommBarrier(pre) +→ AllGatherManual → WindowMemCopyOut → CommBarrier(post) + +所有 rank 获得完整 allgather 输出。 + +## 运行 + +```bash +./run_hostbuild.sh paged_attention_allgather_Manual 2 0 +``` diff --git a/examples/host_build_graph/paged_attention_allgather_Manual/golden.py b/examples/host_build_graph/paged_attention_allgather_Manual/golden.py new file mode 100644 index 00000000..c17a62f4 --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Manual/golden.py @@ -0,0 +1,148 @@ +""" +Paged Attention + AllGather (Manual): Paged Attention → AllGather. + +Each rank independently computes paged attention on its own Q/K/V data, +then AllGather: every rank gets the concatenation of all ranks' first +GATHER_COUNT elements of attn_out. + +Same golden logic as paged_attention_allgather_Tgather. 
+""" + +import ctypes +import struct +import torch +import numpy as np + +GATHER_COUNT = 64 +BATCH = 1 +NUM_HEADS = 16 +KV_HEAD_NUM = 1 +HEAD_DIM = 16 +BLOCK_SIZE = 16 +CONTEXT_LEN = 16 +MAX_MODEL_LEN = 256 + +__outputs__ = ["attn_out", "allgather_out"] +RTOL = 1e-2 +ATOL = 1e-2 +ALL_CASES = {"Default": {}} +DEFAULT_CASE = "Default" + +def _make_block_table_and_context(): + max_num_blocks_per_req = MAX_MODEL_LEN // BLOCK_SIZE + cur_valid_blocks = (CONTEXT_LEN + BLOCK_SIZE - 1) // BLOCK_SIZE + total_blocks = BATCH * cur_valid_blocks + torch.manual_seed(100) + block_table = torch.randint(0, max(total_blocks, 1), size=(BATCH, max_num_blocks_per_req), dtype=torch.int32) + context_lens = torch.full((BATCH,), CONTEXT_LEN, dtype=torch.int32) + return block_table, context_lens, total_blocks, max_num_blocks_per_req + +def _make_qkv(rank_id, total_blocks): + torch.manual_seed(42 + rank_id) + q = (torch.rand(BATCH, 1, NUM_HEADS * HEAD_DIM) - 0.5).to(torch.float16) + q = q.reshape(BATCH, NUM_HEADS, HEAD_DIM) + k = (torch.rand(total_blocks, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) - 0.5).to(torch.float16) + v = (torch.rand(total_blocks, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) * 2 - 1).to(torch.float16) + return q, k, v + +def generate_inputs(params: dict) -> list: + rank_id = params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + block_table, context_lens, total_blocks, max_num_blocks_per_req = _make_block_table_and_context() + query_fp16, key_fp16, value_fp16 = _make_qkv(rank_id, total_blocks) + scale_value = 1.0 + scale_bits = struct.unpack('I', struct.pack('f', scale_value))[0] + config = torch.tensor([BATCH, NUM_HEADS, KV_HEAD_NUM, HEAD_DIM, BLOCK_SIZE, max_num_blocks_per_req, scale_bits], dtype=torch.int64) + query = query_fp16.flatten() + key_cache = key_fp16.flatten() + value_cache = value_fp16.flatten() + block_table_flat = block_table.flatten() + attn_out = torch.zeros(BATCH * NUM_HEADS * HEAD_DIM, dtype=torch.float32) + allgather_out = 
torch.zeros(n_ranks * GATHER_COUNT, dtype=torch.float32) + result = [ + ("query", query), ("key_cache", key_cache), ("value_cache", value_cache), + ("block_table", block_table_flat), ("context_lens", context_lens), + ("attn_out", attn_out), ("allgather_out", allgather_out), ("config", config), + ("size_query", ctypes.c_int64(query.nbytes)), ("size_key_cache", ctypes.c_int64(key_cache.nbytes)), + ("size_value_cache", ctypes.c_int64(value_cache.nbytes)), ("size_block_table", ctypes.c_int64(block_table_flat.nbytes)), + ("size_context_lens", ctypes.c_int64(context_lens.nbytes)), ("size_attn_out", ctypes.c_int64(attn_out.nbytes)), + ("size_allgather_out", ctypes.c_int64(allgather_out.nbytes)), ("size_config", ctypes.c_int64(config.nbytes)), + ] + if "device_ctx_ptr" in params and "win_in_base" in params and "win_out_base" in params: + result.extend([ + ("device_ctx_ptr", ctypes.c_uint64(params["device_ctx_ptr"])), + ("win_in_base", ctypes.c_uint64(params["win_in_base"])), + ("win_out_base", ctypes.c_uint64(params["win_out_base"])), + ("n_ranks", ctypes.c_int32(n_ranks)), ("root", ctypes.c_int32(root)), ("rank_id", ctypes.c_int32(rank_id)), + ]) + return result + +def paged_attention(query, key_cache, value_cache, num_kv_heads, num_heads, scale_value, block_table, context_lens): + assert num_kv_heads == 1 + batch, num_heads_dim, head_dim = query.shape + _, block_size, _, _ = key_cache.shape + key_cache_flat = key_cache.reshape(-1, block_size, head_dim) + value_cache_flat = value_cache.reshape(-1, block_size, head_dim) + out = torch.zeros((batch, num_heads_dim, head_dim), dtype=torch.float32) + q_tile = min(num_heads_dim, 128) + max_bn = int(((context_lens.max().item()) + block_size - 1) // block_size) + for q_offset in range(0, num_heads_dim, q_tile): + q_tile_size = min(q_tile, num_heads_dim - q_offset) + qi = query[:, q_offset:q_offset + q_tile_size, :].to(torch.float32) + oi, li, mi = None, None, None + for bn in range(max_bn): + valid_lens = torch.clamp(context_lens 
- bn * block_size, min=0, max=block_size) + active_mask = valid_lens > 0 + if not active_mask.any(): break + block_indices = block_table[:, bn] + kj_all = key_cache_flat[block_indices].to(torch.float32) + vj_all = value_cache_flat[block_indices].to(torch.float32) + sij = torch.bmm(qi, kj_all.transpose(1, 2)) * scale_value + pos = torch.arange(block_size, device=sij.device).unsqueeze(0) + valid_mask = pos < valid_lens.unsqueeze(1) + valid_mask = valid_mask.unsqueeze(1) + sij = sij.masked_fill(~valid_mask, float('-inf')) + batch_mask = active_mask.view(-1, 1, 1) + sij = sij.masked_fill(~batch_mask, float('-inf')) + mij = sij.max(dim=-1, keepdim=True)[0] + mij = mij.clamp(min=-1e30) + pij = torch.exp(sij - mij) + pij = pij.masked_fill(~valid_mask, 0.0) + pij = pij.masked_fill(~batch_mask, 0.0) + pij = pij.to(torch.bfloat16).to(torch.float32) + lij = pij.sum(dim=-1, keepdim=True) + oi_new = torch.bmm(pij, vj_all) + if bn == 0: + oi, li, mi = oi_new, lij, mij + else: + mi_new = torch.maximum(mi, mij) + alpha = torch.exp(mi - mi_new) + beta = torch.exp(mij - mi_new) + li = alpha * li + beta * lij + oi = alpha * oi + beta * oi_new + mi = mi_new + out[:, q_offset:q_offset + q_tile_size, :] = oi / li + return out.reshape(-1, head_dim) + +def _compute_rank_attn(rank_id, block_table, context_lens, total_blocks): + q, k, v = _make_qkv(rank_id, total_blocks) + return paged_attention(q, k.reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM), v.reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM), + KV_HEAD_NUM, NUM_HEADS, 1.0, block_table, context_lens) + +def compute_golden(tensors: dict, params: dict) -> None: + n_ranks = params.get("n_ranks", 2) + max_num_blocks_per_req = MAX_MODEL_LEN // BLOCK_SIZE + total_blocks = BATCH * ((CONTEXT_LEN + BLOCK_SIZE - 1) // BLOCK_SIZE) + block_table = tensors["block_table"].reshape(BATCH, max_num_blocks_per_req) + context_lens_t = tensors["context_lens"] + query = tensors["query"].reshape(BATCH, NUM_HEADS, HEAD_DIM) + key_cache = 
tensors["key_cache"].reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) + value_cache = tensors["value_cache"].reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) + attn_result = paged_attention(query, key_cache, value_cache, KV_HEAD_NUM, NUM_HEADS, 1.0, block_table, context_lens_t) + tensors["attn_out"][:] = attn_result.flatten() + allgather_np = tensors["allgather_out"].cpu().numpy() if hasattr(tensors["allgather_out"], 'cpu') else np.asarray(tensors["allgather_out"]) + for r in range(n_ranks): + attn_r = _compute_rank_attn(r, block_table, context_lens_t, total_blocks) + flat_r = attn_r.flatten().numpy() + allgather_np[r * GATHER_COUNT : (r + 1) * GATHER_COUNT] = flat_r[:GATHER_COUNT] diff --git a/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aic/aic_pv_matmul.cpp b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aic/aic_pv_matmul.cpp new file mode 100644 index 00000000..45bf49eb --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aic/aic_pv_matmul.cpp @@ -0,0 +1,90 @@ +// PV Matmul Kernel: pij(M, K) @ vj(K, N) -> oi_new(M, N) +// +// Fixed tile size: (16, 16) @ (16, 16) -> (16, 16) +// +// pij is float16 (converted from fp32 in softmax_prepare via TCVT). +// vj is stored as (K, N) = (block_size, head_dim) in row-major (ND) layout. +// Standard non-transposed B pattern: ND GlobalB + ColMajor/RowMajor TileMatB. 
+ +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static __aicore__ void pv_matmul_impl(__gm__ uint8_t* pij_raw, __gm__ uint8_t* vj_raw, __gm__ uint8_t* oi_raw) +{ + constexpr int M = 16, K = 16, N = 16; + + __gm__ half* pij = reinterpret_cast<__gm__ half*>(pij_raw); + __gm__ half* vj = reinterpret_cast<__gm__ half*>(vj_raw); + __gm__ float* oi = reinterpret_cast<__gm__ float*>(oi_raw); + + // pij (M, K) fp16, vj (K, N) fp16 in ND (row-major), oi_new (M, N) fp32 + using GlobalA = GlobalTensor, Stride>; + using GlobalB = GlobalTensor, Stride>; + using GlobalOut = GlobalTensor, Stride>; + + GlobalA pijGlobal(pij); + GlobalB vjGlobal(vj); + GlobalOut oiGlobal(oi); + + // L1 Mat tiles: standard ND pattern for both A and B + using TileMatA = Tile; + using TileMatB = Tile; + + // L0 tiles + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile bTile; + AccTile cTile; + TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); + TASSIGN(cTile, 0x0); + + // Load pij and vj to L1 + TLOAD(aMatTile, pijGlobal); + TLOAD(bMatTile, vjGlobal); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + // Move to L0A/L0B + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + // Single matmul: (M,K) x (K,N) -> (M,N) + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(oiGlobal, cTile); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) +{ + __gm__ uint8_t* pij = reinterpret_cast<__gm__ uint8_t*>(args[0]); + __gm__ uint8_t* vj = reinterpret_cast<__gm__ uint8_t*>(args[1]); + __gm__ uint8_t* oi_new = reinterpret_cast<__gm__ 
uint8_t*>(args[2]); + + pv_matmul_impl(pij, vj, oi_new); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aic/aic_qk_matmul.cpp b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aic/aic_qk_matmul.cpp new file mode 100644 index 00000000..e1e026a2 --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aic/aic_qk_matmul.cpp @@ -0,0 +1,91 @@ +// QK Matmul Kernel: qi(M, K) @ kj.T(K, N) -> sij(M, N) +// +// Fixed tile size: (16, 16) @ (16, 16).T -> (16, 16) +// +// kj is stored as (N, K) = (block_size, head_dim) in row-major memory. +// This is equivalent to (K, N) in column-major (DN) layout. +// Using DN GlobalB + RowMajor/ColMajor TileMatB to handle the transposed B pattern. + +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static __aicore__ void qk_matmul_impl(__gm__ uint8_t* qi_raw, __gm__ uint8_t* kj_raw, __gm__ uint8_t* sij_raw) +{ + constexpr int M = 16, K = 16, N = 16; + + __gm__ half* qi = reinterpret_cast<__gm__ half*>(qi_raw); + __gm__ half* kj = reinterpret_cast<__gm__ half*>(kj_raw); + __gm__ float* sij = reinterpret_cast<__gm__ float*>(sij_raw); + + // qi (M, K) fp16 in ND (row-major) layout + using GlobalA = GlobalTensor, Stride>; + // kj stored as (N, K) row-major = (K, N) column-major -> DN layout + using GlobalB = GlobalTensor, Stride, Layout::DN>; + using GlobalOut = GlobalTensor, Stride>; + + GlobalA qiGlobal(qi); + GlobalB kjGlobal(kj); + GlobalOut sijGlobal(sij); + + // L1 Mat tiles: A is standard ND, B uses transposed-B pattern (RowMajor/ColMajor) + using TileMatA = Tile; + using TileMatB = Tile; + + // L0 tiles + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile bTile; + AccTile cTile; + 
TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); + TASSIGN(cTile, 0x0); + + // Load qi and kj to L1 + TLOAD(aMatTile, qiGlobal); + TLOAD(bMatTile, kjGlobal); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + // Move to L0A/L0B + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + // Single matmul: (M,K) x (K,N) -> (M,N) + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(sijGlobal, cTile); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) +{ + __gm__ uint8_t* qi = reinterpret_cast<__gm__ uint8_t*>(args[0]); + __gm__ uint8_t* kj = reinterpret_cast<__gm__ uint8_t*>(args[1]); + __gm__ uint8_t* sij = reinterpret_cast<__gm__ uint8_t*>(args[2]); + + qk_matmul_impl(qi, kj, sij); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/aiv_online_update.cpp b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/aiv_online_update.cpp new file mode 100644 index 00000000..16e93016 --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/aiv_online_update.cpp @@ -0,0 +1,230 @@ +// Online Softmax Update + Normalize Kernel (AIV) +// +// Fixed tile size: oi/oi_new are (16, 16), mij/lij/mi/li are 16-element vectors +// +// Scalar layout strategy: +// M scalar floats stored contiguously in GM can be loaded as either: +// - ND (kScalarRows, kScalarCols) RowMajor for element-wise ops (TMAX, TSUB, TEXP, TMUL, TADD) +// - DN (kAlignedRows, 1) ColMajor for row-broadcast ops (TROWEXPANDMUL, TROWEXPANDDIV) +// Conversion between layouts uses GM round-trip: ND TSTORE -> DN TLOAD. 
+ +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static __aicore__ void online_update_impl(__gm__ uint8_t* mij_raw, __gm__ uint8_t* lij_raw, + __gm__ uint8_t* oi_new_raw, __gm__ uint8_t* mi_raw, + __gm__ uint8_t* li_raw, __gm__ uint8_t* oi_raw, + int is_first, int is_last, __gm__ uint8_t* dst_raw) +{ + constexpr int M = 16, N = 16; + + __gm__ float* mij_ptr = reinterpret_cast<__gm__ float*>(mij_raw); + __gm__ float* lij_ptr = reinterpret_cast<__gm__ float*>(lij_raw); + __gm__ float* oi_new_ptr = reinterpret_cast<__gm__ float*>(oi_new_raw); + __gm__ float* mi_ptr = reinterpret_cast<__gm__ float*>(mi_raw); + __gm__ float* li_ptr = reinterpret_cast<__gm__ float*>(li_raw); + __gm__ float* oi_ptr = reinterpret_cast<__gm__ float*>(oi_raw); + __gm__ float* dst_ptr = reinterpret_cast<__gm__ float*>(dst_raw); + + // Scalar tile dimensions for RowMajor layout: + // kScalarCols = 32 bytes / 4 bytes per float = 8 floats per row (one 32-byte block) + // kScalarRows = M / 8 (M=16 -> 2 rows) + constexpr int kScalarCols = 32 / sizeof(float); + constexpr int kScalarRows = M / kScalarCols; + // Aligned rows for ColMajor DN tiles (32-byte alignment) + constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float)); + + // --- GlobalTensor types --- + + // Data (M, N) RowMajor + using GlobalDataMxN = GlobalTensor, Stride<1, 1, 1, N, 1>>; + + // Scalar ND: M contiguous floats as (kScalarRows, kScalarCols) RowMajor + using GlobalScalarND = GlobalTensor, + Stride<1, 1, 1, kScalarCols, 1>>; + + // Scalar DN: same M contiguous floats as (kAlignedRows, 1) ColMajor + using GlobalScalarDN = GlobalTensor, + Stride<1, 1, 1, 1, 1>, Layout::DN>; + + // --- GlobalTensor instances --- + + GlobalDataMxN oiNewGlobal(oi_new_ptr); + GlobalDataMxN oiGlobal(oi_ptr); + GlobalDataMxN dstGlobal(dst_ptr); + + // ND globals for scalar element-wise operations + GlobalScalarND 
mijGlobalND(mij_ptr); + GlobalScalarND lijGlobalND(lij_ptr); + GlobalScalarND miGlobalND(mi_ptr); + GlobalScalarND liGlobalND(li_ptr); + + // DN globals aliased to same GM for ColMajor reload (used after ND TSTORE) + GlobalScalarDN mijGlobalDN(mij_ptr); + GlobalScalarDN lijGlobalDN(lij_ptr); + GlobalScalarDN liGlobalDN(li_ptr); + + // --- Tile types --- + + using TileDataMxN = Tile; + using TileScalarND = Tile; + using TileScalarDN = Tile; + + // --- UB memory layout --- + + constexpr int kDataBytes = M * N * sizeof(float); + constexpr int kScalarNDBytes = kScalarRows * kScalarCols * sizeof(float); + constexpr int kScalarDNBytes = kAlignedRows * sizeof(float); + + // Data tiles + TileDataMxN oiNewTile; + TileDataMxN oiTile; + + // Scalar ND tiles for element-wise arithmetic + TileScalarND mijND, lijND, miND, liND; + TileScalarND miNewND, alphaND, betaND, tmpND; + + // Scalar DN tiles for TROWEXPAND operations + TileScalarDN alphaDN, betaDN, liDN; + + TASSIGN(oiNewTile, 0); + TASSIGN(oiTile, kDataBytes); + TASSIGN(mijND, 2 * kDataBytes); + TASSIGN(lijND, 2 * kDataBytes + kScalarNDBytes); + TASSIGN(miND, 2 * kDataBytes + 2 * kScalarNDBytes); + TASSIGN(liND, 2 * kDataBytes + 3 * kScalarNDBytes); + TASSIGN(miNewND, 2 * kDataBytes + 4 * kScalarNDBytes); + TASSIGN(alphaND, 2 * kDataBytes + 5 * kScalarNDBytes); + TASSIGN(betaND, 2 * kDataBytes + 6 * kScalarNDBytes); + TASSIGN(tmpND, 2 * kDataBytes + 7 * kScalarNDBytes); + TASSIGN(alphaDN, 2 * kDataBytes + 8 * kScalarNDBytes); + TASSIGN(betaDN, 2 * kDataBytes + 8 * kScalarNDBytes + kScalarDNBytes); + TASSIGN(liDN, 2 * kDataBytes + 8 * kScalarNDBytes + 2 * kScalarDNBytes); + + if (is_first) { + // --- First block: copy inputs to accumulators --- + TLOAD(oiNewTile, oiNewGlobal); + TLOAD(mijND, mijGlobalND); + TLOAD(lijND, lijGlobalND); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Passthrough to MTE3 (no V compute needed) + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + 
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, mijND); // mi = mij + TSTORE(liGlobalND, lijND); // li = lij + TSTORE(oiGlobal, oiNewTile); // oi = oi_new + + if (is_last) { + // Single block: normalize dst = oi_new / lij + // lij stored to li buffer in ND format; reload as DN for TROWEXPANDDIV + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + TLOAD(liDN, liGlobalDN); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TROWEXPANDDIV(oiNewTile, oiNewTile, liDN); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(dstGlobal, oiNewTile); + } + } else { + // --- Subsequent blocks: accumulate --- + + // Phase 1: Load all inputs + TLOAD(oiNewTile, oiNewGlobal); + TLOAD(oiTile, oiGlobal); + TLOAD(mijND, mijGlobalND); + TLOAD(lijND, lijGlobalND); + TLOAD(miND, miGlobalND); + TLOAD(liND, liGlobalND); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Phase 2: Scalar arithmetic in RowMajor (kScalarRows, kScalarCols) + // pipe_barrier(PIPE_V) required between each dependent vector operation + // to resolve RAW hazards on shared UB tiles. 
+ TMAX(miNewND, miND, mijND); // mi_new = max(mi, mij) + pipe_barrier(PIPE_V); + TSUB(alphaND, miND, miNewND); // alpha = mi - mi_new + pipe_barrier(PIPE_V); + TEXP(alphaND, alphaND); // alpha = exp(mi - mi_new) + pipe_barrier(PIPE_V); + TSUB(betaND, mijND, miNewND); // beta = mij - mi_new + pipe_barrier(PIPE_V); + TEXP(betaND, betaND); // beta = exp(mij - mi_new) + pipe_barrier(PIPE_V); + TMUL(liND, alphaND, liND); // li = alpha * li + pipe_barrier(PIPE_V); + TMUL(tmpND, betaND, lijND); // tmp = beta * lij + pipe_barrier(PIPE_V); + TADD(liND, liND, tmpND); // li = alpha * li + beta * lij (= li_new) + + // Phase 3: Store scalar results to GM (ND format) + // mi_new -> mi accumulator, li_new -> li accumulator + // alpha -> mij buffer (reuse), beta -> lij buffer (reuse) + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, miNewND); // persist mi_new + TSTORE(liGlobalND, liND); // persist li_new + TSTORE(mijGlobalND, alphaND); // temp: alpha to mij buffer + TSTORE(lijGlobalND, betaND); // temp: beta to lij buffer + + // Phase 4: Reload alpha, beta (and li if last) as ColMajor DN + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + TLOAD(alphaDN, mijGlobalDN); // alpha from mij buffer as DN + TLOAD(betaDN, lijGlobalDN); // beta from lij buffer as DN + if (is_last) { + TLOAD(liDN, liGlobalDN); // li_new from li buffer as DN + } + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + + // Phase 5: Scale data tiles using row-broadcast multiply + TROWEXPANDMUL(oiTile, oiTile, alphaDN); // oi *= alpha + TROWEXPANDMUL(oiNewTile, oiNewTile, betaDN); // oi_new *= beta + pipe_barrier(PIPE_V); + TADD(oiTile, oiTile, oiNewTile); // oi = alpha*oi + beta*oi_new + + if (is_last) { + // Phase 6: Normalize and output + pipe_barrier(PIPE_V); + TROWEXPANDDIV(oiTile, oiTile, liDN); // dst = oi / li_new + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, 
PIPE_MTE3, EVENT_ID1); + TSTORE(dstGlobal, oiTile); + } else { + // Phase 6: Store updated accumulators + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(oiGlobal, oiTile); + } + } +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { + __gm__ uint8_t* mij = reinterpret_cast<__gm__ uint8_t*>(args[0]); + __gm__ uint8_t* lij = reinterpret_cast<__gm__ uint8_t*>(args[1]); + __gm__ uint8_t* oi_new = reinterpret_cast<__gm__ uint8_t*>(args[2]); + __gm__ uint8_t* mi = reinterpret_cast<__gm__ uint8_t*>(args[3]); + __gm__ uint8_t* li = reinterpret_cast<__gm__ uint8_t*>(args[4]); + __gm__ uint8_t* oi = reinterpret_cast<__gm__ uint8_t*>(args[5]); + int is_first = static_cast(args[6]); + int is_last = static_cast(args[7]); + __gm__ uint8_t* dst = reinterpret_cast<__gm__ uint8_t*>(args[8]); + + online_update_impl(mij, lij, oi_new, mi, li, oi, is_first, is_last, dst); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/aiv_softmax_prepare.cpp b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/aiv_softmax_prepare.cpp new file mode 100644 index 00000000..6715cf07 --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/aiv_softmax_prepare.cpp @@ -0,0 +1,94 @@ +// Softmax Preparation Kernel (AIV) +// +// Fixed tile size: sij is (16, 16) +// +// Computes: +// sij_scale = sij * scale +// mij = row_max(sij_scale) -> (M, 1) +// pij = exp(sij_scale - mij) -> (M, N) +// lij = row_sum(pij) -> (M, 1) + +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static __aicore__ void softmax_prepare_impl(__gm__ uint8_t* sij_raw, float scale_value, + __gm__ uint8_t* pij_raw, __gm__ uint8_t* mij_raw, + __gm__ uint8_t* lij_raw) +{ + constexpr int M = 16, N = 16; + + __gm__ float* sij = reinterpret_cast<__gm__ float*>(sij_raw); + __gm__ half* pij = 
reinterpret_cast<__gm__ half*>(pij_raw); + __gm__ float* mij = reinterpret_cast<__gm__ float*>(mij_raw); + __gm__ float* lij = reinterpret_cast<__gm__ float*>(lij_raw); + + constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float)); + + using GlobalDataMxN = GlobalTensor, Stride<1, 1, 1, N, 1>>; + using GlobalDataMxN_f16 = GlobalTensor, Stride<1, 1, 1, N, 1>>; + using GlobalScalarDN = GlobalTensor, Stride<1, 1, 1, 1, 1>, Layout::DN>; + + GlobalDataMxN sijGlobal(sij); + GlobalDataMxN_f16 pijGlobal(pij); + GlobalScalarDN mijGlobal(mij); + GlobalScalarDN lijGlobal(lij); + + using TileVecMxN = Tile; + using TileVecMxN_f16 = Tile; + using TileScalarDN = Tile; + + TileVecMxN sijTile; + TileVecMxN pijTile; + TileVecMxN tmpTile; + TileScalarDN maxTile; + TileScalarDN sumTile; + TileVecMxN_f16 pijF16Tile; + + TASSIGN(sijTile, 0x0); + TASSIGN(pijTile, M * N * sizeof(float)); + TASSIGN(tmpTile, 2 * M * N * sizeof(float)); + TASSIGN(maxTile, 3 * M * N * sizeof(float)); + TASSIGN(sumTile, 3 * M * N * sizeof(float) + kAlignedRows * sizeof(float)); + TASSIGN(pijF16Tile, 3 * M * N * sizeof(float) + 2 * kAlignedRows * sizeof(float)); + + TLOAD(sijTile, sijGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + TMULS(sijTile, sijTile, scale_value); + TROWMAX(maxTile, sijTile, tmpTile); + TROWEXPANDSUB(pijTile, sijTile, maxTile); + TEXP(pijTile, pijTile); + // Truncate pij to fp16 first, then compute lij from truncated values (matches golden) + TCVT(pijF16Tile, pijTile, RoundMode::CAST_ROUND); + TCVT(pijTile, pijF16Tile, RoundMode::CAST_ROUND); + TROWSUM(sumTile, pijTile, tmpTile); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(mijGlobal, maxTile); + TSTORE(lijGlobal, sumTile); + TSTORE(pijGlobal, pijF16Tile); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { + __gm__ uint8_t* sij = reinterpret_cast<__gm__ uint8_t*>(args[0]); + union { uint64_t u; 
float f; } scale_conv; + scale_conv.u = static_cast(args[1]); + float scale_value = scale_conv.f; + __gm__ uint8_t* pij = reinterpret_cast<__gm__ uint8_t*>(args[2]); + __gm__ uint8_t* mij = reinterpret_cast<__gm__ uint8_t*>(args[3]); + __gm__ uint8_t* lij = reinterpret_cast<__gm__ uint8_t*>(args[4]); + + softmax_prepare_impl(sij, scale_value, pij, mij, lij); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/allgather_manual_kernel.cpp b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/allgather_manual_kernel.cpp new file mode 100644 index 00000000..6a93a85c --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/allgather_manual_kernel.cpp @@ -0,0 +1,64 @@ +/** + * Manual AllGather kernel - direct RDMA reads, no TGATHER. + * + * Each rank independently reads from all ranks' win_src via HcclRemotePtr + * and writes to local dst. No collective TGATHER call, so no deadlock. + * All ranks can run in parallel (single kernel, single barrier). 
+ * + * Args: dst, src, ctx, nranks, rank_id (rank_id unused, for API compatibility) + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static constexpr size_t GATHER_COUNT = 64; + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* src = reinterpret_cast<__gm__ float*>(args[1]); + __gm__ HcclDeviceContext* hcclCtx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[2]); + int nranks = static_cast(args[3]); + (void)args[4]; /* rank_id unused */ + + using ShapeDyn = pto::Shape; + using StrideDyn = pto::Stride; + using Global = pto::GlobalTensor; + using TileData = pto::Tile; + + ShapeDyn sliceShape(1, 1, 1, 1, GATHER_COUNT); + StrideDyn sliceStride(GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, 1); + + TileData ubTile(1, GATHER_COUNT); + TASSIGN(ubTile, 0x0); + + int actual_nranks = (nranks > 16) ? 
16 : nranks; + for (int r = 0; r < actual_nranks; ++r) { + __gm__ float* remote_src = HcclRemotePtr(hcclCtx, src, r); + __gm__ float* local_dst = dst + static_cast(r) * GATHER_COUNT; + + Global srcG(remote_src, sliceShape, sliceStride); + Global dstG(local_dst, sliceShape, sliceStride); + + TLOAD(ubTile, srcG); + set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); + TSTORE(dstG, ubTile); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + } + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/comm_barrier_all_kernel.cpp b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/comm_barrier_all_kernel.cpp new file mode 100644 index 00000000..b1711665 --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/comm_barrier_all_kernel.cpp @@ -0,0 +1,52 @@ +/** + * All-to-all barrier (多对多): every rank waits for every other rank. + * + * Used by AllGather where every rank reads from all ranks' windows. + * Unlike comm_barrier_kernel (many-to-one), ALL ranks do TWAIT here. + * + * Flow: + * 1. Each rank TNOTIFY to root's barrier slot[my_rank] + * 2. 
Each rank TWAIT on root's barrier until all n_ranks slots >= 1 + * + * Args: + * args[0] = barrier_base (local barrier buffer; root's is used for sync) + * args[1] = device_ctx_ptr (HcclDeviceContext*) + * args[2] = n_ranks + * args[3] = root (whose barrier buffer is the sync point) + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ int32_t* local_barrier = reinterpret_cast<__gm__ int32_t*>(args[0]); + __gm__ HcclDeviceContext* ctx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[1]); + int n_ranks = static_cast(args[2]); + int root = static_cast(args[3]); + int my_rank = static_cast(ctx->rankId); + + // Step 1: Each rank writes flag=1 to root's barrier slot[my_rank] via RDMA. + __gm__ int32_t* remote_slot = HcclRemotePtr(ctx, local_barrier, root) + my_rank; + pto::comm::Signal sig(remote_slot); + pto::comm::TNOTIFY(sig, 1, pto::comm::NotifyOp::Set); + + // Step 2: ALL ranks wait until every rank's flag is >= 1 (multi-to-multi). + __gm__ int32_t* root_barrier = HcclRemotePtr(ctx, local_barrier, root); + for (int i = 0; i < n_ranks; ++i) { + pto::comm::Signal slot(root_barrier + i); + pto::comm::TWAIT(slot, 1, pto::comm::WaitCmp::GE); + } + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/window_memcopy_in.cpp b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/window_memcopy_in.cpp new file mode 100644 index 00000000..38408baa --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/window_memcopy_in.cpp @@ -0,0 +1,26 @@ +/** + * WindowMemCopyIn: Copy device buffer to HCCL window. + * Used before AllGather so remote ranks can read. 
+ */ + +#include +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* win_dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* dev_src = reinterpret_cast<__gm__ float*>(args[1]); + int count = static_cast(args[2]); + + for (int i = 0; i < count; ++i) { + win_dst[i] = dev_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/window_memcopy_out.cpp b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/window_memcopy_out.cpp new file mode 100644 index 00000000..99e83e76 --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/aiv/window_memcopy_out.cpp @@ -0,0 +1,26 @@ +/** + * WindowMemCopyOut: Copy HCCL window to device buffer. + * After AllGather, every rank copies gathered result to device. + */ + +#include +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* dev_dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* win_src = reinterpret_cast<__gm__ float*>(args[1]); + int count = static_cast(args[2]); + + for (int i = 0; i < count; ++i) { + dev_dst[i] = win_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Manual/kernels/kernel_config.py b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/kernel_config.py new file mode 100644 index 00000000..a530a52d --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/kernel_config.py @@ -0,0 +1,39 @@ +""" +Paged Attention + AllGather (Manual): Paged Attention → AllGather. 
+ +Flow per rank: + QK → Softmax → PV → OnlineUpdate (paged attention, possibly multi-block) + → WindowMemCopyIn → CommBarrier → AllGatherManual (direct RDMA reads) + → WindowMemCopyOut → CommBarrier(post) + +All ranks get the full allgather output (concatenation of all ranks' attn_out prefix). +""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "paged_attention_allgather_orch.cpp"), + "function_name": "build_paged_attention_allgather_graph", +} + +KERNELS = [ + {"func_id": 0, "name": "QK", "source": str(_KERNELS_ROOT / "aic" / "aic_qk_matmul.cpp"), "core_type": "aic"}, + {"func_id": 1, "name": "SF", "source": str(_KERNELS_ROOT / "aiv" / "aiv_softmax_prepare.cpp"), "core_type": "aiv"}, + {"func_id": 2, "name": "PV", "source": str(_KERNELS_ROOT / "aic" / "aic_pv_matmul.cpp"), "core_type": "aic"}, + {"func_id": 3, "name": "UP", "source": str(_KERNELS_ROOT / "aiv" / "aiv_online_update.cpp"), "core_type": "aiv"}, + {"func_id": 4, "name": "WindowMemCopyIn", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_in.cpp"), "core_type": "aiv"}, + {"func_id": 5, "name": "AllGatherManual", "source": str(_KERNELS_ROOT / "aiv" / "allgather_manual_kernel.cpp"), "core_type": "aiv"}, + {"func_id": 6, "name": "WindowMemCopyOut", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_out.cpp"), "core_type": "aiv"}, + {"func_id": 7, "name": "CommBarrierAll", "source": str(_KERNELS_ROOT / "aiv" / "comm_barrier_all_kernel.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "host_build_graph", + "n_devices": 2, + "first_device_id": 0, + "requires_comm": True, + "aicpu_thread_num": 3, + "block_dim": 3, +} diff --git a/examples/host_build_graph/paged_attention_allgather_Manual/kernels/orchestration/paged_attention_allgather_orch.cpp b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/orchestration/paged_attention_allgather_orch.cpp new file mode 100644 index 
00000000..5484c070 --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Manual/kernels/orchestration/paged_attention_allgather_orch.cpp @@ -0,0 +1,290 @@ +/** + * Paged Attention + AllGather (Manual): Paged Attention → AllGather (direct RDMA). + * + * Phase 1: QK → Softmax → PV → OnlineUpdate (paged attention) + * Phase 2: WindowMemCopyIn → CommBarrier(pre) → AllGatherManual → WindowMemCopyOut → CommBarrier(post) + * All ranks get the full allgather output. + */ + +#include "runtime.h" +#include +#include +#include +#include + +#define FUNC_QK_MATMUL 0 +#define FUNC_SOFTMAX_PREPARE 1 +#define FUNC_PV_MATMUL 2 +#define FUNC_ONLINE_UPDATE 3 +#define FUNC_WIN_MEMCOPY_IN 4 +#define FUNC_ALLGATHER 5 +#define FUNC_WIN_MEMCOPY_OUT 6 +#define FUNC_COMM_BARRIER 7 + +constexpr int GATHER_COUNT = 64; +constexpr size_t HCCL_WIN_SYNC_PREFIX = 64 * sizeof(int32_t); + +extern "C" { + +int build_paged_attention_allgather_graph(Runtime* runtime, uint64_t* args, int arg_count) { + if (arg_count < 22) { + std::cerr << "build_paged_attention_allgather_graph: Expected at least 22 args, got " << arg_count << '\n'; + return -1; + } + + void* host_query = reinterpret_cast<void*>(args[0]); + void* host_key_cache = reinterpret_cast<void*>(args[1]); + void* host_value_cache = reinterpret_cast<void*>(args[2]); + int* host_block_table = reinterpret_cast<int*>(args[3]); + int* host_context_lens = reinterpret_cast<int*>(args[4]); + void* host_attn_out = reinterpret_cast<void*>(args[5]); + void* host_allgather_out = reinterpret_cast<void*>(args[6]); + int64_t* host_config = reinterpret_cast<int64_t*>(args[7]); + + size_t query_size = static_cast<size_t>(args[8]); + size_t key_cache_size = static_cast<size_t>(args[9]); + size_t value_cache_size = static_cast<size_t>(args[10]); + size_t block_table_size = static_cast<size_t>(args[11]); + size_t context_lens_size = static_cast<size_t>(args[12]); + size_t attn_out_size = static_cast<size_t>(args[13]); + size_t allgather_out_size = static_cast<size_t>(args[14]); + size_t config_size = static_cast<size_t>(args[15]); + uint64_t device_ctx_ptr = 
args[16]; + uint64_t win_in_base = args[17]; + uint64_t win_out_base = args[18]; + int n_ranks = static_cast<int>(args[19]); + int root = static_cast<int>(args[20]); + int rank_id = static_cast<int>(args[21]); + + int batch = static_cast<int>(host_config[0]); + int num_heads = static_cast<int>(host_config[1]); + int kv_head_num = static_cast<int>(host_config[2]); + int head_dim = static_cast<int>(host_config[3]); + int block_size = static_cast<int>(host_config[4]); + int max_num_blocks = static_cast<int>(host_config[5]); + uint64_t scale_value_bits = static_cast<uint64_t>(host_config[6]); + + int q_tile_size = std::min(num_heads, 128); + int num_head_tiles = (num_heads + q_tile_size - 1) / q_tile_size; + + std::cout << "\n=== build_paged_attention_allgather_graph (Manual) ===" << '\n'; + std::cout << " n_ranks=" << n_ranks << " rank_id=" << rank_id << '\n'; + + void* dev_query = runtime->host_api.device_malloc(query_size); + void* dev_key_cache = runtime->host_api.device_malloc(key_cache_size); + void* dev_value_cache = runtime->host_api.device_malloc(value_cache_size); + void* dev_attn_out = runtime->host_api.device_malloc(attn_out_size); + void* dev_allgather_out = runtime->host_api.device_malloc(allgather_out_size); + + if (!dev_query || !dev_key_cache || !dev_value_cache || !dev_attn_out || !dev_allgather_out) { + std::cerr << "Error: Failed to allocate device memory\n"; + return -1; + } + + runtime->host_api.copy_to_device(dev_query, host_query, query_size); + runtime->host_api.copy_to_device(dev_key_cache, host_key_cache, key_cache_size); + runtime->host_api.copy_to_device(dev_value_cache, host_value_cache, value_cache_size); + runtime->record_tensor_pair(host_attn_out, dev_attn_out, attn_out_size); + runtime->record_tensor_pair(host_allgather_out, dev_allgather_out, allgather_out_size); + + size_t sij_size = static_cast<size_t>(q_tile_size) * block_size * sizeof(float); + size_t pij_size = static_cast<size_t>(q_tile_size) * block_size * sizeof(uint16_t); + size_t mij_size = static_cast<size_t>(q_tile_size) * sizeof(float); + size_t 
lij_size = mij_size; + size_t oi_new_size = static_cast<size_t>(q_tile_size) * head_dim * sizeof(float); + + int total_buffers = batch * max_num_blocks; + void** dev_sij_arr = new void*[total_buffers]; + void** dev_pij_arr = new void*[total_buffers]; + void** dev_mij_arr = new void*[total_buffers]; + void** dev_lij_arr = new void*[total_buffers]; + void** dev_oi_new_arr = new void*[total_buffers]; + + for (int i = 0; i < total_buffers; i++) { + dev_sij_arr[i] = runtime->host_api.device_malloc(sij_size); + dev_pij_arr[i] = runtime->host_api.device_malloc(pij_size); + dev_mij_arr[i] = runtime->host_api.device_malloc(mij_size); + dev_lij_arr[i] = runtime->host_api.device_malloc(lij_size); + dev_oi_new_arr[i] = runtime->host_api.device_malloc(oi_new_size); + } + + int total_accums = batch * num_head_tiles; + size_t mi_size = static_cast<size_t>(q_tile_size) * sizeof(float); + size_t li_size = mi_size; + size_t oi_size = static_cast<size_t>(q_tile_size) * head_dim * sizeof(float); + + void** dev_mi_arr = new void*[total_accums]; + void** dev_li_arr = new void*[total_accums]; + void** dev_oi_arr = new void*[total_accums]; + + for (int i = 0; i < total_accums; i++) { + dev_mi_arr[i] = runtime->host_api.device_malloc(mi_size); + dev_li_arr[i] = runtime->host_api.device_malloc(li_size); + dev_oi_arr[i] = runtime->host_api.device_malloc(oi_size); + } + + std::vector<int> last_pa_tasks; + + for (int b_idx = 0; b_idx < batch; b_idx++) { + int cur_seq = host_context_lens[b_idx]; + int bn_this_batch = (cur_seq + block_size - 1) / block_size; + + for (int ht = 0; ht < num_head_tiles; ht++) { + int cur_offset = ht * q_tile_size; + uint8_t* qi_ptr = reinterpret_cast<uint8_t*>(dev_query) + + static_cast<size_t>(b_idx * num_heads + cur_offset) * head_dim * sizeof(uint16_t); + uint8_t* out_ptr = reinterpret_cast<uint8_t*>(dev_attn_out) + + static_cast<size_t>(b_idx * num_heads + cur_offset) * head_dim * sizeof(float); + int kv_head_idx = cur_offset / (num_heads / kv_head_num); + int accum_idx = b_idx * num_head_tiles + ht; + void* dev_mi = 
dev_mi_arr[accum_idx]; + void* dev_li = dev_li_arr[accum_idx]; + void* dev_oi = dev_oi_arr[accum_idx]; + + int t_up_prev = -1; + + for (int bn = 0; bn < bn_this_batch; bn++) { + int cur_block_idx = host_block_table[b_idx * max_num_blocks + bn]; + uint8_t* kj_ptr = reinterpret_cast<uint8_t*>(dev_key_cache) + + (static_cast<size_t>(cur_block_idx) * block_size * kv_head_num + kv_head_idx) + * head_dim * sizeof(uint16_t); + uint8_t* vj_ptr = reinterpret_cast<uint8_t*>(dev_value_cache) + + (static_cast<size_t>(cur_block_idx) * block_size * kv_head_num + kv_head_idx) + * head_dim * sizeof(uint16_t); + + int buf_idx = b_idx * max_num_blocks + bn; + void* dev_sij = dev_sij_arr[buf_idx]; + void* dev_pij = dev_pij_arr[buf_idx]; + void* dev_mij = dev_mij_arr[buf_idx]; + void* dev_lij = dev_lij_arr[buf_idx]; + void* dev_oi_new = dev_oi_new_arr[buf_idx]; + + uint64_t qk_args[6] = { + reinterpret_cast<uint64_t>(qi_ptr), + reinterpret_cast<uint64_t>(kj_ptr), + reinterpret_cast<uint64_t>(dev_sij), + static_cast<uint64_t>(q_tile_size), + static_cast<uint64_t>(head_dim), + static_cast<uint64_t>(block_size) + }; + int t_qk = runtime->add_task(qk_args, 6, FUNC_QK_MATMUL, CoreType::AIC); + + uint64_t sf_args[7] = { + reinterpret_cast<uint64_t>(dev_sij), + scale_value_bits, + reinterpret_cast<uint64_t>(dev_pij), + reinterpret_cast<uint64_t>(dev_mij), + reinterpret_cast<uint64_t>(dev_lij), + static_cast<uint64_t>(q_tile_size), + static_cast<uint64_t>(block_size) + }; + int t_sf = runtime->add_task(sf_args, 7, FUNC_SOFTMAX_PREPARE, CoreType::AIV); + + uint64_t pv_args[6] = { + reinterpret_cast<uint64_t>(dev_pij), + reinterpret_cast<uint64_t>(vj_ptr), + reinterpret_cast<uint64_t>(dev_oi_new), + static_cast<uint64_t>(q_tile_size), + static_cast<uint64_t>(block_size), + static_cast<uint64_t>(head_dim) + }; + int t_pv = runtime->add_task(pv_args, 6, FUNC_PV_MATMUL, CoreType::AIC); + + runtime->add_successor(t_qk, t_sf); + runtime->add_successor(t_sf, t_pv); + + int is_first = (bn == 0) ? 1 : 0; + int is_last = (bn == bn_this_batch - 1) ? 
1 : 0; + + uint64_t up_args[11] = { + reinterpret_cast<uint64_t>(dev_mij), + reinterpret_cast<uint64_t>(dev_lij), + reinterpret_cast<uint64_t>(dev_oi_new), + reinterpret_cast<uint64_t>(dev_mi), + reinterpret_cast<uint64_t>(dev_li), + reinterpret_cast<uint64_t>(dev_oi), + static_cast<uint64_t>(is_first), + static_cast<uint64_t>(is_last), + reinterpret_cast<uint64_t>(out_ptr), + static_cast<uint64_t>(q_tile_size), + static_cast<uint64_t>(head_dim) + }; + int t_up = runtime->add_task(up_args, 11, FUNC_ONLINE_UPDATE, CoreType::AIV); + + runtime->add_successor(t_pv, t_up); + if (t_up_prev >= 0) { + runtime->add_successor(t_up_prev, t_up); + } + t_up_prev = t_up; + } + last_pa_tasks.push_back(t_up_prev); + } + } + + /* Phase 2: AllGather (Manual) */ + size_t barrier_size = static_cast<size_t>(n_ranks) * sizeof(int32_t); + uint64_t barrier_base_pre = win_in_base + HCCL_WIN_SYNC_PREFIX; + uint64_t barrier_base_post = barrier_base_pre + barrier_size; + uint64_t win_src = barrier_base_post + barrier_size; + uint64_t win_dst = win_src + GATHER_COUNT * sizeof(float); + + int32_t zeros[64] = {}; + std::memset(zeros, 0, sizeof(zeros)); + runtime->host_api.copy_to_device(reinterpret_cast<void*>(barrier_base_pre), zeros, barrier_size); + runtime->host_api.copy_to_device(reinterpret_cast<void*>(barrier_base_post), zeros, barrier_size); + + uint64_t args_wmin[3] = { + win_src, + reinterpret_cast<uint64_t>(dev_attn_out), + static_cast<uint64_t>(GATHER_COUNT) + }; + int t_wmin = runtime->add_task(args_wmin, 3, FUNC_WIN_MEMCOPY_IN, CoreType::AIV); + for (int t : last_pa_tasks) { + runtime->add_successor(t, t_wmin); + } + + uint64_t args_barrier_pre[4] = { + barrier_base_pre, device_ctx_ptr, + static_cast<uint64_t>(n_ranks), static_cast<uint64_t>(0) + }; + int t_barrier_pre = runtime->add_task(args_barrier_pre, 4, FUNC_COMM_BARRIER, CoreType::AIV); + runtime->add_successor(t_wmin, t_barrier_pre); + + uint64_t args_allgather[5] = { + win_dst, win_src, device_ctx_ptr, + static_cast<uint64_t>(n_ranks), static_cast<uint64_t>(rank_id) + }; + int t_allgather = runtime->add_task(args_allgather, 5, FUNC_ALLGATHER, CoreType::AIV); + runtime->add_successor(t_barrier_pre, t_allgather); 
+ + uint64_t args_wmout[3] = { + reinterpret_cast<uint64_t>(dev_allgather_out), + win_dst, + static_cast<uint64_t>(n_ranks * GATHER_COUNT) + }; + int t_wmout = runtime->add_task(args_wmout, 3, FUNC_WIN_MEMCOPY_OUT, CoreType::AIV); + runtime->add_successor(t_allgather, t_wmout); + + uint64_t args_barrier_post[4] = { + barrier_base_post, device_ctx_ptr, + static_cast<uint64_t>(n_ranks), static_cast<uint64_t>(0) + }; + int t_barrier_post = runtime->add_task(args_barrier_post, 4, FUNC_COMM_BARRIER, CoreType::AIV); + runtime->add_successor(t_wmout, t_barrier_post); + + delete[] dev_sij_arr; + delete[] dev_pij_arr; + delete[] dev_mij_arr; + delete[] dev_lij_arr; + delete[] dev_oi_new_arr; + delete[] dev_mi_arr; + delete[] dev_li_arr; + delete[] dev_oi_arr; + + std::cout << "Created paged_attention_allgather (Manual) graph\n"; + runtime->print_runtime(); + + return 0; +} + +} // extern "C" diff --git a/examples/host_build_graph/paged_attention_allgather_Tgather/README.md b/examples/host_build_graph/paged_attention_allgather_Tgather/README.md new file mode 100644 index 00000000..4fc1eb54 --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Tgather/README.md @@ -0,0 +1,12 @@ +# Paged Attention + AllGather (TGATHER) - host_build_graph + +Paged Attention 计算后 AllGather(N 轮 TGATHER)。 + +流程:QK → Softmax → PV → OnlineUpdate → WindowMemCopyIn +→ for r in [0,n_ranks): Barrier → Gather(root=r) → [rank r: WindowMemCopyOut] → Barrier(post) + +## 运行 + +```bash +./run_hostbuild.sh paged_attention_allgather_Tgather 2 0 +``` diff --git a/examples/host_build_graph/paged_attention_allgather_Tgather/golden.py b/examples/host_build_graph/paged_attention_allgather_Tgather/golden.py new file mode 100644 index 00000000..6eb31afd --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Tgather/golden.py @@ -0,0 +1,144 @@ +""" +Paged Attention + AllGather (TGATHER): Paged Attention → N sequential Gathers. + +Same golden logic as paged_attention_allgather_Manual (output is identical). 
+""" + +import ctypes +import struct +import torch +import numpy as np + +GATHER_COUNT = 64 +BATCH = 1 +NUM_HEADS = 16 +KV_HEAD_NUM = 1 +HEAD_DIM = 16 +BLOCK_SIZE = 16 +CONTEXT_LEN = 16 +MAX_MODEL_LEN = 256 + +__outputs__ = ["attn_out", "allgather_out"] +RTOL = 1e-2 +ATOL = 1e-2 +ALL_CASES = {"Default": {}} +DEFAULT_CASE = "Default" + +def _make_block_table_and_context(): + max_num_blocks_per_req = MAX_MODEL_LEN // BLOCK_SIZE + cur_valid_blocks = (CONTEXT_LEN + BLOCK_SIZE - 1) // BLOCK_SIZE + total_blocks = BATCH * cur_valid_blocks + torch.manual_seed(100) + block_table = torch.randint(0, max(total_blocks, 1), size=(BATCH, max_num_blocks_per_req), dtype=torch.int32) + context_lens = torch.full((BATCH,), CONTEXT_LEN, dtype=torch.int32) + return block_table, context_lens, total_blocks, max_num_blocks_per_req + +def _make_qkv(rank_id, total_blocks): + torch.manual_seed(42 + rank_id) + q = (torch.rand(BATCH, 1, NUM_HEADS * HEAD_DIM) - 0.5).to(torch.float16) + q = q.reshape(BATCH, NUM_HEADS, HEAD_DIM) + k = (torch.rand(total_blocks, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) - 0.5).to(torch.float16) + v = (torch.rand(total_blocks, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) * 2 - 1).to(torch.float16) + return q, k, v + +def generate_inputs(params: dict) -> list: + rank_id = params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + block_table, context_lens, total_blocks, max_num_blocks_per_req = _make_block_table_and_context() + query_fp16, key_fp16, value_fp16 = _make_qkv(rank_id, total_blocks) + scale_value = 1.0 + scale_bits = struct.unpack('I', struct.pack('f', scale_value))[0] + config = torch.tensor([BATCH, NUM_HEADS, KV_HEAD_NUM, HEAD_DIM, BLOCK_SIZE, max_num_blocks_per_req, scale_bits], dtype=torch.int64) + query = query_fp16.flatten() + key_cache = key_fp16.flatten() + value_cache = value_fp16.flatten() + block_table_flat = block_table.flatten() + attn_out = torch.zeros(BATCH * NUM_HEADS * HEAD_DIM, dtype=torch.float32) + allgather_out = 
torch.zeros(n_ranks * GATHER_COUNT, dtype=torch.float32) + result = [ + ("query", query), ("key_cache", key_cache), ("value_cache", value_cache), + ("block_table", block_table_flat), ("context_lens", context_lens), + ("attn_out", attn_out), ("allgather_out", allgather_out), ("config", config), + ("size_query", ctypes.c_int64(query.nbytes)), ("size_key_cache", ctypes.c_int64(key_cache.nbytes)), + ("size_value_cache", ctypes.c_int64(value_cache.nbytes)), ("size_block_table", ctypes.c_int64(block_table_flat.nbytes)), + ("size_context_lens", ctypes.c_int64(context_lens.nbytes)), ("size_attn_out", ctypes.c_int64(attn_out.nbytes)), + ("size_allgather_out", ctypes.c_int64(allgather_out.nbytes)), ("size_config", ctypes.c_int64(config.nbytes)), + ] + if "device_ctx_ptr" in params and "win_in_base" in params and "win_out_base" in params: + result.extend([ + ("device_ctx_ptr", ctypes.c_uint64(params["device_ctx_ptr"])), + ("win_in_base", ctypes.c_uint64(params["win_in_base"])), + ("win_out_base", ctypes.c_uint64(params["win_out_base"])), + ("n_ranks", ctypes.c_int32(n_ranks)), ("root", ctypes.c_int32(root)), ("rank_id", ctypes.c_int32(rank_id)), + ]) + return result + +def paged_attention(query, key_cache, value_cache, num_kv_heads, num_heads, scale_value, block_table, context_lens): + assert num_kv_heads == 1 + batch, num_heads_dim, head_dim = query.shape + _, block_size, _, _ = key_cache.shape + key_cache_flat = key_cache.reshape(-1, block_size, head_dim) + value_cache_flat = value_cache.reshape(-1, block_size, head_dim) + out = torch.zeros((batch, num_heads_dim, head_dim), dtype=torch.float32) + q_tile = min(num_heads_dim, 128) + max_bn = int(((context_lens.max().item()) + block_size - 1) // block_size) + for q_offset in range(0, num_heads_dim, q_tile): + q_tile_size = min(q_tile, num_heads_dim - q_offset) + qi = query[:, q_offset:q_offset + q_tile_size, :].to(torch.float32) + oi, li, mi = None, None, None + for bn in range(max_bn): + valid_lens = torch.clamp(context_lens 
- bn * block_size, min=0, max=block_size) + active_mask = valid_lens > 0 + if not active_mask.any(): break + block_indices = block_table[:, bn] + kj_all = key_cache_flat[block_indices].to(torch.float32) + vj_all = value_cache_flat[block_indices].to(torch.float32) + sij = torch.bmm(qi, kj_all.transpose(1, 2)) * scale_value + pos = torch.arange(block_size, device=sij.device).unsqueeze(0) + valid_mask = pos < valid_lens.unsqueeze(1) + valid_mask = valid_mask.unsqueeze(1) + sij = sij.masked_fill(~valid_mask, float('-inf')) + batch_mask = active_mask.view(-1, 1, 1) + sij = sij.masked_fill(~batch_mask, float('-inf')) + mij = sij.max(dim=-1, keepdim=True)[0] + mij = mij.clamp(min=-1e30) + pij = torch.exp(sij - mij) + pij = pij.masked_fill(~valid_mask, 0.0) + pij = pij.masked_fill(~batch_mask, 0.0) + pij = pij.to(torch.bfloat16).to(torch.float32) + lij = pij.sum(dim=-1, keepdim=True) + oi_new = torch.bmm(pij, vj_all) + if bn == 0: + oi, li, mi = oi_new, lij, mij + else: + mi_new = torch.maximum(mi, mij) + alpha = torch.exp(mi - mi_new) + beta = torch.exp(mij - mi_new) + li = alpha * li + beta * lij + oi = alpha * oi + beta * oi_new + mi = mi_new + out[:, q_offset:q_offset + q_tile_size, :] = oi / li + return out.reshape(-1, head_dim) + +def _compute_rank_attn(rank_id, block_table, context_lens, total_blocks): + q, k, v = _make_qkv(rank_id, total_blocks) + return paged_attention(q, k.reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM), v.reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM), + KV_HEAD_NUM, NUM_HEADS, 1.0, block_table, context_lens) + +def compute_golden(tensors: dict, params: dict) -> None: + n_ranks = params.get("n_ranks", 2) + max_num_blocks_per_req = MAX_MODEL_LEN // BLOCK_SIZE + total_blocks = BATCH * ((CONTEXT_LEN + BLOCK_SIZE - 1) // BLOCK_SIZE) + block_table = tensors["block_table"].reshape(BATCH, max_num_blocks_per_req) + context_lens_t = tensors["context_lens"] + query = tensors["query"].reshape(BATCH, NUM_HEADS, HEAD_DIM) + key_cache = 
tensors["key_cache"].reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) + value_cache = tensors["value_cache"].reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) + attn_result = paged_attention(query, key_cache, value_cache, KV_HEAD_NUM, NUM_HEADS, 1.0, block_table, context_lens_t) + tensors["attn_out"][:] = attn_result.flatten() + allgather_np = tensors["allgather_out"].cpu().numpy() if hasattr(tensors["allgather_out"], 'cpu') else np.asarray(tensors["allgather_out"]) + for r in range(n_ranks): + attn_r = _compute_rank_attn(r, block_table, context_lens_t, total_blocks) + flat_r = attn_r.flatten().numpy() + allgather_np[r * GATHER_COUNT : (r + 1) * GATHER_COUNT] = flat_r[:GATHER_COUNT] diff --git a/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aic/aic_pv_matmul.cpp b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aic/aic_pv_matmul.cpp new file mode 100644 index 00000000..45bf49eb --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aic/aic_pv_matmul.cpp @@ -0,0 +1,90 @@ +// PV Matmul Kernel: pij(M, K) @ vj(K, N) -> oi_new(M, N) +// +// Fixed tile size: (16, 16) @ (16, 16) -> (16, 16) +// +// pij is float16 (converted from fp32 in softmax_prepare via TCVT). +// vj is stored as (K, N) = (block_size, head_dim) in row-major (ND) layout. +// Standard non-transposed B pattern: ND GlobalB + ColMajor/RowMajor TileMatB. 
+ +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static __aicore__ void pv_matmul_impl(__gm__ uint8_t* pij_raw, __gm__ uint8_t* vj_raw, __gm__ uint8_t* oi_raw) +{ + constexpr int M = 16, K = 16, N = 16; + + __gm__ half* pij = reinterpret_cast<__gm__ half*>(pij_raw); + __gm__ half* vj = reinterpret_cast<__gm__ half*>(vj_raw); + __gm__ float* oi = reinterpret_cast<__gm__ float*>(oi_raw); + + // pij (M, K) fp16, vj (K, N) fp16 in ND (row-major), oi_new (M, N) fp32 + using GlobalA = GlobalTensor, Stride>; + using GlobalB = GlobalTensor, Stride>; + using GlobalOut = GlobalTensor, Stride>; + + GlobalA pijGlobal(pij); + GlobalB vjGlobal(vj); + GlobalOut oiGlobal(oi); + + // L1 Mat tiles: standard ND pattern for both A and B + using TileMatA = Tile; + using TileMatB = Tile; + + // L0 tiles + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile bTile; + AccTile cTile; + TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); + TASSIGN(cTile, 0x0); + + // Load pij and vj to L1 + TLOAD(aMatTile, pijGlobal); + TLOAD(bMatTile, vjGlobal); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + // Move to L0A/L0B + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + // Single matmul: (M,K) x (K,N) -> (M,N) + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(oiGlobal, cTile); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) +{ + __gm__ uint8_t* pij = reinterpret_cast<__gm__ uint8_t*>(args[0]); + __gm__ uint8_t* vj = reinterpret_cast<__gm__ uint8_t*>(args[1]); + __gm__ uint8_t* oi_new = reinterpret_cast<__gm__ 
uint8_t*>(args[2]); + + pv_matmul_impl(pij, vj, oi_new); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aic/aic_qk_matmul.cpp b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aic/aic_qk_matmul.cpp new file mode 100644 index 00000000..e1e026a2 --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aic/aic_qk_matmul.cpp @@ -0,0 +1,91 @@ +// QK Matmul Kernel: qi(M, K) @ kj.T(K, N) -> sij(M, N) +// +// Fixed tile size: (16, 16) @ (16, 16).T -> (16, 16) +// +// kj is stored as (N, K) = (block_size, head_dim) in row-major memory. +// This is equivalent to (K, N) in column-major (DN) layout. +// Using DN GlobalB + RowMajor/ColMajor TileMatB to handle the transposed B pattern. + +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static __aicore__ void qk_matmul_impl(__gm__ uint8_t* qi_raw, __gm__ uint8_t* kj_raw, __gm__ uint8_t* sij_raw) +{ + constexpr int M = 16, K = 16, N = 16; + + __gm__ half* qi = reinterpret_cast<__gm__ half*>(qi_raw); + __gm__ half* kj = reinterpret_cast<__gm__ half*>(kj_raw); + __gm__ float* sij = reinterpret_cast<__gm__ float*>(sij_raw); + + // qi (M, K) fp16 in ND (row-major) layout + using GlobalA = GlobalTensor, Stride>; + // kj stored as (N, K) row-major = (K, N) column-major -> DN layout + using GlobalB = GlobalTensor, Stride, Layout::DN>; + using GlobalOut = GlobalTensor, Stride>; + + GlobalA qiGlobal(qi); + GlobalB kjGlobal(kj); + GlobalOut sijGlobal(sij); + + // L1 Mat tiles: A is standard ND, B uses transposed-B pattern (RowMajor/ColMajor) + using TileMatA = Tile; + using TileMatB = Tile; + + // L0 tiles + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile bTile; + AccTile cTile; + 
TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); + TASSIGN(cTile, 0x0); + + // Load qi and kj to L1 + TLOAD(aMatTile, qiGlobal); + TLOAD(bMatTile, kjGlobal); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + // Move to L0A/L0B + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + // Single matmul: (M,K) x (K,N) -> (M,N) + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(sijGlobal, cTile); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) +{ + __gm__ uint8_t* qi = reinterpret_cast<__gm__ uint8_t*>(args[0]); + __gm__ uint8_t* kj = reinterpret_cast<__gm__ uint8_t*>(args[1]); + __gm__ uint8_t* sij = reinterpret_cast<__gm__ uint8_t*>(args[2]); + + qk_matmul_impl(qi, kj, sij); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/aiv_online_update.cpp b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/aiv_online_update.cpp new file mode 100644 index 00000000..16e93016 --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/aiv_online_update.cpp @@ -0,0 +1,230 @@ +// Online Softmax Update + Normalize Kernel (AIV) +// +// Fixed tile size: oi/oi_new are (16, 16), mij/lij/mi/li are 16-element vectors +// +// Scalar layout strategy: +// M scalar floats stored contiguously in GM can be loaded as either: +// - ND (kScalarRows, kScalarCols) RowMajor for element-wise ops (TMAX, TSUB, TEXP, TMUL, TADD) +// - DN (kAlignedRows, 1) ColMajor for row-broadcast ops (TROWEXPANDMUL, TROWEXPANDDIV) +// Conversion between layouts uses GM round-trip: ND TSTORE -> DN TLOAD. 
+ +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static __aicore__ void online_update_impl(__gm__ uint8_t* mij_raw, __gm__ uint8_t* lij_raw, + __gm__ uint8_t* oi_new_raw, __gm__ uint8_t* mi_raw, + __gm__ uint8_t* li_raw, __gm__ uint8_t* oi_raw, + int is_first, int is_last, __gm__ uint8_t* dst_raw) +{ + constexpr int M = 16, N = 16; + + __gm__ float* mij_ptr = reinterpret_cast<__gm__ float*>(mij_raw); + __gm__ float* lij_ptr = reinterpret_cast<__gm__ float*>(lij_raw); + __gm__ float* oi_new_ptr = reinterpret_cast<__gm__ float*>(oi_new_raw); + __gm__ float* mi_ptr = reinterpret_cast<__gm__ float*>(mi_raw); + __gm__ float* li_ptr = reinterpret_cast<__gm__ float*>(li_raw); + __gm__ float* oi_ptr = reinterpret_cast<__gm__ float*>(oi_raw); + __gm__ float* dst_ptr = reinterpret_cast<__gm__ float*>(dst_raw); + + // Scalar tile dimensions for RowMajor layout: + // kScalarCols = 32 bytes / 4 bytes per float = 8 floats per row (one 32-byte block) + // kScalarRows = M / 8 (M=16 -> 2 rows) + constexpr int kScalarCols = 32 / sizeof(float); + constexpr int kScalarRows = M / kScalarCols; + // Aligned rows for ColMajor DN tiles (32-byte alignment) + constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float)); + + // --- GlobalTensor types --- + + // Data (M, N) RowMajor + using GlobalDataMxN = GlobalTensor, Stride<1, 1, 1, N, 1>>; + + // Scalar ND: M contiguous floats as (kScalarRows, kScalarCols) RowMajor + using GlobalScalarND = GlobalTensor, + Stride<1, 1, 1, kScalarCols, 1>>; + + // Scalar DN: same M contiguous floats as (kAlignedRows, 1) ColMajor + using GlobalScalarDN = GlobalTensor, + Stride<1, 1, 1, 1, 1>, Layout::DN>; + + // --- GlobalTensor instances --- + + GlobalDataMxN oiNewGlobal(oi_new_ptr); + GlobalDataMxN oiGlobal(oi_ptr); + GlobalDataMxN dstGlobal(dst_ptr); + + // ND globals for scalar element-wise operations + GlobalScalarND 
mijGlobalND(mij_ptr); + GlobalScalarND lijGlobalND(lij_ptr); + GlobalScalarND miGlobalND(mi_ptr); + GlobalScalarND liGlobalND(li_ptr); + + // DN globals aliased to same GM for ColMajor reload (used after ND TSTORE) + GlobalScalarDN mijGlobalDN(mij_ptr); + GlobalScalarDN lijGlobalDN(lij_ptr); + GlobalScalarDN liGlobalDN(li_ptr); + + // --- Tile types --- + + using TileDataMxN = Tile; + using TileScalarND = Tile; + using TileScalarDN = Tile; + + // --- UB memory layout --- + + constexpr int kDataBytes = M * N * sizeof(float); + constexpr int kScalarNDBytes = kScalarRows * kScalarCols * sizeof(float); + constexpr int kScalarDNBytes = kAlignedRows * sizeof(float); + + // Data tiles + TileDataMxN oiNewTile; + TileDataMxN oiTile; + + // Scalar ND tiles for element-wise arithmetic + TileScalarND mijND, lijND, miND, liND; + TileScalarND miNewND, alphaND, betaND, tmpND; + + // Scalar DN tiles for TROWEXPAND operations + TileScalarDN alphaDN, betaDN, liDN; + + TASSIGN(oiNewTile, 0); + TASSIGN(oiTile, kDataBytes); + TASSIGN(mijND, 2 * kDataBytes); + TASSIGN(lijND, 2 * kDataBytes + kScalarNDBytes); + TASSIGN(miND, 2 * kDataBytes + 2 * kScalarNDBytes); + TASSIGN(liND, 2 * kDataBytes + 3 * kScalarNDBytes); + TASSIGN(miNewND, 2 * kDataBytes + 4 * kScalarNDBytes); + TASSIGN(alphaND, 2 * kDataBytes + 5 * kScalarNDBytes); + TASSIGN(betaND, 2 * kDataBytes + 6 * kScalarNDBytes); + TASSIGN(tmpND, 2 * kDataBytes + 7 * kScalarNDBytes); + TASSIGN(alphaDN, 2 * kDataBytes + 8 * kScalarNDBytes); + TASSIGN(betaDN, 2 * kDataBytes + 8 * kScalarNDBytes + kScalarDNBytes); + TASSIGN(liDN, 2 * kDataBytes + 8 * kScalarNDBytes + 2 * kScalarDNBytes); + + if (is_first) { + // --- First block: copy inputs to accumulators --- + TLOAD(oiNewTile, oiNewGlobal); + TLOAD(mijND, mijGlobalND); + TLOAD(lijND, lijGlobalND); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Passthrough to MTE3 (no V compute needed) + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + 
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, mijND); // mi = mij + TSTORE(liGlobalND, lijND); // li = lij + TSTORE(oiGlobal, oiNewTile); // oi = oi_new + + if (is_last) { + // Single block: normalize dst = oi_new / lij + // lij stored to li buffer in ND format; reload as DN for TROWEXPANDDIV + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + TLOAD(liDN, liGlobalDN); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TROWEXPANDDIV(oiNewTile, oiNewTile, liDN); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(dstGlobal, oiNewTile); + } + } else { + // --- Subsequent blocks: accumulate --- + + // Phase 1: Load all inputs + TLOAD(oiNewTile, oiNewGlobal); + TLOAD(oiTile, oiGlobal); + TLOAD(mijND, mijGlobalND); + TLOAD(lijND, lijGlobalND); + TLOAD(miND, miGlobalND); + TLOAD(liND, liGlobalND); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Phase 2: Scalar arithmetic in RowMajor (kScalarRows, kScalarCols) + // pipe_barrier(PIPE_V) required between each dependent vector operation + // to resolve RAW hazards on shared UB tiles. 
+ TMAX(miNewND, miND, mijND); // mi_new = max(mi, mij) + pipe_barrier(PIPE_V); + TSUB(alphaND, miND, miNewND); // alpha = mi - mi_new + pipe_barrier(PIPE_V); + TEXP(alphaND, alphaND); // alpha = exp(mi - mi_new) + pipe_barrier(PIPE_V); + TSUB(betaND, mijND, miNewND); // beta = mij - mi_new + pipe_barrier(PIPE_V); + TEXP(betaND, betaND); // beta = exp(mij - mi_new) + pipe_barrier(PIPE_V); + TMUL(liND, alphaND, liND); // li = alpha * li + pipe_barrier(PIPE_V); + TMUL(tmpND, betaND, lijND); // tmp = beta * lij + pipe_barrier(PIPE_V); + TADD(liND, liND, tmpND); // li = alpha * li + beta * lij (= li_new) + + // Phase 3: Store scalar results to GM (ND format) + // mi_new -> mi accumulator, li_new -> li accumulator + // alpha -> mij buffer (reuse), beta -> lij buffer (reuse) + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, miNewND); // persist mi_new + TSTORE(liGlobalND, liND); // persist li_new + TSTORE(mijGlobalND, alphaND); // temp: alpha to mij buffer + TSTORE(lijGlobalND, betaND); // temp: beta to lij buffer + + // Phase 4: Reload alpha, beta (and li if last) as ColMajor DN + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + TLOAD(alphaDN, mijGlobalDN); // alpha from mij buffer as DN + TLOAD(betaDN, lijGlobalDN); // beta from lij buffer as DN + if (is_last) { + TLOAD(liDN, liGlobalDN); // li_new from li buffer as DN + } + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + + // Phase 5: Scale data tiles using row-broadcast multiply + TROWEXPANDMUL(oiTile, oiTile, alphaDN); // oi *= alpha + TROWEXPANDMUL(oiNewTile, oiNewTile, betaDN); // oi_new *= beta + pipe_barrier(PIPE_V); + TADD(oiTile, oiTile, oiNewTile); // oi = alpha*oi + beta*oi_new + + if (is_last) { + // Phase 6: Normalize and output + pipe_barrier(PIPE_V); + TROWEXPANDDIV(oiTile, oiTile, liDN); // dst = oi / li_new + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, 
PIPE_MTE3, EVENT_ID1); + TSTORE(dstGlobal, oiTile); + } else { + // Phase 6: Store updated accumulators + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(oiGlobal, oiTile); + } + } +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { + __gm__ uint8_t* mij = reinterpret_cast<__gm__ uint8_t*>(args[0]); + __gm__ uint8_t* lij = reinterpret_cast<__gm__ uint8_t*>(args[1]); + __gm__ uint8_t* oi_new = reinterpret_cast<__gm__ uint8_t*>(args[2]); + __gm__ uint8_t* mi = reinterpret_cast<__gm__ uint8_t*>(args[3]); + __gm__ uint8_t* li = reinterpret_cast<__gm__ uint8_t*>(args[4]); + __gm__ uint8_t* oi = reinterpret_cast<__gm__ uint8_t*>(args[5]); + int is_first = static_cast(args[6]); + int is_last = static_cast(args[7]); + __gm__ uint8_t* dst = reinterpret_cast<__gm__ uint8_t*>(args[8]); + + online_update_impl(mij, lij, oi_new, mi, li, oi, is_first, is_last, dst); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/aiv_softmax_prepare.cpp b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/aiv_softmax_prepare.cpp new file mode 100644 index 00000000..6715cf07 --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/aiv_softmax_prepare.cpp @@ -0,0 +1,94 @@ +// Softmax Preparation Kernel (AIV) +// +// Fixed tile size: sij is (16, 16) +// +// Computes: +// sij_scale = sij * scale +// mij = row_max(sij_scale) -> (M, 1) +// pij = exp(sij_scale - mij) -> (M, N) +// lij = row_sum(pij) -> (M, 1) + +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static __aicore__ void softmax_prepare_impl(__gm__ uint8_t* sij_raw, float scale_value, + __gm__ uint8_t* pij_raw, __gm__ uint8_t* mij_raw, + __gm__ uint8_t* lij_raw) +{ + constexpr int M = 16, N = 16; + + __gm__ float* sij = reinterpret_cast<__gm__ float*>(sij_raw); + __gm__ half* pij = 
reinterpret_cast<__gm__ half*>(pij_raw); + __gm__ float* mij = reinterpret_cast<__gm__ float*>(mij_raw); + __gm__ float* lij = reinterpret_cast<__gm__ float*>(lij_raw); + + constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float)); + + using GlobalDataMxN = GlobalTensor, Stride<1, 1, 1, N, 1>>; + using GlobalDataMxN_f16 = GlobalTensor, Stride<1, 1, 1, N, 1>>; + using GlobalScalarDN = GlobalTensor, Stride<1, 1, 1, 1, 1>, Layout::DN>; + + GlobalDataMxN sijGlobal(sij); + GlobalDataMxN_f16 pijGlobal(pij); + GlobalScalarDN mijGlobal(mij); + GlobalScalarDN lijGlobal(lij); + + using TileVecMxN = Tile; + using TileVecMxN_f16 = Tile; + using TileScalarDN = Tile; + + TileVecMxN sijTile; + TileVecMxN pijTile; + TileVecMxN tmpTile; + TileScalarDN maxTile; + TileScalarDN sumTile; + TileVecMxN_f16 pijF16Tile; + + TASSIGN(sijTile, 0x0); + TASSIGN(pijTile, M * N * sizeof(float)); + TASSIGN(tmpTile, 2 * M * N * sizeof(float)); + TASSIGN(maxTile, 3 * M * N * sizeof(float)); + TASSIGN(sumTile, 3 * M * N * sizeof(float) + kAlignedRows * sizeof(float)); + TASSIGN(pijF16Tile, 3 * M * N * sizeof(float) + 2 * kAlignedRows * sizeof(float)); + + TLOAD(sijTile, sijGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + TMULS(sijTile, sijTile, scale_value); + TROWMAX(maxTile, sijTile, tmpTile); + TROWEXPANDSUB(pijTile, sijTile, maxTile); + TEXP(pijTile, pijTile); + // Truncate pij to fp16 first, then compute lij from truncated values (matches golden) + TCVT(pijF16Tile, pijTile, RoundMode::CAST_ROUND); + TCVT(pijTile, pijF16Tile, RoundMode::CAST_ROUND); + TROWSUM(sumTile, pijTile, tmpTile); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(mijGlobal, maxTile); + TSTORE(lijGlobal, sumTile); + TSTORE(pijGlobal, pijF16Tile); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { + __gm__ uint8_t* sij = reinterpret_cast<__gm__ uint8_t*>(args[0]); + union { uint64_t u; 
float f; } scale_conv; + scale_conv.u = static_cast(args[1]); + float scale_value = scale_conv.f; + __gm__ uint8_t* pij = reinterpret_cast<__gm__ uint8_t*>(args[2]); + __gm__ uint8_t* mij = reinterpret_cast<__gm__ uint8_t*>(args[3]); + __gm__ uint8_t* lij = reinterpret_cast<__gm__ uint8_t*>(args[4]); + + softmax_prepare_impl(sij, scale_value, pij, mij, lij); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/comm_barrier_all_kernel.cpp b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/comm_barrier_all_kernel.cpp new file mode 100644 index 00000000..b1711665 --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/comm_barrier_all_kernel.cpp @@ -0,0 +1,52 @@ +/** + * All-to-all barrier (多对多): every rank waits for every other rank. + * + * Used by AllGather where every rank reads from all ranks' windows. + * Unlike comm_barrier_kernel (many-to-one), ALL ranks do TWAIT here. + * + * Flow: + * 1. Each rank TNOTIFY to root's barrier slot[my_rank] + * 2. Each rank TWAIT on root's barrier until all n_ranks slots >= 1 + * + * Args: + * args[0] = barrier_base (local barrier buffer; root's is used for sync) + * args[1] = device_ctx_ptr (HcclDeviceContext*) + * args[2] = n_ranks + * args[3] = root (whose barrier buffer is the sync point) + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ int32_t* local_barrier = reinterpret_cast<__gm__ int32_t*>(args[0]); + __gm__ HcclDeviceContext* ctx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[1]); + int n_ranks = static_cast(args[2]); + int root = static_cast(args[3]); + int my_rank = static_cast(ctx->rankId); + + // Step 1: Each rank writes flag=1 to root's barrier slot[my_rank] via RDMA. 
+ __gm__ int32_t* remote_slot = HcclRemotePtr(ctx, local_barrier, root) + my_rank; + pto::comm::Signal sig(remote_slot); + pto::comm::TNOTIFY(sig, 1, pto::comm::NotifyOp::Set); + + // Step 2: ALL ranks wait until every rank's flag is >= 1 (multi-to-multi). + __gm__ int32_t* root_barrier = HcclRemotePtr(ctx, local_barrier, root); + for (int i = 0; i < n_ranks; ++i) { + pto::comm::Signal slot(root_barrier + i); + pto::comm::TWAIT(slot, 1, pto::comm::WaitCmp::GE); + } + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/gather_kernel.cpp b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/gather_kernel.cpp new file mode 100644 index 00000000..2d972cfa --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/gather_kernel.cpp @@ -0,0 +1,62 @@ +/** + * TGATHER collective kernel - root gathers from all ranks. + * Requires pto-comm-isa (PTO_ISA_ROOT or PTO_COMM_ISA_ROOT). + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static constexpr size_t GATHER_COUNT = 64; + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* src = reinterpret_cast<__gm__ float*>(args[1]); + __gm__ HcclDeviceContext* hcclCtx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[2]); + int nranks = static_cast(args[3]); + int root = static_cast(args[4]); + + using ShapeDyn = pto::Shape; + using StrideDyn = pto::Stride; + using Global = pto::GlobalTensor; + + using TileData = pto::Tile; + + int my_rank = static_cast(hcclCtx->rankId); + + ShapeDyn srcShape(1, 1, 1, 1, GATHER_COUNT); + StrideDyn srcStride(GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, 1); + + ShapeDyn dstShape(1, 1, 1, 
nranks, GATHER_COUNT); + StrideDyn dstStride(nranks * GATHER_COUNT, nranks * GATHER_COUNT, nranks * GATHER_COUNT, GATHER_COUNT, 1); + Global dstG(dst, dstShape, dstStride); + + Global tensors[16]; + int actual_nranks = (nranks > 16) ? 16 : nranks; + for (int i = 0; i < actual_nranks; ++i) { + __gm__ float* remoteSrc = HcclRemotePtr(hcclCtx, src, i); + tensors[i] = Global(remoteSrc, srcShape, srcStride); + } + + pto::comm::ParallelGroup pg(tensors, actual_nranks, root); + + TileData ubTile(1, GATHER_COUNT); + TASSIGN(ubTile, 0x0); + + if (my_rank == root) { + pto::comm::TGATHER(pg, dstG, ubTile); + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/window_memcopy_in.cpp b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/window_memcopy_in.cpp new file mode 100644 index 00000000..38408baa --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/window_memcopy_in.cpp @@ -0,0 +1,26 @@ +/** + * WindowMemCopyIn: Copy device buffer to HCCL window. + * Used before AllGather so remote ranks can read. 
+ */ + +#include +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* win_dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* dev_src = reinterpret_cast<__gm__ float*>(args[1]); + int count = static_cast(args[2]); + + for (int i = 0; i < count; ++i) { + win_dst[i] = dev_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/window_memcopy_out.cpp b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/window_memcopy_out.cpp new file mode 100644 index 00000000..99e83e76 --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/aiv/window_memcopy_out.cpp @@ -0,0 +1,26 @@ +/** + * WindowMemCopyOut: Copy HCCL window to device buffer. + * After AllGather, every rank copies gathered result to device. + */ + +#include +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* dev_dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* win_src = reinterpret_cast<__gm__ float*>(args[1]); + int count = static_cast(args[2]); + + for (int i = 0; i < count; ++i) { + dev_dst[i] = win_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/kernel_config.py b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/kernel_config.py new file mode 100644 index 00000000..de624018 --- /dev/null +++ b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/kernel_config.py @@ -0,0 +1,37 @@ +""" +Paged Attention + AllGather (TGATHER): Paged Attention → N sequential Gathers. 
+ +Flow per rank: + QK → Softmax → PV → OnlineUpdate (paged attention) + → WindowMemCopyIn -> for r in [0,n_ranks): Barrier -> Gather(root=r) + -> [if rank==r: WindowMemCopyOut] -> Barrier(post) +""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "paged_attention_allgather_orch.cpp"), + "function_name": "build_paged_attention_allgather_graph", +} + +KERNELS = [ + {"func_id": 0, "name": "QK", "source": str(_KERNELS_ROOT / "aic" / "aic_qk_matmul.cpp"), "core_type": "aic"}, + {"func_id": 1, "name": "SF", "source": str(_KERNELS_ROOT / "aiv" / "aiv_softmax_prepare.cpp"), "core_type": "aiv"}, + {"func_id": 2, "name": "PV", "source": str(_KERNELS_ROOT / "aic" / "aic_pv_matmul.cpp"), "core_type": "aic"}, + {"func_id": 3, "name": "UP", "source": str(_KERNELS_ROOT / "aiv" / "aiv_online_update.cpp"), "core_type": "aiv"}, + {"func_id": 4, "name": "WindowMemCopyIn", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_in.cpp"), "core_type": "aiv"}, + {"func_id": 5, "name": "Gather", "source": str(_KERNELS_ROOT / "aiv" / "gather_kernel.cpp"), "core_type": "aiv"}, + {"func_id": 6, "name": "WindowMemCopyOut", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_out.cpp"), "core_type": "aiv"}, + {"func_id": 7, "name": "CommBarrierAll", "source": str(_KERNELS_ROOT / "aiv" / "comm_barrier_all_kernel.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "host_build_graph", + "n_devices": 2, + "first_device_id": 0, + "requires_comm": True, + "aicpu_thread_num": 3, + "block_dim": 3, +} diff --git a/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/orchestration/paged_attention_allgather_orch.cpp b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/orchestration/paged_attention_allgather_orch.cpp new file mode 100644 index 00000000..c34fc8a0 --- /dev/null +++ 
b/examples/host_build_graph/paged_attention_allgather_Tgather/kernels/orchestration/paged_attention_allgather_orch.cpp @@ -0,0 +1,299 @@ +/** + * Paged Attention + AllGather (TGATHER): Paged Attention → N sequential Gathers. + * + * Phase 1: QK → Softmax → PV → OnlineUpdate (paged attention) + * Phase 2: WindowMemCopyIn -> for r in [0,n_ranks): Barrier -> Gather(root=r) + * -> [if rank==r: WindowMemCopyOut] -> Barrier(post) + */ + +#include "runtime.h" +#include +#include +#include +#include + +#define FUNC_QK_MATMUL 0 +#define FUNC_SOFTMAX_PREPARE 1 +#define FUNC_PV_MATMUL 2 +#define FUNC_ONLINE_UPDATE 3 +#define FUNC_WIN_MEMCOPY_IN 4 +#define FUNC_GATHER 5 +#define FUNC_WIN_MEMCOPY_OUT 6 +#define FUNC_COMM_BARRIER 7 + +constexpr int GATHER_COUNT = 64; +constexpr size_t HCCL_WIN_SYNC_PREFIX = 64 * sizeof(int32_t); + +extern "C" { + +int build_paged_attention_allgather_graph(Runtime* runtime, uint64_t* args, int arg_count) { + if (arg_count < 22) { + std::cerr << "build_paged_attention_allgather_graph: Expected at least 22 args, got " << arg_count << '\n'; + return -1; + } + + void* host_query = reinterpret_cast(args[0]); + void* host_key_cache = reinterpret_cast(args[1]); + void* host_value_cache = reinterpret_cast(args[2]); + int* host_block_table = reinterpret_cast(args[3]); + int* host_context_lens = reinterpret_cast(args[4]); + void* host_attn_out = reinterpret_cast(args[5]); + void* host_allgather_out = reinterpret_cast(args[6]); + int64_t* host_config = reinterpret_cast(args[7]); + + size_t query_size = static_cast(args[8]); + size_t key_cache_size = static_cast(args[9]); + size_t value_cache_size = static_cast(args[10]); + size_t block_table_size = static_cast(args[11]); + size_t context_lens_size = static_cast(args[12]); + size_t attn_out_size = static_cast(args[13]); + size_t allgather_out_size = static_cast(args[14]); + size_t config_size = static_cast(args[15]); + uint64_t device_ctx_ptr = args[16]; + uint64_t win_in_base = args[17]; + uint64_t 
win_out_base = args[18]; + int n_ranks = static_cast(args[19]); + int root = static_cast(args[20]); + int rank_id = static_cast(args[21]); + + int batch = static_cast(host_config[0]); + int num_heads = static_cast(host_config[1]); + int kv_head_num = static_cast(host_config[2]); + int head_dim = static_cast(host_config[3]); + int block_size = static_cast(host_config[4]); + int max_num_blocks = static_cast(host_config[5]); + uint64_t scale_value_bits = static_cast(host_config[6]); + + int q_tile_size = std::min(num_heads, 128); + int num_head_tiles = (num_heads + q_tile_size - 1) / q_tile_size; + + std::cout << "\n=== build_paged_attention_allgather_graph (TGATHER) ===" << '\n'; + std::cout << " n_ranks=" << n_ranks << " rank_id=" << rank_id << '\n'; + + void* dev_query = runtime->host_api.device_malloc(query_size); + void* dev_key_cache = runtime->host_api.device_malloc(key_cache_size); + void* dev_value_cache = runtime->host_api.device_malloc(value_cache_size); + void* dev_attn_out = runtime->host_api.device_malloc(attn_out_size); + void* dev_allgather_out = runtime->host_api.device_malloc(allgather_out_size); + + if (!dev_query || !dev_key_cache || !dev_value_cache || !dev_attn_out || !dev_allgather_out) { + std::cerr << "Error: Failed to allocate device memory\n"; + return -1; + } + + runtime->host_api.copy_to_device(dev_query, host_query, query_size); + runtime->host_api.copy_to_device(dev_key_cache, host_key_cache, key_cache_size); + runtime->host_api.copy_to_device(dev_value_cache, host_value_cache, value_cache_size); + runtime->record_tensor_pair(host_attn_out, dev_attn_out, attn_out_size); + runtime->record_tensor_pair(host_allgather_out, dev_allgather_out, allgather_out_size); + + size_t sij_size = static_cast(q_tile_size) * block_size * sizeof(float); + size_t pij_size = static_cast(q_tile_size) * block_size * sizeof(uint16_t); + size_t mij_size = static_cast(q_tile_size) * sizeof(float); + size_t lij_size = mij_size; + size_t oi_new_size = 
static_cast(q_tile_size) * head_dim * sizeof(float); + + int total_buffers = batch * max_num_blocks; + void** dev_sij_arr = new void*[total_buffers]; + void** dev_pij_arr = new void*[total_buffers]; + void** dev_mij_arr = new void*[total_buffers]; + void** dev_lij_arr = new void*[total_buffers]; + void** dev_oi_new_arr = new void*[total_buffers]; + + for (int i = 0; i < total_buffers; i++) { + dev_sij_arr[i] = runtime->host_api.device_malloc(sij_size); + dev_pij_arr[i] = runtime->host_api.device_malloc(pij_size); + dev_mij_arr[i] = runtime->host_api.device_malloc(mij_size); + dev_lij_arr[i] = runtime->host_api.device_malloc(lij_size); + dev_oi_new_arr[i] = runtime->host_api.device_malloc(oi_new_size); + } + + int total_accums = batch * num_head_tiles; + size_t mi_size = static_cast(q_tile_size) * sizeof(float); + size_t li_size = mi_size; + size_t oi_size = static_cast(q_tile_size) * head_dim * sizeof(float); + + void** dev_mi_arr = new void*[total_accums]; + void** dev_li_arr = new void*[total_accums]; + void** dev_oi_arr = new void*[total_accums]; + + for (int i = 0; i < total_accums; i++) { + dev_mi_arr[i] = runtime->host_api.device_malloc(mi_size); + dev_li_arr[i] = runtime->host_api.device_malloc(li_size); + dev_oi_arr[i] = runtime->host_api.device_malloc(oi_size); + } + + std::vector last_pa_tasks; + + for (int b_idx = 0; b_idx < batch; b_idx++) { + int cur_seq = host_context_lens[b_idx]; + int bn_this_batch = (cur_seq + block_size - 1) / block_size; + + for (int ht = 0; ht < num_head_tiles; ht++) { + int cur_offset = ht * q_tile_size; + uint8_t* qi_ptr = reinterpret_cast(dev_query) + + static_cast(b_idx * num_heads + cur_offset) * head_dim * sizeof(uint16_t); + uint8_t* out_ptr = reinterpret_cast(dev_attn_out) + + static_cast(b_idx * num_heads + cur_offset) * head_dim * sizeof(float); + int kv_head_idx = cur_offset / (num_heads / kv_head_num); + int accum_idx = b_idx * num_head_tiles + ht; + void* dev_mi = dev_mi_arr[accum_idx]; + void* dev_li = 
dev_li_arr[accum_idx]; + void* dev_oi = dev_oi_arr[accum_idx]; + + int t_up_prev = -1; + + for (int bn = 0; bn < bn_this_batch; bn++) { + int cur_block_idx = host_block_table[b_idx * max_num_blocks + bn]; + uint8_t* kj_ptr = reinterpret_cast(dev_key_cache) + + (static_cast(cur_block_idx) * block_size * kv_head_num + kv_head_idx) + * head_dim * sizeof(uint16_t); + uint8_t* vj_ptr = reinterpret_cast(dev_value_cache) + + (static_cast(cur_block_idx) * block_size * kv_head_num + kv_head_idx) + * head_dim * sizeof(uint16_t); + + int buf_idx = b_idx * max_num_blocks + bn; + void* dev_sij = dev_sij_arr[buf_idx]; + void* dev_pij = dev_pij_arr[buf_idx]; + void* dev_mij = dev_mij_arr[buf_idx]; + void* dev_lij = dev_lij_arr[buf_idx]; + void* dev_oi_new = dev_oi_new_arr[buf_idx]; + + uint64_t qk_args[6] = { + reinterpret_cast(qi_ptr), + reinterpret_cast(kj_ptr), + reinterpret_cast(dev_sij), + static_cast(q_tile_size), + static_cast(head_dim), + static_cast(block_size) + }; + int t_qk = runtime->add_task(qk_args, 6, FUNC_QK_MATMUL, CoreType::AIC); + + uint64_t sf_args[7] = { + reinterpret_cast(dev_sij), + scale_value_bits, + reinterpret_cast(dev_pij), + reinterpret_cast(dev_mij), + reinterpret_cast(dev_lij), + static_cast(q_tile_size), + static_cast(block_size) + }; + int t_sf = runtime->add_task(sf_args, 7, FUNC_SOFTMAX_PREPARE, CoreType::AIV); + + uint64_t pv_args[6] = { + reinterpret_cast(dev_pij), + reinterpret_cast(vj_ptr), + reinterpret_cast(dev_oi_new), + static_cast(q_tile_size), + static_cast(block_size), + static_cast(head_dim) + }; + int t_pv = runtime->add_task(pv_args, 6, FUNC_PV_MATMUL, CoreType::AIC); + + runtime->add_successor(t_qk, t_sf); + runtime->add_successor(t_sf, t_pv); + + int is_first = (bn == 0) ? 1 : 0; + int is_last = (bn == bn_this_batch - 1) ? 
1 : 0; + + uint64_t up_args[11] = { + reinterpret_cast(dev_mij), + reinterpret_cast(dev_lij), + reinterpret_cast(dev_oi_new), + reinterpret_cast(dev_mi), + reinterpret_cast(dev_li), + reinterpret_cast(dev_oi), + static_cast(is_first), + static_cast(is_last), + reinterpret_cast(out_ptr), + static_cast(q_tile_size), + static_cast(head_dim) + }; + int t_up = runtime->add_task(up_args, 11, FUNC_ONLINE_UPDATE, CoreType::AIV); + + runtime->add_successor(t_pv, t_up); + if (t_up_prev >= 0) { + runtime->add_successor(t_up_prev, t_up); + } + t_up_prev = t_up; + } + last_pa_tasks.push_back(t_up_prev); + } + } + + /* Phase 2: AllGather (TGATHER) - N rounds */ + size_t barrier_size = static_cast(n_ranks) * sizeof(int32_t); + size_t total_barrier_bytes = barrier_size * (static_cast(n_ranks) + 1); + uint64_t barrier_base_0 = win_in_base + HCCL_WIN_SYNC_PREFIX; + uint64_t win_src = barrier_base_0 + total_barrier_bytes; + uint64_t win_dst = win_src + GATHER_COUNT * sizeof(float); + + int32_t zeros[64] = {}; + std::memset(zeros, 0, sizeof(zeros)); + runtime->host_api.copy_to_device(reinterpret_cast(barrier_base_0), zeros, total_barrier_bytes); + + uint64_t args_wmin[3] = { + win_src, + reinterpret_cast(dev_attn_out), + static_cast(GATHER_COUNT) + }; + int t_wmin = runtime->add_task(args_wmin, 3, FUNC_WIN_MEMCOPY_IN, CoreType::AIV); + for (int t : last_pa_tasks) { + runtime->add_successor(t, t_wmin); + } + + int t_prev = t_wmin; + for (int r = 0; r < n_ranks; r++) { + uint64_t barrier_base_r = barrier_base_0 + static_cast(r) * barrier_size; + uint64_t args_barrier[4] = { + barrier_base_r, device_ctx_ptr, + static_cast(n_ranks), static_cast(0) + }; + int t_barrier = runtime->add_task(args_barrier, 4, FUNC_COMM_BARRIER, CoreType::AIV); + runtime->add_successor(t_prev, t_barrier); + + uint64_t args_gather[5] = { + win_dst, win_src, device_ctx_ptr, + static_cast(n_ranks), static_cast(r) + }; + int t_gather = runtime->add_task(args_gather, 5, FUNC_GATHER, CoreType::AIV); + 
runtime->add_successor(t_barrier, t_gather); + + if (rank_id == r) { + uint64_t args_wmout[3] = { + reinterpret_cast(dev_allgather_out), + win_dst, + static_cast(n_ranks * GATHER_COUNT) + }; + int t_wmout = runtime->add_task(args_wmout, 3, FUNC_WIN_MEMCOPY_OUT, CoreType::AIV); + runtime->add_successor(t_gather, t_wmout); + t_prev = t_wmout; + } else { + t_prev = t_gather; + } + } + + uint64_t barrier_base_post = barrier_base_0 + static_cast(n_ranks) * barrier_size; + uint64_t args_barrier_post[4] = { + barrier_base_post, device_ctx_ptr, + static_cast(n_ranks), static_cast(0) + }; + int t_post = runtime->add_task(args_barrier_post, 4, FUNC_COMM_BARRIER, CoreType::AIV); + runtime->add_successor(t_prev, t_post); + + delete[] dev_sij_arr; + delete[] dev_pij_arr; + delete[] dev_mij_arr; + delete[] dev_lij_arr; + delete[] dev_oi_new_arr; + delete[] dev_mi_arr; + delete[] dev_li_arr; + delete[] dev_oi_arr; + + std::cout << "Created paged_attention_allgather (TGATHER) graph\n"; + runtime->print_runtime(); + + return 0; +} + +} // extern "C" diff --git a/examples/host_build_graph/paged_attention_gather/README.md b/examples/host_build_graph/paged_attention_gather/README.md new file mode 100644 index 00000000..9a399390 --- /dev/null +++ b/examples/host_build_graph/paged_attention_gather/README.md @@ -0,0 +1,11 @@ +# Paged Attention + Gather (host_build_graph) + +Paged Attention 计算后 TGATHER。Root 收集各 rank 的 attn_out 前 GATHER_COUNT 元素。 + +流程:QK → Softmax → PV → OnlineUpdate → WindowMemCopyIn → CommBarrier → TGATHER → WindowMemCopyOut (root only) + +## 运行 + +```bash +./run_hostbuild.sh paged_attention_gather 2 0 +``` diff --git a/examples/host_build_graph/paged_attention_gather/golden.py b/examples/host_build_graph/paged_attention_gather/golden.py new file mode 100644 index 00000000..7bb490f3 --- /dev/null +++ b/examples/host_build_graph/paged_attention_gather/golden.py @@ -0,0 +1,148 @@ +""" +Paged Attention + Gather: Paged Attention → TGATHER. 
+ +Each rank computes paged attention on its own Q/K/V, then root gathers +first GATHER_COUNT elements from each rank's attn_out into gather_out. +""" + +import ctypes +import struct +import torch +import numpy as np + +GATHER_COUNT = 64 +BATCH = 1 +NUM_HEADS = 16 +KV_HEAD_NUM = 1 +HEAD_DIM = 16 +BLOCK_SIZE = 16 +CONTEXT_LEN = 16 +MAX_MODEL_LEN = 256 + +__outputs__ = ["attn_out", "gather_out"] +RTOL = 1e-2 +ATOL = 1e-2 +ALL_CASES = {"Default": {}} +DEFAULT_CASE = "Default" + +def _make_block_table_and_context(): + max_num_blocks_per_req = MAX_MODEL_LEN // BLOCK_SIZE + cur_valid_blocks = (CONTEXT_LEN + BLOCK_SIZE - 1) // BLOCK_SIZE + total_blocks = BATCH * cur_valid_blocks + torch.manual_seed(100) + block_table = torch.randint(0, max(total_blocks, 1), size=(BATCH, max_num_blocks_per_req), dtype=torch.int32) + context_lens = torch.full((BATCH,), CONTEXT_LEN, dtype=torch.int32) + return block_table, context_lens, total_blocks, max_num_blocks_per_req + +def _make_qkv(rank_id, total_blocks): + torch.manual_seed(42 + rank_id) + q = (torch.rand(BATCH, 1, NUM_HEADS * HEAD_DIM) - 0.5).to(torch.float16) + q = q.reshape(BATCH, NUM_HEADS, HEAD_DIM) + k = (torch.rand(total_blocks, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) - 0.5).to(torch.float16) + v = (torch.rand(total_blocks, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) * 2 - 1).to(torch.float16) + return q, k, v + +def generate_inputs(params: dict) -> list: + rank_id = params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + block_table, context_lens, total_blocks, max_num_blocks_per_req = _make_block_table_and_context() + query_fp16, key_fp16, value_fp16 = _make_qkv(rank_id, total_blocks) + scale_value = 1.0 + scale_bits = struct.unpack('I', struct.pack('f', scale_value))[0] + config = torch.tensor([BATCH, NUM_HEADS, KV_HEAD_NUM, HEAD_DIM, BLOCK_SIZE, max_num_blocks_per_req, scale_bits], dtype=torch.int64) + query = query_fp16.flatten() + key_cache = key_fp16.flatten() + value_cache = 
value_fp16.flatten() + block_table_flat = block_table.flatten() + attn_out = torch.zeros(BATCH * NUM_HEADS * HEAD_DIM, dtype=torch.float32) + gather_out = torch.zeros(n_ranks * GATHER_COUNT, dtype=torch.float32) # root only, but all ranks need buffer + result = [ + ("query", query), ("key_cache", key_cache), ("value_cache", value_cache), + ("block_table", block_table_flat), ("context_lens", context_lens), + ("attn_out", attn_out), ("gather_out", gather_out), ("config", config), + ("size_query", ctypes.c_int64(query.nbytes)), ("size_key_cache", ctypes.c_int64(key_cache.nbytes)), + ("size_value_cache", ctypes.c_int64(value_cache.nbytes)), ("size_block_table", ctypes.c_int64(block_table_flat.nbytes)), + ("size_context_lens", ctypes.c_int64(context_lens.nbytes)), ("size_attn_out", ctypes.c_int64(attn_out.nbytes)), + ("size_gather_out", ctypes.c_int64(gather_out.nbytes)), ("size_config", ctypes.c_int64(config.nbytes)), + ] + if "device_ctx_ptr" in params and "win_in_base" in params and "win_out_base" in params: + result.extend([ + ("device_ctx_ptr", ctypes.c_uint64(params["device_ctx_ptr"])), + ("win_in_base", ctypes.c_uint64(params["win_in_base"])), + ("win_out_base", ctypes.c_uint64(params["win_out_base"])), + ("n_ranks", ctypes.c_int32(n_ranks)), ("root", ctypes.c_int32(root)), ("rank_id", ctypes.c_int32(rank_id)), + ]) + return result + +def paged_attention(query, key_cache, value_cache, num_kv_heads, num_heads, scale_value, block_table, context_lens): + assert num_kv_heads == 1 + batch, num_heads_dim, head_dim = query.shape + _, block_size, _, _ = key_cache.shape + key_cache_flat = key_cache.reshape(-1, block_size, head_dim) + value_cache_flat = value_cache.reshape(-1, block_size, head_dim) + out = torch.zeros((batch, num_heads_dim, head_dim), dtype=torch.float32) + q_tile = min(num_heads_dim, 128) + max_bn = int(((context_lens.max().item()) + block_size - 1) // block_size) + for q_offset in range(0, num_heads_dim, q_tile): + q_tile_size = min(q_tile, num_heads_dim 
- q_offset) + qi = query[:, q_offset:q_offset + q_tile_size, :].to(torch.float32) + oi, li, mi = None, None, None + for bn in range(max_bn): + valid_lens = torch.clamp(context_lens - bn * block_size, min=0, max=block_size) + active_mask = valid_lens > 0 + if not active_mask.any(): break + block_indices = block_table[:, bn] + kj_all = key_cache_flat[block_indices].to(torch.float32) + vj_all = value_cache_flat[block_indices].to(torch.float32) + sij = torch.bmm(qi, kj_all.transpose(1, 2)) * scale_value + pos = torch.arange(block_size, device=sij.device).unsqueeze(0) + valid_mask = pos < valid_lens.unsqueeze(1) + valid_mask = valid_mask.unsqueeze(1) + sij = sij.masked_fill(~valid_mask, float('-inf')) + batch_mask = active_mask.view(-1, 1, 1) + sij = sij.masked_fill(~batch_mask, float('-inf')) + mij = sij.max(dim=-1, keepdim=True)[0] + mij = mij.clamp(min=-1e30) + pij = torch.exp(sij - mij) + pij = pij.masked_fill(~valid_mask, 0.0) + pij = pij.masked_fill(~batch_mask, 0.0) + pij = pij.to(torch.bfloat16).to(torch.float32) + lij = pij.sum(dim=-1, keepdim=True) + oi_new = torch.bmm(pij, vj_all) + if bn == 0: + oi, li, mi = oi_new, lij, mij + else: + mi_new = torch.maximum(mi, mij) + alpha = torch.exp(mi - mi_new) + beta = torch.exp(mij - mi_new) + li = alpha * li + beta * lij + oi = alpha * oi + beta * oi_new + mi = mi_new + out[:, q_offset:q_offset + q_tile_size, :] = oi / li + return out.reshape(-1, head_dim) + +def _compute_rank_attn(rank_id, block_table, context_lens, total_blocks): + q, k, v = _make_qkv(rank_id, total_blocks) + return paged_attention(q, k.reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM), v.reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM), + KV_HEAD_NUM, NUM_HEADS, 1.0, block_table, context_lens) + +def compute_golden(tensors: dict, params: dict) -> None: + rank_id = params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + max_num_blocks_per_req = MAX_MODEL_LEN // BLOCK_SIZE + total_blocks = BATCH * ((CONTEXT_LEN + 
BLOCK_SIZE - 1) // BLOCK_SIZE) + block_table = tensors["block_table"].reshape(BATCH, max_num_blocks_per_req) + context_lens_t = tensors["context_lens"] + query = tensors["query"].reshape(BATCH, NUM_HEADS, HEAD_DIM) + key_cache = tensors["key_cache"].reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) + value_cache = tensors["value_cache"].reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) + attn_result = paged_attention(query, key_cache, value_cache, KV_HEAD_NUM, NUM_HEADS, 1.0, block_table, context_lens_t) + tensors["attn_out"][:] = attn_result.flatten() + if rank_id == root: + gather_np = tensors["gather_out"].cpu().numpy() if hasattr(tensors["gather_out"], 'cpu') else np.asarray(tensors["gather_out"]) + for r in range(n_ranks): + attn_r = _compute_rank_attn(r, block_table, context_lens_t, total_blocks) + flat_r = attn_r.flatten().numpy() + gather_np[r * GATHER_COUNT : (r + 1) * GATHER_COUNT] = flat_r[:GATHER_COUNT] diff --git a/examples/host_build_graph/paged_attention_gather/kernels/aic/aic_pv_matmul.cpp b/examples/host_build_graph/paged_attention_gather/kernels/aic/aic_pv_matmul.cpp new file mode 100644 index 00000000..45bf49eb --- /dev/null +++ b/examples/host_build_graph/paged_attention_gather/kernels/aic/aic_pv_matmul.cpp @@ -0,0 +1,90 @@ +// PV Matmul Kernel: pij(M, K) @ vj(K, N) -> oi_new(M, N) +// +// Fixed tile size: (16, 16) @ (16, 16) -> (16, 16) +// +// pij is float16 (converted from fp32 in softmax_prepare via TCVT). +// vj is stored as (K, N) = (block_size, head_dim) in row-major (ND) layout. +// Standard non-transposed B pattern: ND GlobalB + ColMajor/RowMajor TileMatB. 
+ +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static __aicore__ void pv_matmul_impl(__gm__ uint8_t* pij_raw, __gm__ uint8_t* vj_raw, __gm__ uint8_t* oi_raw) +{ + constexpr int M = 16, K = 16, N = 16; + + __gm__ half* pij = reinterpret_cast<__gm__ half*>(pij_raw); + __gm__ half* vj = reinterpret_cast<__gm__ half*>(vj_raw); + __gm__ float* oi = reinterpret_cast<__gm__ float*>(oi_raw); + + // pij (M, K) fp16, vj (K, N) fp16 in ND (row-major), oi_new (M, N) fp32 + using GlobalA = GlobalTensor, Stride>; + using GlobalB = GlobalTensor, Stride>; + using GlobalOut = GlobalTensor, Stride>; + + GlobalA pijGlobal(pij); + GlobalB vjGlobal(vj); + GlobalOut oiGlobal(oi); + + // L1 Mat tiles: standard ND pattern for both A and B + using TileMatA = Tile; + using TileMatB = Tile; + + // L0 tiles + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile bTile; + AccTile cTile; + TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); + TASSIGN(cTile, 0x0); + + // Load pij and vj to L1 + TLOAD(aMatTile, pijGlobal); + TLOAD(bMatTile, vjGlobal); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + // Move to L0A/L0B + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + // Single matmul: (M,K) x (K,N) -> (M,N) + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(oiGlobal, cTile); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) +{ + __gm__ uint8_t* pij = reinterpret_cast<__gm__ uint8_t*>(args[0]); + __gm__ uint8_t* vj = reinterpret_cast<__gm__ uint8_t*>(args[1]); + __gm__ uint8_t* oi_new = reinterpret_cast<__gm__ 
uint8_t*>(args[2]); + + pv_matmul_impl(pij, vj, oi_new); +} diff --git a/examples/host_build_graph/paged_attention_gather/kernels/aic/aic_qk_matmul.cpp b/examples/host_build_graph/paged_attention_gather/kernels/aic/aic_qk_matmul.cpp new file mode 100644 index 00000000..e1e026a2 --- /dev/null +++ b/examples/host_build_graph/paged_attention_gather/kernels/aic/aic_qk_matmul.cpp @@ -0,0 +1,91 @@ +// QK Matmul Kernel: qi(M, K) @ kj.T(K, N) -> sij(M, N) +// +// Fixed tile size: (16, 16) @ (16, 16).T -> (16, 16) +// +// kj is stored as (N, K) = (block_size, head_dim) in row-major memory. +// This is equivalent to (K, N) in column-major (DN) layout. +// Using DN GlobalB + RowMajor/ColMajor TileMatB to handle the transposed B pattern. + +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static __aicore__ void qk_matmul_impl(__gm__ uint8_t* qi_raw, __gm__ uint8_t* kj_raw, __gm__ uint8_t* sij_raw) +{ + constexpr int M = 16, K = 16, N = 16; + + __gm__ half* qi = reinterpret_cast<__gm__ half*>(qi_raw); + __gm__ half* kj = reinterpret_cast<__gm__ half*>(kj_raw); + __gm__ float* sij = reinterpret_cast<__gm__ float*>(sij_raw); + + // qi (M, K) fp16 in ND (row-major) layout + using GlobalA = GlobalTensor, Stride>; + // kj stored as (N, K) row-major = (K, N) column-major -> DN layout + using GlobalB = GlobalTensor, Stride, Layout::DN>; + using GlobalOut = GlobalTensor, Stride>; + + GlobalA qiGlobal(qi); + GlobalB kjGlobal(kj); + GlobalOut sijGlobal(sij); + + // L1 Mat tiles: A is standard ND, B uses transposed-B pattern (RowMajor/ColMajor) + using TileMatA = Tile; + using TileMatB = Tile; + + // L0 tiles + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile bTile; + AccTile cTile; + TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); 
+ TASSIGN(cTile, 0x0); + + // Load qi and kj to L1 + TLOAD(aMatTile, qiGlobal); + TLOAD(bMatTile, kjGlobal); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + // Move to L0A/L0B + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + // Single matmul: (M,K) x (K,N) -> (M,N) + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(sijGlobal, cTile); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) +{ + __gm__ uint8_t* qi = reinterpret_cast<__gm__ uint8_t*>(args[0]); + __gm__ uint8_t* kj = reinterpret_cast<__gm__ uint8_t*>(args[1]); + __gm__ uint8_t* sij = reinterpret_cast<__gm__ uint8_t*>(args[2]); + + qk_matmul_impl(qi, kj, sij); +} diff --git a/examples/host_build_graph/paged_attention_gather/kernels/aiv/aiv_online_update.cpp b/examples/host_build_graph/paged_attention_gather/kernels/aiv/aiv_online_update.cpp new file mode 100644 index 00000000..16e93016 --- /dev/null +++ b/examples/host_build_graph/paged_attention_gather/kernels/aiv/aiv_online_update.cpp @@ -0,0 +1,230 @@ +// Online Softmax Update + Normalize Kernel (AIV) +// +// Fixed tile size: oi/oi_new are (16, 16), mij/lij/mi/li are 16-element vectors +// +// Scalar layout strategy: +// M scalar floats stored contiguously in GM can be loaded as either: +// - ND (kScalarRows, kScalarCols) RowMajor for element-wise ops (TMAX, TSUB, TEXP, TMUL, TADD) +// - DN (kAlignedRows, 1) ColMajor for row-broadcast ops (TROWEXPANDMUL, TROWEXPANDDIV) +// Conversion between layouts uses GM round-trip: ND TSTORE -> DN TLOAD. 
+ +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static __aicore__ void online_update_impl(__gm__ uint8_t* mij_raw, __gm__ uint8_t* lij_raw, + __gm__ uint8_t* oi_new_raw, __gm__ uint8_t* mi_raw, + __gm__ uint8_t* li_raw, __gm__ uint8_t* oi_raw, + int is_first, int is_last, __gm__ uint8_t* dst_raw) +{ + constexpr int M = 16, N = 16; + + __gm__ float* mij_ptr = reinterpret_cast<__gm__ float*>(mij_raw); + __gm__ float* lij_ptr = reinterpret_cast<__gm__ float*>(lij_raw); + __gm__ float* oi_new_ptr = reinterpret_cast<__gm__ float*>(oi_new_raw); + __gm__ float* mi_ptr = reinterpret_cast<__gm__ float*>(mi_raw); + __gm__ float* li_ptr = reinterpret_cast<__gm__ float*>(li_raw); + __gm__ float* oi_ptr = reinterpret_cast<__gm__ float*>(oi_raw); + __gm__ float* dst_ptr = reinterpret_cast<__gm__ float*>(dst_raw); + + // Scalar tile dimensions for RowMajor layout: + // kScalarCols = 32 bytes / 4 bytes per float = 8 floats per row (one 32-byte block) + // kScalarRows = M / 8 (M=16 -> 2 rows) + constexpr int kScalarCols = 32 / sizeof(float); + constexpr int kScalarRows = M / kScalarCols; + // Aligned rows for ColMajor DN tiles (32-byte alignment) + constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float)); + + // --- GlobalTensor types --- + + // Data (M, N) RowMajor + using GlobalDataMxN = GlobalTensor, Stride<1, 1, 1, N, 1>>; + + // Scalar ND: M contiguous floats as (kScalarRows, kScalarCols) RowMajor + using GlobalScalarND = GlobalTensor, + Stride<1, 1, 1, kScalarCols, 1>>; + + // Scalar DN: same M contiguous floats as (kAlignedRows, 1) ColMajor + using GlobalScalarDN = GlobalTensor, + Stride<1, 1, 1, 1, 1>, Layout::DN>; + + // --- GlobalTensor instances --- + + GlobalDataMxN oiNewGlobal(oi_new_ptr); + GlobalDataMxN oiGlobal(oi_ptr); + GlobalDataMxN dstGlobal(dst_ptr); + + // ND globals for scalar element-wise operations + GlobalScalarND 
mijGlobalND(mij_ptr); + GlobalScalarND lijGlobalND(lij_ptr); + GlobalScalarND miGlobalND(mi_ptr); + GlobalScalarND liGlobalND(li_ptr); + + // DN globals aliased to same GM for ColMajor reload (used after ND TSTORE) + GlobalScalarDN mijGlobalDN(mij_ptr); + GlobalScalarDN lijGlobalDN(lij_ptr); + GlobalScalarDN liGlobalDN(li_ptr); + + // --- Tile types --- + + using TileDataMxN = Tile; + using TileScalarND = Tile; + using TileScalarDN = Tile; + + // --- UB memory layout --- + + constexpr int kDataBytes = M * N * sizeof(float); + constexpr int kScalarNDBytes = kScalarRows * kScalarCols * sizeof(float); + constexpr int kScalarDNBytes = kAlignedRows * sizeof(float); + + // Data tiles + TileDataMxN oiNewTile; + TileDataMxN oiTile; + + // Scalar ND tiles for element-wise arithmetic + TileScalarND mijND, lijND, miND, liND; + TileScalarND miNewND, alphaND, betaND, tmpND; + + // Scalar DN tiles for TROWEXPAND operations + TileScalarDN alphaDN, betaDN, liDN; + + TASSIGN(oiNewTile, 0); + TASSIGN(oiTile, kDataBytes); + TASSIGN(mijND, 2 * kDataBytes); + TASSIGN(lijND, 2 * kDataBytes + kScalarNDBytes); + TASSIGN(miND, 2 * kDataBytes + 2 * kScalarNDBytes); + TASSIGN(liND, 2 * kDataBytes + 3 * kScalarNDBytes); + TASSIGN(miNewND, 2 * kDataBytes + 4 * kScalarNDBytes); + TASSIGN(alphaND, 2 * kDataBytes + 5 * kScalarNDBytes); + TASSIGN(betaND, 2 * kDataBytes + 6 * kScalarNDBytes); + TASSIGN(tmpND, 2 * kDataBytes + 7 * kScalarNDBytes); + TASSIGN(alphaDN, 2 * kDataBytes + 8 * kScalarNDBytes); + TASSIGN(betaDN, 2 * kDataBytes + 8 * kScalarNDBytes + kScalarDNBytes); + TASSIGN(liDN, 2 * kDataBytes + 8 * kScalarNDBytes + 2 * kScalarDNBytes); + + if (is_first) { + // --- First block: copy inputs to accumulators --- + TLOAD(oiNewTile, oiNewGlobal); + TLOAD(mijND, mijGlobalND); + TLOAD(lijND, lijGlobalND); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Passthrough to MTE3 (no V compute needed) + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + 
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, mijND); // mi = mij + TSTORE(liGlobalND, lijND); // li = lij + TSTORE(oiGlobal, oiNewTile); // oi = oi_new + + if (is_last) { + // Single block: normalize dst = oi_new / lij + // lij stored to li buffer in ND format; reload as DN for TROWEXPANDDIV + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + TLOAD(liDN, liGlobalDN); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TROWEXPANDDIV(oiNewTile, oiNewTile, liDN); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(dstGlobal, oiNewTile); + } + } else { + // --- Subsequent blocks: accumulate --- + + // Phase 1: Load all inputs + TLOAD(oiNewTile, oiNewGlobal); + TLOAD(oiTile, oiGlobal); + TLOAD(mijND, mijGlobalND); + TLOAD(lijND, lijGlobalND); + TLOAD(miND, miGlobalND); + TLOAD(liND, liGlobalND); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Phase 2: Scalar arithmetic in RowMajor (kScalarRows, kScalarCols) + // pipe_barrier(PIPE_V) required between each dependent vector operation + // to resolve RAW hazards on shared UB tiles. 
+ TMAX(miNewND, miND, mijND); // mi_new = max(mi, mij) + pipe_barrier(PIPE_V); + TSUB(alphaND, miND, miNewND); // alpha = mi - mi_new + pipe_barrier(PIPE_V); + TEXP(alphaND, alphaND); // alpha = exp(mi - mi_new) + pipe_barrier(PIPE_V); + TSUB(betaND, mijND, miNewND); // beta = mij - mi_new + pipe_barrier(PIPE_V); + TEXP(betaND, betaND); // beta = exp(mij - mi_new) + pipe_barrier(PIPE_V); + TMUL(liND, alphaND, liND); // li = alpha * li + pipe_barrier(PIPE_V); + TMUL(tmpND, betaND, lijND); // tmp = beta * lij + pipe_barrier(PIPE_V); + TADD(liND, liND, tmpND); // li = alpha * li + beta * lij (= li_new) + + // Phase 3: Store scalar results to GM (ND format) + // mi_new -> mi accumulator, li_new -> li accumulator + // alpha -> mij buffer (reuse), beta -> lij buffer (reuse) + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, miNewND); // persist mi_new + TSTORE(liGlobalND, liND); // persist li_new + TSTORE(mijGlobalND, alphaND); // temp: alpha to mij buffer + TSTORE(lijGlobalND, betaND); // temp: beta to lij buffer + + // Phase 4: Reload alpha, beta (and li if last) as ColMajor DN + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + TLOAD(alphaDN, mijGlobalDN); // alpha from mij buffer as DN + TLOAD(betaDN, lijGlobalDN); // beta from lij buffer as DN + if (is_last) { + TLOAD(liDN, liGlobalDN); // li_new from li buffer as DN + } + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + + // Phase 5: Scale data tiles using row-broadcast multiply + TROWEXPANDMUL(oiTile, oiTile, alphaDN); // oi *= alpha + TROWEXPANDMUL(oiNewTile, oiNewTile, betaDN); // oi_new *= beta + pipe_barrier(PIPE_V); + TADD(oiTile, oiTile, oiNewTile); // oi = alpha*oi + beta*oi_new + + if (is_last) { + // Phase 6: Normalize and output + pipe_barrier(PIPE_V); + TROWEXPANDDIV(oiTile, oiTile, liDN); // dst = oi / li_new + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, 
PIPE_MTE3, EVENT_ID1); + TSTORE(dstGlobal, oiTile); + } else { + // Phase 6: Store updated accumulators + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(oiGlobal, oiTile); + } + } +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { + __gm__ uint8_t* mij = reinterpret_cast<__gm__ uint8_t*>(args[0]); + __gm__ uint8_t* lij = reinterpret_cast<__gm__ uint8_t*>(args[1]); + __gm__ uint8_t* oi_new = reinterpret_cast<__gm__ uint8_t*>(args[2]); + __gm__ uint8_t* mi = reinterpret_cast<__gm__ uint8_t*>(args[3]); + __gm__ uint8_t* li = reinterpret_cast<__gm__ uint8_t*>(args[4]); + __gm__ uint8_t* oi = reinterpret_cast<__gm__ uint8_t*>(args[5]); + int is_first = static_cast(args[6]); + int is_last = static_cast(args[7]); + __gm__ uint8_t* dst = reinterpret_cast<__gm__ uint8_t*>(args[8]); + + online_update_impl(mij, lij, oi_new, mi, li, oi, is_first, is_last, dst); +} diff --git a/examples/host_build_graph/paged_attention_gather/kernels/aiv/aiv_softmax_prepare.cpp b/examples/host_build_graph/paged_attention_gather/kernels/aiv/aiv_softmax_prepare.cpp new file mode 100644 index 00000000..6715cf07 --- /dev/null +++ b/examples/host_build_graph/paged_attention_gather/kernels/aiv/aiv_softmax_prepare.cpp @@ -0,0 +1,94 @@ +// Softmax Preparation Kernel (AIV) +// +// Fixed tile size: sij is (16, 16) +// +// Computes: +// sij_scale = sij * scale +// mij = row_max(sij_scale) -> (M, 1) +// pij = exp(sij_scale - mij) -> (M, N) +// lij = row_sum(pij) -> (M, 1) + +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static __aicore__ void softmax_prepare_impl(__gm__ uint8_t* sij_raw, float scale_value, + __gm__ uint8_t* pij_raw, __gm__ uint8_t* mij_raw, + __gm__ uint8_t* lij_raw) +{ + constexpr int M = 16, N = 16; + + __gm__ float* sij = reinterpret_cast<__gm__ float*>(sij_raw); + __gm__ half* pij = reinterpret_cast<__gm__ half*>(pij_raw); 
+ __gm__ float* mij = reinterpret_cast<__gm__ float*>(mij_raw); + __gm__ float* lij = reinterpret_cast<__gm__ float*>(lij_raw); + + constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float)); + + using GlobalDataMxN = GlobalTensor, Stride<1, 1, 1, N, 1>>; + using GlobalDataMxN_f16 = GlobalTensor, Stride<1, 1, 1, N, 1>>; + using GlobalScalarDN = GlobalTensor, Stride<1, 1, 1, 1, 1>, Layout::DN>; + + GlobalDataMxN sijGlobal(sij); + GlobalDataMxN_f16 pijGlobal(pij); + GlobalScalarDN mijGlobal(mij); + GlobalScalarDN lijGlobal(lij); + + using TileVecMxN = Tile; + using TileVecMxN_f16 = Tile; + using TileScalarDN = Tile; + + TileVecMxN sijTile; + TileVecMxN pijTile; + TileVecMxN tmpTile; + TileScalarDN maxTile; + TileScalarDN sumTile; + TileVecMxN_f16 pijF16Tile; + + TASSIGN(sijTile, 0x0); + TASSIGN(pijTile, M * N * sizeof(float)); + TASSIGN(tmpTile, 2 * M * N * sizeof(float)); + TASSIGN(maxTile, 3 * M * N * sizeof(float)); + TASSIGN(sumTile, 3 * M * N * sizeof(float) + kAlignedRows * sizeof(float)); + TASSIGN(pijF16Tile, 3 * M * N * sizeof(float) + 2 * kAlignedRows * sizeof(float)); + + TLOAD(sijTile, sijGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + TMULS(sijTile, sijTile, scale_value); + TROWMAX(maxTile, sijTile, tmpTile); + TROWEXPANDSUB(pijTile, sijTile, maxTile); + TEXP(pijTile, pijTile); + // Truncate pij to fp16 first, then compute lij from truncated values (matches golden) + TCVT(pijF16Tile, pijTile, RoundMode::CAST_ROUND); + TCVT(pijTile, pijF16Tile, RoundMode::CAST_ROUND); + TROWSUM(sumTile, pijTile, tmpTile); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(mijGlobal, maxTile); + TSTORE(lijGlobal, sumTile); + TSTORE(pijGlobal, pijF16Tile); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { + __gm__ uint8_t* sij = reinterpret_cast<__gm__ uint8_t*>(args[0]); + union { uint64_t u; float f; } scale_conv; + scale_conv.u = 
static_cast(args[1]); + float scale_value = scale_conv.f; + __gm__ uint8_t* pij = reinterpret_cast<__gm__ uint8_t*>(args[2]); + __gm__ uint8_t* mij = reinterpret_cast<__gm__ uint8_t*>(args[3]); + __gm__ uint8_t* lij = reinterpret_cast<__gm__ uint8_t*>(args[4]); + + softmax_prepare_impl(sij, scale_value, pij, mij, lij); +} diff --git a/examples/host_build_graph/paged_attention_gather/kernels/aiv/comm_barrier_kernel.cpp b/examples/host_build_graph/paged_attention_gather/kernels/aiv/comm_barrier_kernel.cpp new file mode 100644 index 00000000..7e210a16 --- /dev/null +++ b/examples/host_build_graph/paged_attention_gather/kernels/aiv/comm_barrier_kernel.cpp @@ -0,0 +1,51 @@ +/** + * Device-side cross-rank barrier using TNOTIFY/TWAIT from pto-comm-isa. + * + * Each rank notifies root that it has finished the compute phase by writing + * a flag to root's barrier slot. Root then spins until all ranks have + * reported, guaranteeing that every rank's window data is visible before + * TGATHER reads it. + * + * Args: + * args[0] = barrier_base (local barrier signal buffer in own windowsIn) + * args[1] = device_ctx_ptr (HcclDeviceContext*) + * args[2] = n_ranks + * args[3] = root + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ int32_t* local_barrier = reinterpret_cast<__gm__ int32_t*>(args[0]); + __gm__ HcclDeviceContext* ctx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[1]); + int n_ranks = static_cast(args[2]); + int root = static_cast(args[3]); + int my_rank = static_cast(ctx->rankId); + + // Each rank writes flag=1 to root's barrier slot[my_rank] via RDMA. 
+ __gm__ int32_t* remote_slot = HcclRemotePtr(ctx, local_barrier, root) + my_rank; + pto::comm::Signal sig(remote_slot); + pto::comm::TNOTIFY(sig, 1, pto::comm::NotifyOp::Set); + + // Root waits until every rank's flag is >= 1. + if (my_rank == root) { + for (int i = 0; i < n_ranks; ++i) { + pto::comm::Signal slot(local_barrier + i); + pto::comm::TWAIT(slot, 1, pto::comm::WaitCmp::GE); + } + } + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/paged_attention_gather/kernels/aiv/gather_kernel.cpp b/examples/host_build_graph/paged_attention_gather/kernels/aiv/gather_kernel.cpp new file mode 100644 index 00000000..2d972cfa --- /dev/null +++ b/examples/host_build_graph/paged_attention_gather/kernels/aiv/gather_kernel.cpp @@ -0,0 +1,62 @@ +/** + * TGATHER collective kernel - root gathers from all ranks. + * Requires pto-comm-isa (PTO_ISA_ROOT or PTO_COMM_ISA_ROOT). + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static constexpr size_t GATHER_COUNT = 64; + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* src = reinterpret_cast<__gm__ float*>(args[1]); + __gm__ HcclDeviceContext* hcclCtx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[2]); + int nranks = static_cast(args[3]); + int root = static_cast(args[4]); + + using ShapeDyn = pto::Shape; + using StrideDyn = pto::Stride; + using Global = pto::GlobalTensor; + + using TileData = pto::Tile; + + int my_rank = static_cast(hcclCtx->rankId); + + ShapeDyn srcShape(1, 1, 1, 1, GATHER_COUNT); + StrideDyn srcStride(GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, 1); + + ShapeDyn dstShape(1, 1, 1, nranks, GATHER_COUNT); + StrideDyn dstStride(nranks * GATHER_COUNT, nranks * GATHER_COUNT, nranks * 
GATHER_COUNT, GATHER_COUNT, 1); + Global dstG(dst, dstShape, dstStride); + + Global tensors[16]; + int actual_nranks = (nranks > 16) ? 16 : nranks; + for (int i = 0; i < actual_nranks; ++i) { + __gm__ float* remoteSrc = HcclRemotePtr(hcclCtx, src, i); + tensors[i] = Global(remoteSrc, srcShape, srcStride); + } + + pto::comm::ParallelGroup pg(tensors, actual_nranks, root); + + TileData ubTile(1, GATHER_COUNT); + TASSIGN(ubTile, 0x0); + + if (my_rank == root) { + pto::comm::TGATHER(pg, dstG, ubTile); + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/paged_attention_gather/kernels/aiv/window_memcopy_in.cpp b/examples/host_build_graph/paged_attention_gather/kernels/aiv/window_memcopy_in.cpp new file mode 100644 index 00000000..73504fa1 --- /dev/null +++ b/examples/host_build_graph/paged_attention_gather/kernels/aiv/window_memcopy_in.cpp @@ -0,0 +1,26 @@ +/** + * WindowMemCopyIn: Copy device buffer to HCCL window. + * Used before TGATHER so remote ranks can read. + */ + +#include +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* win_dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* dev_src = reinterpret_cast<__gm__ float*>(args[1]); + int count = static_cast(args[2]); + + for (int i = 0; i < count; ++i) { + win_dst[i] = dev_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/paged_attention_gather/kernels/aiv/window_memcopy_out.cpp b/examples/host_build_graph/paged_attention_gather/kernels/aiv/window_memcopy_out.cpp new file mode 100644 index 00000000..3f2ef586 --- /dev/null +++ b/examples/host_build_graph/paged_attention_gather/kernels/aiv/window_memcopy_out.cpp @@ -0,0 +1,26 @@ +/** + * WindowMemCopyOut: Copy HCCL window to device buffer. + * Root only - after TGATHER, copy gathered result to device. 
+ */ + +#include +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ float* dev_dst = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* win_src = reinterpret_cast<__gm__ float*>(args[1]); + int count = static_cast(args[2]); + + for (int i = 0; i < count; ++i) { + dev_dst[i] = win_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/host_build_graph/paged_attention_gather/kernels/kernel_config.py b/examples/host_build_graph/paged_attention_gather/kernels/kernel_config.py new file mode 100644 index 00000000..c81042e1 --- /dev/null +++ b/examples/host_build_graph/paged_attention_gather/kernels/kernel_config.py @@ -0,0 +1,36 @@ +""" +Paged Attention + Gather: Paged Attention → TGATHER. + +Flow per rank: + QK → Softmax → PV → OnlineUpdate (paged attention, possibly multi-block) + → WindowMemCopyIn → CommBarrier → TGATHER → WindowMemCopyOut (root only) +""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "paged_attention_gather_orch.cpp"), + "function_name": "build_paged_attention_gather_graph", +} + +KERNELS = [ + {"func_id": 0, "name": "QK", "source": str(_KERNELS_ROOT / "aic" / "aic_qk_matmul.cpp"), "core_type": "aic"}, + {"func_id": 1, "name": "SF", "source": str(_KERNELS_ROOT / "aiv" / "aiv_softmax_prepare.cpp"), "core_type": "aiv"}, + {"func_id": 2, "name": "PV", "source": str(_KERNELS_ROOT / "aic" / "aic_pv_matmul.cpp"), "core_type": "aic"}, + {"func_id": 3, "name": "UP", "source": str(_KERNELS_ROOT / "aiv" / "aiv_online_update.cpp"), "core_type": "aiv"}, + {"func_id": 4, "name": "WindowMemCopyIn", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_in.cpp"), "core_type": "aiv"}, + {"func_id": 5, "name": "Gather", "source": str(_KERNELS_ROOT / "aiv" / "gather_kernel.cpp"), "core_type": 
"aiv"}, + {"func_id": 6, "name": "WindowMemCopyOut", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_out.cpp"), "core_type": "aiv"}, + {"func_id": 7, "name": "CommBarrier", "source": str(_KERNELS_ROOT / "aiv" / "comm_barrier_kernel.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "host_build_graph", + "n_devices": 2, + "first_device_id": 0, + "requires_comm": True, + "aicpu_thread_num": 3, + "block_dim": 3, +} diff --git a/examples/host_build_graph/paged_attention_gather/kernels/orchestration/paged_attention_gather_orch.cpp b/examples/host_build_graph/paged_attention_gather/kernels/orchestration/paged_attention_gather_orch.cpp new file mode 100644 index 00000000..03dfa233 --- /dev/null +++ b/examples/host_build_graph/paged_attention_gather/kernels/orchestration/paged_attention_gather_orch.cpp @@ -0,0 +1,294 @@ +/** + * Paged Attention + Gather: Paged Attention → TGATHER. + * + * Phase 1: QK → Softmax → PV → OnlineUpdate (paged attention) + * Phase 2: WindowMemCopyIn → CommBarrier → TGATHER → WindowMemCopyOut (root only) + * + * Args layout: same as paged_attention but host_out is attn_out, plus: + * host_gather_out (root only), device_ctx_ptr, win_in_base, win_out_base, + * n_ranks, root, rank_id. Expect arg_count >= 16 for comm args. 
+ */ + +#include "runtime.h" +#include +#include +#include +#include + +#define FUNC_QK_MATMUL 0 +#define FUNC_SOFTMAX_PREPARE 1 +#define FUNC_PV_MATMUL 2 +#define FUNC_ONLINE_UPDATE 3 +#define FUNC_WIN_MEMCOPY_IN 4 +#define FUNC_GATHER 5 +#define FUNC_WIN_MEMCOPY_OUT 6 +#define FUNC_COMM_BARRIER 7 + +constexpr int GATHER_COUNT = 64; +constexpr size_t HCCL_WIN_SYNC_PREFIX = 64 * sizeof(int32_t); + +extern "C" { + +int build_paged_attention_gather_graph(Runtime* runtime, uint64_t* args, int arg_count) { + if (arg_count < 16) { + std::cerr << "build_paged_attention_gather_graph: Expected at least 16 args, got " << arg_count << '\n'; + return -1; + } + + void* host_query = reinterpret_cast(args[0]); + void* host_key_cache = reinterpret_cast(args[1]); + void* host_value_cache = reinterpret_cast(args[2]); + int* host_block_table = reinterpret_cast(args[3]); + int* host_context_lens = reinterpret_cast(args[4]); + void* host_attn_out = reinterpret_cast(args[5]); + void* host_gather_out = reinterpret_cast(args[6]); + int64_t* host_config = reinterpret_cast(args[7]); + + size_t query_size = static_cast(args[8]); + size_t key_cache_size = static_cast(args[9]); + size_t value_cache_size = static_cast(args[10]); + size_t block_table_size = static_cast(args[11]); + size_t context_lens_size = static_cast(args[12]); + size_t attn_out_size = static_cast(args[13]); + size_t gather_out_size = static_cast(args[14]); + size_t config_size = static_cast(args[15]); + uint64_t device_ctx_ptr = args[16]; + uint64_t win_in_base = args[17]; + uint64_t win_out_base = args[18]; + int n_ranks = static_cast(args[19]); + int root = static_cast(args[20]); + int rank_id = static_cast(args[21]); + + int batch = static_cast(host_config[0]); + int num_heads = static_cast(host_config[1]); + int kv_head_num = static_cast(host_config[2]); + int head_dim = static_cast(host_config[3]); + int block_size = static_cast(host_config[4]); + int max_num_blocks = static_cast(host_config[5]); + uint64_t 
scale_value_bits = static_cast(host_config[6]); + + int q_tile_size = std::min(num_heads, 128); + int num_head_tiles = (num_heads + q_tile_size - 1) / q_tile_size; + + std::cout << "\n=== build_paged_attention_gather_graph ===" << '\n'; + std::cout << " n_ranks=" << n_ranks << " root=" << root << " rank_id=" << rank_id << '\n'; + + void* dev_query = runtime->host_api.device_malloc(query_size); + void* dev_key_cache = runtime->host_api.device_malloc(key_cache_size); + void* dev_value_cache = runtime->host_api.device_malloc(value_cache_size); + void* dev_attn_out = runtime->host_api.device_malloc(attn_out_size); + + if (!dev_query || !dev_key_cache || !dev_value_cache || !dev_attn_out) { + std::cerr << "Error: Failed to allocate device memory\n"; + return -1; + } + + runtime->host_api.copy_to_device(dev_query, host_query, query_size); + runtime->host_api.copy_to_device(dev_key_cache, host_key_cache, key_cache_size); + runtime->host_api.copy_to_device(dev_value_cache, host_value_cache, value_cache_size); + runtime->record_tensor_pair(host_attn_out, dev_attn_out, attn_out_size); + + void* dev_gather_out = nullptr; + if (rank_id == root) { + dev_gather_out = runtime->host_api.device_malloc(gather_out_size); + if (!dev_gather_out) { + runtime->host_api.device_free(dev_attn_out); + return -1; + } + runtime->record_tensor_pair(host_gather_out, dev_gather_out, gather_out_size); + } + + size_t sij_size = static_cast(q_tile_size) * block_size * sizeof(float); + size_t pij_size = static_cast(q_tile_size) * block_size * sizeof(uint16_t); + size_t mij_size = static_cast(q_tile_size) * sizeof(float); + size_t lij_size = mij_size; + size_t oi_new_size = static_cast(q_tile_size) * head_dim * sizeof(float); + + int total_buffers = batch * max_num_blocks; + void** dev_sij_arr = new void*[total_buffers]; + void** dev_pij_arr = new void*[total_buffers]; + void** dev_mij_arr = new void*[total_buffers]; + void** dev_lij_arr = new void*[total_buffers]; + void** dev_oi_new_arr = new 
void*[total_buffers]; + + for (int i = 0; i < total_buffers; i++) { + dev_sij_arr[i] = runtime->host_api.device_malloc(sij_size); + dev_pij_arr[i] = runtime->host_api.device_malloc(pij_size); + dev_mij_arr[i] = runtime->host_api.device_malloc(mij_size); + dev_lij_arr[i] = runtime->host_api.device_malloc(lij_size); + dev_oi_new_arr[i] = runtime->host_api.device_malloc(oi_new_size); + } + + int total_accums = batch * num_head_tiles; + size_t mi_size = static_cast(q_tile_size) * sizeof(float); + size_t li_size = mi_size; + size_t oi_size = static_cast(q_tile_size) * head_dim * sizeof(float); + + void** dev_mi_arr = new void*[total_accums]; + void** dev_li_arr = new void*[total_accums]; + void** dev_oi_arr = new void*[total_accums]; + + for (int i = 0; i < total_accums; i++) { + dev_mi_arr[i] = runtime->host_api.device_malloc(mi_size); + dev_li_arr[i] = runtime->host_api.device_malloc(li_size); + dev_oi_arr[i] = runtime->host_api.device_malloc(oi_size); + } + + std::vector last_pa_tasks; + + for (int b_idx = 0; b_idx < batch; b_idx++) { + int cur_seq = host_context_lens[b_idx]; + int bn_this_batch = (cur_seq + block_size - 1) / block_size; + + for (int ht = 0; ht < num_head_tiles; ht++) { + int cur_offset = ht * q_tile_size; + uint8_t* qi_ptr = reinterpret_cast(dev_query) + + static_cast(b_idx * num_heads + cur_offset) * head_dim * sizeof(uint16_t); + uint8_t* out_ptr = reinterpret_cast(dev_attn_out) + + static_cast(b_idx * num_heads + cur_offset) * head_dim * sizeof(float); + int kv_head_idx = cur_offset / (num_heads / kv_head_num); + int accum_idx = b_idx * num_head_tiles + ht; + void* dev_mi = dev_mi_arr[accum_idx]; + void* dev_li = dev_li_arr[accum_idx]; + void* dev_oi = dev_oi_arr[accum_idx]; + + int t_up_prev = -1; + + for (int bn = 0; bn < bn_this_batch; bn++) { + int cur_block_idx = host_block_table[b_idx * max_num_blocks + bn]; + uint8_t* kj_ptr = reinterpret_cast(dev_key_cache) + + (static_cast(cur_block_idx) * block_size * kv_head_num + kv_head_idx) + * 
head_dim * sizeof(uint16_t); + uint8_t* vj_ptr = reinterpret_cast(dev_value_cache) + + (static_cast(cur_block_idx) * block_size * kv_head_num + kv_head_idx) + * head_dim * sizeof(uint16_t); + + int buf_idx = b_idx * max_num_blocks + bn; + void* dev_sij = dev_sij_arr[buf_idx]; + void* dev_pij = dev_pij_arr[buf_idx]; + void* dev_mij = dev_mij_arr[buf_idx]; + void* dev_lij = dev_lij_arr[buf_idx]; + void* dev_oi_new = dev_oi_new_arr[buf_idx]; + + uint64_t qk_args[6] = { + reinterpret_cast(qi_ptr), + reinterpret_cast(kj_ptr), + reinterpret_cast(dev_sij), + static_cast(q_tile_size), + static_cast(head_dim), + static_cast(block_size) + }; + int t_qk = runtime->add_task(qk_args, 6, FUNC_QK_MATMUL, CoreType::AIC); + + uint64_t sf_args[7] = { + reinterpret_cast(dev_sij), + scale_value_bits, + reinterpret_cast(dev_pij), + reinterpret_cast(dev_mij), + reinterpret_cast(dev_lij), + static_cast(q_tile_size), + static_cast(block_size) + }; + int t_sf = runtime->add_task(sf_args, 7, FUNC_SOFTMAX_PREPARE, CoreType::AIV); + + uint64_t pv_args[6] = { + reinterpret_cast(dev_pij), + reinterpret_cast(vj_ptr), + reinterpret_cast(dev_oi_new), + static_cast(q_tile_size), + static_cast(block_size), + static_cast(head_dim) + }; + int t_pv = runtime->add_task(pv_args, 6, FUNC_PV_MATMUL, CoreType::AIC); + + runtime->add_successor(t_qk, t_sf); + runtime->add_successor(t_sf, t_pv); + + int is_first = (bn == 0) ? 1 : 0; + int is_last = (bn == bn_this_batch - 1) ? 
1 : 0; + + uint64_t up_args[11] = { + reinterpret_cast(dev_mij), + reinterpret_cast(dev_lij), + reinterpret_cast(dev_oi_new), + reinterpret_cast(dev_mi), + reinterpret_cast(dev_li), + reinterpret_cast(dev_oi), + static_cast(is_first), + static_cast(is_last), + reinterpret_cast(out_ptr), + static_cast(q_tile_size), + static_cast(head_dim) + }; + int t_up = runtime->add_task(up_args, 11, FUNC_ONLINE_UPDATE, CoreType::AIV); + + runtime->add_successor(t_pv, t_up); + if (t_up_prev >= 0) { + runtime->add_successor(t_up_prev, t_up); + } + t_up_prev = t_up; + } + last_pa_tasks.push_back(t_up_prev); + } + } + + /* Phase 2: Gather */ + size_t barrier_size = static_cast(n_ranks) * sizeof(int32_t); + uint64_t barrier_base = win_in_base + HCCL_WIN_SYNC_PREFIX; + uint64_t win_src = barrier_base + barrier_size; + uint64_t win_dst = win_src + GATHER_COUNT * sizeof(float); + + int32_t zeros[64] = {}; + std::memset(zeros, 0, sizeof(zeros)); + runtime->host_api.copy_to_device(reinterpret_cast(barrier_base), zeros, barrier_size); + + uint64_t args_wmin[3] = { + win_src, + reinterpret_cast(dev_attn_out), + static_cast(GATHER_COUNT) + }; + int t_wmin = runtime->add_task(args_wmin, 3, FUNC_WIN_MEMCOPY_IN, CoreType::AIV); + for (int t : last_pa_tasks) { + runtime->add_successor(t, t_wmin); + } + + uint64_t args_barrier[4] = { + barrier_base, device_ctx_ptr, + static_cast(n_ranks), static_cast(root) + }; + int t_barrier = runtime->add_task(args_barrier, 4, FUNC_COMM_BARRIER, CoreType::AIV); + runtime->add_successor(t_wmin, t_barrier); + + uint64_t args_gather[5] = { + win_dst, win_src, device_ctx_ptr, + static_cast(n_ranks), static_cast(root) + }; + int t_gather = runtime->add_task(args_gather, 5, FUNC_GATHER, CoreType::AIV); + runtime->add_successor(t_barrier, t_gather); + + if (dev_gather_out != nullptr) { + uint64_t args_wmout[3] = { + reinterpret_cast(dev_gather_out), + win_dst, + static_cast(n_ranks * GATHER_COUNT) + }; + int t_wmout = runtime->add_task(args_wmout, 3, 
FUNC_WIN_MEMCOPY_OUT, CoreType::AIV); + runtime->add_successor(t_gather, t_wmout); + } + + delete[] dev_sij_arr; + delete[] dev_pij_arr; + delete[] dev_mij_arr; + delete[] dev_lij_arr; + delete[] dev_oi_new_arr; + delete[] dev_mi_arr; + delete[] dev_li_arr; + delete[] dev_oi_arr; + + std::cout << "Created paged_attention_gather graph with gather phase\n"; + runtime->print_runtime(); + + return 0; +} + +} // extern "C" diff --git a/examples/scripts/code_runner.py b/examples/scripts/code_runner.py index 7d048e9a..6678d1cc 100644 --- a/examples/scripts/code_runner.py +++ b/examples/scripts/code_runner.py @@ -338,6 +338,10 @@ def __init__( enable_profiling: bool = False, run_all_cases: bool = False, case_name: Optional[str] = None, + n_devices: Optional[int] = None, + first_device_id: Optional[int] = None, + compiled_artifacts: Optional[dict] = None, + prebuilt_dir: Optional[str] = None, ): # Setup logging if not already configured (e.g., when used directly, not via run_example.py) _setup_logging_if_needed() @@ -346,7 +350,12 @@ def __init__( self.golden_path = Path(golden_path).resolve() self.platform = platform self.enable_profiling = enable_profiling + self.run_all_cases = run_all_cases + self.case_name = case_name self.project_root = _get_project_root() + self.compiled_artifacts = compiled_artifacts + self.prebuilt_dir = Path(prebuilt_dir) if prebuilt_dir else None + self._skip_build = compiled_artifacts is not None or (self.prebuilt_dir is not None and self.prebuilt_dir.exists()) # Resolve device ID self.device_id = device_id if device_id is not None else 0 @@ -383,6 +392,9 @@ def __init__( self.aicpu_thread_num = runtime_config.get('aicpu_thread_num', 3) self.block_dim = runtime_config.get('block_dim', 24) self.runtime_name = runtime_config.get('runtime', 'host_build_graph') + # Multi-device: CLI overrides config + self.n_devices = n_devices if n_devices is not None else runtime_config.get('n_devices', 1) + self.first_device_id = first_device_id if 
first_device_id is not None else runtime_config.get('first_device_id', 0) def _load_kernel_config(self): """Load kernel_config.py from kernels directory.""" @@ -602,86 +614,88 @@ def _build_func_args(self, tensors: Dict[str, torch.Tensor]) -> Tuple[List[int], def run(self) -> None: """ Execute the full test flow: - 1. Check environment - 2. Build runtime - 3. Load runtime and set device - 4. Compile orchestration - 5. Compile and register kernels - 6. For each params in params_list: - - Generate inputs using golden.py - - Initialize and launch runtime - - Finalize and compare with golden + - If compiled_artifacts or prebuilt_dir: skip build, load and run (set_device → init → launch → finalize) + - Else: build first, then run """ - # Import runtime modules (deferred import to avoid top-level dependency) - from runtime_builder import RuntimeBuilder from bindings import bind_host_binary, set_device, launch_runtime - from elf_parser import extract_text_section - - # Auto-setup PTO_ISA_ROOT if needed (for all platforms, since kernels may use PTO ISA headers) - pto_isa_root = _ensure_pto_isa_root(verbose=True) - if pto_isa_root is None: - raise EnvironmentError( - "PTO_ISA_ROOT could not be resolved.\n" - "Please set it to the PTO-ISA root directory, e.g.:\n" - " export PTO_ISA_ROOT=$(pwd)/examples/scripts/_deps/pto-isa" - ) - - # Step 1: Build runtime, orchestration, and kernels in parallel - # (they are independent — all only need kernel_compiler which is ready) - logger.info(f"=== Building Runtime: {self.runtime_name} (platform: {self.platform}) ===") - builder = RuntimeBuilder(platform=self.platform) - kernel_compiler = builder.get_kernel_compiler() - - from concurrent.futures import ThreadPoolExecutor, Future - - runtime_include_dirs = [ - os.path.join(self.project_root, "src", "runtime", self.runtime_name, "runtime") - ] - def _build_runtime(): - return builder.build(self.runtime_name) - - def _compile_orchestration(): - return 
kernel_compiler.compile_orchestration( - self.runtime_name, - self.orchestration["source"], - ) - - def _compile_one_kernel(kernel): - logger.info(f"Compiling kernel: {kernel['source']} (func_id={kernel['func_id']})") - incore_o = kernel_compiler.compile_incore( - kernel["source"], - core_type=kernel["core_type"], - pto_isa_root=pto_isa_root, - extra_include_dirs=runtime_include_dirs, - ) - if self.platform == "a2a3sim": - kernel_bin = incore_o + if self._skip_build: + if self.compiled_artifacts: + artifacts = self.compiled_artifacts else: - kernel_bin = extract_text_section(incore_o) - return (kernel["func_id"], kernel_bin) + artifacts = _load_artifacts_from_dir(self.prebuilt_dir) + host_binary = artifacts["host_binary"] + orch_so_binary = artifacts["orch_so_binary"] + aicpu_binary = artifacts["aicpu_binary"] + aicore_binary = artifacts["aicore_binary"] + kernel_binaries = artifacts["kernel_binaries"] + orch_func_name = artifacts["orch_func_name"] + logger.info(f"=== Using pre-built artifacts ({len(kernel_binaries)} kernels) ===") + else: + # Build path + from runtime_builder import RuntimeBuilder + from elf_parser import extract_text_section + + pto_isa_root = _ensure_pto_isa_root(verbose=True) + if pto_isa_root is None: + raise EnvironmentError( + "PTO_ISA_ROOT could not be resolved.\n" + "Please set it to the PTO-ISA root directory, e.g.:\n" + " export PTO_ISA_ROOT=$(pwd)/examples/scripts/_deps/pto-isa" + ) - # Launch all compilations concurrently - max_workers = 2 + len(self.kernels) # runtime + orchestration + kernels - with ThreadPoolExecutor(max_workers=max_workers) as executor: - fut_runtime = executor.submit(_build_runtime) - fut_orch = executor.submit(_compile_orchestration) - fut_kernels = [executor.submit(_compile_one_kernel, k) for k in self.kernels] + logger.info(f"=== Building Runtime: {self.runtime_name} (platform: {self.platform}) ===") + builder = RuntimeBuilder(platform=self.platform) + kernel_compiler = builder.get_kernel_compiler() - try: - 
host_binary, aicpu_binary, aicore_binary = fut_runtime.result() - except Exception as e: - raise RuntimeError( - f"Failed to build runtime '{self.runtime_name}' for platform '{self.platform}'.\n" - f"Error: {e}" - ) from e + from concurrent.futures import ThreadPoolExecutor - orch_so_binary = fut_orch.result() - kernel_binaries = [f.result() for f in fut_kernels] + runtime_include_dirs = [ + os.path.join(self.project_root, "src", "runtime", self.runtime_name, "runtime") + ] + + def _build_runtime(): + return builder.build(self.runtime_name) - logger.info(f"Compiled {len(kernel_binaries)} kernel(s)") + def _compile_orchestration(): + return kernel_compiler.compile_orchestration( + self.runtime_name, + self.orchestration["source"], + ) - # Step 2: Load runtime and set device + def _compile_one_kernel(kernel): + logger.info(f"Compiling kernel: {kernel['source']} (func_id={kernel['func_id']})") + incore_o = kernel_compiler.compile_incore( + kernel["source"], + core_type=kernel["core_type"], + pto_isa_root=pto_isa_root, + extra_include_dirs=runtime_include_dirs, + ) + if self.platform == "a2a3sim": + return (kernel["func_id"], incore_o) + return (kernel["func_id"], extract_text_section(incore_o)) + + max_workers = 2 + len(self.kernels) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + fut_runtime = executor.submit(_build_runtime) + fut_orch = executor.submit(_compile_orchestration) + fut_kernels = [executor.submit(_compile_one_kernel, k) for k in self.kernels] + + try: + host_binary, aicpu_binary, aicore_binary = fut_runtime.result() + except Exception as e: + raise RuntimeError( + f"Failed to build runtime '{self.runtime_name}' for platform '{self.platform}'.\n" + f"Error: {e}" + ) from e + + orch_so_binary = fut_orch.result() + kernel_binaries = [f.result() for f in fut_kernels] + + logger.info(f"Compiled {len(kernel_binaries)} kernel(s)") + orch_func_name = self.orchestration["function_name"] + + # Load runtime and set device logger.info(f"=== Loading 
Runtime ({len(host_binary)} bytes) ===") Runtime = bind_host_binary(host_binary) @@ -735,7 +749,7 @@ def _compile_one_kernel(kernel): with _temporary_env(run_env): runtime.initialize( orch_so_binary, - self.orchestration["function_name"], + orch_func_name, func_args, arg_types=arg_types, arg_sizes=arg_sizes, @@ -829,9 +843,176 @@ def _compare_with_golden( def create_code_runner(kernels_dir, golden_path, device_id=None, platform="a2a3", - enable_profiling=False, run_all_cases=False, case_name=None): + enable_profiling=False, run_all_cases=False, case_name=None, + n_devices=None, first_device_id=None, + compiled_artifacts=None, prebuilt_dir=None): """Factory: creates a CodeRunner based on kernel_config.""" return CodeRunner(kernels_dir=kernels_dir, golden_path=golden_path, device_id=device_id, platform=platform, enable_profiling=enable_profiling, - run_all_cases=run_all_cases, case_name=case_name) + run_all_cases=run_all_cases, case_name=case_name, + n_devices=n_devices, first_device_id=first_device_id, + compiled_artifacts=compiled_artifacts, prebuilt_dir=prebuilt_dir) + + +def run_on_device( + device_id: int, + artifacts: dict, + kernels_dir: str, + golden_path: str, + platform: str = "a2a3", + enable_profiling: bool = False, + run_all_cases: bool = False, + case_name: Optional[str] = None, +) -> None: + """ + Worker entry for multiprocessing: create CodeRunner and run. + Must be at module level so child process can import without running main(). 
+ """ + runner = create_code_runner( + kernels_dir=kernels_dir, + golden_path=golden_path, + device_id=device_id, + platform=platform, + enable_profiling=enable_profiling, + run_all_cases=run_all_cases, + case_name=case_name, + n_devices=1, + first_device_id=device_id, + compiled_artifacts=artifacts, + ) + runner.run() + + +# ============================================================================= +# PTOCompiler - compile once, run many +# ============================================================================= + +def _write_artifacts_to_dir(artifacts: dict, out_dir: Path) -> None: + """Write compiled artifacts to a directory for subprocess loading.""" + import json + out_dir.mkdir(parents=True, exist_ok=True) + (out_dir / "host.bin").write_bytes(artifacts["host_binary"]) + (out_dir / "orch.so").write_bytes(artifacts["orch_so_binary"]) + (out_dir / "aicpu.so").write_bytes(artifacts["aicpu_binary"]) + (out_dir / "aicore.bin").write_bytes(artifacts["aicore_binary"]) + for func_id, bin_data in artifacts["kernel_binaries"]: + (out_dir / f"kernel_{func_id}.bin").write_bytes(bin_data) + manifest = { + "orch_func_name": artifacts["orch_func_name"], + "kernel_func_ids": [k[0] for k in artifacts["kernel_binaries"]], + } + (out_dir / "manifest.json").write_text(json.dumps(manifest), encoding="utf-8") + + +def _load_artifacts_from_dir(prebuilt_dir: Path) -> dict: + """Load compiled artifacts from a prebuilt directory.""" + import json + manifest = json.loads((prebuilt_dir / "manifest.json").read_text(encoding="utf-8")) + kernel_binaries = [] + for func_id in manifest["kernel_func_ids"]: + bin_data = (prebuilt_dir / f"kernel_{func_id}.bin").read_bytes() + kernel_binaries.append((func_id, bin_data)) + return { + "host_binary": (prebuilt_dir / "host.bin").read_bytes(), + "orch_so_binary": (prebuilt_dir / "orch.so").read_bytes(), + "aicpu_binary": (prebuilt_dir / "aicpu.so").read_bytes(), + "aicore_binary": (prebuilt_dir / "aicore.bin").read_bytes(), + 
"kernel_binaries": kernel_binaries, + "orch_func_name": manifest["orch_func_name"], + } + + +class PTOCompiler: + """Compiles PTO runtime, orchestration, and kernels. Returns artifacts for Runner.""" + + def __init__( + self, + kernels_dir: str, + platform: str = "a2a3", + ): + self.kernels_dir = Path(kernels_dir).resolve() + self.platform = platform + self.project_root = _get_project_root() + self._kernel_config = _load_module_from_path( + self.kernels_dir / "kernel_config.py", f"kernel_config_compiler_{id(self)}" + ) + self.kernels = self._kernel_config.KERNELS + self.orchestration = self._kernel_config.ORCHESTRATION + runtime_config = getattr(self._kernel_config, "RUNTIME_CONFIG", {}) + self.runtime_name = runtime_config.get("runtime", "host_build_graph") + + def compile(self) -> dict: + """Build runtime, orchestration, and kernels. Return artifacts dict.""" + from runtime_builder import RuntimeBuilder + from elf_parser import extract_text_section + + pto_isa_root = _ensure_pto_isa_root(verbose=True) + if pto_isa_root is None: + raise EnvironmentError( + "PTO_ISA_ROOT could not be resolved.\n" + "Please set it to the PTO-ISA root directory." 
+ ) + + logger.info(f"=== PTOCompiler: Building {self.runtime_name} (platform: {self.platform}) ===") + builder = RuntimeBuilder(platform=self.platform) + kernel_compiler = builder.get_kernel_compiler() + + from concurrent.futures import ThreadPoolExecutor + + runtime_include_dirs = [ + os.path.join(self.project_root, "src", "runtime", self.runtime_name, "runtime") + ] + + def _build_runtime(): + return builder.build(self.runtime_name) + + def _compile_orchestration(): + return kernel_compiler.compile_orchestration( + self.runtime_name, + self.orchestration["source"], + ) + + def _compile_one_kernel(kernel): + logger.info(f"Compiling kernel: {kernel['source']} (func_id={kernel['func_id']})") + incore_o = kernel_compiler.compile_incore( + kernel["source"], + core_type=kernel["core_type"], + pto_isa_root=pto_isa_root, + extra_include_dirs=runtime_include_dirs, + ) + if self.platform == "a2a3sim": + return (kernel["func_id"], incore_o) + return (kernel["func_id"], extract_text_section(incore_o)) + + max_workers = 2 + len(self.kernels) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + fut_runtime = executor.submit(_build_runtime) + fut_orch = executor.submit(_compile_orchestration) + fut_kernels = [executor.submit(_compile_one_kernel, k) for k in self.kernels] + + try: + host_binary, aicpu_binary, aicore_binary = fut_runtime.result() + except Exception as e: + raise RuntimeError( + f"Failed to build runtime '{self.runtime_name}': {e}" + ) from e + + orch_so_binary = fut_orch.result() + kernel_binaries = [f.result() for f in fut_kernels] + + logger.info(f"PTOCompiler: Compiled {len(kernel_binaries)} kernel(s)") + + return { + "host_binary": host_binary, + "orch_so_binary": orch_so_binary, + "aicpu_binary": aicpu_binary, + "aicore_binary": aicore_binary, + "kernel_binaries": kernel_binaries, + "orch_func_name": self.orchestration["function_name"], + } + + +def create_compiler(kernels_dir, platform="a2a3"): + """Factory: creates a PTOCompiler.""" + return 
PTOCompiler(kernels_dir=kernels_dir, platform=platform)
diff --git a/examples/scripts/comm_include/hccl_context.h b/examples/scripts/comm_include/hccl_context.h
new file mode 100644
index 00000000..5f93290f
--- /dev/null
+++ b/examples/scripts/comm_include/hccl_context.h
@@ -0,0 +1,21 @@
+/**
+ * HcclDeviceContext - device-side context for HCCL collective ops.
+ * Extracted from pto-comm-isa for use in simpler-PTO cpt_and_comm.
+ */
+
+#pragma once
+
+#include <cstdint>
+
+static constexpr uint32_t HCCL_MAX_RANK_NUM = 64;
+
+struct HcclDeviceContext {
+    uint64_t workSpace;
+    uint64_t workSpaceSize;
+
+    uint32_t rankId;
+    uint32_t rankNum;
+    uint64_t winSize;
+    uint64_t windowsIn[HCCL_MAX_RANK_NUM];
+    uint64_t windowsOut[HCCL_MAX_RANK_NUM];
+};
diff --git a/examples/scripts/comm_include/hccl_helpers.h b/examples/scripts/comm_include/hccl_helpers.h
new file mode 100644
index 00000000..3f4f9012
--- /dev/null
+++ b/examples/scripts/comm_include/hccl_helpers.h
@@ -0,0 +1,36 @@
+/**
+ * HCCL device-side helpers: HcclRemotePtr, WindowAlloc.
+ * Extracted from pto-comm-isa common.hpp for use in simpler-PTO cpt_and_comm.
+ */
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+#include "hccl_context.h"
+
+#ifndef AICORE
+#define AICORE
+#endif
+
+#ifndef __gm__
+#define __gm__
+#endif
+
+// Convert local window pointer to remote rank's equivalent address
+template <typename T>
+AICORE inline __gm__ T* HcclRemotePtr(__gm__ HcclDeviceContext* ctx, __gm__ T* localPtr, int pe) {
+    // TGATHER source tensors are laid out in windowsIn.
+ uint64_t localBase = ctx->windowsIn[ctx->rankId]; + uint64_t peerBase = ctx->windowsIn[pe]; + uint64_t offset = (uint64_t)localPtr - localBase; + return (__gm__ T*)(peerBase + offset); +} + +// Allocate from window at (windowBase + offset), advance offset +inline void* WindowAlloc(uint64_t windowBase, size_t& offset, size_t bytes) { + void* ptr = reinterpret_cast(windowBase + offset); + offset += bytes; + return ptr; +} diff --git a/examples/scripts/hccl_bindings.py b/examples/scripts/hccl_bindings.py new file mode 100644 index 00000000..0d35c65f --- /dev/null +++ b/examples/scripts/hccl_bindings.py @@ -0,0 +1,167 @@ +""" +HCCL Python bindings for multi-card communication setup. + +Prefer C++ helper lib (libhccl_helper.so) — same link as pto-comm-isa (ascendcl, hcomm, runtime). +Build: cd examples/scripts/hccl_helper && mkdir build && cd build && cmake .. && make +Then run with CANN env: source .../set_env.sh + +Usage: + from hccl_bindings import hccl_get_root_info, hccl_init_comm, hccl_barrier, HCCL_ROOT_INFO_BYTES +""" + +import ctypes +import os +import sys +from ctypes import ( + POINTER, + c_void_p, + c_uint32, + c_int, + c_uint64, + c_char_p, + Structure, + create_string_buffer, +) +from pathlib import Path +from typing import Optional, Tuple + +# Set after loading libhccl_helper +HCCL_ROOT_INFO_BYTES = 1024 + +_lib_helper = None + + +def _find_helper_so() -> Optional[Path]: + """Locate libhccl_helper.so next to this script or in hccl_helper/build.""" + script_dir = Path(__file__).resolve().parent + candidates = [ + script_dir / "hccl_helper" / "build" / "libhccl_helper.so", + script_dir / "build" / "libhccl_helper.so", + script_dir / "libhccl_helper.so", + ] + for p in candidates: + if p.exists(): + return p + return None + + +def _load_helper(): + """Load libhccl_helper.so (C++ helper, same link as pto-comm-isa).""" + global _lib_helper, HCCL_ROOT_INFO_BYTES + if _lib_helper is not None: + return + + path = _find_helper_so() + if path is None: + raise 
RuntimeError( + "libhccl_helper.so not found. Build it with CANN env set:\n" + " source /usr/local/Ascend/ascend-toolkit/latest/bin/setenv.bash\n" + " cd examples/scripts/hccl_helper && mkdir build && cd build\n" + " cmake .. && make\n" + "Then run your script with the same CANN env (source setenv.bash)." + ) + try: + _lib_helper = ctypes.CDLL(str(path)) + except OSError as e: + raise RuntimeError( + f"Failed to load {path}: {e}\n" + "Ensure CANN env is set (source .../setenv.bash) so dependencies (ascendcl, hcomm, runtime) can be found." + ) from e + + # C API + _lib_helper.hccl_helper_root_info_bytes.restype = c_uint32 + HCCL_ROOT_INFO_BYTES = _lib_helper.hccl_helper_root_info_bytes() + + _lib_helper.hccl_helper_get_root_info.argtypes = [c_int, c_void_p, c_uint32] + _lib_helper.hccl_helper_get_root_info.restype = c_int + + _lib_helper.hccl_helper_init_comm.argtypes = [ + c_int, c_int, c_int, c_int, # rank_id, n_ranks, n_devices, first_device_id + c_void_p, c_uint32, # root_info, root_info_len + POINTER(c_void_p), POINTER(c_void_p), POINTER(c_uint64), POINTER(c_uint64), POINTER(c_void_p), POINTER(c_int), + ] + _lib_helper.hccl_helper_init_comm.restype = c_int + + _lib_helper.hccl_helper_barrier.argtypes = [c_void_p, c_void_p] + _lib_helper.hccl_helper_barrier.restype = c_int + + +def hccl_get_root_info(device_id: int) -> bytes: + """ + Rank 0: get HcclRootInfo (C++ helper sets device and calls HcclGetRootInfo). + + Returns: + bytes of length HCCL_ROOT_INFO_BYTES + """ + _load_helper() + buf = create_string_buffer(HCCL_ROOT_INFO_BYTES) + ret = _lib_helper.hccl_helper_get_root_info(device_id, buf, HCCL_ROOT_INFO_BYTES) + if ret != 0: + raise RuntimeError(f"hccl_helper_get_root_info failed: {ret}") + return buf.raw[:HCCL_ROOT_INFO_BYTES] + + +def hccl_init_comm( + rank_id: int, + n_ranks: int, + n_devices: int, + first_device_id: int, + root_info: bytes, +) -> Tuple[int, int, int, int, int, int]: + """ + All ranks: init HCCL comm (same link as pto-comm-isa). 
+ + Returns: + (comm, device_ctx_ptr, win_in_base, win_out_base, stream, actual_rank_id) as integers. + """ + _load_helper() + if len(root_info) < HCCL_ROOT_INFO_BYTES: + raise ValueError(f"root_info must be at least {HCCL_ROOT_INFO_BYTES} bytes") + + comm = c_void_p() + ctx_ptr = c_void_p() + win_in_base = c_uint64() + win_out_base = c_uint64() + stream = c_void_p() + actual_rank_id = c_int(-1) + + buf = create_string_buffer(len(root_info)) + # copy bytes into mutable buffer + ctypes.memmove(buf, root_info, len(root_info)) + + ret = _lib_helper.hccl_helper_init_comm( + rank_id, + n_ranks, + n_devices, + first_device_id, + buf, + len(root_info), + ctypes.byref(comm), + ctypes.byref(ctx_ptr), + ctypes.byref(win_in_base), + ctypes.byref(win_out_base), + ctypes.byref(stream), + ctypes.byref(actual_rank_id), + ) + if ret != 0: + raise RuntimeError(f"hccl_helper_init_comm failed: {ret}") + + return ( + comm.value or 0, + ctx_ptr.value or 0, + win_in_base.value, + win_out_base.value, + stream.value or 0, + actual_rank_id.value, + ) + + +def hccl_barrier(comm_handle: int, stream_handle: int) -> None: + """HcclBarrier + stream sync (C++ helper).""" + _load_helper() + ret = _lib_helper.hccl_helper_barrier( + ctypes.c_void_p(comm_handle), + ctypes.c_void_p(stream_handle), + ) + if ret != 0: + raise RuntimeError(f"hccl_helper_barrier failed: {ret}") diff --git a/examples/scripts/hccl_helper/CMakeLists.txt b/examples/scripts/hccl_helper/CMakeLists.txt new file mode 100644 index 00000000..f6662d66 --- /dev/null +++ b/examples/scripts/hccl_helper/CMakeLists.txt @@ -0,0 +1,24 @@ +# Build libhccl_helper.so — same link as pto-comm-isa (ascendcl, hcomm, runtime). +# Run with CANN env: source .../setenv.bash, then: +# mkdir build && cd build && cmake .. 
&& make + +cmake_minimum_required(VERSION 3.16) +project(hccl_helper CXX) + +set(CMAKE_CXX_STANDARD 17) + +if(DEFINED ENV{ASCEND_HOME_PATH}) + set(ASCEND_HOME_PATH $ENV{ASCEND_HOME_PATH}) +elseif(DEFINED ENV{ASCEND_HOME}) + set(ASCEND_HOME_PATH $ENV{ASCEND_HOME}) +else() + message(FATAL_ERROR "ASCEND_HOME_PATH or ASCEND_HOME not set. Run: source /usr/local/Ascend/ascend-toolkit/latest/bin/setenv.bash") +endif() + +add_library(hccl_helper SHARED hccl_helper.cpp) +target_include_directories(hccl_helper PRIVATE + ${ASCEND_HOME_PATH}/include + ${ASCEND_HOME_PATH}/include/hccl +) +target_link_directories(hccl_helper PRIVATE ${ASCEND_HOME_PATH}/lib64) +target_link_libraries(hccl_helper PRIVATE ascendcl hcomm runtime) diff --git a/examples/scripts/hccl_helper/README.md b/examples/scripts/hccl_helper/README.md new file mode 100644 index 00000000..e43f2437 --- /dev/null +++ b/examples/scripts/hccl_helper/README.md @@ -0,0 +1,32 @@ +# libhccl_helper + +C++ 辅助库,与 pto-comm-isa 相同方式链接(ascendcl、hcomm、runtime),供 Python 通过 ctypes 调用。不依赖 Python 侧直接加载 libacl.so/libhccl.so。 + +## 编译 + +在已 source CANN 环境(与跑 pto-comm-isa comm case 相同)下: + +```bash +source /usr/local/Ascend/ascend-toolkit/latest/bin/setenv.bash +cd examples/scripts/hccl_helper +mkdir build && cd build +cmake .. 
+make +``` + +生成 `libhccl_helper.so`。运行多卡脚本前同样需要 `source .../setenv.bash`,以便运行时找到 ascendcl、hcomm、runtime 等依赖。 + +## 依赖 + +- CANN:需设置 `ASCEND_HOME_PATH`(一般由 `setenv.bash` 设置) +- 与 pto-comm-isa 的 a2a3 comm 用例相同 + +## Python 使用 + +`hccl_bindings.py` 会优先在以下位置查找 `libhccl_helper.so`: + +- `examples/scripts/hccl_helper/build/libhccl_helper.so` +- `examples/scripts/build/libhccl_helper.so` +- `examples/scripts/libhccl_helper.so` + +找到则使用 C++ 实现;未找到则报错并提示先编译本库。 diff --git a/examples/scripts/hccl_helper/hccl_helper.cpp b/examples/scripts/hccl_helper/hccl_helper.cpp new file mode 100644 index 00000000..9788f1a4 --- /dev/null +++ b/examples/scripts/hccl_helper/hccl_helper.cpp @@ -0,0 +1,489 @@ +/** + * HCCL helper shared library — same link as pto-comm-isa (ascendcl, hcomm, runtime). + * Python loads this .so and calls the C API; no direct libacl/libhccl in Python. + * + * Build: from examples/scripts/hccl_helper with CANN env set (source set_env.sh): + * mkdir build && cd build && cmake .. && make + * Output: libhccl_helper.so + */ + +#include +#include +#include +#include +#include + +#include "acl/acl.h" +#include "hccl/hccl_comm.h" +#include "hccl/hccl_types.h" + +#define HCCL_HELPER_ROOT_INFO_BYTES 1024 + +using CommTopo = uint32_t; +static constexpr uint32_t COMM_IS_NOT_SET_DEVICE = 0; +static constexpr uint32_t COMM_TOPO_MESH = 0b1u; +static constexpr uint32_t GROUP_NAME_SIZE = 128U; +static constexpr uint32_t ALG_CONFIG_SIZE = 128U; +static constexpr uint32_t MAX_CC_TILING_NUM = 8U; + +extern "C" int rtSetDevice(int32_t device); +extern "C" int rtStreamCreate(void** stream, int32_t priority); +extern "C" int rtStreamDestroy(void* stream); +extern "C" int HcclAllocComResourceByTiling(void* comm, void* stream, void* mc2Tiling, void** commContext); +extern "C" int HcomGetCommHandleByGroup(const char* group, void** commHandle); +extern "C" int HcomGetL0TopoTypeEx(const char* group, CommTopo* topoType, uint32_t isSetDevice); + +struct Mc2InitTilingInner { + uint32_t 
version; + uint32_t mc2HcommCnt; + uint32_t offset[MAX_CC_TILING_NUM]; + uint8_t debugMode; + uint8_t preparePosition; + uint16_t queueNum; + uint16_t commBlockNum; + uint8_t devType; + char reserved[17]; +}; + +struct Mc2cCTilingInner { + uint8_t skipLocalRankCopy; + uint8_t skipBufferWindowCopy; + uint8_t stepSize; + uint8_t version; + char reserved[9]; + uint8_t commEngine; + uint8_t srcDataType; + uint8_t dstDataType; + char groupName[GROUP_NAME_SIZE]; + char algConfig[ALG_CONFIG_SIZE]; + uint32_t opType; + uint32_t reduceType; +}; + +struct Mc2CommConfigV2 { + Mc2InitTilingInner init; + Mc2cCTilingInner inner; +}; + +static constexpr uint32_t HCCL_MAX_RANK_NUM = 64; +struct HcclDeviceContext { + uint64_t workSpace; + uint64_t workSpaceSize; + uint32_t rankId; + uint32_t rankNum; + uint64_t winSize; + uint64_t windowsIn[HCCL_MAX_RANK_NUM]; + uint64_t windowsOut[HCCL_MAX_RANK_NUM]; +}; + +static constexpr int RT_STREAM_PRIORITY_DEFAULT = 0; + +// --------------------------------------------------------------------------- +// HcclOpResParam compat structs — binary-compatible copies of HCCL internal +// types (from pto-comm-isa common.hpp). Used only on host side to compute +// windowsIn[...] for RING topology. 
+// --------------------------------------------------------------------------- + +namespace hccl_compat { + +struct HcclSignalInfo { + uint64_t resId; + uint64_t addr; + uint32_t devId; + uint32_t tsId; + uint32_t rankId; + uint32_t flag; +}; + +struct HcclStreamInfo { + int32_t streamIds; + uint32_t sqIds; + uint32_t cqIds; + uint32_t logicCqids; +}; + +struct ListCommon { + uint64_t nextHost; + uint64_t preHost; + uint64_t nextDevice; + uint64_t preDevice; +}; + +static constexpr uint32_t COMPAT_LOCAL_NOTIFY_MAX_NUM = 64; +static constexpr uint32_t COMPAT_LOCAL_STREAM_MAX_NUM = 19; +static constexpr uint32_t COMPAT_AICPU_OP_NOTIFY_MAX_NUM = 2; + +struct LocalResInfoV2 { + uint32_t streamNum; + uint32_t signalNum; + HcclSignalInfo localSignals[COMPAT_LOCAL_NOTIFY_MAX_NUM]; + HcclStreamInfo streamInfo[COMPAT_LOCAL_STREAM_MAX_NUM]; + HcclStreamInfo mainStreamInfo; + HcclSignalInfo aicpuOpNotify[COMPAT_AICPU_OP_NOTIFY_MAX_NUM]; + ListCommon nextTagRes; +}; + +struct AlgoTopoInfo { + uint32_t userRank; + uint32_t userRankSize; + int32_t deviceLogicId; + bool isSingleMeshAggregation; + uint32_t deviceNumPerAggregation; + uint32_t superPodNum; + uint32_t devicePhyId; + uint32_t topoType; + uint32_t deviceType; + uint32_t serverNum; + uint32_t meshAggregationRankSize; + uint32_t multiModuleDiffDeviceNumMode; + uint32_t multiSuperPodDiffServerNumMode; + uint32_t realUserRank; + bool isDiffDeviceModule; + bool isDiffDeviceType; + uint32_t gcdDeviceNumPerAggregation; + uint32_t moduleNum; + uint32_t isUsedRdmaRankPairNum; + uint64_t isUsedRdmaRankPair; + uint32_t pairLinkCounterNum; + uint64_t pairLinkCounter; + uint32_t nicNum; + uint64_t nicList; + uint64_t complanRankLength; + uint64_t complanRank; + uint64_t bridgeRankNum; + uint64_t bridgeRank; + uint64_t serverAndsuperPodRankLength; + uint64_t serverAndsuperPodRank; +}; + +struct HcclOpConfig { + uint8_t deterministic; + uint8_t retryEnable; + uint8_t highPerfEnable; + uint8_t padding[5]; + uint8_t linkTimeOut[8]; + 
uint64_t notifyWaitTime; + uint32_t retryHoldTime; + uint32_t retryIntervalTime; + bool interXLinkDisable; + uint32_t floatOverflowMode; + uint32_t multiQpThreshold; +}; + +struct HDCommunicateParams { + uint64_t hostAddr; + uint64_t deviceAddr; + uint64_t readCacheAddr; + uint32_t devMemSize; + uint32_t buffLen; + uint32_t flag; +}; + +struct RemoteResPtr { + uint64_t nextHostPtr; + uint64_t nextDevicePtr; +}; + +struct HcclMC2WorkSpace { + uint64_t workspace; + uint64_t workspaceSize; +}; + +struct HcclRankRelationResV2 { + uint32_t remoteUsrRankId; + uint32_t remoteWorldRank; + uint64_t windowsIn; + uint64_t windowsOut; + uint64_t windowsExp; + ListCommon nextTagRes; +}; + +struct HcclOpResParamHead { + uint32_t localUsrRankId; + uint32_t rankSize; + uint64_t winSize; + uint64_t localWindowsIn; + uint64_t localWindowsOut; + char hcomId[128]; + uint64_t winExpSize; + uint64_t localWindowsExp; +}; + +// Full struct layout for offsetof(remoteRes) computation. +// Array size of remoteRes does not affect the offset calculation. +struct HcclOpResParam { + HcclMC2WorkSpace mc2WorkSpace; + uint32_t localUsrRankId; + uint32_t rankSize; + uint64_t winSize; + uint64_t localWindowsIn; + uint64_t localWindowsOut; + char hcomId[128]; + uint64_t winExpSize; + uint64_t localWindowsExp; + uint32_t rWinStart; + uint32_t rWinOffset; + uint64_t version; + LocalResInfoV2 localRes; + AlgoTopoInfo topoInfo; + HcclOpConfig config; + uint64_t hostStateInfo; + uint64_t aicpuStateInfo; + uint64_t lockAddr; + uint32_t rsv[16]; + uint32_t notifysize; + uint32_t remoteResNum; + RemoteResPtr remoteRes[1]; +}; + +} // namespace hccl_compat + +// --------------------------------------------------------------------------- +// C API for Python ctypes +// --------------------------------------------------------------------------- + +extern "C" { + +unsigned int hccl_helper_root_info_bytes(void) { + return HCCL_HELPER_ROOT_INFO_BYTES; +} + +// Rank 0: set device, get root info into out_buf. 
Returns 0 on success. +int hccl_helper_get_root_info(int device_id, void* out_buf, unsigned buf_size) { + if (out_buf == nullptr || buf_size < HCCL_HELPER_ROOT_INFO_BYTES) + return -EINVAL; + aclError e = aclrtSetDevice(device_id); + if (e != ACL_SUCCESS) + return -static_cast(e); + // HcclGetRootInfo expects HcclRootInfo* (opaque struct, typically 1024 bytes) + auto* root = reinterpret_cast(out_buf); + int ret = HcclGetRootInfo(root); + return (ret == HCCL_SUCCESS) ? 0 : -ret; +} + +// All ranks: init comm. Fills out_comm, out_ctx_ptr, out_win_in_base, out_win_out_base, out_stream. +// Returns 0 on success. +// root_info from rank 0 (hccl_helper_get_root_info). No MPI — Python uses Barrier. +int hccl_helper_init_comm( + int rank_id, + int n_ranks, + int n_devices, + int first_device_id, + const void* root_info, + unsigned root_info_len, + void** out_comm, + void** out_ctx_ptr, + uint64_t* out_win_in_base, + uint64_t* out_win_out_base, + void** out_stream, + int* out_actual_rank_id +) { + if (out_comm == nullptr || out_ctx_ptr == nullptr || + out_win_in_base == nullptr || out_win_out_base == nullptr || + out_stream == nullptr || + out_actual_rank_id == nullptr || + root_info == nullptr || root_info_len < HCCL_HELPER_ROOT_INFO_BYTES) + return -EINVAL; + + int device_id = rank_id % n_devices + first_device_id; + + int rtRet = rtSetDevice(device_id); + if (rtRet != 0) + return -rtRet; + + void* stream = nullptr; + rtRet = rtStreamCreate(&stream, RT_STREAM_PRIORITY_DEFAULT); + if (rtRet != 0 || stream == nullptr) + return rtRet != 0 ? -rtRet : -1; + + HcclComm comm = nullptr; + auto* root = reinterpret_cast(root_info); + int hret = HcclCommInitRootInfo( + static_cast(n_ranks), + root, + static_cast(rank_id), + &comm); + if (hret != HCCL_SUCCESS || comm == nullptr) { + rtStreamDestroy(stream); + return hret != HCCL_SUCCESS ? 
-hret : -1; + } + + char group[128] = {}; + hret = HcclGetCommName(comm, group); + if (hret != HCCL_SUCCESS) { + HcclCommDestroy(comm); + rtStreamDestroy(stream); + return -hret; + } + + CommTopo topo = 0; + hret = HcomGetL0TopoTypeEx(group, &topo, COMM_IS_NOT_SET_DEVICE); + if (hret != HCCL_SUCCESS) { + HcclCommDestroy(comm); + rtStreamDestroy(stream); + return -hret; + } + + void* commHandle = nullptr; + hret = HcomGetCommHandleByGroup(group, &commHandle); + if (hret != HCCL_SUCCESS) { + HcclCommDestroy(comm); + rtStreamDestroy(stream); + return -hret; + } + + Mc2CommConfigV2 tiling{}; + memset(&tiling, 0, sizeof(tiling)); + tiling.init.version = 100U; + tiling.init.mc2HcommCnt = 1U; + tiling.init.commBlockNum = 48U; + tiling.init.devType = 4U; + tiling.init.offset[0] = static_cast( + reinterpret_cast(&tiling.inner) - reinterpret_cast(&tiling.init)); + tiling.inner.opType = 18U; + tiling.inner.commEngine = 3U; + tiling.inner.version = 1U; + strncpy(tiling.inner.groupName, group, GROUP_NAME_SIZE - 1); + strncpy(tiling.inner.algConfig, "BatchWrite=level0:fullmesh", ALG_CONFIG_SIZE - 1); + + void* ctxPtr = nullptr; + hret = HcclAllocComResourceByTiling(commHandle, stream, &tiling, &ctxPtr); + if (hret != HCCL_SUCCESS || ctxPtr == nullptr) { + HcclCommDestroy(comm); + rtStreamDestroy(stream); + return hret != HCCL_SUCCESS ? -hret : -1; + } + + // Build host-side HcclDeviceContext for both MESH and RING topo. + HcclDeviceContext hostCtx{}; + void* deviceCtxPtr = nullptr; + + if (topo == COMM_TOPO_MESH) { + // MESH: ctxPtr is HcclCombinOpParamA5 whose first fields match HcclDeviceContext. + aclError aRet = aclrtMemcpy(&hostCtx, sizeof(hostCtx), ctxPtr, sizeof(hostCtx), ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) { + HcclCommDestroy(comm); + rtStreamDestroy(stream); + return -static_cast(aRet); + } + deviceCtxPtr = ctxPtr; + } else { + // RING: ctxPtr is HcclOpResParam. Extract remote windows and build our own HcclDeviceContext. 
+ using namespace hccl_compat; + auto* rawCtx = reinterpret_cast(ctxPtr); + + // 1. Read HcclOpResParam head (from localUsrRankId through localWindowsExp). + HcclOpResParamHead head{}; + const size_t headOff = offsetof(HcclOpResParam, localUsrRankId); + aclError aRet = aclrtMemcpy(&head, sizeof(head), rawCtx + headOff, sizeof(head), ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) { + HcclCommDestroy(comm); + rtStreamDestroy(stream); + return -static_cast(aRet); + } + + if (head.rankSize == 0 || head.rankSize > HCCL_MAX_RANK_NUM) { + HcclCommDestroy(comm); + rtStreamDestroy(stream); + return -EINVAL; + } + + // 2. Read remoteRes[0..rankSize-1] (array of device-pointer pairs). + const size_t remoteResOff = offsetof(HcclOpResParam, remoteRes); + const size_t remoteResBytes = head.rankSize * sizeof(RemoteResPtr); + std::vector remoteResArr(head.rankSize); + + aRet = aclrtMemcpy(remoteResArr.data(), remoteResBytes, rawCtx + remoteResOff, remoteResBytes, + ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) { + HcclCommDestroy(comm); + rtStreamDestroy(stream); + return -static_cast(aRet); + } + + // 3. Build hostCtx with correct per-rank RDMA window addresses. + std::memset(&hostCtx, 0, sizeof(hostCtx)); + + // Read mc2WorkSpace (first 16 bytes of HcclOpResParam). 
+ uint64_t wsFields[2] = {0, 0}; + aRet = aclrtMemcpy(wsFields, sizeof(wsFields), rawCtx, sizeof(wsFields), ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet == ACL_SUCCESS) { + hostCtx.workSpace = wsFields[0]; + hostCtx.workSpaceSize = wsFields[1]; + } + + hostCtx.rankId = head.localUsrRankId; + hostCtx.rankNum = head.rankSize; + hostCtx.winSize = head.winSize; + + for (uint32_t i = 0; i < head.rankSize; ++i) { + if (i == head.localUsrRankId) { + hostCtx.windowsIn[i] = head.localWindowsIn; + hostCtx.windowsOut[i] = head.localWindowsOut; + continue; + } + + uint64_t devPtr = remoteResArr[i].nextDevicePtr; + if (devPtr == 0) { + HcclCommDestroy(comm); + rtStreamDestroy(stream); + return -EINVAL; + } + + HcclRankRelationResV2 remoteInfo{}; + aRet = aclrtMemcpy(&remoteInfo, sizeof(remoteInfo), reinterpret_cast(devPtr), sizeof(remoteInfo), + ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) { + HcclCommDestroy(comm); + rtStreamDestroy(stream); + return -static_cast(aRet); + } + + hostCtx.windowsIn[i] = remoteInfo.windowsIn; + hostCtx.windowsOut[i] = remoteInfo.windowsOut; + } + + // 4. Allocate new device memory and copy our correctly-built HcclDeviceContext. + void* newDevMem = nullptr; + aRet = aclrtMalloc(&newDevMem, sizeof(HcclDeviceContext), ACL_MEM_MALLOC_HUGE_FIRST); + if (aRet != ACL_SUCCESS || newDevMem == nullptr) { + HcclCommDestroy(comm); + rtStreamDestroy(stream); + return -static_cast(aRet); + } + + aRet = aclrtMemcpy(newDevMem, sizeof(HcclDeviceContext), &hostCtx, sizeof(HcclDeviceContext), + ACL_MEMCPY_HOST_TO_DEVICE); + if (aRet != ACL_SUCCESS) { + aclrtFree(newDevMem); + HcclCommDestroy(comm); + rtStreamDestroy(stream); + return -static_cast(aRet); + } + + deviceCtxPtr = newDevMem; + } + + *out_comm = comm; + *out_ctx_ptr = deviceCtxPtr; + *out_win_in_base = hostCtx.windowsIn[hostCtx.rankId]; + *out_win_out_base = (hostCtx.windowsOut[hostCtx.rankId] != 0) + ? 
hostCtx.windowsOut[hostCtx.rankId] + : hostCtx.windowsIn[hostCtx.rankId]; + *out_stream = stream; + *out_actual_rank_id = static_cast(hostCtx.rankId); + return 0; +} + +// Barrier + stream sync. +int hccl_helper_barrier(void* comm, void* stream) { + if (comm == nullptr || stream == nullptr) + return -EINVAL; + int hret = HcclBarrier(comm, stream); + if (hret != HCCL_SUCCESS) + return -hret; + aclError e = aclrtSynchronizeStream(stream); + return (e == ACL_SUCCESS) ? 0 : -static_cast(e); +} + +} // extern "C" diff --git a/examples/scripts/kill_python_processes.ps1 b/examples/scripts/kill_python_processes.ps1 new file mode 100644 index 00000000..8fa398d4 --- /dev/null +++ b/examples/scripts/kill_python_processes.ps1 @@ -0,0 +1,13 @@ +# Batch kill Python processes (e.g. runaway multi_bgemm subprocesses). +# Run in PowerShell: .\kill_python_processes.ps1 +# Or: powershell -ExecutionPolicy Bypass -File kill_python_processes.ps1 + +$procs = Get-Process -Name python*, py -ErrorAction SilentlyContinue +if ($procs) { + $count = ($procs | Measure-Object).Count + Write-Host "Found $count Python process(es). Killing..." + $procs | Stop-Process -Force + Write-Host "Done." +} else { + Write-Host "No Python processes found." +} diff --git a/examples/scripts/multi_card_code_runner.py b/examples/scripts/multi_card_code_runner.py new file mode 100644 index 00000000..aee333bf --- /dev/null +++ b/examples/scripts/multi_card_code_runner.py @@ -0,0 +1,1131 @@ +""" +CodeRunner - Simplified test framework for PTO runtime tests (multi-card version). + +This module provides a simplified interface for writing runtime tests. +Users only need to provide: +1. A kernels directory with kernel_config.py +2. 
A golden.py script with generate_inputs() and compute_golden() + +Usage: + # Command line + python examples/scripts/multi_card_run_example.py --kernels ./my_test/kernels --golden ./my_test/golden.py + + # In Python + from multi_card_code_runner import CodeRunner + runner = CodeRunner("./kernels", "./golden.py") + runner.run() + +Golden.py interface: + # Required functions + def generate_inputs(params: dict) -> list: + '''Return flat argument list — tensors as (name, tensor) tuples, scalars as ctypes typed values''' + a = torch.tensor(...) + b = torch.tensor(...) + out_f = torch.zeros(...) + return [ + ("a", a), + ("b", b), + ("out_f", out_f), + ("size_a", ctypes.c_int64(a.nbytes)), + ("size_b", ctypes.c_int64(b.nbytes)), + ("size_f", ctypes.c_int64(out_f.nbytes)), + ("SIZE", ctypes.c_int64(a.numel())), + ] + + def compute_golden(tensors: dict, params: dict) -> None: + '''Compute expected outputs in-place''' + tensors["out_f"][:] = tensors["a"] + tensors["b"] + + # Optional configuration + ALL_CASES = {"Case1": {"size": 1024}, "Case2": {"size": 2048}} # Multiple test cases + DEFAULT_CASE = "Case1" # Default case to run + RTOL = 1e-5 # Relative tolerance + ATOL = 1e-5 # Absolute tolerance + __outputs__ = ["out_f"] # Explicit output names (or use 'out_' prefix) +""" + +import importlib.util +import logging +import os +import sys +import time +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import torch +import numpy as np + +logger = logging.getLogger(__name__) + + +def _setup_logging_if_needed() -> None: + """ + Setup logging if not already configured (for direct CodeRunner usage). + Uses PTO_LOG_LEVEL environment variable or defaults to 'info'. 
+ """ + # Only setup if logging hasn't been configured yet + if not logging.getLogger().hasHandlers(): + level_str = os.environ.get('PTO_LOG_LEVEL', 'info') + level_map = { + 'error': logging.ERROR, + 'warn': logging.WARNING, + 'info': logging.INFO, + 'debug': logging.DEBUG, + } + log_level = level_map.get(level_str.lower(), logging.INFO) + logging.basicConfig( + level=log_level, + format='[%(levelname)s] %(message)s', + force=True + ) + + +def _to_torch(tensor) -> torch.Tensor: + """Convert tensor to torch.Tensor, handling bfloat16 and other tensor types.""" + if isinstance(tensor, torch.Tensor): + # Already a torch tensor, ensure it's on CPU and contiguous + return tensor.cpu().contiguous() + + # For any non-torch tensor, try direct torch conversion first + # This handles most array-like objects including numpy arrays + try: + return torch.as_tensor(tensor) + except (TypeError, RuntimeError): + # If direct conversion fails, fall back to numpy path + import numpy as np + arr = np.asarray(tensor) + return torch.from_numpy(arr) + + +def _load_module_from_path(module_path: Path, module_name: str): + """Dynamically load a Python module from file path.""" + spec = importlib.util.spec_from_file_location(module_name, module_path) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot load module from {module_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _get_project_root() -> Path: + """Get the project root directory.""" + return Path(__file__).parent.parent.parent # examples/scripts/ -> examples/ -> simpler/ + + +def _get_pto_isa_clone_path() -> Path: + """Get the expected path to pto-isa clone.""" + return _get_project_root() / "examples" / "scripts" / "_deps" / "pto-isa" + + +def _is_pto_isa_cloned() -> bool: + """ + Check if pto-isa is cloned. + + A clone is considered valid if: + 1. The directory exists + 2. 
It contains the include directory (essential content) + """ + clone_path = _get_pto_isa_clone_path() + if not clone_path.exists(): + return False + + # Check for essential content + include_dir = clone_path / "include" + return include_dir.exists() and include_dir.is_dir() + + +def _is_git_available() -> bool: + """Check if git command is available.""" + try: + import subprocess + result = subprocess.run( + ["git", "--version"], + capture_output=True, + text=True, + timeout=5 + ) + return result.returncode == 0 + except (FileNotFoundError, subprocess.TimeoutExpired): + return False + + +_PTO_ISA_REPO = "https://gitcode.com/cann/pto-isa.git" + + +def _clone_pto_isa(verbose: bool = False) -> bool: + """ + Clone pto-isa repository. + + Args: + verbose: Print detailed progress information + + Returns: + True if successful, False otherwise + """ + import subprocess + + if not _is_git_available(): + if verbose: + logger.warning("git command not available, cannot clone pto-isa") + return False + + clone_path = _get_pto_isa_clone_path() + + # Create parent deps directory if it doesn't exist + deps_dir = clone_path.parent + try: + deps_dir.mkdir(parents=True, exist_ok=True) + except Exception as e: + if verbose: + logger.warning(f"Failed to create deps directory: {e}") + return False + + try: + if verbose: + logger.info(f"Cloning pto-isa to {clone_path}...") + logger.info("This may take a few moments on first run...") + + result = subprocess.run( + [ + "git", "clone", + _PTO_ISA_REPO, + str(clone_path) + ], + capture_output=True, + text=True, + timeout=300 # 5 minutes timeout + ) + + if result.returncode != 0: + if verbose: + logger.warning(f"Failed to clone pto-isa:\n{result.stderr}") + return False + + if verbose: + logger.info(f"pto-isa cloned successfully to: {clone_path}") + + return True + + except subprocess.TimeoutExpired: + if verbose: + logger.warning("Clone operation timed out") + return False + except Exception as e: + if verbose: + logger.warning(f"Failed to 
clone pto-isa: {e}") + return False + + +def _ensure_pto_isa_root(verbose: bool = False) -> Optional[str]: + """ + Ensure PTO_ISA_ROOT is available, either from environment or cloned repo. + + This function: + 1. Checks if PTO_ISA_ROOT is already set + 2. If not, tries to clone pto-isa repository + 3. Returns the resolved path + + Args: + verbose: Print detailed progress information + + Returns: + PTO_ISA_ROOT path if successful, None otherwise + """ + # Check if already set in environment + existing_root = os.environ.get("PTO_ISA_ROOT") + if existing_root: + if verbose: + logger.info(f"Using existing PTO_ISA_ROOT: {existing_root}") + return existing_root + + # Try to use cloned repository + clone_path = _get_pto_isa_clone_path() + + # Clone if needed + if not _is_pto_isa_cloned(): + if verbose: + logger.info("PTO_ISA_ROOT not set, cloning pto-isa repository...") + if not _clone_pto_isa(verbose=verbose): + if verbose: + logger.warning("Failed to automatically clone pto-isa.") + logger.warning("You can manually clone it with:") + logger.warning(f" mkdir -p {clone_path.parent}") + logger.warning(f" git clone {_PTO_ISA_REPO} {clone_path}") + logger.warning("Or set PTO_ISA_ROOT to an existing pto-isa installation:") + logger.warning(" export PTO_ISA_ROOT=/path/to/pto-isa") + return None + + # Verify clone has expected content + include_dir = clone_path / "include" + if not include_dir.exists(): + if verbose: + logger.warning(f"pto-isa cloned but missing include directory: {include_dir}") + return None + + return str(clone_path.resolve()) + + +def _kernel_config_runtime_env(kernel_config_module, kernels_dir: Path) -> Dict[str, str]: + """ + Optional per-example environment variables for runtime compilation. + + `kernel_config.py` may define: + RUNTIME_ENV = {"ENV_KEY": "value", ...} + + If a value looks like a path (ENV key ends with _DIR/_PATH) + and is not absolute, it is resolved relative to + `kernels_dir`. 
+ """ + runtime_env = getattr(kernel_config_module, "RUNTIME_ENV", None) + if not isinstance(runtime_env, dict): + return {} + + out: Dict[str, str] = {} + for k, v in runtime_env.items(): + if not isinstance(k, str): + continue + s = str(v) + is_path_like = k.endswith("_DIR") or k.endswith("_PATH") + if is_path_like and s: + p = Path(s) + if not p.is_absolute(): + s = str((kernels_dir / p).resolve()) + out[k] = s + return out + + +@contextmanager +def _temporary_env(env_updates: Dict[str, str]): + """Temporarily apply env vars for the duration of the context.""" + old = {k: os.environ.get(k) for k in env_updates.keys()} + for k, v in env_updates.items(): + os.environ[k] = v + try: + yield + finally: + for k, prev in old.items(): + if prev is None: + os.environ.pop(k, None) + else: + os.environ[k] = prev + + +class CodeRunner: + """ + Simplified test runner that loads kernel config and golden script. + + This class automates: + - Loading kernel_config.py and golden.py dynamically + - Building func_args automatically from torch tensors + - Converting numpy arrays to torch tensors + - Separating inputs and outputs based on naming convention + - Running the full test flow + + Args: + kernels_dir: Path to kernels directory containing kernel_config.py + golden_path: Path to golden.py script + device_id: Device ID (defaults to 0) + platform: Platform name ("a2a3" for hardware, "a2a3sim" for simulation, default: "a2a3") + """ + + def __init__( + self, + kernels_dir: str, + golden_path: str, + device_id: Optional[int] = None, + platform: str = "a2a3", + enable_profiling: bool = False, + run_all_cases: bool = False, + case_name: Optional[str] = None, + n_devices: Optional[int] = None, + first_device_id: Optional[int] = None, + compiled_artifacts: Optional[dict] = None, + prebuilt_dir: Optional[str] = None, + rank_id: Optional[int] = None, + ): + # Setup logging if not already configured (e.g., when used directly, not via multi_card_run_example.py) + 
_setup_logging_if_needed() + + self.kernels_dir = Path(kernels_dir).resolve() + self.golden_path = Path(golden_path).resolve() + self.platform = platform + self.enable_profiling = enable_profiling + self.run_all_cases = run_all_cases + self.case_name = case_name + self.project_root = _get_project_root() + self.compiled_artifacts = compiled_artifacts + self.prebuilt_dir = Path(prebuilt_dir) if prebuilt_dir else None + self._skip_build = compiled_artifacts is not None or (self.prebuilt_dir is not None and self.prebuilt_dir.exists()) + + # Resolve device ID + self.device_id = device_id if device_id is not None else 0 + + # Load configurations + self._kernel_config = self._load_kernel_config() + self._golden_module = self._load_golden_module() + + # Extract kernel configuration + self.kernels = self._kernel_config.KERNELS + self.orchestration = self._kernel_config.ORCHESTRATION + + # Extract golden configuration — determine which cases to run + all_cases = getattr(self._golden_module, 'ALL_CASES', {"Default": {}}) + default_case = getattr(self._golden_module, 'DEFAULT_CASE', "Default") + + if run_all_cases: + self.params_list = [{"name": name, **params} for name, params in all_cases.items()] + logger.info(f"Running all {len(self.params_list)} cases: {list(all_cases.keys())}") + elif case_name is not None: + if case_name not in all_cases: + raise ValueError(f"Case '{case_name}' not found. 
Available: {list(all_cases.keys())}") + self.params_list = [{"name": case_name, **all_cases[case_name]}] + else: + self.params_list = [{"name": default_case, **all_cases[default_case]}] + + self.rtol = getattr(self._golden_module, 'RTOL', 1e-5) + self.atol = getattr(self._golden_module, 'ATOL', 1e-5) + self.output_names = getattr(self._golden_module, '__outputs__', None) + self.tensor_order = getattr(self._golden_module, 'TENSOR_ORDER', None) + + # Runtime configuration - read from kernel_config or use defaults + runtime_config = getattr(self._kernel_config, 'RUNTIME_CONFIG', {}) + self.aicpu_thread_num = runtime_config.get('aicpu_thread_num', 3) + self.block_dim = runtime_config.get('block_dim', 24) + self.runtime_name = runtime_config.get('runtime', 'host_build_graph') + # Multi-device: CLI overrides config + self.n_devices = n_devices if n_devices is not None else runtime_config.get('n_devices', 1) + self.first_device_id = first_device_id if first_device_id is not None else runtime_config.get('first_device_id', 0) + self.requires_comm = runtime_config.get('requires_comm', False) + self.rank_id = rank_id if rank_id is not None else (device_id if device_id is not None else 0) + + def _load_kernel_config(self): + """Load kernel_config.py from kernels directory.""" + config_path = self.kernels_dir / "kernel_config.py" + if not config_path.exists(): + raise FileNotFoundError( + f"kernel_config.py not found in {self.kernels_dir}\n" + f"Expected: {config_path}" + ) + return _load_module_from_path(config_path, f"kernel_config_{id(self)}") + + def _load_golden_module(self): + """Load golden.py script.""" + if not self.golden_path.exists(): + raise FileNotFoundError(f"Golden script not found: {self.golden_path}") + + module = _load_module_from_path(self.golden_path, f"golden_{id(self)}") + + # Validate required functions + if not hasattr(module, 'generate_inputs'): + raise AttributeError( + f"golden.py must define generate_inputs(params) function\n" + f"File: 
{self.golden_path}" + ) + if not hasattr(module, 'compute_golden'): + raise AttributeError( + f"golden.py must define compute_golden(tensors, params) function\n" + f"File: {self.golden_path}" + ) + + return module + + def _identify_outputs(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Dict, Dict]: + """ + Separate inputs and outputs from tensor dict. + + Uses either explicit __outputs__ list or 'out_' prefix convention. + + Returns: + Tuple of (inputs_dict, outputs_dict) + """ + if self.output_names: + # Use explicit output names + outputs = {k: v for k, v in tensors.items() if k in self.output_names} + inputs = {k: v for k, v in tensors.items() if k not in self.output_names} + else: + # Use 'out_' prefix convention + outputs = {k: v for k, v in tensors.items() if k.startswith('out_')} + inputs = {k: v for k, v in tensors.items() if not k.startswith('out_')} + + if not outputs: + raise ValueError( + "No output tensors identified. Either:\n" + "1. Define __outputs__ = ['tensor_name'] in golden.py, or\n" + "2. Use 'out_' prefix for output tensor names (e.g., 'out_result')" + ) + + return inputs, outputs + + def _build_func_args_from_list( + self, args_list: list + ) -> Tuple[List[int], List[int], List[int], Dict[str, Any], Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + """ + Build func_args from an explicit argument list returned by generate_inputs. + + Every element must be a (name, value) pair where value is either: + - torch.Tensor / numpy array: a tensor argument + - ctypes scalar (ctypes.c_int64, ctypes.c_float, etc.): a scalar argument + + All named items (tensors and scalars) are collected into the args dict + passed to compute_golden, so compute_golden can reference any arg by name. + + Returns: + Tuple of (func_args, arg_types, arg_sizes, args, inputs, outputs) + where args contains all named items, inputs/outputs contain tensor-only subsets. 
+ """ + import ctypes + import numpy as np + from bindings import ARG_SCALAR, ARG_INPUT_PTR, ARG_OUTPUT_PTR + + # Identify outputs + if self.output_names: + output_set = set(self.output_names) + else: + output_set = set() + + func_args = [] + arg_types = [] + arg_sizes = [] + args = {} # all named items: tensors + scalars → passed to compute_golden + inputs = {} # tensor inputs only → for logging + outputs = {} # tensor outputs only → for comparison + + for item in args_list: + if not (isinstance(item, tuple) and len(item) == 2): + raise TypeError( + f"Each element in generate_inputs() list must be a (name, value) pair, " + f"got: {type(item)}\n" + f"Tensors: ('name', tensor) Scalars: ('name', ctypes.c_int64(...))" + ) + + name, value = item + + if isinstance(value, (torch.Tensor, np.ndarray)): + tensor = _to_torch(value) + tensor = tensor.cpu().contiguous() + args[name] = tensor + + func_args.append(tensor.data_ptr()) + nbytes = tensor.element_size() * tensor.numel() + arg_sizes.append(nbytes) + + if name in output_set or (not output_set and name.startswith('out_')): + arg_types.append(ARG_OUTPUT_PTR) + outputs[name] = tensor + else: + arg_types.append(ARG_INPUT_PTR) + inputs[name] = tensor + + elif isinstance(value, ctypes._SimpleCData): + if isinstance(value, (ctypes.c_float, ctypes.c_double)): + uint_type = ctypes.c_uint32 if isinstance(value, ctypes.c_float) else ctypes.c_uint64 + bits = uint_type.from_buffer_copy(value).value + func_args.append(bits) + else: + func_args.append(int(value.value)) + args[name] = value.value + arg_types.append(ARG_SCALAR) + arg_sizes.append(0) + + else: + raise TypeError( + f"Unsupported value type for arg '{name}': {type(value)}\n" + f"Expected torch.Tensor, numpy array, or ctypes scalar (ctypes.c_int64, ctypes.c_float, etc.)" + ) + + if not outputs: + raise ValueError( + "No output tensors identified. Either:\n" + "1. Define __outputs__ = ['tensor_name'] in golden.py, or\n" + "2. 
Use 'out_' prefix for output tensor names (e.g., 'out_result')" + ) + + return func_args, arg_types, arg_sizes, args, inputs, outputs + + def _build_func_args(self, tensors: Dict[str, torch.Tensor]) -> Tuple[List[int], List[int], List[int]]: + """ + Build func_args, arg_types, and arg_sizes from tensors dict (legacy path). + + Convention for orchestration function signature: + int BuildGraph(Runtime* runtime, uint64_t* args, int arg_count) + + Where args layout is: + [ptr_0, ptr_1, ..., ptr_n, nbytes_0, nbytes_1, ..., nbytes_n, count] + + Args: + tensors: Dict of torch tensors (will be modified to ensure contiguous) + + Returns: + Tuple of (func_args, arg_types, arg_sizes) + """ + from bindings import ARG_SCALAR, ARG_INPUT_PTR, ARG_OUTPUT_PTR + + # Determine tensor order + if self.tensor_order: + order = self.tensor_order + else: + order = list(tensors.keys()) + + # Identify outputs + if self.output_names: + output_set = set(self.output_names) + else: + output_set = {k for k in tensors.keys() if k.startswith('out_')} + + # First pass: ensure all tensors are CPU and contiguous (update dict in place) + for name in order: + if name not in tensors: + raise KeyError( + f"Tensor '{name}' from TENSOR_ORDER not found in generate_inputs() result.\n" + f"Available tensors: {list(tensors.keys())}" + ) + tensors[name] = tensors[name].cpu().contiguous() + + func_args = [] + arg_types = [] + arg_sizes = [] + + # Add pointers + for name in order: + tensor = tensors[name] + func_args.append(tensor.data_ptr()) + + # Determine arg type based on whether it's an output + if name in output_set: + arg_types.append(ARG_OUTPUT_PTR) + else: + arg_types.append(ARG_INPUT_PTR) + + arg_sizes.append(tensor.element_size() * tensor.numel()) + + # Add sizes (as scalars) + for name in order: + tensor = tensors[name] + func_args.append(tensor.element_size() * tensor.numel()) + arg_types.append(ARG_SCALAR) + arg_sizes.append(0) + + # Add element count (as scalar) + count = tensors[order[0]].numel() + 
func_args.append(count) + arg_types.append(ARG_SCALAR) + arg_sizes.append(0) + + return func_args, arg_types, arg_sizes + + def run(self, comm_context: Optional[Dict[str, Any]] = None) -> None: + """ + Execute the full test flow: + - If compiled_artifacts or prebuilt_dir: skip build, load and run (set_device → init → launch → finalize) + - Else: build first, then run + + When requires_comm, pass comm_context with device_ctx_ptr, win_in_base, win_out_base, n_ranks, root, rank_id + (and optionally comm, stream for HcclBarrier). Merged into params before generate_inputs. + """ + from bindings import bind_host_binary, set_device, launch_runtime + + if self._skip_build: + if self.compiled_artifacts: + artifacts = self.compiled_artifacts + else: + artifacts = _load_artifacts_from_dir(self.prebuilt_dir) + host_binary = artifacts["host_binary"] + orch_so_binary = artifacts["orch_so_binary"] + aicpu_binary = artifacts["aicpu_binary"] + aicore_binary = artifacts["aicore_binary"] + kernel_binaries = artifacts["kernel_binaries"] + orch_func_name = artifacts["orch_func_name"] + logger.info(f"=== Using pre-built artifacts ({len(kernel_binaries)} kernels) ===") + else: + # Build path + from runtime_builder import RuntimeBuilder + from elf_parser import extract_text_section + + # For requires_comm, prefer PTO_COMM_ISA_ROOT (pto-comm-isa has comm headers) + if self.requires_comm and os.environ.get("PTO_COMM_ISA_ROOT"): + pto_isa_root = os.environ["PTO_COMM_ISA_ROOT"].strip() + pto_isa_root = str(Path(pto_isa_root).resolve()) + pto_inst = Path(pto_isa_root) / "include" / "pto" / "pto-inst.hpp" + if not pto_inst.exists(): + raise EnvironmentError( + f"PTO_COMM_ISA_ROOT/include/pto/pto-inst.hpp not found:\n{pto_inst}\n" + f"PTO_COMM_ISA_ROOT={pto_isa_root}\n" + f"Ensure PTO_COMM_ISA_ROOT points to pto-comm-isa root (use single =)." 
+ ) + logger.info(f"Using PTO_COMM_ISA_ROOT for comm kernels: {pto_isa_root}") + else: + pto_isa_root = _ensure_pto_isa_root(verbose=True) + if pto_isa_root is None: + raise EnvironmentError( + "PTO_ISA_ROOT (or PTO_COMM_ISA_ROOT for comm) could not be resolved.\n" + "Please set PTO_ISA_ROOT or PTO_COMM_ISA_ROOT to the pto-isa/pto-comm-isa root." + ) + + logger.info(f"=== Building Runtime: {self.runtime_name} (platform: {self.platform}) ===") + builder = RuntimeBuilder(platform=self.platform) + kernel_compiler = builder.get_kernel_compiler() + + from concurrent.futures import ThreadPoolExecutor + + runtime_include_dirs = [ + os.path.join(self.project_root, "src", "runtime", self.runtime_name, "runtime") + ] + if self.requires_comm: + comm_include = self.project_root / "examples" / "scripts" / "comm_include" + if comm_include.exists(): + runtime_include_dirs.append(str(comm_include)) + + def _build_runtime(): + return builder.build(self.runtime_name) + + def _compile_orchestration(): + return kernel_compiler.compile_orchestration( + self.runtime_name, + self.orchestration["source"], + ) + + def _compile_one_kernel(kernel): + logger.info(f"Compiling kernel: {kernel['source']} (func_id={kernel['func_id']})") + incore_o = kernel_compiler.compile_incore( + kernel["source"], + core_type=kernel["core_type"], + pto_isa_root=pto_isa_root, + extra_include_dirs=runtime_include_dirs, + ) + if self.platform == "a2a3sim": + return (kernel["func_id"], incore_o) + return (kernel["func_id"], extract_text_section(incore_o)) + + max_workers = 2 + len(self.kernels) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + fut_runtime = executor.submit(_build_runtime) + fut_orch = executor.submit(_compile_orchestration) + fut_kernels = [executor.submit(_compile_one_kernel, k) for k in self.kernels] + + try: + host_binary, aicpu_binary, aicore_binary = fut_runtime.result() + except Exception as e: + raise RuntimeError( + f"Failed to build runtime '{self.runtime_name}' for 
platform '{self.platform}'.\n" + f"Error: {e}" + ) from e + + orch_so_binary = fut_orch.result() + kernel_binaries = [f.result() for f in fut_kernels] + + logger.info(f"Compiled {len(kernel_binaries)} kernel(s)") + orch_func_name = self.orchestration["function_name"] + + # Load runtime and set device + logger.info(f"=== Loading Runtime ({len(host_binary)} bytes) ===") + Runtime = bind_host_binary(host_binary) + + logger.info(f"=== Setting Device {self.device_id} ===") + set_device(self.device_id) + + # Step 5: Run each parameter set + total_cases = len(self.params_list) + for case_idx, params in enumerate(self.params_list): + if comm_context: + params = {**params, **comm_context} + logger.info("=" * 60) + logger.info(f"=== Case {case_idx + 1}/{total_cases}: {params} ===") + logger.info("=" * 60) + + # Generate tensors using golden.py + logger.info("=== Generating Inputs ===") + result = self._golden_module.generate_inputs(params) + + if isinstance(result, list): + # New-style: generate_inputs returns flat argument list + func_args, arg_types, arg_sizes, args, inputs, outputs = \ + self._build_func_args_from_list(result) + tensors = args # args contains all named items; compute_golden receives all + else: + # Legacy: generate_inputs returns dict of tensors + tensors = {k: _to_torch(v) for k, v in result.items()} + func_args, arg_types, arg_sizes = self._build_func_args(tensors) + inputs, outputs = self._identify_outputs(tensors) + + logger.info(f"Inputs: {list(inputs.keys())}") + logger.info(f"Outputs: {list(outputs.keys())}") + + # Determine actual tensor order for debugging + logger.debug(f"Tensor order: {list(tensors.keys())}") + logger.debug(f"func_args count: {len(func_args)}") + + # Save expected values BEFORE hardware execution (outputs will be overwritten) + golden = {k: v.clone() for k, v in outputs.items()} + golden_with_inputs = {**inputs, **golden} + _t_golden_start = time.perf_counter() + self._golden_module.compute_golden(golden_with_inputs, params) + 
_t_golden_end = time.perf_counter() + logger.info(f">>> compute_golden() took {_t_golden_end - _t_golden_start:.3f}s") + + # Build environment for runtime initialization + run_env = _kernel_config_runtime_env(self._kernel_config, self.kernels_dir) + if run_env: + logger.debug(f"Runtime init env overrides: {run_env}") + + # Create and initialize runtime + logger.info("=== Initializing Runtime ===") + runtime = Runtime() + + if self.enable_profiling: + runtime.enable_profiling(True) + logger.info("Profiling enabled") + + _t_init_start = time.perf_counter() + with _temporary_env(run_env): + runtime.initialize( + orch_so_binary, + orch_func_name, + func_args, + arg_types=arg_types, + arg_sizes=arg_sizes, + kernel_binaries=kernel_binaries, + ) + _t_init_end = time.perf_counter() + logger.info(f">>> runtime.initialize() took {_t_init_end - _t_init_start:.3f}s") + + # HcclBarrier before launch (when using comm) + if comm_context and "comm" in comm_context and "stream" in comm_context: + from hccl_bindings import hccl_barrier + hccl_barrier(comm_context["comm"], comm_context["stream"]) + logger.info("HcclBarrier (pre-launch) done") + + # Launch runtime + logger.info("=== Launching Runtime ===") + import sys + sys.stdout.flush() + + launch_runtime( + runtime, + aicpu_thread_num=self.aicpu_thread_num, + block_dim=self.block_dim, + device_id=self.device_id, + aicpu_binary=aicpu_binary, + aicore_binary=aicore_binary, + ) + + logger.info("Launch completed successfully") + + # HcclBarrier after launch (when using comm) + if comm_context and "comm" in comm_context and "stream" in comm_context: + from hccl_bindings import hccl_barrier + hccl_barrier(comm_context["comm"], comm_context["stream"]) + logger.info("HcclBarrier (post-launch) done") + + # Finalize + logger.info("=== Finalizing Runtime ===") + runtime.finalize() + + # Compute golden and compare + logger.info("=== Comparing Results ===") + self._compare_with_golden(outputs, golden) + + logger.info(f"=== Case {case_idx + 
1}/{total_cases} Passed ===") + + logger.info("=" * 60) + logger.info(f"=== All {total_cases} cases passed ===") + logger.info("=" * 60) + + def _compare_with_golden( + self, + outputs: Dict[str, torch.Tensor], + golden: Dict[str, torch.Tensor], + ) -> None: + """Compare hardware outputs with pre-computed golden values.""" + # Compare each output + for name in outputs: + actual = outputs[name] + expected = golden[name] + logger.info(f"Comparing {name}: shape={actual.shape}, dtype={actual.dtype}") + + # Ensure both are on CPU for comparison + actual = actual.cpu() + expected = expected.cpu() + + # Show first 10 values + if actual.numel() > 0: + flat_actual = actual.flatten() + flat_expected = expected.flatten() + n_show = min(10, flat_actual.numel()) + # 始终打印前几个元素,方便对比实际值与 golden + logger.info(f" First {n_show} actual: {flat_actual[:n_show].tolist()}") + logger.info(f" First {n_show} expected: {flat_expected[:n_show].tolist()}") + + # Use torch for comparison + if not torch.allclose(actual, expected, rtol=self.rtol, atol=self.atol): + close_mask = torch.isclose(actual, expected, rtol=self.rtol, atol=self.atol) + mismatches = (~close_mask).sum().item() + total = actual.numel() + + logger.warning( + "Output '%s' does not match golden (mismatched %d/%d, rtol=%g, atol=%g)", + name, + mismatches, + total, + self.rtol, + self.atol, + ) + + flat_a = actual.flatten().tolist() + flat_e = expected.flatten().tolist() + chunk = 64 + logger.warning("--- actual (%s) ---", name) + for i in range(0, len(flat_a), chunk): + logger.warning(" [%4d:%4d] %s", i, min(i + chunk, len(flat_a)), flat_a[i:i + chunk]) + logger.warning("--- expected (%s) ---", name) + for i in range(0, len(flat_e), chunk): + logger.warning(" [%4d:%4d] %s", i, min(i + chunk, len(flat_e)), flat_e[i:i + chunk]) + else: + matched = torch.isclose(actual, expected, rtol=self.rtol, atol=self.atol).sum().item() + logger.info(f" {name}: PASS ({matched}/{actual.numel()} elements matched)") + + +def 
create_code_runner(kernels_dir, golden_path, device_id=None, platform="a2a3", + enable_profiling=False, run_all_cases=False, case_name=None, + n_devices=None, first_device_id=None, + compiled_artifacts=None, prebuilt_dir=None, rank_id=None): + """Factory: creates a CodeRunner based on kernel_config.""" + return CodeRunner(kernels_dir=kernels_dir, golden_path=golden_path, + device_id=device_id, platform=platform, + enable_profiling=enable_profiling, + run_all_cases=run_all_cases, case_name=case_name, + n_devices=n_devices, first_device_id=first_device_id, + compiled_artifacts=compiled_artifacts, prebuilt_dir=prebuilt_dir, + rank_id=rank_id) + + +def run_on_device( + device_id: int, + artifacts: dict, + kernels_dir: str, + golden_path: str, + platform: str = "a2a3", + enable_profiling: bool = False, + run_all_cases: bool = False, + case_name: Optional[str] = None, +) -> None: + """ + Worker entry for multiprocessing: create CodeRunner and run. + Must be at module level so child process can import without running main(). + """ + runner = create_code_runner( + kernels_dir=kernels_dir, + golden_path=golden_path, + device_id=device_id, + platform=platform, + enable_profiling=enable_profiling, + run_all_cases=run_all_cases, + case_name=case_name, + n_devices=1, + first_device_id=device_id, + compiled_artifacts=artifacts, + ) + runner.run() + + +def run_on_device_comm( + rank_id: int, + device_id: int, + root_info: bytes, + artifacts: dict, + kernels_dir: str, + golden_path: str, + n_ranks: int, + n_devices: int, + first_device_id: int, + root: int, + platform: str = "a2a3", + enable_profiling: bool = False, + run_all_cases: bool = False, + case_name: Optional[str] = None, +) -> None: + """ + Worker for requires_comm: init HCCL, create CodeRunner, run with comm_context. 
+ """ + from hccl_bindings import hccl_init_comm + + comm, device_ctx_ptr, win_in_base, win_out_base, stream, actual_rank_id = hccl_init_comm( + rank_id, n_ranks, n_devices, first_device_id, root_info + ) + + comm_context = { + "device_ctx_ptr": device_ctx_ptr, + "win_in_base": win_in_base, + "win_out_base": win_out_base, + "n_ranks": n_ranks, + "root": root, + # Keep graph/golden rank selection aligned with HCCL runtime rank. + "rank_id": actual_rank_id, + "comm": comm, + "stream": stream, + } + + runner = create_code_runner( + kernels_dir=kernels_dir, + golden_path=golden_path, + device_id=device_id, + platform=platform, + enable_profiling=enable_profiling, + run_all_cases=run_all_cases, + case_name=case_name, + n_devices=n_devices, + first_device_id=first_device_id, + compiled_artifacts=artifacts, + rank_id=rank_id, + ) + runner.run(comm_context=comm_context) + + +# ============================================================================= +# PTOCompiler - compile once, run many +# ============================================================================= + +def _write_artifacts_to_dir(artifacts: dict, out_dir: Path) -> None: + """Write compiled artifacts to a directory for subprocess loading.""" + import json + out_dir.mkdir(parents=True, exist_ok=True) + (out_dir / "host.bin").write_bytes(artifacts["host_binary"]) + (out_dir / "orch.so").write_bytes(artifacts["orch_so_binary"]) + (out_dir / "aicpu.so").write_bytes(artifacts["aicpu_binary"]) + (out_dir / "aicore.bin").write_bytes(artifacts["aicore_binary"]) + for func_id, bin_data in artifacts["kernel_binaries"]: + (out_dir / f"kernel_{func_id}.bin").write_bytes(bin_data) + manifest = { + "orch_func_name": artifacts["orch_func_name"], + "kernel_func_ids": [k[0] for k in artifacts["kernel_binaries"]], + } + (out_dir / "manifest.json").write_text(json.dumps(manifest), encoding="utf-8") + + +def _load_artifacts_from_dir(prebuilt_dir: Path) -> dict: + """Load compiled artifacts from a prebuilt directory.""" + 
import json + manifest = json.loads((prebuilt_dir / "manifest.json").read_text(encoding="utf-8")) + kernel_binaries = [] + for func_id in manifest["kernel_func_ids"]: + bin_data = (prebuilt_dir / f"kernel_{func_id}.bin").read_bytes() + kernel_binaries.append((func_id, bin_data)) + return { + "host_binary": (prebuilt_dir / "host.bin").read_bytes(), + "orch_so_binary": (prebuilt_dir / "orch.so").read_bytes(), + "aicpu_binary": (prebuilt_dir / "aicpu.so").read_bytes(), + "aicore_binary": (prebuilt_dir / "aicore.bin").read_bytes(), + "kernel_binaries": kernel_binaries, + "orch_func_name": manifest["orch_func_name"], + } + + +class PTOCompiler: + """Compiles PTO runtime, orchestration, and kernels. Returns artifacts for Runner.""" + + def __init__( + self, + kernels_dir: str, + platform: str = "a2a3", + ): + self.kernels_dir = Path(kernels_dir).resolve() + self.platform = platform + self.project_root = _get_project_root() + self._kernel_config = _load_module_from_path( + self.kernels_dir / "kernel_config.py", f"kernel_config_compiler_{id(self)}" + ) + self.kernels = self._kernel_config.KERNELS + self.orchestration = self._kernel_config.ORCHESTRATION + runtime_config = getattr(self._kernel_config, "RUNTIME_CONFIG", {}) + self.runtime_name = runtime_config.get("runtime", "host_build_graph") + self.requires_comm = runtime_config.get("requires_comm", False) + + def compile(self) -> dict: + """Build runtime, orchestration, and kernels. 
Return artifacts dict.""" + from runtime_builder import RuntimeBuilder + from elf_parser import extract_text_section + + if self.requires_comm and os.environ.get("PTO_COMM_ISA_ROOT"): + pto_isa_root = os.environ["PTO_COMM_ISA_ROOT"].strip() + pto_isa_root = str(Path(pto_isa_root).resolve()) + pto_inst = Path(pto_isa_root) / "include" / "pto" / "pto-inst.hpp" + if not pto_inst.exists(): + raise EnvironmentError( + f"PTO_COMM_ISA_ROOT/include/pto/pto-inst.hpp not found:\n{pto_inst}\n" + f"Ensure PTO_COMM_ISA_ROOT points to pto-comm-isa root (use single =)." + ) + else: + pto_isa_root = _ensure_pto_isa_root(verbose=True) + if pto_isa_root is None: + raise EnvironmentError( + "PTO_ISA_ROOT (or PTO_COMM_ISA_ROOT for comm) could not be resolved." + ) + + logger.info(f"=== PTOCompiler: Building {self.runtime_name} (platform: {self.platform}) ===") + builder = RuntimeBuilder(platform=self.platform) + kernel_compiler = builder.get_kernel_compiler() + + from concurrent.futures import ThreadPoolExecutor + + runtime_include_dirs = [ + os.path.join(self.project_root, "src", "runtime", self.runtime_name, "runtime") + ] + if self.requires_comm: + comm_include = self.project_root / "examples" / "scripts" / "comm_include" + if comm_include.exists(): + runtime_include_dirs.append(str(comm_include)) + + def _build_runtime(): + return builder.build(self.runtime_name) + + def _compile_orchestration(): + return kernel_compiler.compile_orchestration( + self.runtime_name, + self.orchestration["source"], + ) + + def _compile_one_kernel(kernel): + logger.info(f"Compiling kernel: {kernel['source']} (func_id={kernel['func_id']})") + incore_o = kernel_compiler.compile_incore( + kernel["source"], + core_type=kernel["core_type"], + pto_isa_root=pto_isa_root, + extra_include_dirs=runtime_include_dirs, + ) + if self.platform == "a2a3sim": + return (kernel["func_id"], incore_o) + return (kernel["func_id"], extract_text_section(incore_o)) + + max_workers = 2 + len(self.kernels) + with 
ThreadPoolExecutor(max_workers=max_workers) as executor: + fut_runtime = executor.submit(_build_runtime) + fut_orch = executor.submit(_compile_orchestration) + fut_kernels = [executor.submit(_compile_one_kernel, k) for k in self.kernels] + + try: + host_binary, aicpu_binary, aicore_binary = fut_runtime.result() + except Exception as e: + raise RuntimeError( + f"Failed to build runtime '{self.runtime_name}': {e}" + ) from e + + orch_so_binary = fut_orch.result() + kernel_binaries = [f.result() for f in fut_kernels] + + logger.info(f"PTOCompiler: Compiled {len(kernel_binaries)} kernel(s)") + + result = { + "host_binary": host_binary, + "orch_so_binary": orch_so_binary, + "aicpu_binary": aicpu_binary, + "aicore_binary": aicore_binary, + "kernel_binaries": kernel_binaries, + "orch_func_name": self.orchestration["function_name"], + } + return result + + +def create_compiler(kernels_dir, platform="a2a3"): + """Factory: creates a PTOCompiler.""" + return PTOCompiler(kernels_dir=kernels_dir, platform=platform) diff --git a/examples/scripts/multi_card_run_example.py b/examples/scripts/multi_card_run_example.py new file mode 100644 index 00000000..847a3fd0 --- /dev/null +++ b/examples/scripts/multi_card_run_example.py @@ -0,0 +1,453 @@ +#!/usr/bin/env python3 +""" +Multi-card test runner for PTO runtime tests. + +This script provides a command-line interface to run PTO runtime tests +with multi-card support (compile once, run on N devices in parallel). +Users only need to provide: +1. A kernels directory with kernel_config.py +2. 
A golden.py script + +Usage: + python examples/scripts/multi_card_run_example.py --kernels ./my_test/kernels --golden ./my_test/golden.py + python examples/scripts/multi_card_run_example.py -k ./kernels -g ./golden.py --device 0 --platform a2a3sim + +Examples: + # Run hardware example (requires Ascend device) + python examples/scripts/multi_card_run_example.py -k examples/host_build_graph/vector_example/kernels \ + -g examples/host_build_graph/vector_example/golden.py + + # Run simulation example (no hardware required) + python examples/scripts/multi_card_run_example.py -k examples/host_build_graph/vector_example/kernels \ + -g examples/host_build_graph/vector_example/golden.py \ + -p a2a3sim + + # Run with specific device + python examples/scripts/multi_card_run_example.py -k ./kernels -g ./golden.py -d 0 + + # Multi-card (e.g. multi_bgemm) + python examples/scripts/multi_card_run_example.py -k examples/host_build_graph/multi_bgemm/kernels \ + -g examples/host_build_graph/multi_bgemm/golden.py \ + --n-devices 2 --first-device 0 +""" + +import argparse +import logging +import os +import subprocess +import sys +from pathlib import Path + +# Get script and project directories +script_dir = Path(__file__).parent.resolve() +project_root = script_dir.parent.parent +python_dir = project_root / "python" +if python_dir.exists(): + sys.path.insert(0, str(python_dir)) + +logger = logging.getLogger(__name__) + + +def _get_device_log_dir(device_id): + """Return the device log directory using the same logic as device_log_resolver.""" + ascend_work_path = os.environ.get("ASCEND_WORK_PATH") + if ascend_work_path: + root = Path(ascend_work_path).expanduser() / "log" / "debug" + if root.exists(): + return root / f"device-{device_id}" + return Path.home() / "ascend" / "log" / "debug" / f"device-{device_id}" + + +def _run_profiling_swimlane(args, kernels_path, project_root, device_log_dir, pre_run_device_logs, log_level_str): + """Run swimlane converter after test. 
Returns 0 on success.""" + swimlane_script = project_root / "tools" / "swimlane_converter.py" + if not swimlane_script.exists(): + logger.warning("Swimlane converter script not found") + return 0 + import subprocess + try: + cmd = [sys.executable, str(swimlane_script), "-k", str(kernels_path)] + if device_log_dir is not None: + device_log_file = _wait_for_new_device_log(device_log_dir, pre_run_device_logs) + if device_log_file: + cmd += ["--device-log", str(device_log_file)] + else: + cmd += ["-d", str(args.device)] + else: + cmd += ["-d", str(args.device)] + if log_level_str == "debug": + cmd.append("-v") + subprocess.run(cmd, check=True, capture_output=True, text=True) + logger.info("Swimlane JSON generation completed") + except subprocess.CalledProcessError as e: + logger.warning(f"Swimlane conversion failed: {e}") + return 0 + + +def _wait_for_new_device_log(log_dir, pre_run_logs, timeout=15, interval=0.5): + """Wait for a new device log file that wasn't present before the run. + + CANN dlog writes device logs asynchronously, so the file may appear + a few seconds after the run completes. + """ + import time + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if log_dir.exists(): + current_logs = set(log_dir.glob("*.log")) + new_logs = current_logs - pre_run_logs + if new_logs: + return max(new_logs, key=lambda p: p.stat().st_mtime) + time.sleep(interval) + return None + + +def _ensure_hccl_helper_built(): + """Ensure libhccl_helper.so is built. 
Build if build dir or .so is missing.""" + hccl_helper_dir = script_dir / "hccl_helper" + build_dir = hccl_helper_dir / "build" + lib_path = build_dir / "libhccl_helper.so" + if lib_path.exists(): + return + logger.info("HCCL helper not built, compiling...") + build_dir.mkdir(parents=True, exist_ok=True) + try: + subprocess.run( + ["cmake", ".."], + cwd=str(build_dir), + check=True, + capture_output=True, + text=True, + ) + subprocess.run( + ["make"], + cwd=str(build_dir), + check=True, + capture_output=True, + text=True, + ) + logger.info("HCCL helper built successfully") + except subprocess.CalledProcessError as e: + raise RuntimeError( + f"HCCL helper build failed: {e}\n" + "Ensure CANN env is set: source /usr/local/Ascend/ascend-toolkit/latest/bin/setenv.bash" + ) from e + + +def main(): + parser = argparse.ArgumentParser( + description="Run PTO runtime test with multi-card support (kernel config and golden script)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python examples/scripts/multi_card_run_example.py --kernels ./my_test/kernels --golden ./my_test/golden.py + python examples/scripts/multi_card_run_example.py -k ./kernels -g ./golden.py -d 0 + +Golden.py interface: + def generate_inputs(params: dict) -> dict: + '''Return dict of numpy arrays (inputs + outputs)''' + return {"a": np.array(...), "out_f": np.zeros(...)} + + def compute_golden(tensors: dict, params: dict) -> None: + '''Compute expected outputs in-place''' + tensors["out_f"][:] = tensors["a"] + 1 + + # Optional — for parameterized test cases: + ALL_CASES = {"Case1": {"size": 1024}, "Case2": {"size": 2048}} + DEFAULT_CASE = "Case1" + RTOL = 1e-5 # Relative tolerance + ATOL = 1e-5 # Absolute tolerance + __outputs__ = ["out_f"] # Or use 'out_' prefix + """ + ) + + parser.add_argument( + "-k", "--kernels", + required=True, + help="Path to kernels directory containing kernel_config.py" + ) + + parser.add_argument( + "-g", "--golden", + required=True, + 
help="Path to golden.py script" + ) + + parser.add_argument( + "-d", "--device", + type=int, + default=0, + help="Device ID (default: 0)" + ) + + parser.add_argument( + "--n-devices", + type=int, + default=None, + help="Number of devices to run on (multi-card). Overrides kernel_config RUNTIME_CONFIG. Default from config or 1." + ) + + parser.add_argument( + "--first-device", + type=int, + default=None, + help="First device ID for multi-card (e.g. 4 with --n-devices 4 uses devices 4,5,6,7). Overrides kernel_config." + ) + + parser.add_argument( + "-p", "--platform", + default="a2a3", + choices=["a2a3", "a2a3sim"], + help="Platform name: 'a2a3' for hardware, 'a2a3sim' for simulation (default: a2a3)" + ) + + parser.add_argument( + "-v", "--verbose", + action="store_true", + help="Enable verbose output (equivalent to --log-level debug)" + ) + + parser.add_argument( + "--silent", + action="store_true", + help="Silent mode - only show errors (equivalent to --log-level error)" + ) + + parser.add_argument( + "--log-level", + choices=["error", "warn", "info", "debug"], + help="Set log level explicitly (overrides --verbose and --silent)" + ) + + parser.add_argument( + "--enable-profiling", + action="store_true", + help="Enable profiling and generate swimlane.json" + ) + + parser.add_argument( + "--all", + action="store_true", + help="Run all test cases defined in ALL_CASES (default: run only DEFAULT_CASE)" + ) + + parser.add_argument( + "--case", + type=str, + default=None, + help="Run a specific test case by name (e.g., --case Case2)" + ) + + args = parser.parse_args() + + if args.all and args.case: + parser.error("--all and --case are mutually exclusive") + + # Determine log level from arguments + log_level_str = None + if args.log_level: + log_level_str = args.log_level + elif args.verbose: + log_level_str = "debug" + elif args.silent: + log_level_str = "error" + else: + log_level_str = "info" + + # Setup logging before any other operations + level_map = { + 'error': 
logging.ERROR, + 'warn': logging.WARNING, + 'info': logging.INFO, + 'debug': logging.DEBUG, + } + log_level = level_map.get(log_level_str.lower(), logging.INFO) + + # Configure Python logging + logging.basicConfig( + level=log_level, + format='[%(levelname)s] %(message)s', + force=True + ) + + # Set environment variable for C++ side + os.environ['PTO_LOG_LEVEL'] = log_level_str + + # Add script_dir for multi_card_code_runner (now co-located) + sys.path.insert(0, str(script_dir)) + + # Validate paths + kernels_path = Path(args.kernels) + golden_path = Path(args.golden) + + if not kernels_path.exists(): + logger.error(f"Kernels directory not found: {kernels_path}") + return 1 + + if not golden_path.exists(): + logger.error(f"Golden script not found: {golden_path}") + return 1 + + kernel_config_path = kernels_path / "kernel_config.py" + if not kernel_config_path.exists(): + logger.error(f"kernel_config.py not found in {kernels_path}") + return 1 + + # Import and run + try: + from concurrent.futures import ProcessPoolExecutor, as_completed + + from multi_card_code_runner import create_code_runner, create_compiler, run_on_device, run_on_device_comm + + # Ensure HCCL helper is built (for multi-card comm) before compile + _ensure_hccl_helper_built() + + # Compile first + compiler = create_compiler(kernels_dir=str(args.kernels), platform=args.platform) + artifacts = compiler.compile() + + # Resolve n_devices and first_device_id (args override config) + import importlib.util + spec = importlib.util.spec_from_file_location("kernel_config", kernel_config_path) + cfg = importlib.util.module_from_spec(spec) + spec.loader.exec_module(cfg) + runtime_config = getattr(cfg, "RUNTIME_CONFIG", {}) + n_devices = args.n_devices if args.n_devices is not None else runtime_config.get("n_devices", 1) + first_device_id = args.first_device if args.first_device is not None else runtime_config.get("first_device_id", 0) + requires_comm = runtime_config.get("requires_comm", False) + root = 
runtime_config.get("root", 0) + + if requires_comm and n_devices > 1: + # Multi-card with HCCL: use Barrier + shared root_info + import multiprocessing as mp + from hccl_bindings import hccl_get_root_info, HCCL_ROOT_INFO_BYTES + + # Shared buffer for HcclRootInfo: use unsigned char ('B') to match bytes 0-255 + root_info_arr = mp.Array("B", HCCL_ROOT_INFO_BYTES) + barrier = mp.Barrier(n_devices) + + def _comm_worker(rank_id): + device_id = rank_id % n_devices + first_device_id + if rank_id == 0: + root_info = hccl_get_root_info(device_id) + root_info_arr[:] = root_info[:HCCL_ROOT_INFO_BYTES] + barrier.wait() + root_info = bytes(root_info_arr[:]) + run_on_device_comm( + rank_id=rank_id, + device_id=device_id, + root_info=root_info, + artifacts=artifacts, + kernels_dir=str(args.kernels), + golden_path=str(args.golden), + n_ranks=n_devices, + n_devices=n_devices, + first_device_id=first_device_id, + root=root, + platform=args.platform, + enable_profiling=args.enable_profiling, + run_all_cases=args.all, + case_name=args.case, + ) + + procs = [mp.Process(target=_comm_worker, args=(r,)) for r in range(n_devices)] + for p in procs: + p.start() + failed = [] + for r, p in enumerate(procs): + p.join() + if p.exitcode != 0: + failed.append((r, RuntimeError(f"Rank {r} exited with code {p.exitcode}"))) + if failed: + err_msg = "; ".join(f"rank {d}: {e}" for d, e in failed) + raise RuntimeError(f"Multi-card comm run failed: {err_msg}") + logger.info("=" * 60) + logger.info("TEST PASSED (all ranks)") + logger.info("=" * 60) + return 0 + elif n_devices > 1: + # Multi-device: create N CodeRunner instances, run in parallel via ProcessPoolExecutor + device_ids = list(range(first_device_id, first_device_id + n_devices)) + logger.info(f"=== Multi-device: compile done, running on devices {device_ids} (parallel) ===") + + failed = [] + with ProcessPoolExecutor(max_workers=n_devices) as executor: + futures = { + executor.submit( + run_on_device, + did, + artifacts, + str(args.kernels), + 
str(args.golden), + args.platform, + args.enable_profiling, + args.all, + args.case, + ): did + for did in device_ids + } + for fut in as_completed(futures): + did = futures[fut] + try: + fut.result() + logger.info(f"Device {did}: PASS") + except Exception as e: + failed.append((did, e)) + logger.error(f"Device {did} failed: {e}") + + if failed: + err_msg = "; ".join(f"device {d}: {e}" for d, e in failed) + raise RuntimeError(f"Multi-device run failed: {err_msg}") + + logger.info("=" * 60) + logger.info("TEST PASSED (all devices)") + logger.info("=" * 60) + return 0 + else: + # Single device: run in-process with compiled artifacts + runner = create_code_runner( + kernels_dir=str(args.kernels), + golden_path=str(args.golden), + device_id=args.device, + platform=args.platform, + enable_profiling=args.enable_profiling, + run_all_cases=args.all, + case_name=args.case, + n_devices=1, + first_device_id=args.device, + compiled_artifacts=artifacts, + ) + + pre_run_device_logs = set() + device_log_dir = None + if args.enable_profiling and args.platform == "a2a3": + device_log_dir = _get_device_log_dir(args.device) + if device_log_dir.exists(): + pre_run_device_logs = set(device_log_dir.glob("*.log")) + + runner.run() + logger.info("=" * 60) + logger.info("TEST PASSED") + logger.info("=" * 60) + + if args.enable_profiling: + logger.info("Generating swimlane visualization...") + _run_profiling_swimlane(args, kernels_path, project_root, device_log_dir, pre_run_device_logs, log_level_str) + + return 0 + + except ImportError as e: + logger.error(f"Import error: {e}") + logger.error("Make sure you're running from the project root directory.") + return 1 + + except Exception as e: + logger.error(f"TEST FAILED: {e}") + if log_level_str == "debug": + import traceback + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/scripts/run_example.py b/examples/scripts/run_example.py index 7a9512ac..36c23330 100644 --- 
a/examples/scripts/run_example.py +++ b/examples/scripts/run_example.py @@ -51,6 +51,32 @@ def _get_device_log_dir(device_id): return Path.home() / "ascend" / "log" / "debug" / f"device-{device_id}" +def _run_profiling_swimlane(args, kernels_path, project_root, device_log_dir, pre_run_device_logs, log_level_str): + """Run swimlane converter after test. Returns 0 on success.""" + swimlane_script = project_root / "tools" / "swimlane_converter.py" + if not swimlane_script.exists(): + logger.warning("Swimlane converter script not found") + return 0 + import subprocess + try: + cmd = [sys.executable, str(swimlane_script), "-k", str(kernels_path)] + if device_log_dir is not None: + device_log_file = _wait_for_new_device_log(device_log_dir, pre_run_device_logs) + if device_log_file: + cmd += ["--device-log", str(device_log_file)] + else: + cmd += ["-d", str(args.device)] + else: + cmd += ["-d", str(args.device)] + if log_level_str == "debug": + cmd.append("-v") + subprocess.run(cmd, check=True, capture_output=True, text=True) + logger.info("Swimlane JSON generation completed") + except subprocess.CalledProcessError as e: + logger.warning(f"Swimlane conversion failed: {e}") + return 0 + + def _wait_for_new_device_log(log_dir, pre_run_logs, timeout=15, interval=0.5): """Wait for a new device log file that wasn't present before the run. @@ -115,6 +141,20 @@ def compute_golden(tensors: dict, params: dict) -> None: help="Device ID (default: 0)" ) + parser.add_argument( + "--n-devices", + type=int, + default=None, + help="Number of devices to run on (multi-card). Overrides kernel_config RUNTIME_CONFIG. Default from config or 1." + ) + + parser.add_argument( + "--first-device", + type=int, + default=None, + help="First device ID for multi-card (e.g. 4 with --n-devices 4 uses devices 4,5,6,7). Overrides kernel_config." 
+ ) + parser.add_argument( "-p", "--platform", default="a2a3", @@ -216,72 +256,91 @@ def compute_golden(tensors: dict, params: dict) -> None: # Import and run try: - from code_runner import create_code_runner - - runner = create_code_runner( - kernels_dir=str(args.kernels), - golden_path=str(args.golden), - device_id=args.device, - platform=args.platform, - enable_profiling=args.enable_profiling, - run_all_cases=args.all, - case_name=args.case, - ) - - # Snapshot existing device logs before the run so we can identify the - # new log created by this run (CANN writes device logs asynchronously). - pre_run_device_logs = set() - device_log_dir = None - if args.enable_profiling and args.platform == "a2a3": - device_log_dir = _get_device_log_dir(args.device) - if device_log_dir.exists(): - pre_run_device_logs = set(device_log_dir.glob("*.log")) - - runner.run() - logger.info("=" * 60) - logger.info("TEST PASSED") - logger.info("=" * 60) - - # If profiling was enabled, generate merged swimlane JSON - if args.enable_profiling: - logger.info("Generating swimlane visualization...") - kernel_config_path = kernels_path / "kernel_config.py" - swimlane_script = project_root / "tools" / "swimlane_converter.py" - - if swimlane_script.exists(): - import subprocess - try: - cmd = [ - sys.executable, - str(swimlane_script), - "-k", - str(kernel_config_path), - ] - - # Find the device log created by this run via snapshot diff - if device_log_dir is not None: - device_log_file = _wait_for_new_device_log( - device_log_dir, pre_run_device_logs) - if device_log_file: - cmd += ["--device-log", str(device_log_file)] - else: - logger.warning("No new device log found, falling back to device-id") - cmd += ["-d", str(args.device)] - else: - cmd += ["-d", str(args.device)] - - if log_level_str == "debug": - cmd.append("-v") - - result = subprocess.run(cmd, check=True, capture_output=True, text=True) - logger.info(result.stdout) - logger.info("Swimlane JSON generation completed") - except 
subprocess.CalledProcessError as e: - logger.warning(f"Failed to generate swimlane JSON: {e}") - if log_level_str == "debug": - logger.debug(f"stderr: {e.stderr}") - else: - logger.warning(f"Swimlane converter script not found: {swimlane_script}") + from concurrent.futures import ProcessPoolExecutor, as_completed + + from code_runner import create_code_runner, create_compiler, run_on_device + + # Compile first + compiler = create_compiler(kernels_dir=str(args.kernels), platform=args.platform) + artifacts = compiler.compile() + + # Resolve n_devices and first_device_id (args override config) + import importlib.util + spec = importlib.util.spec_from_file_location("kernel_config", kernel_config_path) + cfg = importlib.util.module_from_spec(spec) + spec.loader.exec_module(cfg) + runtime_config = getattr(cfg, "RUNTIME_CONFIG", {}) + n_devices = args.n_devices if args.n_devices is not None else runtime_config.get("n_devices", 1) + first_device_id = args.first_device if args.first_device is not None else runtime_config.get("first_device_id", 0) + + if n_devices > 1: + # Multi-device: create N CodeRunner instances, run in parallel via ProcessPoolExecutor + device_ids = list(range(first_device_id, first_device_id + n_devices)) + logger.info(f"=== Multi-device: compile done, running on devices {device_ids} (parallel) ===") + + failed = [] + with ProcessPoolExecutor(max_workers=n_devices) as executor: + futures = { + executor.submit( + run_on_device, + did, + artifacts, + str(args.kernels), + str(args.golden), + args.platform, + args.enable_profiling, + args.all, + args.case, + ): did + for did in device_ids + } + for fut in as_completed(futures): + did = futures[fut] + try: + fut.result() + logger.info(f"Device {did}: PASS") + except Exception as e: + failed.append((did, e)) + logger.error(f"Device {did} failed: {e}") + + if failed: + err_msg = "; ".join(f"device {d}: {e}" for d, e in failed) + raise RuntimeError(f"Multi-device run failed: {err_msg}") + + logger.info("=" * 
60) + logger.info("TEST PASSED (all devices)") + logger.info("=" * 60) + return 0 + else: + # Single device: run in-process with compiled artifacts + runner = create_code_runner( + kernels_dir=str(args.kernels), + golden_path=str(args.golden), + device_id=args.device, + platform=args.platform, + enable_profiling=args.enable_profiling, + run_all_cases=args.all, + case_name=args.case, + n_devices=1, + first_device_id=args.device, + compiled_artifacts=artifacts, + ) + + pre_run_device_logs = set() + device_log_dir = None + if args.enable_profiling and args.platform == "a2a3": + device_log_dir = _get_device_log_dir(args.device) + if device_log_dir.exists(): + pre_run_device_logs = set(device_log_dir.glob("*.log")) + + runner.run() + logger.info("=" * 60) + logger.info("TEST PASSED") + logger.info("=" * 60) + + if args.enable_profiling: + logger.info("Generating swimlane visualization...") + _run_profiling_swimlane(args, kernels_path, project_root, device_log_dir, pre_run_device_logs, log_level_str) return 0 diff --git a/examples/tensormap_and_ringbuffer/allgather_Manual/README.md b/examples/tensormap_and_ringbuffer/allgather_Manual/README.md new file mode 100644 index 00000000..4439cfa3 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/allgather_Manual/README.md @@ -0,0 +1,16 @@ +# AllGather (Manual) - tensormap_and_ringbuffer + +直接 RDMA 读取的 AllGather,无 TGATHER。用于性能对比。 + +流程:WindowMemCopyIn → CommBarrier(pre) → AllGatherManual → WindowMemCopyOut → CommBarrier(post) + +所有 rank 并行执行,每个 rank 获得完整拼接结果。 + +## 运行 + +```bash +cd simpler-PTO +./run_tensormap.sh allgather_Manual 2 0 +``` + +需要设置 `PTO_COMM_ISA_ROOT` 及多卡 HCCL 环境。 diff --git a/examples/tensormap_and_ringbuffer/allgather_Manual/golden.py b/examples/tensormap_and_ringbuffer/allgather_Manual/golden.py new file mode 100644 index 00000000..dd6ba20f --- /dev/null +++ b/examples/tensormap_and_ringbuffer/allgather_Manual/golden.py @@ -0,0 +1,55 @@ +""" +Golden reference for AllGather (Manual RDMA variant, no compute). 
+ +Each rank contributes GATHER_COUNT float32 elements. +After AllGather, EVERY rank holds the concatenation of all ranks' data. +""" + +import ctypes +import numpy as np + +GATHER_COUNT = 64 + +ALL_CASES = {"Default": {}} +DEFAULT_CASE = "Default" +__outputs__ = ["out"] +RTOL = 1e-4 +ATOL = 1e-4 + +def generate_inputs(params: dict) -> list: + rank_id = params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + + np.random.seed(42 + rank_id) + src = np.random.randn(GATHER_COUNT).astype(np.float32) * 0.1 + out = np.zeros((n_ranks * GATHER_COUNT,), dtype=np.float32) + + result = [ + ("src", src), + ("out", out), + ("size_src", ctypes.c_int64(src.nbytes)), + ("size_out", ctypes.c_int64(out.nbytes)), + ] + + if "device_ctx_ptr" in params and "win_in_base" in params and "win_out_base" in params: + result.extend([ + ("device_ctx_ptr", ctypes.c_uint64(params["device_ctx_ptr"])), + ("win_in_base", ctypes.c_uint64(params["win_in_base"])), + ("win_out_base", ctypes.c_uint64(params["win_out_base"])), + ("n_ranks", ctypes.c_int32(n_ranks)), + ("root", ctypes.c_int32(root)), + ("rank_id", ctypes.c_int32(rank_id)), + ]) + + return result + +def compute_golden(tensors: dict, params: dict) -> None: + n_ranks = params.get("n_ranks", 2) + out = tensors["out"] + + out_np = out.cpu().numpy() if hasattr(out, 'cpu') else np.asarray(out) + for r in range(n_ranks): + np.random.seed(42 + r) + src_r = np.random.randn(GATHER_COUNT).astype(np.float32) * 0.1 + out_np[r * GATHER_COUNT : (r + 1) * GATHER_COUNT] = src_r[:GATHER_COUNT] diff --git a/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/aiv/allgather_manual_kernel.cpp b/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/aiv/allgather_manual_kernel.cpp new file mode 100644 index 00000000..fde62b31 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/aiv/allgather_manual_kernel.cpp @@ -0,0 +1,75 @@ +/** + * Manual AllGather kernel - direct RDMA reads, no 
TGATHER. + * + * Each rank independently reads from all ranks' win_src via HcclRemotePtr + * and writes to local dst. All ranks run in parallel. + * + * Tensormap_and_ringbuffer: args are TensorData* for buffers, scalars for ctx/n_ranks/rank_id. + * args[0] = dst (TensorData*) + * args[1] = src (TensorData*) + * args[2] = sync_done (TensorData*, dependency - ignored) + * args[3] = device_ctx_ptr (scalar) + * args[4] = nranks (scalar) + * args[5] = rank_id (scalar, unused) + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static constexpr size_t GATHER_COUNT = 64; + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* dst_td = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* src_td = reinterpret_cast<__gm__ TensorData*>(args[1]); + (void)args[2]; + __gm__ HcclDeviceContext* hcclCtx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[3]); + int nranks = static_cast(args[4]); + (void)args[5]; + + __gm__ float* dst = reinterpret_cast<__gm__ float*>(dst_td->buffer.addr); + __gm__ float* src = reinterpret_cast<__gm__ float*>(src_td->buffer.addr); + + using ShapeDyn = pto::Shape; + using StrideDyn = pto::Stride; + using Global = pto::GlobalTensor; + using TileData = pto::Tile; + + ShapeDyn sliceShape(1, 1, 1, 1, GATHER_COUNT); + StrideDyn sliceStride(GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, 1); + + TileData ubTile(1, GATHER_COUNT); + TASSIGN(ubTile, 0x0); + + int actual_nranks = (nranks > 16) ? 
16 : nranks; + for (int r = 0; r < actual_nranks; ++r) { + __gm__ float* remote_src = HcclRemotePtr(hcclCtx, src, r); + __gm__ float* local_dst = dst + static_cast(r) * GATHER_COUNT; + + Global srcG(remote_src, sliceShape, sliceStride); + Global dstG(local_dst, sliceShape, sliceStride); + + TLOAD(ubTile, srcG); + set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); + TSTORE(dstG, ubTile); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + } + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/aiv/comm_barrier_all_kernel.cpp b/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/aiv/comm_barrier_all_kernel.cpp new file mode 100644 index 00000000..24a7d384 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/aiv/comm_barrier_all_kernel.cpp @@ -0,0 +1,57 @@ +/** + * All-to-all barrier: every rank waits for every other rank. + * + * Used by AllGather where every rank reads from all ranks' windows. + * ALL ranks do TWAIT here. 
+ * + * Tensormap_and_ringbuffer: args[0..4] as below, args[5] = sync_done (TensorData* output) + * args[0] = barrier_base (TensorData*) + * args[1] = device_ctx_ptr (scalar) + * args[2] = n_ranks (scalar) + * args[3] = root (scalar) + * args[4] = dependency (TensorData*, ignored) + * args[5] = sync_done (TensorData* output) - write 1 after barrier for task ordering + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +#include "tensor.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* barrier_td = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ int32_t* local_barrier = reinterpret_cast<__gm__ int32_t*>(barrier_td->buffer.addr); + __gm__ HcclDeviceContext* ctx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[1]); + int n_ranks = static_cast(args[2]); + int root = static_cast(args[3]); + int my_rank = static_cast(ctx->rankId); + + (void)args[4]; + + __gm__ int32_t* remote_slot = HcclRemotePtr(ctx, local_barrier, root) + my_rank; + pto::comm::Signal sig(remote_slot); + pto::comm::TNOTIFY(sig, 1, pto::comm::NotifyOp::Set); + + __gm__ int32_t* root_barrier = HcclRemotePtr(ctx, local_barrier, root); + for (int i = 0; i < n_ranks; ++i) { + pto::comm::Signal slot(root_barrier + i); + pto::comm::TWAIT(slot, 1, pto::comm::WaitCmp::GE); + } + + __gm__ TensorData* sync_td = reinterpret_cast<__gm__ TensorData*>(args[5]); + __gm__ int32_t* sync_done = reinterpret_cast<__gm__ int32_t*>(sync_td->buffer.addr); + sync_done[0] = 1; + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/aiv/window_memcopy_in.cpp b/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/aiv/window_memcopy_in.cpp new file mode 100644 index 00000000..04740386 --- /dev/null +++ 
b/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/aiv/window_memcopy_in.cpp @@ -0,0 +1,36 @@ +/** + * WindowMemCopyIn: Copy device buffer to HCCL window. + * Used before TGATHER so remote ranks can read. + * + * Tensormap_and_ringbuffer: args are TensorData* for buffers, scalar for count. + * args[0] = win_dst (TensorData*) + * args[1] = dev_src (TensorData*) + * args[2] = count (scalar) + */ + +#include +#include + +#include "tensor.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* win_dst_td = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* dev_src_td = reinterpret_cast<__gm__ TensorData*>(args[1]); + int count = static_cast(args[2]); + + __gm__ float* win_dst = reinterpret_cast<__gm__ float*>(win_dst_td->buffer.addr); + __gm__ float* dev_src = reinterpret_cast<__gm__ float*>(dev_src_td->buffer.addr); + + for (int i = 0; i < count; ++i) { + win_dst[i] = dev_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/aiv/window_memcopy_out.cpp b/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/aiv/window_memcopy_out.cpp new file mode 100644 index 00000000..abd63fb1 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/aiv/window_memcopy_out.cpp @@ -0,0 +1,36 @@ +/** + * WindowMemCopyOut: Copy HCCL window to device buffer. + * Root only - after TGATHER, copy gathered result to device. + * + * Tensormap_and_ringbuffer: args are TensorData* for buffers, scalar for count. 
+ * args[0] = dev_dst (TensorData*) + * args[1] = win_src (TensorData*) + * args[2] = count (scalar) + */ + +#include +#include + +#include "tensor.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* dev_dst_td = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* win_src_td = reinterpret_cast<__gm__ TensorData*>(args[1]); + int count = static_cast(args[2]); + + __gm__ float* dev_dst = reinterpret_cast<__gm__ float*>(dev_dst_td->buffer.addr); + __gm__ float* win_src = reinterpret_cast<__gm__ float*>(win_src_td->buffer.addr); + + for (int i = 0; i < count; ++i) { + dev_dst[i] = win_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/kernel_config.py b/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/kernel_config.py new file mode 100644 index 00000000..2ffce61b --- /dev/null +++ b/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/kernel_config.py @@ -0,0 +1,32 @@ +""" +AllGather (Manual) for tensormap_and_ringbuffer runtime. + +Direct RDMA reads for performance comparison. No TGATHER. +Flow: WindowMemCopyIn -> CommBarrier(pre) -> AllGatherManual -> WindowMemCopyOut -> CommBarrier(post) +All ranks run in parallel. Every rank gets full concatenation. 
+""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "allgather_orch.cpp"), + "function_name": "aicpu_orchestration_entry", +} + +KERNELS = [ + {"func_id": 0, "name": "WindowMemCopyIn", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_in.cpp"), "core_type": "aiv"}, + {"func_id": 1, "name": "AllGatherManual", "source": str(_KERNELS_ROOT / "aiv" / "allgather_manual_kernel.cpp"), "core_type": "aiv"}, + {"func_id": 2, "name": "WindowMemCopyOut", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_out.cpp"), "core_type": "aiv"}, + {"func_id": 3, "name": "CommBarrierAll", "source": str(_KERNELS_ROOT / "aiv" / "comm_barrier_all_kernel.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "tensormap_and_ringbuffer", + "aicpu_thread_num": 4, + "block_dim": 24, + "n_devices": 2, + "first_device_id": 0, + "requires_comm": True, +} diff --git a/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/orchestration/allgather_orch.cpp b/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/orchestration/allgather_orch.cpp new file mode 100644 index 00000000..f43d3277 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/allgather_Manual/kernels/orchestration/allgather_orch.cpp @@ -0,0 +1,120 @@ +/** + * AllGather (Manual) orchestration for tensormap_and_ringbuffer runtime. + * + * Flow: WindowMemCopyIn -> CommBarrier(pre) -> AllGatherManual -> WindowMemCopyOut -> CommBarrier(post) + * All ranks run in parallel. Every rank gets full concatenation. 
+ * + * Args (10): [0] dev_src, [1] dev_out, [2] size_src, [3] size_out, + * [4] device_ctx_ptr, [5] win_in_base, [6] win_out_base, + * [7] n_ranks, [8] root (unused), [9] rank_id + */ + +#include +#include + +#include "pto_orchestration_api.h" + +constexpr int GATHER_COUNT = 64; +constexpr size_t HCCL_WIN_SYNC_PREFIX = 64 * sizeof(int32_t); + +#define FUNC_WIN_MEMCOPY_IN 0 +#define FUNC_ALLGATHER 1 +#define FUNC_WIN_MEMCOPY_OUT 2 +#define FUNC_COMM_BARRIER 3 + +extern "C" { + +__attribute__((visibility("default"))) +PTO2OrchestrationConfig aicpu_orchestration_config(uint64_t* args, int arg_count) { + (void)args; + (void)arg_count; + return PTO2OrchestrationConfig{ + .expected_arg_count = 10, + }; +} + +__attribute__((visibility("default"))) +void aicpu_orchestration_entry(PTO2Runtime* rt, uint64_t* args, int arg_count) { + (void)arg_count; + pto2_rt_init_tensor_pool(rt); + + void* dev_src = reinterpret_cast(args[0]); + void* dev_out = reinterpret_cast(args[1]); + uint64_t device_ctx_ptr = args[4]; + uint64_t win_in_base = args[5]; + (void)args[6]; // win_out_base unused + int n_ranks = static_cast(args[7]); + int rank_id = static_cast(args[9]); + + LOG_INFO(rt, "allgather_Manual: n_ranks=%d rank_id=%d", n_ranks, rank_id); + + size_t barrier_size = static_cast(n_ranks) * sizeof(int32_t); + uint64_t barrier_base_pre = win_in_base + HCCL_WIN_SYNC_PREFIX; + uint64_t barrier_base_post = barrier_base_pre + barrier_size; + uint64_t win_src = barrier_base_post + barrier_size; + uint64_t win_dst = win_src + GATHER_COUNT * sizeof(float); + + uint64_t src_shapes[1] = {GATHER_COUNT}; + uint64_t dst_shapes[1] = {static_cast(n_ranks) * GATHER_COUNT}; + uint64_t barrier_shapes[1] = {static_cast(n_ranks)}; + uint64_t sync_shapes[1] = {1}; + Tensor sync_done_t = make_tensor(sync_shapes, 1, DataType::INT32); + + Tensor dev_src_t = make_tensor_external(dev_src, src_shapes, 1, DataType::FLOAT32); + Tensor dev_out_t = make_tensor_external(dev_out, dst_shapes, 1, DataType::FLOAT32); + 
Tensor win_src_t = make_tensor_external(reinterpret_cast(win_src), src_shapes, 1, DataType::FLOAT32); + Tensor win_dst_t = make_tensor_external(reinterpret_cast(win_dst), dst_shapes, 1, DataType::FLOAT32); + Tensor barrier_pre_t = make_tensor_external(reinterpret_cast(barrier_base_pre), barrier_shapes, 1, DataType::INT32); + Tensor barrier_post_t = make_tensor_external(reinterpret_cast(barrier_base_post), barrier_shapes, 1, DataType::INT32); + + PTO2_SCOPE(rt) { + PTOParam params_wmin[] = { + make_output_param(win_src_t), + make_input_param(dev_src_t), + make_scalar_param(static_cast(GATHER_COUNT)), + }; + pto2_rt_submit_task(rt, FUNC_WIN_MEMCOPY_IN, PTO2_WORKER_VECTOR, params_wmin, 3); + + PTOParam params_barrier_pre[] = { + make_input_param(barrier_pre_t), + make_scalar_param(device_ctx_ptr), + make_scalar_param(static_cast(n_ranks)), + make_scalar_param(static_cast(0)), + make_input_param(win_src_t), + make_output_param(sync_done_t), + }; + pto2_rt_submit_task(rt, FUNC_COMM_BARRIER, PTO2_WORKER_VECTOR, params_barrier_pre, 6); + + PTOParam params_allgather[] = { + make_output_param(win_dst_t), + make_input_param(win_src_t), + make_input_param(sync_done_t), + make_scalar_param(device_ctx_ptr), + make_scalar_param(static_cast(n_ranks)), + make_scalar_param(static_cast(rank_id)), + }; + pto2_rt_submit_task(rt, FUNC_ALLGATHER, PTO2_WORKER_VECTOR, params_allgather, 6); + + PTOParam params_wmout[] = { + make_output_param(dev_out_t), + make_input_param(win_dst_t), + make_scalar_param(static_cast(n_ranks * GATHER_COUNT)), + }; + pto2_rt_submit_task(rt, FUNC_WIN_MEMCOPY_OUT, PTO2_WORKER_VECTOR, params_wmout, 3); + + Tensor sync_post_t = make_tensor(sync_shapes, 1, DataType::INT32); + PTOParam params_barrier_post[] = { + make_input_param(barrier_post_t), + make_scalar_param(device_ctx_ptr), + make_scalar_param(static_cast(n_ranks)), + make_scalar_param(static_cast(0)), + make_input_param(win_dst_t), + make_output_param(sync_post_t), + }; + pto2_rt_submit_task(rt, 
FUNC_COMM_BARRIER, PTO2_WORKER_VECTOR, params_barrier_post, 6); + } + + LOG_INFO(rt, "allgather_Manual tasks submitted"); +} + +} // extern "C" diff --git a/examples/tensormap_and_ringbuffer/allgather_Tgather/README.md b/examples/tensormap_and_ringbuffer/allgather_Tgather/README.md new file mode 100644 index 00000000..1015a2d5 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/allgather_Tgather/README.md @@ -0,0 +1,16 @@ +# AllGather (TGATHER) - tensormap_and_ringbuffer + +N 次顺序 TGATHER 实现 AllGather,用于性能对比。 + +流程:WindowMemCopyIn -> for r in [0,n_ranks): Barrier -> Gather(root=r) -> [rank==r: WindowMemCopyOut] -> Barrier(post) + +每轮仅 root 调用 TGATHER,避免死锁。 + +## 运行 + +```bash +cd simpler-PTO +./run_tensormap.sh allgather_Tgather 2 0 +``` + +需要设置 `PTO_COMM_ISA_ROOT` 及多卡 HCCL 环境。 diff --git a/examples/tensormap_and_ringbuffer/allgather_Tgather/golden.py b/examples/tensormap_and_ringbuffer/allgather_Tgather/golden.py new file mode 100644 index 00000000..90ed69bf --- /dev/null +++ b/examples/tensormap_and_ringbuffer/allgather_Tgather/golden.py @@ -0,0 +1,55 @@ +""" +Golden reference for AllGather (TGATHER variant, no compute). + +Each rank contributes GATHER_COUNT float32 elements. +After AllGather, EVERY rank holds the concatenation of all ranks' data. 
+""" + +import ctypes +import numpy as np + +GATHER_COUNT = 64 + +ALL_CASES = {"Default": {}} +DEFAULT_CASE = "Default" +__outputs__ = ["out"] +RTOL = 1e-4 +ATOL = 1e-4 + +def generate_inputs(params: dict) -> list: + rank_id = params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + + np.random.seed(42 + rank_id) + src = np.random.randn(GATHER_COUNT).astype(np.float32) * 0.1 + out = np.zeros((n_ranks * GATHER_COUNT,), dtype=np.float32) + + result = [ + ("src", src), + ("out", out), + ("size_src", ctypes.c_int64(src.nbytes)), + ("size_out", ctypes.c_int64(out.nbytes)), + ] + + if "device_ctx_ptr" in params and "win_in_base" in params and "win_out_base" in params: + result.extend([ + ("device_ctx_ptr", ctypes.c_uint64(params["device_ctx_ptr"])), + ("win_in_base", ctypes.c_uint64(params["win_in_base"])), + ("win_out_base", ctypes.c_uint64(params["win_out_base"])), + ("n_ranks", ctypes.c_int32(n_ranks)), + ("root", ctypes.c_int32(root)), + ("rank_id", ctypes.c_int32(rank_id)), + ]) + + return result + +def compute_golden(tensors: dict, params: dict) -> None: + n_ranks = params.get("n_ranks", 2) + out = tensors["out"] + + out_np = out.cpu().numpy() if hasattr(out, 'cpu') else np.asarray(out) + for r in range(n_ranks): + np.random.seed(42 + r) + src_r = np.random.randn(GATHER_COUNT).astype(np.float32) * 0.1 + out_np[r * GATHER_COUNT : (r + 1) * GATHER_COUNT] = src_r[:GATHER_COUNT] diff --git a/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/aiv/comm_barrier_all_kernel.cpp b/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/aiv/comm_barrier_all_kernel.cpp new file mode 100644 index 00000000..24a7d384 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/aiv/comm_barrier_all_kernel.cpp @@ -0,0 +1,57 @@ +/** + * All-to-all barrier: every rank waits for every other rank. + * + * Used by AllGather where every rank reads from all ranks' windows. + * ALL ranks do TWAIT here. 
+ * + * Tensormap_and_ringbuffer: args[0..4] as below, args[5] = sync_done (TensorData* output) + * args[0] = barrier_base (TensorData*) + * args[1] = device_ctx_ptr (scalar) + * args[2] = n_ranks (scalar) + * args[3] = root (scalar) + * args[4] = dependency (TensorData*, ignored) + * args[5] = sync_done (TensorData* output) - write 1 after barrier for task ordering + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +#include "tensor.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* barrier_td = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ int32_t* local_barrier = reinterpret_cast<__gm__ int32_t*>(barrier_td->buffer.addr); + __gm__ HcclDeviceContext* ctx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[1]); + int n_ranks = static_cast(args[2]); + int root = static_cast(args[3]); + int my_rank = static_cast(ctx->rankId); + + (void)args[4]; + + __gm__ int32_t* remote_slot = HcclRemotePtr(ctx, local_barrier, root) + my_rank; + pto::comm::Signal sig(remote_slot); + pto::comm::TNOTIFY(sig, 1, pto::comm::NotifyOp::Set); + + __gm__ int32_t* root_barrier = HcclRemotePtr(ctx, local_barrier, root); + for (int i = 0; i < n_ranks; ++i) { + pto::comm::Signal slot(root_barrier + i); + pto::comm::TWAIT(slot, 1, pto::comm::WaitCmp::GE); + } + + __gm__ TensorData* sync_td = reinterpret_cast<__gm__ TensorData*>(args[5]); + __gm__ int32_t* sync_done = reinterpret_cast<__gm__ int32_t*>(sync_td->buffer.addr); + sync_done[0] = 1; + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/aiv/gather_kernel.cpp b/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/aiv/gather_kernel.cpp new file mode 100644 index 00000000..95051bb8 --- /dev/null +++ 
b/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/aiv/gather_kernel.cpp @@ -0,0 +1,76 @@ +/** + * TGATHER collective kernel - root gathers from all ranks. + * Requires pto-comm-isa (PTO_ISA_ROOT or PTO_COMM_ISA_ROOT). + * + * Tensormap_and_ringbuffer (allgather_Tgather variant): 6 args with sync_done dependency. + * args[0] = dst (TensorData*) + * args[1] = src (TensorData*) + * args[2] = sync_done (TensorData*, dependency only - ignored) + * args[3] = device_ctx_ptr (scalar) + * args[4] = nranks (scalar) + * args[5] = root (scalar) + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static constexpr size_t GATHER_COUNT = 64; + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* dst_td = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* src_td = reinterpret_cast<__gm__ TensorData*>(args[1]); + (void)args[2]; // sync_done dependency - ignored + __gm__ HcclDeviceContext* hcclCtx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[3]); + int nranks = static_cast(args[4]); + int root = static_cast(args[5]); + + __gm__ float* dst = reinterpret_cast<__gm__ float*>(dst_td->buffer.addr); + __gm__ float* src = reinterpret_cast<__gm__ float*>(src_td->buffer.addr); + + using ShapeDyn = pto::Shape; + using StrideDyn = pto::Stride; + using Global = pto::GlobalTensor; + + using TileData = pto::Tile; + + int my_rank = static_cast(hcclCtx->rankId); + + ShapeDyn srcShape(1, 1, 1, 1, GATHER_COUNT); + StrideDyn srcStride(GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, 1); + + ShapeDyn dstShape(1, 1, 1, nranks, GATHER_COUNT); + StrideDyn dstStride(nranks * GATHER_COUNT, nranks * GATHER_COUNT, nranks * GATHER_COUNT, GATHER_COUNT, 1); + Global dstG(dst, dstShape, dstStride); + + Global tensors[16]; + 
int actual_nranks = (nranks > 16) ? 16 : nranks; + for (int i = 0; i < actual_nranks; ++i) { + __gm__ float* remoteSrc = HcclRemotePtr(hcclCtx, src, i); + tensors[i] = Global(remoteSrc, srcShape, srcStride); + } + + pto::comm::ParallelGroup pg(tensors, actual_nranks, root); + + TileData ubTile(1, GATHER_COUNT); + TASSIGN(ubTile, 0x0); + + if (my_rank == root) { + pto::comm::TGATHER(pg, dstG, ubTile); + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/aiv/window_memcopy_in.cpp b/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/aiv/window_memcopy_in.cpp new file mode 100644 index 00000000..04740386 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/aiv/window_memcopy_in.cpp @@ -0,0 +1,36 @@ +/** + * WindowMemCopyIn: Copy device buffer to HCCL window. + * Used before TGATHER so remote ranks can read. + * + * Tensormap_and_ringbuffer: args are TensorData* for buffers, scalar for count. + * args[0] = win_dst (TensorData*) + * args[1] = dev_src (TensorData*) + * args[2] = count (scalar) + */ + +#include +#include + +#include "tensor.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* win_dst_td = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* dev_src_td = reinterpret_cast<__gm__ TensorData*>(args[1]); + int count = static_cast(args[2]); + + __gm__ float* win_dst = reinterpret_cast<__gm__ float*>(win_dst_td->buffer.addr); + __gm__ float* dev_src = reinterpret_cast<__gm__ float*>(dev_src_td->buffer.addr); + + for (int i = 0; i < count; ++i) { + win_dst[i] = dev_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/aiv/window_memcopy_out.cpp b/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/aiv/window_memcopy_out.cpp new 
file mode 100644 index 00000000..abd63fb1 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/aiv/window_memcopy_out.cpp @@ -0,0 +1,36 @@ +/** + * WindowMemCopyOut: Copy HCCL window to device buffer. + * Root only - after TGATHER, copy gathered result to device. + * + * Tensormap_and_ringbuffer: args are TensorData* for buffers, scalar for count. + * args[0] = dev_dst (TensorData*) + * args[1] = win_src (TensorData*) + * args[2] = count (scalar) + */ + +#include +#include + +#include "tensor.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* dev_dst_td = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* win_src_td = reinterpret_cast<__gm__ TensorData*>(args[1]); + int count = static_cast(args[2]); + + __gm__ float* dev_dst = reinterpret_cast<__gm__ float*>(dev_dst_td->buffer.addr); + __gm__ float* win_src = reinterpret_cast<__gm__ float*>(win_src_td->buffer.addr); + + for (int i = 0; i < count; ++i) { + dev_dst[i] = win_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/kernel_config.py b/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/kernel_config.py new file mode 100644 index 00000000..ca41698d --- /dev/null +++ b/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/kernel_config.py @@ -0,0 +1,31 @@ +""" +AllGather (TGATHER) for tensormap_and_ringbuffer runtime. + +N sequential Gathers. 
Flow: WindowMemCopyIn -> for r in [0, n_ranks): Barrier_r +-> Gather(root=r) -> [if rank_id==r: WindowMemCopyOut] -> Barrier(post) +""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "allgather_orch.cpp"), + "function_name": "aicpu_orchestration_entry", +} + +KERNELS = [ + {"func_id": 0, "name": "WindowMemCopyIn", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_in.cpp"), "core_type": "aiv"}, + {"func_id": 1, "name": "Gather", "source": str(_KERNELS_ROOT / "aiv" / "gather_kernel.cpp"), "core_type": "aiv"}, + {"func_id": 2, "name": "WindowMemCopyOut", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_out.cpp"), "core_type": "aiv"}, + {"func_id": 3, "name": "CommBarrierAll", "source": str(_KERNELS_ROOT / "aiv" / "comm_barrier_all_kernel.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "tensormap_and_ringbuffer", + "aicpu_thread_num": 4, + "block_dim": 24, + "n_devices": 2, + "first_device_id": 0, + "requires_comm": True, +} diff --git a/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/orchestration/allgather_orch.cpp b/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/orchestration/allgather_orch.cpp new file mode 100644 index 00000000..afd81ea5 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/allgather_Tgather/kernels/orchestration/allgather_orch.cpp @@ -0,0 +1,128 @@ +/** + * AllGather (TGATHER) orchestration for tensormap_and_ringbuffer runtime. + * + * Flow: WindowMemCopyIn -> for r in [0, n_ranks): Barrier_r -> Gather(root=r) + * -> [if rank_id==r: WindowMemCopyOut] -> Barrier(post) + * Only root calls TGATHER per round. 
+ * + * Args (10): [0] dev_src, [1] dev_out, [2] size_src, [3] size_out, + * [4] device_ctx_ptr, [5] win_in_base, [6] win_out_base, + * [7] n_ranks, [8] root (unused), [9] rank_id + */ + +#include +#include + +#include "pto_orchestration_api.h" + +constexpr int GATHER_COUNT = 64; +constexpr size_t HCCL_WIN_SYNC_PREFIX = 64 * sizeof(int32_t); + +#define FUNC_WIN_MEMCOPY_IN 0 +#define FUNC_GATHER 1 +#define FUNC_WIN_MEMCOPY_OUT 2 +#define FUNC_COMM_BARRIER 3 + +extern "C" { + +__attribute__((visibility("default"))) +PTO2OrchestrationConfig aicpu_orchestration_config(uint64_t* args, int arg_count) { + (void)args; + (void)arg_count; + return PTO2OrchestrationConfig{ + .expected_arg_count = 10, + }; +} + +__attribute__((visibility("default"))) +void aicpu_orchestration_entry(PTO2Runtime* rt, uint64_t* args, int arg_count) { + (void)arg_count; + pto2_rt_init_tensor_pool(rt); + + void* dev_src = reinterpret_cast(args[0]); + void* dev_out = reinterpret_cast(args[1]); + uint64_t device_ctx_ptr = args[4]; + uint64_t win_in_base = args[5]; + (void)args[6]; + int n_ranks = static_cast(args[7]); + int rank_id = static_cast(args[9]); + + LOG_INFO(rt, "allgather_Tgather: n_ranks=%d rank_id=%d", n_ranks, rank_id); + + size_t barrier_size = static_cast(n_ranks) * sizeof(int32_t); + size_t total_barrier_bytes = barrier_size * (static_cast(n_ranks) + 1); + uint64_t barrier_base_0 = win_in_base + HCCL_WIN_SYNC_PREFIX; + uint64_t win_src = barrier_base_0 + total_barrier_bytes; + uint64_t win_dst = win_src + GATHER_COUNT * sizeof(float); + + uint64_t src_shapes[1] = {GATHER_COUNT}; + uint64_t dst_shapes[1] = {static_cast(n_ranks) * GATHER_COUNT}; + uint64_t barrier_shapes[1] = {static_cast(n_ranks)}; + uint64_t sync_shapes[1] = {1}; + + Tensor dev_src_t = make_tensor_external(dev_src, src_shapes, 1, DataType::FLOAT32); + Tensor dev_out_t = make_tensor_external(dev_out, dst_shapes, 1, DataType::FLOAT32); + Tensor win_src_t = make_tensor_external(reinterpret_cast(win_src), src_shapes, 1, 
DataType::FLOAT32); + Tensor win_dst_t = make_tensor_external(reinterpret_cast(win_dst), dst_shapes, 1, DataType::FLOAT32); + + PTO2_SCOPE(rt) { + PTOParam params_wmin[] = { + make_output_param(win_src_t), + make_input_param(dev_src_t), + make_scalar_param(static_cast(GATHER_COUNT)), + }; + pto2_rt_submit_task(rt, FUNC_WIN_MEMCOPY_IN, PTO2_WORKER_VECTOR, params_wmin, 3); + + for (int r = 0; r < n_ranks; r++) { + uint64_t barrier_base_r = barrier_base_0 + static_cast(r) * barrier_size; + Tensor barrier_r_t = make_tensor_external(reinterpret_cast(barrier_base_r), barrier_shapes, 1, DataType::INT32); + Tensor sync_r_t = make_tensor(sync_shapes, 1, DataType::INT32); + + PTOParam params_barrier[] = { + make_input_param(barrier_r_t), + make_scalar_param(device_ctx_ptr), + make_scalar_param(static_cast(n_ranks)), + make_scalar_param(static_cast(0)), + make_input_param(r == 0 ? win_src_t : win_dst_t), + make_output_param(sync_r_t), + }; + pto2_rt_submit_task(rt, FUNC_COMM_BARRIER, PTO2_WORKER_VECTOR, params_barrier, 6); + + PTOParam params_gather[] = { + make_output_param(win_dst_t), + make_input_param(win_src_t), + make_input_param(sync_r_t), + make_scalar_param(device_ctx_ptr), + make_scalar_param(static_cast(n_ranks)), + make_scalar_param(static_cast(r)), + }; + pto2_rt_submit_task(rt, FUNC_GATHER, PTO2_WORKER_VECTOR, params_gather, 6); + + if (rank_id == r) { + PTOParam params_wmout[] = { + make_output_param(dev_out_t), + make_input_param(win_dst_t), + make_scalar_param(static_cast(n_ranks * GATHER_COUNT)), + }; + pto2_rt_submit_task(rt, FUNC_WIN_MEMCOPY_OUT, PTO2_WORKER_VECTOR, params_wmout, 3); + } + } + + uint64_t barrier_base_post = barrier_base_0 + static_cast(n_ranks) * barrier_size; + Tensor barrier_post_t = make_tensor_external(reinterpret_cast(barrier_base_post), barrier_shapes, 1, DataType::INT32); + Tensor sync_post_t = make_tensor(sync_shapes, 1, DataType::INT32); + PTOParam params_barrier_post[] = { + make_input_param(barrier_post_t), + 
make_scalar_param(device_ctx_ptr), + make_scalar_param(static_cast(n_ranks)), + make_scalar_param(static_cast(0)), + make_input_param(win_dst_t), + make_output_param(sync_post_t), + }; + pto2_rt_submit_task(rt, FUNC_COMM_BARRIER, PTO2_WORKER_VECTOR, params_barrier_post, 6); + } + + LOG_INFO(rt, "allgather_Tgather tasks submitted"); +} + +} // extern "C" diff --git a/examples/tensormap_and_ringbuffer/gather/README.md b/examples/tensormap_and_ringbuffer/gather/README.md new file mode 100644 index 00000000..6b60d2eb --- /dev/null +++ b/examples/tensormap_and_ringbuffer/gather/README.md @@ -0,0 +1,26 @@ +# Gather (tensormap_and_ringbuffer) + +纯 gather 通信算子,使用 tensormap_and_ringbuffer 运行时(设备端编排)。 + +流程:WindowMemCopyIn → CommBarrier → TGATHER → WindowMemCopyOut (root only) + +与 host_build_graph/gather 的区别: +- 使用 `aicpu_orchestration_entry` 设备端编排 +- 内核使用 TensorData 接口(buffer.addr) +- 非 root rank 使用 ZeroBuffer 初始化输出以通过校验 + +## 运行 + +```bash +cd simpler-PTO +python examples/scripts/multi_card_run_example.py \ + -k examples/tensormap_and_ringbuffer/gather/kernels \ + -g examples/tensormap_and_ringbuffer/gather/golden.py +``` + +或使用便捷脚本: +```bash +./run_tensormap.sh gather 2 0 +``` + +需要设置 `PTO_COMM_ISA_ROOT` 指向 pto-comm-isa 根目录,以及多卡 HCCL 环境。 diff --git a/examples/tensormap_and_ringbuffer/gather/golden.py b/examples/tensormap_and_ringbuffer/gather/golden.py new file mode 100644 index 00000000..58602dd6 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/gather/golden.py @@ -0,0 +1,67 @@ +""" +Golden reference for gather: multi-card TGATHER only, no computation. + +Each rank has local src data (GATHER_COUNT elements). Root gathers first +GATHER_COUNT from each rank into out: [rank0_data, rank1_data, ...]. + +Same interface as host_build_graph/gather - requires_comm with device_ctx_ptr, +win_in_base, win_out_base, n_ranks, root, rank_id. 
+""" + +import ctypes +import numpy as np + +GATHER_COUNT = 64 + +ALL_CASES = {"Default": {}} +DEFAULT_CASE = "Default" +__outputs__ = ["out"] +RTOL = 1e-4 +ATOL = 1e-4 + + +def generate_inputs(params: dict) -> list: + """Return flat argument list. For requires_comm, params includes device_ctx_ptr, win_in_base, win_out_base, n_ranks, root, rank_id.""" + rank_id = params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + + # Per-rank src data (different per rank) + np.random.seed(42 + rank_id) + src = np.random.randn(GATHER_COUNT).astype(np.float32) * 0.1 + out = np.zeros((n_ranks * GATHER_COUNT,), dtype=np.float32) # root only + + result = [ + ("src", src), + ("out", out), + ("size_src", ctypes.c_int64(src.nbytes)), + ("size_out", ctypes.c_int64(out.nbytes)), + ] + + if "device_ctx_ptr" in params and "win_in_base" in params and "win_out_base" in params: + result.extend([ + ("device_ctx_ptr", ctypes.c_uint64(params["device_ctx_ptr"])), + ("win_in_base", ctypes.c_uint64(params["win_in_base"])), + ("win_out_base", ctypes.c_uint64(params["win_out_base"])), + ("n_ranks", ctypes.c_int32(n_ranks)), + ("root", ctypes.c_int32(root)), + ("rank_id", ctypes.c_int32(rank_id)), + ]) + + return result + + +def compute_golden(tensors: dict, params: dict) -> None: + """Compute expected: gather first GATHER_COUNT from each rank to root.""" + rank_id = params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + + out = tensors["out"] + + if rank_id == root: + out_np = out.cpu().numpy() if hasattr(out, 'cpu') else np.asarray(out) + for r in range(n_ranks): + np.random.seed(42 + r) + src_r = np.random.randn(GATHER_COUNT).astype(np.float32) * 0.1 + out_np[r * GATHER_COUNT : (r + 1) * GATHER_COUNT] = src_r[:GATHER_COUNT] diff --git a/examples/tensormap_and_ringbuffer/gather/kernels/aiv/comm_barrier_kernel.cpp b/examples/tensormap_and_ringbuffer/gather/kernels/aiv/comm_barrier_kernel.cpp new file mode 100644 index 
00000000..84c96bb2 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/gather/kernels/aiv/comm_barrier_kernel.cpp @@ -0,0 +1,53 @@ +/** + * Device-side cross-rank barrier using TNOTIFY/TWAIT from pto-comm-isa. + * + * Tensormap_and_ringbuffer: args are TensorData* for barrier, scalars for ctx/n_ranks/root, + * and optional TensorData* for dependency (win_src - ignored). + * args[0] = barrier_base (TensorData*) + * args[1] = device_ctx_ptr (scalar) + * args[2] = n_ranks (scalar) + * args[3] = root (scalar) + * args[4] = win_src (TensorData*, dependency only - ignored) + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +#include "tensor.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* barrier_td = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ int32_t* local_barrier = reinterpret_cast<__gm__ int32_t*>(barrier_td->buffer.addr); + __gm__ HcclDeviceContext* ctx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[1]); + int n_ranks = static_cast(args[2]); + int root = static_cast(args[3]); + int my_rank = static_cast(ctx->rankId); + + (void)args[4]; // win_src dependency - ignored + + // Each rank writes flag=1 to root's barrier slot[my_rank] via RDMA. + __gm__ int32_t* remote_slot = HcclRemotePtr(ctx, local_barrier, root) + my_rank; + pto::comm::Signal sig(remote_slot); + pto::comm::TNOTIFY(sig, 1, pto::comm::NotifyOp::Set); + + // Root waits until every rank's flag is >= 1. 
+ if (my_rank == root) { + for (int i = 0; i < n_ranks; i++) { + pto::comm::Signal slot(local_barrier + i); + pto::comm::TWAIT(slot, 1, pto::comm::WaitCmp::GE); + } + } + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/tensormap_and_ringbuffer/gather/kernels/aiv/gather_kernel.cpp b/examples/tensormap_and_ringbuffer/gather/kernels/aiv/gather_kernel.cpp new file mode 100644 index 00000000..3e99ae0d --- /dev/null +++ b/examples/tensormap_and_ringbuffer/gather/kernels/aiv/gather_kernel.cpp @@ -0,0 +1,74 @@ +/** + * TGATHER collective kernel - root gathers from all ranks. + * Requires pto-comm-isa (PTO_ISA_ROOT or PTO_COMM_ISA_ROOT). + * + * Tensormap_and_ringbuffer: args are TensorData* for buffers, scalars for ctx/n_ranks/root. + * args[0] = dst (TensorData*) + * args[1] = src (TensorData*) + * args[2] = device_ctx_ptr (scalar) + * args[3] = nranks (scalar) + * args[4] = root (scalar) + */ + +#include +#include +#include +#include "hccl_context.h" +#include "hccl_helpers.h" + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static constexpr size_t GATHER_COUNT = 64; + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* dst_td = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* src_td = reinterpret_cast<__gm__ TensorData*>(args[1]); + __gm__ HcclDeviceContext* hcclCtx = reinterpret_cast<__gm__ HcclDeviceContext*>(args[2]); + int nranks = static_cast(args[3]); + int root = static_cast(args[4]); + + __gm__ float* dst = reinterpret_cast<__gm__ float*>(dst_td->buffer.addr); + __gm__ float* src = reinterpret_cast<__gm__ float*>(src_td->buffer.addr); + + using ShapeDyn = pto::Shape; + using StrideDyn = pto::Stride; + using Global = pto::GlobalTensor; + + using TileData = pto::Tile; + + int my_rank = static_cast(hcclCtx->rankId); + + ShapeDyn srcShape(1, 1, 1, 1, GATHER_COUNT); + StrideDyn 
srcStride(GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, GATHER_COUNT, 1); + + ShapeDyn dstShape(1, 1, 1, nranks, GATHER_COUNT); + StrideDyn dstStride(nranks * GATHER_COUNT, nranks * GATHER_COUNT, nranks * GATHER_COUNT, GATHER_COUNT, 1); + Global dstG(dst, dstShape, dstStride); + + Global tensors[16]; + int actual_nranks = (nranks > 16) ? 16 : nranks; + for (int i = 0; i < actual_nranks; ++i) { + __gm__ float* remoteSrc = HcclRemotePtr(hcclCtx, src, i); + tensors[i] = Global(remoteSrc, srcShape, srcStride); + } + + pto::comm::ParallelGroup pg(tensors, actual_nranks, root); + + TileData ubTile(1, GATHER_COUNT); + TASSIGN(ubTile, 0x0); + + if (my_rank == root) { + pto::comm::TGATHER(pg, dstG, ubTile); + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/tensormap_and_ringbuffer/gather/kernels/aiv/window_memcopy_in.cpp b/examples/tensormap_and_ringbuffer/gather/kernels/aiv/window_memcopy_in.cpp new file mode 100644 index 00000000..04740386 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/gather/kernels/aiv/window_memcopy_in.cpp @@ -0,0 +1,36 @@ +/** + * WindowMemCopyIn: Copy device buffer to HCCL window. + * Used before TGATHER so remote ranks can read. + * + * Tensormap_and_ringbuffer: args are TensorData* for buffers, scalar for count. 
+ * args[0] = win_dst (TensorData*) + * args[1] = dev_src (TensorData*) + * args[2] = count (scalar) + */ + +#include +#include + +#include "tensor.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* win_dst_td = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* dev_src_td = reinterpret_cast<__gm__ TensorData*>(args[1]); + int count = static_cast(args[2]); + + __gm__ float* win_dst = reinterpret_cast<__gm__ float*>(win_dst_td->buffer.addr); + __gm__ float* dev_src = reinterpret_cast<__gm__ float*>(dev_src_td->buffer.addr); + + for (int i = 0; i < count; ++i) { + win_dst[i] = dev_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/tensormap_and_ringbuffer/gather/kernels/aiv/window_memcopy_out.cpp b/examples/tensormap_and_ringbuffer/gather/kernels/aiv/window_memcopy_out.cpp new file mode 100644 index 00000000..abd63fb1 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/gather/kernels/aiv/window_memcopy_out.cpp @@ -0,0 +1,36 @@ +/** + * WindowMemCopyOut: Copy HCCL window to device buffer. + * Root only - after TGATHER, copy gathered result to device. + * + * Tensormap_and_ringbuffer: args are TensorData* for buffers, scalar for count. 
+ * args[0] = dev_dst (TensorData*) + * args[1] = win_src (TensorData*) + * args[2] = count (scalar) + */ + +#include +#include + +#include "tensor.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* dev_dst_td = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* win_src_td = reinterpret_cast<__gm__ TensorData*>(args[1]); + int count = static_cast(args[2]); + + __gm__ float* dev_dst = reinterpret_cast<__gm__ float*>(dev_dst_td->buffer.addr); + __gm__ float* win_src = reinterpret_cast<__gm__ float*>(win_src_td->buffer.addr); + + for (int i = 0; i < count; ++i) { + dev_dst[i] = win_src[i]; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/tensormap_and_ringbuffer/gather/kernels/aiv/zero_buffer.cpp b/examples/tensormap_and_ringbuffer/gather/kernels/aiv/zero_buffer.cpp new file mode 100644 index 00000000..8b24d193 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/gather/kernels/aiv/zero_buffer.cpp @@ -0,0 +1,35 @@ +/** + * ZeroBuffer: Zero a buffer. Used for non-root ranks to initialize output. 
+ * + * Args: + * args[0] = dst (TensorData*) + * args[1] = count (scalar, in elements) + * args[2] = dependency (TensorData*, ignored - for task ordering) + */ + +#include +#include + +#include "tensor.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* dst_td = reinterpret_cast<__gm__ TensorData*>(args[0]); + int count = static_cast(args[1]); + + (void)args[2]; // dependency - ignored + + __gm__ float* dst = reinterpret_cast<__gm__ float*>(dst_td->buffer.addr); + + for (int i = 0; i < count; ++i) { + dst[i] = 0.0f; + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/tensormap_and_ringbuffer/gather/kernels/kernel_config.py b/examples/tensormap_and_ringbuffer/gather/kernels/kernel_config.py new file mode 100644 index 00000000..64afb857 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/gather/kernels/kernel_config.py @@ -0,0 +1,37 @@ +""" +Gather operator for tensormap_and_ringbuffer runtime. + +Multi-card TGATHER communication, no computation. +Flow: WindowMemCopyIn -> CommBarrier -> TGATHER -> WindowMemCopyOut (root only). 
+ +Adapted from host_build_graph/gather with: +- aicpu_orchestration_entry (device-side orchestration) +- AIV_HUB for tensormap_and_ringbuffer runtime +- Kernels use TensorData interface (buffer.addr) +""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "gather_orch.cpp"), + "function_name": "aicpu_orchestration_entry", +} + +KERNELS = [ + {"func_id": 0, "name": "WindowMemCopyIn", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_in.cpp"), "core_type": "aiv"}, + {"func_id": 1, "name": "Gather", "source": str(_KERNELS_ROOT / "aiv" / "gather_kernel.cpp"), "core_type": "aiv"}, + {"func_id": 2, "name": "WindowMemCopyOut", "source": str(_KERNELS_ROOT / "aiv" / "window_memcopy_out.cpp"), "core_type": "aiv"}, + {"func_id": 3, "name": "CommBarrier", "source": str(_KERNELS_ROOT / "aiv" / "comm_barrier_kernel.cpp"), "core_type": "aiv"}, + {"func_id": 5, "name": "ZeroBuffer", "source": str(_KERNELS_ROOT / "aiv" / "zero_buffer.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "tensormap_and_ringbuffer", + "aicpu_thread_num": 4, + "block_dim": 24, + "n_devices": 2, + "first_device_id": 0, + "requires_comm": True, +} diff --git a/examples/tensormap_and_ringbuffer/gather/kernels/orchestration/gather_orch.cpp b/examples/tensormap_and_ringbuffer/gather/kernels/orchestration/gather_orch.cpp new file mode 100644 index 00000000..a431fb7a --- /dev/null +++ b/examples/tensormap_and_ringbuffer/gather/kernels/orchestration/gather_orch.cpp @@ -0,0 +1,123 @@ +/** + * Gather orchestration for tensormap_and_ringbuffer runtime. + * + * Device-side orchestration: WindowMemCopyIn -> CommBarrier -> TGATHER -> WindowMemCopyOut (root only). + * Uses aicpu_orchestration_entry and pto2_rt_submit_task with Tensor params. 
+ * + * Args (10): + * [0] dev_src (device ptr, from runtime_maker conversion of host src) + * [1] dev_out (device ptr, root only) + * [2] size_src + * [3] size_out + * [4] device_ctx_ptr + * [5] win_in_base + * [6] win_out_base + * [7] n_ranks + * [8] root + * [9] rank_id + */ + +#include +#include + +#include "pto_orchestration_api.h" + +constexpr int GATHER_COUNT = 64; +constexpr size_t HCCL_WIN_SYNC_PREFIX = 64 * sizeof(int32_t); + +#define FUNC_WIN_MEMCOPY_IN 0 +#define FUNC_GATHER 1 +#define FUNC_WIN_MEMCOPY_OUT 2 +#define FUNC_COMM_BARRIER 3 +#define FUNC_ZERO_BUFFER 5 + +extern "C" { + +__attribute__((visibility("default"))) +PTO2OrchestrationConfig aicpu_orchestration_config(uint64_t* args, int arg_count) { + (void)args; + (void)arg_count; + return PTO2OrchestrationConfig{ + .expected_arg_count = 10, + }; +} + +__attribute__((visibility("default"))) +void aicpu_orchestration_entry(PTO2Runtime* rt, uint64_t* args, int arg_count) { + (void)arg_count; + pto2_rt_init_tensor_pool(rt); + + void* dev_src = reinterpret_cast(args[0]); + void* dev_out = reinterpret_cast(args[1]); + uint64_t device_ctx_ptr = args[4]; + uint64_t win_in_base = args[5]; + uint64_t win_out_base = args[6]; + int n_ranks = static_cast(args[7]); + int root = static_cast(args[8]); + int rank_id = static_cast(args[9]); + + LOG_INFO(rt, "gather: n_ranks=%d root=%d rank_id=%d", n_ranks, root, rank_id); + + size_t barrier_size = static_cast(n_ranks) * sizeof(int32_t); + uint64_t barrier_base = win_in_base + HCCL_WIN_SYNC_PREFIX; + uint64_t win_src = barrier_base + barrier_size; + uint64_t win_dst = win_src + GATHER_COUNT * sizeof(float); + + uint64_t src_shapes[1] = {GATHER_COUNT}; + uint64_t dst_shapes[1] = {static_cast(n_ranks) * GATHER_COUNT}; + uint64_t barrier_shapes[1] = {static_cast(n_ranks)}; + + Tensor dev_src_t = make_tensor_external(dev_src, src_shapes, 1, DataType::FLOAT32); + Tensor win_src_t = make_tensor_external(reinterpret_cast(win_src), src_shapes, 1, DataType::FLOAT32); + 
Tensor win_dst_t = make_tensor_external(reinterpret_cast(win_dst), dst_shapes, 1, DataType::FLOAT32); + Tensor barrier_t = make_tensor_external(reinterpret_cast(barrier_base), barrier_shapes, 1, DataType::INT32); + + PTO2_SCOPE(rt) { + PTOParam params_wmin[] = { + make_output_param(win_src_t), + make_input_param(dev_src_t), + make_scalar_param(static_cast(GATHER_COUNT)), + }; + pto2_rt_submit_task(rt, FUNC_WIN_MEMCOPY_IN, PTO2_WORKER_VECTOR, params_wmin, 3); + + PTOParam params_barrier[] = { + make_input_param(barrier_t), + make_scalar_param(device_ctx_ptr), + make_scalar_param(static_cast(n_ranks)), + make_scalar_param(static_cast(root)), + make_input_param(win_src_t), // dependency: wait for WindowMemCopyIn + }; + pto2_rt_submit_task(rt, FUNC_COMM_BARRIER, PTO2_WORKER_VECTOR, params_barrier, 5); + + PTOParam params_gather[] = { + make_output_param(win_dst_t), + make_input_param(win_src_t), + make_scalar_param(device_ctx_ptr), + make_scalar_param(static_cast(n_ranks)), + make_scalar_param(static_cast(root)), + }; + pto2_rt_submit_task(rt, FUNC_GATHER, PTO2_WORKER_VECTOR, params_gather, 5); + + if (rank_id == root) { + Tensor dev_out_t = make_tensor_external(dev_out, dst_shapes, 1, DataType::FLOAT32); + PTOParam params_wmout[] = { + make_output_param(dev_out_t), + make_input_param(win_dst_t), + make_scalar_param(static_cast(n_ranks * GATHER_COUNT)), + }; + pto2_rt_submit_task(rt, FUNC_WIN_MEMCOPY_OUT, PTO2_WORKER_VECTOR, params_wmout, 3); + } else { + Tensor dev_out_t = make_tensor_external(dev_out, dst_shapes, 1, DataType::FLOAT32); + PTOParam params_zero[] = { + make_output_param(dev_out_t), + make_scalar_param(static_cast(n_ranks * GATHER_COUNT)), + make_input_param(win_src_t), // dependency: run after CommBarrier + }; + pto2_rt_submit_task(rt, FUNC_ZERO_BUFFER, PTO2_WORKER_VECTOR, params_zero, 3); + } + } + + LOG_INFO(rt, "gather tasks submitted"); +} + +} // extern "C" diff --git a/examples/tensormap_and_ringbuffer/paged_attention_allgather_Manual/README.md 
b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Manual/README.md new file mode 100644 index 00000000..37d6acaf --- /dev/null +++ b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Manual/README.md @@ -0,0 +1,14 @@ +# Paged Attention + AllGather (Manual) - tensormap_and_ringbuffer + +Paged Attention 计算后 AllGather(直接 RDMA)。 + +流程:Paged Attention (QK->SF->PV->UP) -> WindowMemCopyIn -> CommBarrier +-> AllGatherManual -> WindowMemCopyOut -> CommBarrier(post) + +## 运行 + +```bash +./run_tensormap.sh paged_attention_allgather_Manual 2 0 +``` + +注意:编排文件 (paged_attention_allgather_orch.cpp) 需要完整实现,将 Phase 1(Paged Attention)与 Phase 2(AllGather)组合。当前仅包含 kernel_config 和 golden,编排需根据 runtime 接口补充。 diff --git a/examples/tensormap_and_ringbuffer/paged_attention_allgather_Manual/golden.py b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Manual/golden.py new file mode 100644 index 00000000..52c5d99f --- /dev/null +++ b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Manual/golden.py @@ -0,0 +1,149 @@ +""" +Paged Attention + AllGather (Manual): Paged Attention → AllGather. + +Each rank independently computes paged attention on its own Q/K/V data, +then AllGather: every rank gets the concatenation of all ranks' first +GATHER_COUNT elements of attn_out. + +Same golden logic as paged_attention_allgather_Tgather (Manual uses +direct RDMA reads instead of TGATHER, but output is identical). 
+""" + +import ctypes +import struct +import torch +import numpy as np + +GATHER_COUNT = 64 +BATCH = 1 +NUM_HEADS = 16 +KV_HEAD_NUM = 1 +HEAD_DIM = 16 +BLOCK_SIZE = 16 +CONTEXT_LEN = 16 +MAX_MODEL_LEN = 256 + +__outputs__ = ["attn_out", "allgather_out"] +RTOL = 1e-2 +ATOL = 1e-2 +ALL_CASES = {"Default": {}} +DEFAULT_CASE = "Default" + +def _make_block_table_and_context(): + max_num_blocks_per_req = MAX_MODEL_LEN // BLOCK_SIZE + cur_valid_blocks = (CONTEXT_LEN + BLOCK_SIZE - 1) // BLOCK_SIZE + total_blocks = BATCH * cur_valid_blocks + torch.manual_seed(100) + block_table = torch.randint(0, max(total_blocks, 1), size=(BATCH, max_num_blocks_per_req), dtype=torch.int32) + context_lens = torch.full((BATCH,), CONTEXT_LEN, dtype=torch.int32) + return block_table, context_lens, total_blocks, max_num_blocks_per_req + +def _make_qkv(rank_id, total_blocks): + torch.manual_seed(42 + rank_id) + q = (torch.rand(BATCH, 1, NUM_HEADS * HEAD_DIM) - 0.5).to(torch.float16) + q = q.reshape(BATCH, NUM_HEADS, HEAD_DIM) + k = (torch.rand(total_blocks, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) - 0.5).to(torch.float16) + v = (torch.rand(total_blocks, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) * 2 - 1).to(torch.float16) + return q, k, v + +def generate_inputs(params: dict) -> list: + rank_id = params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + block_table, context_lens, total_blocks, max_num_blocks_per_req = _make_block_table_and_context() + query_fp16, key_fp16, value_fp16 = _make_qkv(rank_id, total_blocks) + scale_value = 1.0 + scale_bits = struct.unpack('I', struct.pack('f', scale_value))[0] + config = torch.tensor([BATCH, NUM_HEADS, KV_HEAD_NUM, HEAD_DIM, BLOCK_SIZE, max_num_blocks_per_req, scale_bits], dtype=torch.int64) + query = query_fp16.flatten() + key_cache = key_fp16.flatten() + value_cache = value_fp16.flatten() + block_table_flat = block_table.flatten() + attn_out = torch.zeros(BATCH * NUM_HEADS * HEAD_DIM, dtype=torch.float32) + allgather_out = 
torch.zeros(n_ranks * GATHER_COUNT, dtype=torch.float32) + result = [ + ("query", query), ("key_cache", key_cache), ("value_cache", value_cache), + ("block_table", block_table_flat), ("context_lens", context_lens), + ("attn_out", attn_out), ("allgather_out", allgather_out), ("config", config), + ("size_query", ctypes.c_int64(query.nbytes)), ("size_key_cache", ctypes.c_int64(key_cache.nbytes)), + ("size_value_cache", ctypes.c_int64(value_cache.nbytes)), ("size_block_table", ctypes.c_int64(block_table_flat.nbytes)), + ("size_context_lens", ctypes.c_int64(context_lens.nbytes)), ("size_attn_out", ctypes.c_int64(attn_out.nbytes)), + ("size_allgather_out", ctypes.c_int64(allgather_out.nbytes)), ("size_config", ctypes.c_int64(config.nbytes)), + ] + if "device_ctx_ptr" in params and "win_in_base" in params and "win_out_base" in params: + result.extend([ + ("device_ctx_ptr", ctypes.c_uint64(params["device_ctx_ptr"])), + ("win_in_base", ctypes.c_uint64(params["win_in_base"])), + ("win_out_base", ctypes.c_uint64(params["win_out_base"])), + ("n_ranks", ctypes.c_int32(n_ranks)), ("root", ctypes.c_int32(root)), ("rank_id", ctypes.c_int32(rank_id)), + ]) + return result + +def paged_attention(query, key_cache, value_cache, num_kv_heads, num_heads, scale_value, block_table, context_lens): + assert num_kv_heads == 1 + batch, num_heads_dim, head_dim = query.shape + _, block_size, _, _ = key_cache.shape + key_cache_flat = key_cache.reshape(-1, block_size, head_dim) + value_cache_flat = value_cache.reshape(-1, block_size, head_dim) + out = torch.zeros((batch, num_heads_dim, head_dim), dtype=torch.float32) + q_tile = min(num_heads_dim, 128) + max_bn = int(((context_lens.max().item()) + block_size - 1) // block_size) + for q_offset in range(0, num_heads_dim, q_tile): + q_tile_size = min(q_tile, num_heads_dim - q_offset) + qi = query[:, q_offset:q_offset + q_tile_size, :].to(torch.float32) + oi, li, mi = None, None, None + for bn in range(max_bn): + valid_lens = torch.clamp(context_lens 
- bn * block_size, min=0, max=block_size) + active_mask = valid_lens > 0 + if not active_mask.any(): break + block_indices = block_table[:, bn] + kj_all = key_cache_flat[block_indices].to(torch.float32) + vj_all = value_cache_flat[block_indices].to(torch.float32) + sij = torch.bmm(qi, kj_all.transpose(1, 2)) * scale_value + pos = torch.arange(block_size, device=sij.device).unsqueeze(0) + valid_mask = pos < valid_lens.unsqueeze(1) + valid_mask = valid_mask.unsqueeze(1) + sij = sij.masked_fill(~valid_mask, float('-inf')) + batch_mask = active_mask.view(-1, 1, 1) + sij = sij.masked_fill(~batch_mask, float('-inf')) + mij = sij.max(dim=-1, keepdim=True)[0] + mij = mij.clamp(min=-1e30) + pij = torch.exp(sij - mij) + pij = pij.masked_fill(~valid_mask, 0.0) + pij = pij.masked_fill(~batch_mask, 0.0) + pij = pij.to(torch.bfloat16).to(torch.float32) + lij = pij.sum(dim=-1, keepdim=True) + oi_new = torch.bmm(pij, vj_all) + if bn == 0: + oi, li, mi = oi_new, lij, mij + else: + mi_new = torch.maximum(mi, mij) + alpha = torch.exp(mi - mi_new) + beta = torch.exp(mij - mi_new) + li = alpha * li + beta * lij + oi = alpha * oi + beta * oi_new + mi = mi_new + out[:, q_offset:q_offset + q_tile_size, :] = oi / li + return out.reshape(-1, head_dim) + +def _compute_rank_attn(rank_id, block_table, context_lens, total_blocks): + q, k, v = _make_qkv(rank_id, total_blocks) + return paged_attention(q, k.reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM), v.reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM), + KV_HEAD_NUM, NUM_HEADS, 1.0, block_table, context_lens) + +def compute_golden(tensors: dict, params: dict) -> None: + n_ranks = params.get("n_ranks", 2) + max_num_blocks_per_req = MAX_MODEL_LEN // BLOCK_SIZE + total_blocks = BATCH * ((CONTEXT_LEN + BLOCK_SIZE - 1) // BLOCK_SIZE) + block_table = tensors["block_table"].reshape(BATCH, max_num_blocks_per_req) + context_lens_t = tensors["context_lens"] + query = tensors["query"].reshape(BATCH, NUM_HEADS, HEAD_DIM) + key_cache = 
tensors["key_cache"].reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) + value_cache = tensors["value_cache"].reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) + attn_result = paged_attention(query, key_cache, value_cache, KV_HEAD_NUM, NUM_HEADS, 1.0, block_table, context_lens_t) + tensors["attn_out"][:] = attn_result.flatten() + allgather_np = tensors["allgather_out"].cpu().numpy() + for r in range(n_ranks): + attn_r = _compute_rank_attn(r, block_table, context_lens_t, total_blocks) + flat_r = attn_r.flatten().numpy() + allgather_np[r * GATHER_COUNT : (r + 1) * GATHER_COUNT] = flat_r[:GATHER_COUNT] diff --git a/examples/tensormap_and_ringbuffer/paged_attention_allgather_Manual/kernels/kernel_config.py b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Manual/kernels/kernel_config.py new file mode 100644 index 00000000..19dea142 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Manual/kernels/kernel_config.py @@ -0,0 +1,40 @@ +""" +Paged Attention + AllGather (Manual) for tensormap_and_ringbuffer. + +Flow: Paged Attention (QK->SF->PV->UP) -> WindowMemCopyIn -> CommBarrier +-> AllGatherManual -> WindowMemCopyOut -> CommBarrier(post) +All ranks get the full allgather output. 
+""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent +_PA_ROOT = _KERNELS_ROOT.parent.parent / "paged_attention" / "kernels" +_AG_ROOT = _KERNELS_ROOT.parent.parent / "allgather_Manual" / "kernels" + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "paged_attention_allgather_orch.cpp"), + "function_name": "aicpu_orchestration_entry", +} + +KERNELS = [ + {"func_id": 0, "name": "QK", "source": str(_PA_ROOT / "aic" / "aic_qk_matmul.cpp"), "core_type": "aic"}, + {"func_id": 1, "name": "SF", "source": str(_PA_ROOT / "aiv" / "aiv_softmax_prepare.cpp"), "core_type": "aiv"}, + {"func_id": 2, "name": "PV", "source": str(_PA_ROOT / "aic" / "aic_pv_matmul.cpp"), "core_type": "aic"}, + {"func_id": 3, "name": "UP", "source": str(_PA_ROOT / "aiv" / "aiv_online_update.cpp"), "core_type": "aiv"}, + {"func_id": 4, "name": "AIC_HUB", "source": str(_PA_ROOT / "aic" / "aic_hub.cpp"), "core_type": "aic"}, + {"func_id": 5, "name": "AIV_HUB", "source": str(_PA_ROOT / "aiv" / "aiv_hub.cpp"), "core_type": "aiv"}, + {"func_id": 6, "name": "WindowMemCopyIn", "source": str(_AG_ROOT / "aiv" / "window_memcopy_in.cpp"), "core_type": "aiv"}, + {"func_id": 7, "name": "AllGatherManual", "source": str(_AG_ROOT / "aiv" / "allgather_manual_kernel.cpp"), "core_type": "aiv"}, + {"func_id": 8, "name": "WindowMemCopyOut", "source": str(_AG_ROOT / "aiv" / "window_memcopy_out.cpp"), "core_type": "aiv"}, + {"func_id": 9, "name": "CommBarrierAll", "source": str(_AG_ROOT / "aiv" / "comm_barrier_all_kernel.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "tensormap_and_ringbuffer", + "aicpu_thread_num": 4, + "block_dim": 24, + "n_devices": 2, + "first_device_id": 0, + "requires_comm": True, +} diff --git a/examples/tensormap_and_ringbuffer/paged_attention_allgather_Manual/kernels/orchestration/paged_attention_allgather_orch.cpp b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Manual/kernels/orchestration/paged_attention_allgather_orch.cpp 
new file mode 100644 index 00000000..1c353168 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Manual/kernels/orchestration/paged_attention_allgather_orch.cpp @@ -0,0 +1,277 @@ +/** + * Paged Attention + AllGather (Manual) orchestration for tensormap_and_ringbuffer. + * + * Flow: Phase 1 - Paged Attention (QK->SF->PV->UP) + * Phase 2 - WindowMemCopyIn -> CommBarrier(pre) -> AllGatherManual + * -> WindowMemCopyOut -> CommBarrier(post) + * + * Args (22): [0] query, [1] key_cache, [2] value_cache, [3] block_table, + * [4] context_lens, [5] attn_out, [6] allgather_out, [7] config, + * [8-15] 7 sizes, [16] device_ctx_ptr, [17] win_in_base, [18] win_out_base, + * [19] n_ranks, [20] root, [21] rank_id + */ + +#include +#include + +#include "pto_orchestration_api.h" + +constexpr int GATHER_COUNT = 64; +constexpr size_t HCCL_WIN_SYNC_PREFIX = 64 * sizeof(int32_t); + +#define FUNC_QK_MATMUL 0 +#define FUNC_SOFTMAX_PREPARE 1 +#define FUNC_PV_MATMUL 2 +#define FUNC_ONLINE_UPDATE 3 +#define FUNC_AIC_HUB 4 +#define FUNC_AIV_HUB 5 +#define FUNC_WIN_MEMCOPY_IN 6 +#define FUNC_ALLGATHER 7 +#define FUNC_WIN_MEMCOPY_OUT 8 +#define FUNC_COMM_BARRIER 9 + +static uint64_t float_to_u64(float f) { + union { + float f32; + uint64_t u64; + } conv; + conv.u64 = 0; + conv.f32 = f; + return conv.u64; +} + +extern "C" { + +__attribute__((visibility("default"))) +PTO2OrchestrationConfig aicpu_orchestration_config(uint64_t* args, int arg_count) { + (void)args; + (void)arg_count; + return PTO2OrchestrationConfig{ + .expected_arg_count = 22, + }; +} + +__attribute__((visibility("default"))) +void aicpu_orchestration_entry(PTO2Runtime* rt, uint64_t* args, int arg_count) { + (void)arg_count; + pto2_rt_init_tensor_pool(rt); + + void* dev_query = reinterpret_cast(args[0]); + void* dev_key_cache = reinterpret_cast(args[1]); + void* dev_value_cache = reinterpret_cast(args[2]); + int* dev_block_table = reinterpret_cast(args[3]); + int* dev_context_lens = 
reinterpret_cast(args[4]); + void* dev_attn_out = reinterpret_cast(args[5]); + void* dev_allgather_out = reinterpret_cast(args[6]); + int64_t* dev_config = reinterpret_cast(args[7]); + + size_t query_size = static_cast(args[8]); + size_t key_cache_size = static_cast(args[9]); + size_t value_cache_size = static_cast(args[10]); + (void)args[11]; + (void)args[12]; + (void)args[13]; + (void)args[14]; + (void)args[15]; + + uint64_t device_ctx_ptr = args[16]; + uint64_t win_in_base = args[17]; + (void)args[18]; + int n_ranks = static_cast(args[19]); + (void)args[20]; + int rank_id = static_cast(args[21]); + + uint64_t batch = static_cast(static_cast(dev_config[0])); + uint64_t num_heads = static_cast(static_cast(dev_config[1])); + int kv_head_num = static_cast(dev_config[2]); + uint64_t head_dim = static_cast(static_cast(dev_config[3])); + uint64_t block_size = static_cast(static_cast(dev_config[4])); + uint64_t block_num = static_cast(static_cast(dev_config[5])); + union { uint32_t u; float f; } scale_conv; + scale_conv.u = static_cast(dev_config[6]); + float scale_value = scale_conv.f; + + uint64_t q_head_num = num_heads; + uint64_t q_tile = 16; + uint64_t q_loop = (q_head_num + q_tile - 1) / q_tile; + DataType data_type = DataType::FLOAT16; + uint64_t elem_size = get_element_size(data_type); + + (void)kv_head_num; + + LOG_INFO(rt, "paged_attention_allgather_Manual: n_ranks=%d rank_id=%d batch=%lu", + n_ranks, rank_id, (unsigned long)batch); + + uint64_t query_shapes[2] = {batch * num_heads, head_dim}; + uint64_t kv_total_rows = key_cache_size / (head_dim * elem_size); + uint64_t key_cache_shapes[2] = {kv_total_rows, head_dim}; + uint64_t value_cache_shapes[2] = {kv_total_rows, head_dim}; + uint64_t attn_out_shapes[2] = {batch * num_heads, head_dim}; + + Tensor query = make_tensor_external(dev_query, query_shapes, 2, data_type); + Tensor key_cache = make_tensor_external(dev_key_cache, key_cache_shapes, 2, data_type); + Tensor value_cache = 
make_tensor_external(dev_value_cache, value_cache_shapes, 2, data_type); + Tensor attn_out = make_tensor_external(dev_attn_out, attn_out_shapes, 2, DataType::FLOAT32); + + /* Phase 1: Paged Attention */ + for (uint64_t b_idx = 0; b_idx < batch; b_idx++) { + uint64_t cur_seq = static_cast<uint64_t>(dev_context_lens[b_idx]); + uint64_t bn_this_batch = (cur_seq + block_size - 1) / block_size; + for (uint64_t q_idx = 0; q_idx < q_loop; q_idx++) { + PTO2_SCOPE(rt) { + uint64_t cur_offset = b_idx * q_head_num + q_idx * q_tile; + uint64_t oi_shapes[2] = {q_tile, head_dim}; + uint64_t li_shapes[1] = {q_tile}; + uint64_t mi_shapes[1] = {q_tile}; + Tensor oi = make_tensor(oi_shapes, 2, DataType::FLOAT32); + Tensor li_update = make_tensor(li_shapes, 1, DataType::FLOAT32); + Tensor mi_update = make_tensor(mi_shapes, 1, DataType::FLOAT32); + + uint64_t qi_shapes[2] = {q_tile, head_dim}; + uint64_t qi_offsets[2] = {cur_offset, 0}; + Tensor qi = query.view(qi_shapes, qi_offsets); + uint64_t out_view_shapes[2] = {q_tile, head_dim}; + uint64_t out_view_offsets[2] = {cur_offset, 0}; + Tensor out_view = attn_out.view(out_view_shapes, out_view_offsets); + + PTOParam params_inplace[] = { + make_output_param(oi), + make_output_param(li_update), + make_output_param(mi_update), + }; + pto2_rt_submit_task(rt, FUNC_AIV_HUB, PTO2_WORKER_VECTOR, params_inplace, 3); + + for (uint64_t bn = 0; bn < bn_this_batch; bn++) { + uint64_t cur_block_idx = static_cast<uint64_t>(dev_block_table[b_idx * block_num + bn]); + uint64_t valid_len = block_size < (cur_seq - bn * block_size) ?
block_size : (cur_seq - bn * block_size); + uint64_t kv_shapes[2] = {block_size, head_dim}; + uint64_t kv_offsets[2] = {cur_block_idx * block_size, 0}; + Tensor kj = key_cache.view(kv_shapes, kv_offsets); + Tensor vj = value_cache.view(kv_shapes, kv_offsets); + + uint64_t sij_shapes[2] = {q_tile, block_size}; + Tensor sij = make_tensor(sij_shapes, 2, DataType::FLOAT32); + Tensor pij_f16 = make_tensor(sij_shapes, 2, data_type); + + PTOParam params_qk[] = { + make_input_param(qi), + make_input_param(kj), + make_output_param(sij), + }; + pto2_rt_submit_task(rt, FUNC_QK_MATMUL, PTO2_WORKER_CUBE, params_qk, 3); + + uint64_t sij_valid_shapes[2] = {q_tile, valid_len}; + uint64_t sij_valid_offsets[2] = {0, 0}; + Tensor sij_valid = sij.view(sij_valid_shapes, sij_valid_offsets); + Tensor li = make_tensor(li_shapes, 1, DataType::FLOAT32); + Tensor mi = make_tensor(mi_shapes, 1, DataType::FLOAT32); + PTOParam params_sf[] = { + make_input_param(sij_valid), + make_scalar_param(float_to_u64(scale_value)), + make_output_param(pij_f16), + make_output_param(mi), + make_output_param(li), + }; + pto2_rt_submit_task(rt, FUNC_SOFTMAX_PREPARE, PTO2_WORKER_VECTOR, params_sf, 5); + + uint64_t oi_tmp_shapes[2] = {q_tile, head_dim}; + Tensor oi_tmp = make_tensor(oi_tmp_shapes, 2, DataType::FLOAT32); + + PTOParam params_pv[] = { + make_input_param(pij_f16), + make_input_param(vj), + make_output_param(oi_tmp), + }; + pto2_rt_submit_task(rt, FUNC_PV_MATMUL, PTO2_WORKER_CUBE, params_pv, 3); + + uint64_t is_first = (bn == 0) ? 1 : 0; + uint64_t is_last = (bn == bn_this_batch - 1) ? 
1 : 0; + + PTOParam params_up[] = { + make_input_param(mi), + make_input_param(li), + make_input_param(oi_tmp), + make_inout_param(mi_update), + make_inout_param(li_update), + make_inout_param(oi), + make_output_param(out_view), + make_scalar_param(is_first), + make_scalar_param(is_last), + }; + pto2_rt_submit_task(rt, FUNC_ONLINE_UPDATE, PTO2_WORKER_VECTOR, params_up, 9); + } + } + } + } + + /* Phase 2: AllGather - copy first GATHER_COUNT elements of attn_out, allgather, write to allgather_out */ + size_t barrier_size = static_cast<size_t>(n_ranks) * sizeof(int32_t); + uint64_t barrier_base_pre = win_in_base + HCCL_WIN_SYNC_PREFIX; + uint64_t barrier_base_post = barrier_base_pre + barrier_size; + uint64_t win_src = barrier_base_post + barrier_size; + uint64_t win_dst = win_src + GATHER_COUNT * sizeof(float); + + uint64_t src_shapes[1] = {static_cast<uint64_t>(GATHER_COUNT)}; + uint64_t dst_shapes[1] = {static_cast<uint64_t>(n_ranks) * GATHER_COUNT}; + uint64_t barrier_shapes[1] = {static_cast<uint64_t>(n_ranks)}; + uint64_t sync_shapes[1] = {1}; + Tensor sync_done_t = make_tensor(sync_shapes, 1, DataType::INT32); + + Tensor dev_src_t = make_tensor_external(dev_attn_out, src_shapes, 1, DataType::FLOAT32); + Tensor dev_out_t = make_tensor_external(dev_allgather_out, dst_shapes, 1, DataType::FLOAT32); + Tensor win_src_t = make_tensor_external(reinterpret_cast<void*>(win_src), src_shapes, 1, DataType::FLOAT32); + Tensor win_dst_t = make_tensor_external(reinterpret_cast<void*>(win_dst), dst_shapes, 1, DataType::FLOAT32); + Tensor barrier_pre_t = make_tensor_external(reinterpret_cast<void*>(barrier_base_pre), barrier_shapes, 1, DataType::INT32); + Tensor barrier_post_t = make_tensor_external(reinterpret_cast<void*>(barrier_base_post), barrier_shapes, 1, DataType::INT32); + + PTO2_SCOPE(rt) { + PTOParam params_wmin[] = { + make_output_param(win_src_t), + make_input_param(dev_src_t), + make_scalar_param(static_cast<uint64_t>(GATHER_COUNT)), + }; + pto2_rt_submit_task(rt, FUNC_WIN_MEMCOPY_IN, PTO2_WORKER_VECTOR, params_wmin, 3); + + PTOParam
params_barrier_pre[] = { + make_input_param(barrier_pre_t), + make_scalar_param(device_ctx_ptr), + make_scalar_param(static_cast<uint64_t>(n_ranks)), + make_scalar_param(static_cast<uint64_t>(0)), + make_input_param(win_src_t), + make_output_param(sync_done_t), + }; + pto2_rt_submit_task(rt, FUNC_COMM_BARRIER, PTO2_WORKER_VECTOR, params_barrier_pre, 6); + + PTOParam params_allgather[] = { + make_output_param(win_dst_t), + make_input_param(win_src_t), + make_input_param(sync_done_t), + make_scalar_param(device_ctx_ptr), + make_scalar_param(static_cast<uint64_t>(n_ranks)), + make_scalar_param(static_cast<uint64_t>(rank_id)), + }; + pto2_rt_submit_task(rt, FUNC_ALLGATHER, PTO2_WORKER_VECTOR, params_allgather, 6); + + PTOParam params_wmout[] = { + make_output_param(dev_out_t), + make_input_param(win_dst_t), + make_scalar_param(static_cast<uint64_t>(n_ranks * GATHER_COUNT)), + }; + pto2_rt_submit_task(rt, FUNC_WIN_MEMCOPY_OUT, PTO2_WORKER_VECTOR, params_wmout, 3); + + Tensor sync_post_t = make_tensor(sync_shapes, 1, DataType::INT32); + PTOParam params_barrier_post[] = { + make_input_param(barrier_post_t), + make_scalar_param(device_ctx_ptr), + make_scalar_param(static_cast<uint64_t>(n_ranks)), + make_scalar_param(static_cast<uint64_t>(0)), + make_input_param(win_dst_t), + make_output_param(sync_post_t), + }; + pto2_rt_submit_task(rt, FUNC_COMM_BARRIER, PTO2_WORKER_VECTOR, params_barrier_post, 6); + } + + LOG_INFO(rt, "paged_attention_allgather_Manual tasks submitted"); +} + +} // extern "C" diff --git a/examples/tensormap_and_ringbuffer/paged_attention_allgather_Tgather/README.md b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Tgather/README.md new file mode 100644 index 00000000..d2ea6400 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Tgather/README.md @@ -0,0 +1,14 @@ +# Paged Attention + AllGather (TGATHER) - tensormap_and_ringbuffer + +Paged Attention 计算后 AllGather(N 轮 TGATHER)。 + +流程:Paged Attention (QK->SF->PV->UP) -> WindowMemCopyIn -> for r in [0,n_ranks): +Barrier -> Gather(root=r)
-> [rank r: WindowMemCopyOut] -> Barrier(post) + +## 运行 + +```bash +./run_tensormap.sh paged_attention_allgather_Tgather 2 0 +``` + +注意:编排文件 (paged_attention_allgather_orch.cpp) 需要完整实现。当前仅包含 kernel_config 和 golden,编排需根据 runtime 接口补充。 diff --git a/examples/tensormap_and_ringbuffer/paged_attention_allgather_Tgather/golden.py b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Tgather/golden.py new file mode 100644 index 00000000..64f7a4dc --- /dev/null +++ b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Tgather/golden.py @@ -0,0 +1,144 @@ +""" +Paged Attention + AllGather (TGATHER): Paged Attention → AllGather. + +Same golden logic as paged_attention_allgather_Manual (output is identical). +""" + +import ctypes +import struct +import torch +import numpy as np + +GATHER_COUNT = 64 +BATCH = 1 +NUM_HEADS = 16 +KV_HEAD_NUM = 1 +HEAD_DIM = 16 +BLOCK_SIZE = 16 +CONTEXT_LEN = 16 +MAX_MODEL_LEN = 256 + +__outputs__ = ["attn_out", "allgather_out"] +RTOL = 1e-2 +ATOL = 1e-2 +ALL_CASES = {"Default": {}} +DEFAULT_CASE = "Default" + +def _make_block_table_and_context(): + max_num_blocks_per_req = MAX_MODEL_LEN // BLOCK_SIZE + cur_valid_blocks = (CONTEXT_LEN + BLOCK_SIZE - 1) // BLOCK_SIZE + total_blocks = BATCH * cur_valid_blocks + torch.manual_seed(100) + block_table = torch.randint(0, max(total_blocks, 1), size=(BATCH, max_num_blocks_per_req), dtype=torch.int32) + context_lens = torch.full((BATCH,), CONTEXT_LEN, dtype=torch.int32) + return block_table, context_lens, total_blocks, max_num_blocks_per_req + +def _make_qkv(rank_id, total_blocks): + torch.manual_seed(42 + rank_id) + q = (torch.rand(BATCH, 1, NUM_HEADS * HEAD_DIM) - 0.5).to(torch.float16) + q = q.reshape(BATCH, NUM_HEADS, HEAD_DIM) + k = (torch.rand(total_blocks, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) - 0.5).to(torch.float16) + v = (torch.rand(total_blocks, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) * 2 - 1).to(torch.float16) + return q, k, v + +def generate_inputs(params: dict) -> list: + rank_id = 
params.get("rank_id", 0) + n_ranks = params.get("n_ranks", 2) + root = params.get("root", 0) + block_table, context_lens, total_blocks, max_num_blocks_per_req = _make_block_table_and_context() + query_fp16, key_fp16, value_fp16 = _make_qkv(rank_id, total_blocks) + scale_value = 1.0 + scale_bits = struct.unpack('I', struct.pack('f', scale_value))[0] + config = torch.tensor([BATCH, NUM_HEADS, KV_HEAD_NUM, HEAD_DIM, BLOCK_SIZE, max_num_blocks_per_req, scale_bits], dtype=torch.int64) + query = query_fp16.flatten() + key_cache = key_fp16.flatten() + value_cache = value_fp16.flatten() + block_table_flat = block_table.flatten() + attn_out = torch.zeros(BATCH * NUM_HEADS * HEAD_DIM, dtype=torch.float32) + allgather_out = torch.zeros(n_ranks * GATHER_COUNT, dtype=torch.float32) + result = [ + ("query", query), ("key_cache", key_cache), ("value_cache", value_cache), + ("block_table", block_table_flat), ("context_lens", context_lens), + ("attn_out", attn_out), ("allgather_out", allgather_out), ("config", config), + ("size_query", ctypes.c_int64(query.nbytes)), ("size_key_cache", ctypes.c_int64(key_cache.nbytes)), + ("size_value_cache", ctypes.c_int64(value_cache.nbytes)), ("size_block_table", ctypes.c_int64(block_table_flat.nbytes)), + ("size_context_lens", ctypes.c_int64(context_lens.nbytes)), ("size_attn_out", ctypes.c_int64(attn_out.nbytes)), + ("size_allgather_out", ctypes.c_int64(allgather_out.nbytes)), ("size_config", ctypes.c_int64(config.nbytes)), + ] + if "device_ctx_ptr" in params and "win_in_base" in params and "win_out_base" in params: + result.extend([ + ("device_ctx_ptr", ctypes.c_uint64(params["device_ctx_ptr"])), + ("win_in_base", ctypes.c_uint64(params["win_in_base"])), + ("win_out_base", ctypes.c_uint64(params["win_out_base"])), + ("n_ranks", ctypes.c_int32(n_ranks)), ("root", ctypes.c_int32(root)), ("rank_id", ctypes.c_int32(rank_id)), + ]) + return result + +def paged_attention(query, key_cache, value_cache, num_kv_heads, num_heads, scale_value, 
block_table, context_lens): + assert num_kv_heads == 1 + batch, num_heads_dim, head_dim = query.shape + _, block_size, _, _ = key_cache.shape + key_cache_flat = key_cache.reshape(-1, block_size, head_dim) + value_cache_flat = value_cache.reshape(-1, block_size, head_dim) + out = torch.zeros((batch, num_heads_dim, head_dim), dtype=torch.float32) + q_tile = min(num_heads_dim, 128) + max_bn = int(((context_lens.max().item()) + block_size - 1) // block_size) + for q_offset in range(0, num_heads_dim, q_tile): + q_tile_size = min(q_tile, num_heads_dim - q_offset) + qi = query[:, q_offset:q_offset + q_tile_size, :].to(torch.float32) + oi, li, mi = None, None, None + for bn in range(max_bn): + valid_lens = torch.clamp(context_lens - bn * block_size, min=0, max=block_size) + active_mask = valid_lens > 0 + if not active_mask.any(): break + block_indices = block_table[:, bn] + kj_all = key_cache_flat[block_indices].to(torch.float32) + vj_all = value_cache_flat[block_indices].to(torch.float32) + sij = torch.bmm(qi, kj_all.transpose(1, 2)) * scale_value + pos = torch.arange(block_size, device=sij.device).unsqueeze(0) + valid_mask = pos < valid_lens.unsqueeze(1) + valid_mask = valid_mask.unsqueeze(1) + sij = sij.masked_fill(~valid_mask, float('-inf')) + batch_mask = active_mask.view(-1, 1, 1) + sij = sij.masked_fill(~batch_mask, float('-inf')) + mij = sij.max(dim=-1, keepdim=True)[0] + mij = mij.clamp(min=-1e30) + pij = torch.exp(sij - mij) + pij = pij.masked_fill(~valid_mask, 0.0) + pij = pij.masked_fill(~batch_mask, 0.0) + pij = pij.to(torch.bfloat16).to(torch.float32) + lij = pij.sum(dim=-1, keepdim=True) + oi_new = torch.bmm(pij, vj_all) + if bn == 0: + oi, li, mi = oi_new, lij, mij + else: + mi_new = torch.maximum(mi, mij) + alpha = torch.exp(mi - mi_new) + beta = torch.exp(mij - mi_new) + li = alpha * li + beta * lij + oi = alpha * oi + beta * oi_new + mi = mi_new + out[:, q_offset:q_offset + q_tile_size, :] = oi / li + return out.reshape(-1, head_dim) + +def 
_compute_rank_attn(rank_id, block_table, context_lens, total_blocks): + q, k, v = _make_qkv(rank_id, total_blocks) + return paged_attention(q, k.reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM), v.reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM), + KV_HEAD_NUM, NUM_HEADS, 1.0, block_table, context_lens) + +def compute_golden(tensors: dict, params: dict) -> None: + n_ranks = params.get("n_ranks", 2) + max_num_blocks_per_req = MAX_MODEL_LEN // BLOCK_SIZE + total_blocks = BATCH * ((CONTEXT_LEN + BLOCK_SIZE - 1) // BLOCK_SIZE) + block_table = tensors["block_table"].reshape(BATCH, max_num_blocks_per_req) + context_lens_t = tensors["context_lens"] + query = tensors["query"].reshape(BATCH, NUM_HEADS, HEAD_DIM) + key_cache = tensors["key_cache"].reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) + value_cache = tensors["value_cache"].reshape(-1, BLOCK_SIZE, KV_HEAD_NUM, HEAD_DIM) + attn_result = paged_attention(query, key_cache, value_cache, KV_HEAD_NUM, NUM_HEADS, 1.0, block_table, context_lens_t) + tensors["attn_out"][:] = attn_result.flatten() + allgather_np = tensors["allgather_out"].cpu().numpy() + for r in range(n_ranks): + attn_r = _compute_rank_attn(r, block_table, context_lens_t, total_blocks) + flat_r = attn_r.flatten().numpy() + allgather_np[r * GATHER_COUNT : (r + 1) * GATHER_COUNT] = flat_r[:GATHER_COUNT] diff --git a/examples/tensormap_and_ringbuffer/paged_attention_allgather_Tgather/kernels/kernel_config.py b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Tgather/kernels/kernel_config.py new file mode 100644 index 00000000..97a14183 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Tgather/kernels/kernel_config.py @@ -0,0 +1,39 @@ +""" +Paged Attention + AllGather (TGATHER) for tensormap_and_ringbuffer. 
+ +Flow: Paged Attention (QK->SF->PV->UP) -> WindowMemCopyIn -> for r in [0,n_ranks): +Barrier -> Gather(root=r) -> [rank r: WindowMemCopyOut] -> Barrier(post) +""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent +_PA_ROOT = _KERNELS_ROOT.parent.parent / "paged_attention" / "kernels" +_AG_ROOT = _KERNELS_ROOT.parent.parent / "allgather_Tgather" / "kernels" + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "paged_attention_allgather_orch.cpp"), + "function_name": "aicpu_orchestration_entry", +} + +KERNELS = [ + {"func_id": 0, "name": "QK", "source": str(_PA_ROOT / "aic" / "aic_qk_matmul.cpp"), "core_type": "aic"}, + {"func_id": 1, "name": "SF", "source": str(_PA_ROOT / "aiv" / "aiv_softmax_prepare.cpp"), "core_type": "aiv"}, + {"func_id": 2, "name": "PV", "source": str(_PA_ROOT / "aic" / "aic_pv_matmul.cpp"), "core_type": "aic"}, + {"func_id": 3, "name": "UP", "source": str(_PA_ROOT / "aiv" / "aiv_online_update.cpp"), "core_type": "aiv"}, + {"func_id": 4, "name": "AIC_HUB", "source": str(_PA_ROOT / "aic" / "aic_hub.cpp"), "core_type": "aic"}, + {"func_id": 5, "name": "AIV_HUB", "source": str(_PA_ROOT / "aiv" / "aiv_hub.cpp"), "core_type": "aiv"}, + {"func_id": 6, "name": "WindowMemCopyIn", "source": str(_AG_ROOT / "aiv" / "window_memcopy_in.cpp"), "core_type": "aiv"}, + {"func_id": 7, "name": "Gather", "source": str(_AG_ROOT / "aiv" / "gather_kernel.cpp"), "core_type": "aiv"}, + {"func_id": 8, "name": "WindowMemCopyOut", "source": str(_AG_ROOT / "aiv" / "window_memcopy_out.cpp"), "core_type": "aiv"}, + {"func_id": 9, "name": "CommBarrierAll", "source": str(_AG_ROOT / "aiv" / "comm_barrier_all_kernel.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "tensormap_and_ringbuffer", + "aicpu_thread_num": 4, + "block_dim": 24, + "n_devices": 2, + "first_device_id": 0, + "requires_comm": True, +} diff --git 
a/examples/tensormap_and_ringbuffer/paged_attention_allgather_Tgather/kernels/orchestration/paged_attention_allgather_orch.cpp b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Tgather/kernels/orchestration/paged_attention_allgather_orch.cpp new file mode 100644 index 00000000..f5e5ec87 --- /dev/null +++ b/examples/tensormap_and_ringbuffer/paged_attention_allgather_Tgather/kernels/orchestration/paged_attention_allgather_orch.cpp @@ -0,0 +1,284 @@ +/** + * Paged Attention + AllGather (TGATHER) orchestration for tensormap_and_ringbuffer. + * + * Flow: Phase 1 - Paged Attention (QK->SF->PV->UP) + * Phase 2 - WindowMemCopyIn -> for r in [0,n_ranks): Barrier -> Gather(root=r) + * -> [rank r: WindowMemCopyOut] -> Barrier(post) + * + * Args (22): [0] query, [1] key_cache, [2] value_cache, [3] block_table, + * [4] context_lens, [5] attn_out, [6] allgather_out, [7] config, + * [8-15] 7 sizes, [16] device_ctx_ptr, [17] win_in_base, [18] win_out_base, + * [19] n_ranks, [20] root, [21] rank_id + */ + +#include <cstdint> +#include <cstddef> + +#include "pto_orchestration_api.h" + +constexpr int GATHER_COUNT = 64; +constexpr size_t HCCL_WIN_SYNC_PREFIX = 64 * sizeof(int32_t); + +#define FUNC_QK_MATMUL 0 +#define FUNC_SOFTMAX_PREPARE 1 +#define FUNC_PV_MATMUL 2 +#define FUNC_ONLINE_UPDATE 3 +#define FUNC_AIC_HUB 4 +#define FUNC_AIV_HUB 5 +#define FUNC_WIN_MEMCOPY_IN 6 +#define FUNC_GATHER 7 +#define FUNC_WIN_MEMCOPY_OUT 8 +#define FUNC_COMM_BARRIER 9 + +static uint64_t float_to_u64(float f) { + union { + float f32; + uint64_t u64; + } conv; + conv.u64 = 0; + conv.f32 = f; + return conv.u64; +} + +extern "C" { + +__attribute__((visibility("default"))) +PTO2OrchestrationConfig aicpu_orchestration_config(uint64_t* args, int arg_count) { + (void)args; + (void)arg_count; + return PTO2OrchestrationConfig{ + .expected_arg_count = 22, + }; +} + +__attribute__((visibility("default"))) +void aicpu_orchestration_entry(PTO2Runtime* rt, uint64_t* args, int arg_count) { + (void)arg_count; + 
pto2_rt_init_tensor_pool(rt); + + void* dev_query = reinterpret_cast<void*>(args[0]); + void* dev_key_cache = reinterpret_cast<void*>(args[1]); + void* dev_value_cache = reinterpret_cast<void*>(args[2]); + int* dev_block_table = reinterpret_cast<int*>(args[3]); + int* dev_context_lens = reinterpret_cast<int*>(args[4]); + void* dev_attn_out = reinterpret_cast<void*>(args[5]); + void* dev_allgather_out = reinterpret_cast<void*>(args[6]); + int64_t* dev_config = reinterpret_cast<int64_t*>(args[7]); + + size_t query_size = static_cast<size_t>(args[8]); + size_t key_cache_size = static_cast<size_t>(args[9]); + size_t value_cache_size = static_cast<size_t>(args[10]); + (void)args[11]; + (void)args[12]; + (void)args[13]; + (void)args[14]; + (void)args[15]; + + uint64_t device_ctx_ptr = args[16]; + uint64_t win_in_base = args[17]; + (void)args[18]; + int n_ranks = static_cast<int>(args[19]); + (void)args[20]; + int rank_id = static_cast<int>(args[21]); + + uint64_t batch = static_cast<uint64_t>(static_cast<int64_t>(dev_config[0])); + uint64_t num_heads = static_cast<uint64_t>(static_cast<int64_t>(dev_config[1])); + int kv_head_num = static_cast<int>(dev_config[2]); + uint64_t head_dim = static_cast<uint64_t>(static_cast<int64_t>(dev_config[3])); + uint64_t block_size = static_cast<uint64_t>(static_cast<int64_t>(dev_config[4])); + uint64_t block_num = static_cast<uint64_t>(static_cast<int64_t>(dev_config[5])); + union { uint32_t u; float f; } scale_conv; + scale_conv.u = static_cast<uint32_t>(dev_config[6]); + float scale_value = scale_conv.f; + + uint64_t q_head_num = num_heads; + uint64_t q_tile = 16; + uint64_t q_loop = (q_head_num + q_tile - 1) / q_tile; + DataType data_type = DataType::FLOAT16; + uint64_t elem_size = get_element_size(data_type); + + (void)kv_head_num; + + LOG_INFO(rt, "paged_attention_allgather_Tgather: n_ranks=%d rank_id=%d batch=%lu", + n_ranks, rank_id, (unsigned long)batch); + + uint64_t query_shapes[2] = {batch * num_heads, head_dim}; + uint64_t kv_total_rows = key_cache_size / (head_dim * elem_size); + uint64_t key_cache_shapes[2] = {kv_total_rows, head_dim}; + uint64_t value_cache_shapes[2] = {kv_total_rows, head_dim}; + uint64_t
attn_out_shapes[2] = {batch * num_heads, head_dim}; + + Tensor query = make_tensor_external(dev_query, query_shapes, 2, data_type); + Tensor key_cache = make_tensor_external(dev_key_cache, key_cache_shapes, 2, data_type); + Tensor value_cache = make_tensor_external(dev_value_cache, value_cache_shapes, 2, data_type); + Tensor attn_out = make_tensor_external(dev_attn_out, attn_out_shapes, 2, DataType::FLOAT32); + + /* Phase 1: Paged Attention */ + for (uint64_t b_idx = 0; b_idx < batch; b_idx++) { + uint64_t cur_seq = static_cast<uint64_t>(dev_context_lens[b_idx]); + uint64_t bn_this_batch = (cur_seq + block_size - 1) / block_size; + for (uint64_t q_idx = 0; q_idx < q_loop; q_idx++) { + PTO2_SCOPE(rt) { + uint64_t cur_offset = b_idx * q_head_num + q_idx * q_tile; + uint64_t oi_shapes[2] = {q_tile, head_dim}; + uint64_t li_shapes[1] = {q_tile}; + uint64_t mi_shapes[1] = {q_tile}; + Tensor oi = make_tensor(oi_shapes, 2, DataType::FLOAT32); + Tensor li_update = make_tensor(li_shapes, 1, DataType::FLOAT32); + Tensor mi_update = make_tensor(mi_shapes, 1, DataType::FLOAT32); + + uint64_t qi_shapes[2] = {q_tile, head_dim}; + uint64_t qi_offsets[2] = {cur_offset, 0}; + Tensor qi = query.view(qi_shapes, qi_offsets); + uint64_t out_view_shapes[2] = {q_tile, head_dim}; + uint64_t out_view_offsets[2] = {cur_offset, 0}; + Tensor out_view = attn_out.view(out_view_shapes, out_view_offsets); + + PTOParam params_inplace[] = { + make_output_param(oi), + make_output_param(li_update), + make_output_param(mi_update), + }; + pto2_rt_submit_task(rt, FUNC_AIV_HUB, PTO2_WORKER_VECTOR, params_inplace, 3); + + for (uint64_t bn = 0; bn < bn_this_batch; bn++) { + uint64_t cur_block_idx = static_cast<uint64_t>(dev_block_table[b_idx * block_num + bn]); + uint64_t valid_len = block_size < (cur_seq - bn * block_size) ?
block_size : (cur_seq - bn * block_size); + uint64_t kv_shapes[2] = {block_size, head_dim}; + uint64_t kv_offsets[2] = {cur_block_idx * block_size, 0}; + Tensor kj = key_cache.view(kv_shapes, kv_offsets); + Tensor vj = value_cache.view(kv_shapes, kv_offsets); + + uint64_t sij_shapes[2] = {q_tile, block_size}; + Tensor sij = make_tensor(sij_shapes, 2, DataType::FLOAT32); + Tensor pij_f16 = make_tensor(sij_shapes, 2, data_type); + + PTOParam params_qk[] = { + make_input_param(qi), + make_input_param(kj), + make_output_param(sij), + }; + pto2_rt_submit_task(rt, FUNC_QK_MATMUL, PTO2_WORKER_CUBE, params_qk, 3); + + uint64_t sij_valid_shapes[2] = {q_tile, valid_len}; + uint64_t sij_valid_offsets[2] = {0, 0}; + Tensor sij_valid = sij.view(sij_valid_shapes, sij_valid_offsets); + Tensor li = make_tensor(li_shapes, 1, DataType::FLOAT32); + Tensor mi = make_tensor(mi_shapes, 1, DataType::FLOAT32); + PTOParam params_sf[] = { + make_input_param(sij_valid), + make_scalar_param(float_to_u64(scale_value)), + make_output_param(pij_f16), + make_output_param(mi), + make_output_param(li), + }; + pto2_rt_submit_task(rt, FUNC_SOFTMAX_PREPARE, PTO2_WORKER_VECTOR, params_sf, 5); + + uint64_t oi_tmp_shapes[2] = {q_tile, head_dim}; + Tensor oi_tmp = make_tensor(oi_tmp_shapes, 2, DataType::FLOAT32); + + PTOParam params_pv[] = { + make_input_param(pij_f16), + make_input_param(vj), + make_output_param(oi_tmp), + }; + pto2_rt_submit_task(rt, FUNC_PV_MATMUL, PTO2_WORKER_CUBE, params_pv, 3); + + uint64_t is_first = (bn == 0) ? 1 : 0; + uint64_t is_last = (bn == bn_this_batch - 1) ? 
1 : 0; + + PTOParam params_up[] = { + make_input_param(mi), + make_input_param(li), + make_input_param(oi_tmp), + make_inout_param(mi_update), + make_inout_param(li_update), + make_inout_param(oi), + make_output_param(out_view), + make_scalar_param(is_first), + make_scalar_param(is_last), + }; + pto2_rt_submit_task(rt, FUNC_ONLINE_UPDATE, PTO2_WORKER_VECTOR, params_up, 9); + } + } + } + } + + /* Phase 2: AllGather (TGATHER) - for r in [0,n_ranks): Barrier -> Gather(root=r) -> [rank r: WindowMemCopyOut] */ + size_t barrier_size = static_cast<size_t>(n_ranks) * sizeof(int32_t); + size_t total_barrier_bytes = barrier_size * (static_cast<size_t>(n_ranks) + 1); + uint64_t barrier_base_0 = win_in_base + HCCL_WIN_SYNC_PREFIX; + uint64_t win_src = barrier_base_0 + total_barrier_bytes; + uint64_t win_dst = win_src + GATHER_COUNT * sizeof(float); + + uint64_t src_shapes[1] = {static_cast<uint64_t>(GATHER_COUNT)}; + uint64_t dst_shapes[1] = {static_cast<uint64_t>(n_ranks) * GATHER_COUNT}; + uint64_t barrier_shapes[1] = {static_cast<uint64_t>(n_ranks)}; + uint64_t sync_shapes[1] = {1}; + + Tensor dev_src_t = make_tensor_external(dev_attn_out, src_shapes, 1, DataType::FLOAT32); + Tensor dev_out_t = make_tensor_external(dev_allgather_out, dst_shapes, 1, DataType::FLOAT32); + Tensor win_src_t = make_tensor_external(reinterpret_cast<void*>(win_src), src_shapes, 1, DataType::FLOAT32); + Tensor win_dst_t = make_tensor_external(reinterpret_cast<void*>(win_dst), dst_shapes, 1, DataType::FLOAT32); + + PTO2_SCOPE(rt) { + PTOParam params_wmin[] = { + make_output_param(win_src_t), + make_input_param(dev_src_t), + make_scalar_param(static_cast<uint64_t>(GATHER_COUNT)), + }; + pto2_rt_submit_task(rt, FUNC_WIN_MEMCOPY_IN, PTO2_WORKER_VECTOR, params_wmin, 3); + + for (int r = 0; r < n_ranks; r++) { + uint64_t barrier_base_r = barrier_base_0 + static_cast<uint64_t>(r) * barrier_size; + Tensor barrier_r_t = make_tensor_external(reinterpret_cast<void*>(barrier_base_r), barrier_shapes, 1, DataType::INT32); + Tensor sync_r_t = make_tensor(sync_shapes, 1, DataType::INT32); + + 
PTOParam params_barrier[] = { + make_input_param(barrier_r_t), + make_scalar_param(device_ctx_ptr), + make_scalar_param(static_cast<uint64_t>(n_ranks)), + make_scalar_param(static_cast<uint64_t>(0)), + make_input_param(r == 0 ? win_src_t : win_dst_t), + make_output_param(sync_r_t), + }; + pto2_rt_submit_task(rt, FUNC_COMM_BARRIER, PTO2_WORKER_VECTOR, params_barrier, 6); + + PTOParam params_gather[] = { + make_output_param(win_dst_t), + make_input_param(win_src_t), + make_input_param(sync_r_t), + make_scalar_param(device_ctx_ptr), + make_scalar_param(static_cast<uint64_t>(n_ranks)), + make_scalar_param(static_cast<uint64_t>(r)), + }; + pto2_rt_submit_task(rt, FUNC_GATHER, PTO2_WORKER_VECTOR, params_gather, 6); + + if (rank_id == r) { + PTOParam params_wmout[] = { + make_output_param(dev_out_t), + make_input_param(win_dst_t), + make_scalar_param(static_cast<uint64_t>(n_ranks * GATHER_COUNT)), + }; + pto2_rt_submit_task(rt, FUNC_WIN_MEMCOPY_OUT, PTO2_WORKER_VECTOR, params_wmout, 3); + } + } + + uint64_t barrier_base_post = barrier_base_0 + static_cast<uint64_t>(n_ranks) * barrier_size; + Tensor barrier_post_t = make_tensor_external(reinterpret_cast<void*>(barrier_base_post), barrier_shapes, 1, DataType::INT32); + Tensor sync_post_t = make_tensor(sync_shapes, 1, DataType::INT32); + PTOParam params_barrier_post[] = { + make_input_param(barrier_post_t), + make_scalar_param(device_ctx_ptr), + make_scalar_param(static_cast<uint64_t>(n_ranks)), + make_scalar_param(static_cast<uint64_t>(0)), + make_input_param(win_dst_t), + make_output_param(sync_post_t), + }; + pto2_rt_submit_task(rt, FUNC_COMM_BARRIER, PTO2_WORKER_VECTOR, params_barrier_post, 6); + } + + LOG_INFO(rt, "paged_attention_allgather_Tgather tasks submitted"); +} + +} // extern "C" diff --git a/run_hostbuild.sh b/run_hostbuild.sh new file mode 100644 index 00000000..3a33dcaa --- /dev/null +++ b/run_hostbuild.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +# +# 运行 host_build_graph 多卡算子测试 +# +# 用法: ./run_hostbuild.sh <算子名称> <设备数> <起始卡ID> +# 示例: ./run_hostbuild.sh allgather_Tgather 2 0 +# + +set -e + 
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +if [ $# -lt 3 ]; then + echo "用法: $0 <算子名称> <设备数> <起始卡ID>" + echo "示例: $0 allgather_Tgather 2 0" + echo "" + echo "可用算子示例: allgather_Tgather, allgather_Manual, paged_attention_gather, paged_attention_allgather_Tgather, paged_attention_allgather_Manual, ..." + exit 1 +fi + +OP_NAME="$1" +N_DEVICES="$2" +FIRST_DEVICE="$3" + +KERNELS_DIR="examples/host_build_graph/${OP_NAME}/kernels" +GOLDEN_FILE="examples/host_build_graph/${OP_NAME}/golden.py" + +if [ ! -d "$KERNELS_DIR" ]; then + echo "错误: 算子目录不存在: $KERNELS_DIR" + exit 1 +fi + +if [ ! -f "$GOLDEN_FILE" ]; then + echo "错误: golden 文件不存在: $GOLDEN_FILE" + exit 1 +fi + +exec python3 examples/scripts/multi_card_run_example.py \ + -k "$KERNELS_DIR" \ + -g "$GOLDEN_FILE" \ + --n-devices "$N_DEVICES" \ + --first-device "$FIRST_DEVICE" \ + "${@:4}" diff --git a/run_tensormap.sh b/run_tensormap.sh new file mode 100644 index 00000000..7a054dec --- /dev/null +++ b/run_tensormap.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash +# +# 运行 tensormap_and_ringbuffer 算子测试 +# +# 用法: ./run_tensormap.sh <算子名称> [设备数] [起始卡ID] +# 示例: ./run_tensormap.sh paged_attention +# 示例: ./run_tensormap.sh gather 2 0 +# +# 单卡算子(设备数默认 1): vector_example, paged_attention, batch_paged_attention, bgemm +# 多卡算子(设备数需 >= 2): gather +# + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +if [ $# -lt 1 ]; then + echo "用法: $0 <算子名称> [设备数] [起始卡ID]" + echo "示例: $0 paged_attention" + echo "示例: $0 gather 2 0" + echo "" + echo "可用算子: vector_example, paged_attention, batch_paged_attention, bgemm, gather, allgather_Manual, allgather_Tgather, paged_attention_allgather_Manual, paged_attention_allgather_Tgather" + echo " - 单卡: vector_example, paged_attention, batch_paged_attention, bgemm (默认 1 卡)" + echo " - 多卡: gather, allgather_Manual, allgather_Tgather, paged_attention_allgather_Manual, paged_attention_allgather_Tgather (需指定 2 卡及以上)" + exit 1 +fi + 
+OP_NAME="$1" +N_DEVICES="${2:-1}" +FIRST_DEVICE="${3:-0}" + +KERNELS_DIR="examples/tensormap_and_ringbuffer/${OP_NAME}/kernels" +GOLDEN_FILE="examples/tensormap_and_ringbuffer/${OP_NAME}/golden.py" + +if [ ! -d "$KERNELS_DIR" ]; then + echo "错误: 算子目录不存在: $KERNELS_DIR" + exit 1 +fi + +if [ ! -f "$GOLDEN_FILE" ]; then + echo "错误: golden 文件不存在: $GOLDEN_FILE" + exit 1 +fi + +exec python3 examples/scripts/multi_card_run_example.py \ + -k "$KERNELS_DIR" \ + -g "$GOLDEN_FILE" \ + --n-devices "$N_DEVICES" \ + --first-device "$FIRST_DEVICE" \ + "${@:4}" diff --git a/src/runtime/host_build_graph/aicore/aicore_executor.cpp b/src/runtime/host_build_graph/aicore/aicore_executor.cpp index da665624..7c7ec121 100644 --- a/src/runtime/host_build_graph/aicore/aicore_executor.cpp +++ b/src/runtime/host_build_graph/aicore/aicore_executor.cpp @@ -10,11 +10,22 @@ __aicore__ __attribute__((always_inline)) static void execute_task(__gm__ Task* if (task->function_bin_addr == 0) { return; } + + // Invalidate data cache so kernel scalar reads fetch fresh data from GM, + // not stale cache lines left by a previous task on this core. + dcci(task, ENTIRE_DATA_CACHE); + KernelFunc kernel = (KernelFunc)task->function_bin_addr; kernel(reinterpret_cast<__gm__ int64_t*>(task->args)); - // Ensure all memory writes are visible to other cores + // Drain all pipelines (MTE2/MTE3/Vector/Cube) pipe_barrier(PIPE_ALL); + + // Flush (clean + invalidate) the entire data cache to GM. + // Without this, scalar writes stay in this core's cache and are invisible + // to MTE2 DMA reads issued by successor tasks on any core. + // This is the task-level equivalent of aclrtSynchronizeStream. 
+ dcci(task, ENTIRE_DATA_CACHE, CACHELINE_OUT); } __aicore__ __attribute__((weak)) void aicore_execute(__gm__ Runtime* runtime, int block_idx, CoreType core_type) { diff --git a/src/runtime/tensormap_and_ringbuffer/aicore/aicore_executor.cpp b/src/runtime/tensormap_and_ringbuffer/aicore/aicore_executor.cpp index 384ea615..c4fcedd0 100644 --- a/src/runtime/tensormap_and_ringbuffer/aicore/aicore_executor.cpp +++ b/src/runtime/tensormap_and_ringbuffer/aicore/aicore_executor.cpp @@ -30,8 +30,11 @@ __aicore__ __attribute__((always_inline)) static void execute_task(__gm__ void* UnifiedKernelFunc kernel = (UnifiedKernelFunc)payload->function_bin_addr; kernel(reinterpret_cast<__gm__ int64_t*>(payload->args)); - // Ensure all memory writes are visible to other cores pipe_barrier(PIPE_ALL); + + // Flush (clean + invalidate) data cache to GM so successor tasks' MTE2 DMA + // reads see scalar writes from this kernel. + dcci(task_ptr, ENTIRE_DATA_CACHE, CACHELINE_OUT); } /**