diff --git a/ci.sh b/ci.sh index f108a018..602e4a86 100755 --- a/ci.sh +++ b/ci.sh @@ -16,8 +16,17 @@ while [[ $# -gt 0 ]]; do shift 2 ;; -d|--device) - DEVICE_RANGE="$2" - shift 2 + shift + DEVICE_VALUES=() + while [[ $# -gt 0 && "$1" != -* ]]; do + DEVICE_VALUES+=("$1") + shift + done + if [[ ${#DEVICE_VALUES[@]} -eq 0 ]]; then + echo "Missing value for --device" + exit 1 + fi + DEVICE_RANGE=$(IFS=,; echo "${DEVICE_VALUES[*]}") ;; -r|--runtime) RUNTIME="$2" @@ -78,15 +87,22 @@ if [[ -n "$RUNTIME" ]]; then fi fi -# Parse device range (e.g., "5-8" or "5") -if [[ "$DEVICE_RANGE" == *-* ]]; then - IFS='-' read -r DEV_START DEV_END <<< "$DEVICE_RANGE" - DEVICES=() - for ((i=DEV_START; i<=DEV_END; i++)); do - DEVICES+=("$i") - done +# Parse device spec (e.g., "5-8", "5", or "0,1,3,5") +DEVICES=() +if [[ -z "$DEVICE_RANGE" ]]; then + DEVICES=("0") else - DEVICES=("${DEVICE_RANGE:-0}") + IFS=',' read -r -a DEVICE_ITEMS <<< "$DEVICE_RANGE" + for item in "${DEVICE_ITEMS[@]}"; do + if [[ "$item" == *-* ]]; then + IFS='-' read -r DEV_START DEV_END <<< "$item" + for ((i=DEV_START; i<=DEV_END; i++)); do + DEVICES+=("$i") + done + else + DEVICES+=("$item") + fi + done fi NUM_DEVICES=${#DEVICES[@]} @@ -199,6 +215,40 @@ pin_pto_isa_on_failure() { return 0 # Pinned, caller should retry } +get_task_device_count() { + local kernel_config="$1" + python - "$kernel_config" <<'PY' +import importlib.util +import sys + +path = sys.argv[1] +spec = importlib.util.spec_from_file_location("kernel_config", path) +mod = importlib.util.module_from_spec(spec) +spec.loader.exec_module(mod) +dist = getattr(mod, "DISTRIBUTED_CONFIG", None) +nranks = 1 +if isinstance(dist, dict): + try: + nranks = int(dist.get("nranks", 1)) + except (TypeError, ValueError): + nranks = 1 +print(max(nranks, 1)) +PY +} + +format_device_spec() { + local count="$1" + if [[ "$count" -le 1 ]]; then + echo "${DEVICES[0]}" + return 0 + fi + + local selected=("${DEVICES[@]:0:$count}") + local joined + joined=$(IFS=,; echo 
"${selected[*]}") + echo "$joined" +} + # ---- Discover all tasks ---- EXAMPLES_DIR="examples" DEVICE_TESTS_DIR="tests/device_tests" @@ -206,6 +256,7 @@ DEVICE_TESTS_DIR="tests/device_tests" declare -a HW_TASK_NAMES=() declare -a HW_TASK_DIRS=() declare -a HW_TASK_PLATS=() +declare -a HW_TASK_DEVICE_COUNTS=() declare -a SIM_TASK_NAMES=() declare -a SIM_TASK_DIRS=() declare -a SIM_TASK_PLATS=() @@ -245,18 +296,22 @@ while IFS= read -r -d '' example_dir; do SIM_TASK_DIRS+=("${example_dir}") SIM_TASK_PLATS+=("${PLATFORM}") else + required_devices="$(get_task_device_count "$kernel_config")" HW_TASK_NAMES+=("example:${example_name}") HW_TASK_DIRS+=("${example_dir}") HW_TASK_PLATS+=("${PLATFORM}") + HW_TASK_DEVICE_COUNTS+=("${required_devices}") fi elif [[ "$OS" == "Darwin" ]]; then SIM_TASK_NAMES+=("example:${example_name}") SIM_TASK_DIRS+=("${example_dir}") SIM_TASK_PLATS+=("${example_arch}sim") else + required_devices="$(get_task_device_count "$kernel_config")" HW_TASK_NAMES+=("example:${example_name}") HW_TASK_DIRS+=("${example_dir}") HW_TASK_PLATS+=("${example_arch}") + HW_TASK_DEVICE_COUNTS+=("${required_devices}") SIM_TASK_NAMES+=("example:${example_name}") SIM_TASK_DIRS+=("${example_dir}") SIM_TASK_PLATS+=("${example_arch}sim") @@ -299,6 +354,7 @@ if [[ -d "$DEVICE_TESTS_DIR" ]]; then HW_TASK_NAMES+=("device_test:${test_name}") HW_TASK_DIRS+=("${test_dir}") HW_TASK_PLATS+=("${PLATFORM:-${test_arch}}") + HW_TASK_DEVICE_COUNTS+=("$(get_task_device_count "$kernel_config")") done < <(find "$DEVICE_TESTS_DIR" -mindepth 1 -type d -print0 | sort -z) else echo "Skipping device tests (hardware platforms only)" @@ -314,7 +370,7 @@ MAX_RETRIES=3 # Log naming: ${safe_name}_${platform}_attempt${attempt}.log # Result format: name|platform|PASS/FAIL|device:X|attempt:N|Xs run_task() { - local name="$1" dir="$2" platform="$3" attempt="$4" device_id="$5" print_log_on_fail="${6:-true}" + local name="$1" dir="$2" platform="$3" attempt="$4" device_spec="$5" 
print_log_on_fail="${6:-true}" required_devices="${7:-1}" local safe_name="${name//[:\/]/_}" local task_log="${LOG_DIR}/${safe_name}_${platform}_attempt${attempt}.log" local start_time=$SECONDS @@ -323,10 +379,16 @@ run_task() { cmd=(python examples/scripts/run_example.py -k "${dir}/kernels" -g "${dir}/golden.py" -p "$platform" --clone-protocol "$CLONE_PROTOCOL" "${commit_flag[@]}") - [[ -n "$device_id" ]] && cmd+=(-d "$device_id") + if [[ -n "$device_spec" ]]; then + if [[ "$required_devices" -gt 1 ]]; then + cmd+=(--devices "$device_spec" --nranks "$required_devices") + else + cmd+=(-d "$device_spec") + fi + fi # Progress to stdout (not captured in log) - echo "[${platform}${device_id:+:dev${device_id}}] Running: $name (attempt $attempt)" + echo "[${platform}${device_spec:+:dev${device_spec}}] Running: $name (attempt $attempt)" # Command output to log file only "${cmd[@]}" > "$task_log" 2>&1 @@ -336,21 +398,46 @@ run_task() { local status if [[ $rc -eq 0 ]]; then status="PASS" - echo "[${platform}${device_id:+:dev${device_id}}] PASS: $name (${elapsed}s)" + echo "[${platform}${device_spec:+:dev${device_spec}}] PASS: $name (${elapsed}s)" else status="FAIL" - echo "[${platform}${device_id:+:dev${device_id}}] FAIL: $name (${elapsed}s)" + echo "[${platform}${device_spec:+:dev${device_spec}}] FAIL: $name (${elapsed}s)" if [[ "$print_log_on_fail" == "true" ]]; then echo "--- LOG: $name (attempt $attempt) ---" cat "$task_log" echo "--- END ---" fi fi - echo "${name}|${platform}|${status}|device:${device_id:-sim}|attempt:${attempt}|${elapsed}s" \ + echo "${name}|${platform}|${status}|device:${device_spec:-sim}|attempt:${attempt}|${elapsed}s" \ >> "$RESULTS_FILE" return $rc } +run_hw_multidevice_tasks() { + local attempt="$1"; shift + local indices=("$@") + HW_MULTI_FAILURES=() + + for idx in "${indices[@]}"; do + local required_devices="${HW_TASK_DEVICE_COUNTS[$idx]}" + local platform="${HW_TASK_PLATS[$idx]}" + local name="${HW_TASK_NAMES[$idx]}" + + if [[ 
"$required_devices" -gt "$NUM_DEVICES" ]]; then + echo "[${platform}] FAIL: $name requires ${required_devices} devices, only ${NUM_DEVICES} available" + echo "${name}|${platform}|FAIL|device:insufficient|attempt:${attempt}|0s" >> "$RESULTS_FILE" + HW_MULTI_FAILURES+=("$idx") + continue + fi + + local device_spec + device_spec="$(format_device_spec "$required_devices")" + if ! run_task "$name" "${HW_TASK_DIRS[$idx]}" "$platform" "$attempt" "$device_spec" "true" "$required_devices"; then + HW_MULTI_FAILURES+=("$idx") + fi + done +} + # ---- SIM executor ---- # run_sim_tasks ... # Sets SIM_FAILURES to array of failed indices after return. @@ -429,7 +516,7 @@ run_hw_tasks() { IFS=':' read -r idx attempt <<< "$entry" - if run_task "${HW_TASK_NAMES[$idx]}" "${HW_TASK_DIRS[$idx]}" "${HW_TASK_PLATS[$idx]}" "$attempt" "$device_id" "false"; then + if run_task "${HW_TASK_NAMES[$idx]}" "${HW_TASK_DIRS[$idx]}" "${HW_TASK_PLATS[$idx]}" "$attempt" "$device_id" "false" "1"; then echo "${idx}|PASS" >> "$hw_marker" else next=$((attempt + 1)) @@ -473,12 +560,36 @@ fi # HW phase if [[ ${#HW_TASK_NAMES[@]} -gt 0 ]]; then - ALL_HW=($(seq 0 $((${#HW_TASK_NAMES[@]} - 1)))) - echo "---- HW: ${#ALL_HW[@]} tasks on ${NUM_DEVICES} devices ----" - run_hw_tasks "${ALL_HW[@]}" - if [[ ${#HW_FAILURES[@]} -gt 0 ]] && pin_pto_isa_on_failure; then - echo "[CI] Retrying ${#HW_FAILURES[@]} HW failures with pinned PTO-ISA" - run_hw_tasks "${HW_FAILURES[@]}" + ALL_HW_SINGLE=() + ALL_HW_MULTI=() + for idx in $(seq 0 $((${#HW_TASK_NAMES[@]} - 1))); do + if [[ "${HW_TASK_DEVICE_COUNTS[$idx]}" -gt 1 ]]; then + ALL_HW_MULTI+=("$idx") + else + ALL_HW_SINGLE+=("$idx") + fi + done + + echo "---- HW: ${#ALL_HW_SINGLE[@]} single-device tasks, ${#ALL_HW_MULTI[@]} multi-device tasks on ${NUM_DEVICES} devices ----" + + HW_MULTI_FAILURES=() + if [[ ${#ALL_HW_MULTI[@]} -gt 0 ]]; then + run_hw_multidevice_tasks 0 "${ALL_HW_MULTI[@]}" + if [[ ${#HW_MULTI_FAILURES[@]} -gt 0 ]] && pin_pto_isa_on_failure; then + echo "[CI] 
Retrying ${#HW_MULTI_FAILURES[@]} multi-device HW failures with pinned PTO-ISA" + run_hw_multidevice_tasks 1 "${HW_MULTI_FAILURES[@]}" + fi + fi + + HW_SINGLE_FAILURES=() + if [[ ${#ALL_HW_SINGLE[@]} -gt 0 ]]; then + run_hw_tasks "${ALL_HW_SINGLE[@]}" + HW_SINGLE_FAILURES=("${HW_FAILURES[@]}") + if [[ ${#HW_SINGLE_FAILURES[@]} -gt 0 ]] && pin_pto_isa_on_failure; then + echo "[CI] Retrying ${#HW_SINGLE_FAILURES[@]} HW failures with pinned PTO-ISA" + run_hw_tasks "${HW_SINGLE_FAILURES[@]}" + HW_SINGLE_FAILURES=("${HW_FAILURES[@]}") + fi fi fi diff --git a/examples/a2a3/aicpu_build_graph/allreduce_distributed/golden.py b/examples/a2a3/aicpu_build_graph/allreduce_distributed/golden.py new file mode 100644 index 00000000..0923630a --- /dev/null +++ b/examples/a2a3/aicpu_build_graph/allreduce_distributed/golden.py @@ -0,0 +1,41 @@ +""" +Golden script for distributed AllReduce. + +Each rank r contributes input[i] = i + r * 100 for i in [0, 256). +Every rank independently reduces (Sum) all inputs, so all ranks +produce the same output. 
+ +Expected output (same on every rank): + output[i] = sum_{r=0}^{nranks-1} (i + r * 100) + = nranks * i + 100 * nranks * (nranks - 1) / 2 +""" + +ALLREDUCE_COUNT = 256 +NRANKS = 4 + +__outputs__ = ["output"] + +RTOL = 1e-5 +ATOL = 1e-5 + + +def generate_distributed_inputs(rank: int, nranks: int, root: int, + comm_ctx=None) -> list: + """Each rank generates its own input; output is allocated on all ranks.""" + input_data = [float(i + rank * 100) for i in range(ALLREDUCE_COUNT)] + output_data = [0.0] * ALLREDUCE_COUNT + return [ + ("input", input_data), + ("output", output_data), + ("nranks", nranks), + ("root", root), + ] + + +def compute_golden(tensors: dict, params: dict) -> None: + """Compute expected output — same for every rank.""" + nranks = params.get("nranks", NRANKS) + output = tensors["output"] + for i in range(ALLREDUCE_COUNT): + output[i] = float( + nranks * i + 100 * nranks * (nranks - 1) // 2) diff --git a/examples/a2a3/aicpu_build_graph/allreduce_distributed/kernels/aiv/allreduce_kernel.cpp b/examples/a2a3/aicpu_build_graph/allreduce_distributed/kernels/aiv/allreduce_kernel.cpp new file mode 100644 index 00000000..b4268cc4 --- /dev/null +++ b/examples/a2a3/aicpu_build_graph/allreduce_distributed/kernels/aiv/allreduce_kernel.cpp @@ -0,0 +1,104 @@ +/** + * AllReduce kernel for simpler's kernel_entry signature. + * + * Every rank independently reads all ranks' inputs from the RDMA window, + * computes the element-wise sum, and writes the result to its own output. + * This is a symmetric allreduce — no designated root, all ranks active. 
+ * + * args layout (all uint64_t, cast as needed): + * args[0] = __gm__ float* input (device addr in RDMA window) + * args[1] = __gm__ float* output (device addr, local) + * args[2] = int nranks + * args[3] = (unused, kept for ABI compatibility) + * args[4] = __gm__ CommDeviceContext* ctx (device addr) + */ + +#include +#include +#include "pto/comm/comm_types.hpp" +#include "pto/comm/pto_comm_inst.hpp" +#include "common/comm_context.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static constexpr size_t ALLREDUCE_COUNT = 256; +static constexpr int kMaxSupportedRanks = 16; + +template +AICORE inline __gm__ T *CommRemotePtr( + __gm__ CommDeviceContext *ctx, __gm__ T *localPtr, int pe) +{ + uint64_t localBase = ctx->windowsIn[ctx->rankId]; + uint64_t offset = (uint64_t)localPtr - localBase; + return (__gm__ T *)(ctx->windowsIn[pe] + offset); +} + + +extern "C" __aicore__ __attribute__((always_inline)) +void kernel_entry(__gm__ int64_t* args) { + __gm__ float* input = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* output = reinterpret_cast<__gm__ float*>(args[1]); + int nranks = static_cast(args[2]); + int root = static_cast(args[3]); + __gm__ CommDeviceContext* commCtx = + reinterpret_cast<__gm__ CommDeviceContext*>(args[4]); + + using ShapeDyn = pto::Shape; + using StrideDyn = pto::Stride; + using Global = pto::GlobalTensor; + using TileData = pto::Tile; + + int my_rank = static_cast(commCtx->rankId); + + ShapeDyn shape(1, 1, 1, 1, ALLREDUCE_COUNT); + StrideDyn stride(ALLREDUCE_COUNT, ALLREDUCE_COUNT, ALLREDUCE_COUNT, + ALLREDUCE_COUNT, 1); + + TileData accTile(1, ALLREDUCE_COUNT); + TileData recvTile(1, ALLREDUCE_COUNT); + TASSIGN(accTile, 0x0); + TASSIGN(recvTile, 0x10000); + + if (nranks <= 0 || nranks > kMaxSupportedRanks) { + pipe_barrier(PIPE_ALL); + return; + } + + // Every rank reads all inputs and sums them into its own output. 
+ Global outputG(output, shape, stride); + + __gm__ float* firstInput = CommRemotePtr(commCtx, input, 0); + Global firstG(firstInput, shape, stride); + TLOAD(accTile, firstG); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + for (int r = 1; r < nranks; ++r) { + __gm__ float* remoteInput = CommRemotePtr(commCtx, input, r); + Global remoteG(remoteInput, shape, stride); + TLOAD(recvTile, remoteG); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TADD(accTile, accTile, recvTile); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + } + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(outputG, accTile); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/a2a3/aicpu_build_graph/allreduce_distributed/kernels/kernel_config.py b/examples/a2a3/aicpu_build_graph/allreduce_distributed/kernels/kernel_config.py new file mode 100644 index 00000000..598efc86 --- /dev/null +++ b/examples/a2a3/aicpu_build_graph/allreduce_distributed/kernels/kernel_config.py @@ -0,0 +1,62 @@ +""" +Distributed AllReduce kernel configuration — aicpu_build_graph runtime. + +Every rank reads all inputs via RDMA and computes the sum locally. +The AICPU orchestration plugin reads args from runtime->orch_args[], +builds the task graph via the aicpu_build_api, and publishes tasks for +the AICPU scheduler threads. 
+""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "allreduce_orch.cpp"), + "function_name": "build_allreduce_graph", +} + +KERNELS = [ + { + "func_id": 0, + "source": str(_KERNELS_ROOT / "aiv" / "allreduce_kernel.cpp"), + "core_type": "aiv", + }, +] + +RUNTIME_CONFIG = { + "runtime": "aicpu_build_graph", + "aicpu_thread_num": 4, + "block_dim": 4, +} + +RUNTIME_ENV = { + "PTO_AICPU_BUILD_GRAPH_BUILD_MODE": "1", +} + +# Distributed layout contract consumed by DistributedCodeRunner/worker: +# - win_sync_prefix reserves a small header at the front of each rank's RDMA +# window before any placement="window" buffers are laid out. +# - buffers declares runtime allocation metadata: +# * count is the element count, not byte size. +# * placement="window": buffer lives in the shared RDMA window and may be +# accessed by remote ranks. +# * placement="device": buffer uses regular device_malloc and is local-only. +# - inputs/outputs control which buffers are loaded from .bin files and which +# are copied back after execution. +# - args defines the orchestration/kernel uint64_t* args order. +DISTRIBUTED_CONFIG = { + "nranks": 4, + "root": 0, + "win_sync_prefix": 256, + "buffers": [ + # Every rank reads all ranks' inputs via CommRemotePtr, so the + # input buffer must be placed in the shared RDMA window. + {"name": "input", "dtype": "float32", "count": 256, "placement": "window"}, + # Each rank writes the reduced sum to its own local output. 
+ {"name": "output", "dtype": "float32", "count": 256, "placement": "device"}, + ], + "inputs": ["input"], + "outputs": ["output"], + "args": ["input", "output", "nranks", "root", "deviceCtx"], +} diff --git a/examples/a2a3/aicpu_build_graph/allreduce_distributed/kernels/orchestration/allreduce_orch.cpp b/examples/a2a3/aicpu_build_graph/allreduce_distributed/kernels/orchestration/allreduce_orch.cpp new file mode 100644 index 00000000..83f2cfd5 --- /dev/null +++ b/examples/a2a3/aicpu_build_graph/allreduce_distributed/kernels/orchestration/allreduce_orch.cpp @@ -0,0 +1,37 @@ +/** + * AllReduce Orchestration — aicpu_build_graph runtime. + * + * This orchestration plugin runs on AICPU. It reads args from + * runtime->orch_args[] (populated by init_runtime from func_args[]) + * and builds a single AIV task via the aicpu_build_api. + * + * orch_args layout (same as host_build_graph variant): + * [0] = input device pointer (in RDMA window) + * [1] = output device pointer (regular device memory) + * [2] = nranks + * [3] = root rank + * [4] = CommDeviceContext device pointer + */ + +#include "runtime.h" +#include + +extern "C" int build_allreduce_graph(Runtime* runtime) { + if (runtime == nullptr || runtime->orch_argc < 5) { + return -1; + } + + uint64_t task_args[5]; + task_args[0] = runtime->orch_args[0]; + task_args[1] = runtime->orch_args[1]; + task_args[2] = runtime->orch_args[2]; + task_args[3] = runtime->orch_args[3]; + task_args[4] = runtime->orch_args[4]; + + const AicpuBuildApi& api = runtime->aicpu_build_api; + int t0 = api.add_task(runtime, task_args, 5, 0, CoreType::AIV, 0); + if (t0 < 0) return -1; + api.publish_task(runtime, t0); + + return 0; +} diff --git a/examples/a2a3/host_build_graph/allreduce_distributed/golden.py b/examples/a2a3/host_build_graph/allreduce_distributed/golden.py new file mode 100644 index 00000000..0923630a --- /dev/null +++ b/examples/a2a3/host_build_graph/allreduce_distributed/golden.py @@ -0,0 +1,41 @@ +""" +Golden script for 
distributed AllReduce. + +Each rank r contributes input[i] = i + r * 100 for i in [0, 256). +Every rank independently reduces (Sum) all inputs, so all ranks +produce the same output. + +Expected output (same on every rank): + output[i] = sum_{r=0}^{nranks-1} (i + r * 100) + = nranks * i + 100 * nranks * (nranks - 1) / 2 +""" + +ALLREDUCE_COUNT = 256 +NRANKS = 4 + +__outputs__ = ["output"] + +RTOL = 1e-5 +ATOL = 1e-5 + + +def generate_distributed_inputs(rank: int, nranks: int, root: int, + comm_ctx=None) -> list: + """Each rank generates its own input; output is allocated on all ranks.""" + input_data = [float(i + rank * 100) for i in range(ALLREDUCE_COUNT)] + output_data = [0.0] * ALLREDUCE_COUNT + return [ + ("input", input_data), + ("output", output_data), + ("nranks", nranks), + ("root", root), + ] + + +def compute_golden(tensors: dict, params: dict) -> None: + """Compute expected output — same for every rank.""" + nranks = params.get("nranks", NRANKS) + output = tensors["output"] + for i in range(ALLREDUCE_COUNT): + output[i] = float( + nranks * i + 100 * nranks * (nranks - 1) // 2) diff --git a/examples/a2a3/host_build_graph/allreduce_distributed/kernels/aiv/allreduce_kernel.cpp b/examples/a2a3/host_build_graph/allreduce_distributed/kernels/aiv/allreduce_kernel.cpp new file mode 100644 index 00000000..b4268cc4 --- /dev/null +++ b/examples/a2a3/host_build_graph/allreduce_distributed/kernels/aiv/allreduce_kernel.cpp @@ -0,0 +1,104 @@ +/** + * AllReduce kernel for simpler's kernel_entry signature. + * + * Every rank independently reads all ranks' inputs from the RDMA window, + * computes the element-wise sum, and writes the result to its own output. + * This is a symmetric allreduce — no designated root, all ranks active. 
+ * + * args layout (all uint64_t, cast as needed): + * args[0] = __gm__ float* input (device addr in RDMA window) + * args[1] = __gm__ float* output (device addr, local) + * args[2] = int nranks + * args[3] = (unused, kept for ABI compatibility) + * args[4] = __gm__ CommDeviceContext* ctx (device addr) + */ + +#include +#include +#include "pto/comm/comm_types.hpp" +#include "pto/comm/pto_comm_inst.hpp" +#include "common/comm_context.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static constexpr size_t ALLREDUCE_COUNT = 256; +static constexpr int kMaxSupportedRanks = 16; + +template +AICORE inline __gm__ T *CommRemotePtr( + __gm__ CommDeviceContext *ctx, __gm__ T *localPtr, int pe) +{ + uint64_t localBase = ctx->windowsIn[ctx->rankId]; + uint64_t offset = (uint64_t)localPtr - localBase; + return (__gm__ T *)(ctx->windowsIn[pe] + offset); +} + + +extern "C" __aicore__ __attribute__((always_inline)) +void kernel_entry(__gm__ int64_t* args) { + __gm__ float* input = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* output = reinterpret_cast<__gm__ float*>(args[1]); + int nranks = static_cast(args[2]); + int root = static_cast(args[3]); + __gm__ CommDeviceContext* commCtx = + reinterpret_cast<__gm__ CommDeviceContext*>(args[4]); + + using ShapeDyn = pto::Shape; + using StrideDyn = pto::Stride; + using Global = pto::GlobalTensor; + using TileData = pto::Tile; + + int my_rank = static_cast(commCtx->rankId); + + ShapeDyn shape(1, 1, 1, 1, ALLREDUCE_COUNT); + StrideDyn stride(ALLREDUCE_COUNT, ALLREDUCE_COUNT, ALLREDUCE_COUNT, + ALLREDUCE_COUNT, 1); + + TileData accTile(1, ALLREDUCE_COUNT); + TileData recvTile(1, ALLREDUCE_COUNT); + TASSIGN(accTile, 0x0); + TASSIGN(recvTile, 0x10000); + + if (nranks <= 0 || nranks > kMaxSupportedRanks) { + pipe_barrier(PIPE_ALL); + return; + } + + // Every rank reads all inputs and sums them into its own output. 
+ Global outputG(output, shape, stride); + + __gm__ float* firstInput = CommRemotePtr(commCtx, input, 0); + Global firstG(firstInput, shape, stride); + TLOAD(accTile, firstG); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + for (int r = 1; r < nranks; ++r) { + __gm__ float* remoteInput = CommRemotePtr(commCtx, input, r); + Global remoteG(remoteInput, shape, stride); + TLOAD(recvTile, remoteG); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TADD(accTile, accTile, recvTile); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + } + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(outputG, accTile); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/a2a3/host_build_graph/allreduce_distributed/kernels/kernel_config.py b/examples/a2a3/host_build_graph/allreduce_distributed/kernels/kernel_config.py new file mode 100644 index 00000000..d578cc82 --- /dev/null +++ b/examples/a2a3/host_build_graph/allreduce_distributed/kernels/kernel_config.py @@ -0,0 +1,61 @@ +""" +Distributed AllReduce kernel configuration. + +Multi-card collective allreduce (Sum) across N ranks using PTO comm +instructions. Every rank reads all inputs via RDMA and computes the sum +locally. Communication addresses are set up by the comm_* platform API. + +DISTRIBUTED_CONFIG is the "multi-card graph" — it describes buffer layout, +args order, and artifact names. DistributedCodeRunner translates this into +distributed_worker.py CLI arguments. 
+""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "allreduce_orch.cpp"), + "function_name": "build_allreduce_graph", +} + +KERNELS = [ + { + "func_id": 0, + "source": str(_KERNELS_ROOT / "aiv" / "allreduce_kernel.cpp"), + "core_type": "aiv", + }, +] + +RUNTIME_CONFIG = { + "runtime": "host_build_graph", + "aicpu_thread_num": 1, + "block_dim": 1, +} + +# Distributed layout contract consumed by DistributedCodeRunner/worker: +# - win_sync_prefix reserves a small header at the front of each rank's RDMA +# window before any placement="window" buffers are laid out. +# - buffers declares runtime allocation metadata: +# * count is the element count, not byte size. +# * placement="window": buffer lives in the shared RDMA window and may be +# accessed by remote ranks. +# * placement="device": buffer uses regular device_malloc and is local-only. +# - inputs/outputs control which buffers are loaded from .bin files and which +# are copied back after execution. +# - args defines the orchestration/kernel uint64_t* args order. +DISTRIBUTED_CONFIG = { + "nranks": 4, + "root": 0, + "win_sync_prefix": 256, + "buffers": [ + # Every rank reads all ranks' inputs via CommRemotePtr, so the + # input buffer must be placed in the shared RDMA window. + {"name": "input", "dtype": "float32", "count": 256, "placement": "window"}, + # Each rank writes the reduced sum to its own local output. 
+ {"name": "output", "dtype": "float32", "count": 256, "placement": "device"}, + ], + "inputs": ["input"], + "outputs": ["output"], + "args": ["input", "output", "nranks", "root", "deviceCtx"], +} diff --git a/examples/a2a3/host_build_graph/allreduce_distributed/kernels/orchestration/allreduce_orch.cpp b/examples/a2a3/host_build_graph/allreduce_distributed/kernels/orchestration/allreduce_orch.cpp new file mode 100644 index 00000000..c9ab7125 --- /dev/null +++ b/examples/a2a3/host_build_graph/allreduce_distributed/kernels/orchestration/allreduce_orch.cpp @@ -0,0 +1,54 @@ +/** + * AllReduce Orchestration Function + * + * All arguments are device pointers / scalars already set up by the + * distributed_worker (comm window addresses, device context pointer). + * Creates a single AIV task with func_id=0 (allreduce kernel). + * + * args layout: + * args[0] = input device pointer (in RDMA window) + * args[1] = output device pointer (regular device memory) + * args[2] = nranks + * args[3] = root rank + * args[4] = CommDeviceContext device pointer + */ + +#include "runtime.h" +#include +#include + +extern "C" { + +int build_allreduce_graph(Runtime* runtime, uint64_t* args, int arg_count) { + if (arg_count < 5) { + std::cerr << "build_allreduce_graph: need 5 args, got " << arg_count << '\n'; + return -1; + } + + uint64_t input_dev = args[0]; + uint64_t output_dev = args[1]; + uint64_t nranks = args[2]; + uint64_t root = args[3]; + uint64_t comm_ctx_dev = args[4]; + + std::cout << "\n=== build_allreduce_graph ===" << '\n'; + std::cout << " input_dev = 0x" << std::hex << input_dev << '\n'; + std::cout << " output_dev = 0x" << output_dev << '\n'; + std::cout << " comm_ctx = 0x" << comm_ctx_dev << std::dec << '\n'; + std::cout << " nranks = " << nranks << '\n'; + std::cout << " root = " << root << '\n'; + + uint64_t task_args[5]; + task_args[0] = input_dev; + task_args[1] = output_dev; + task_args[2] = nranks; + task_args[3] = root; + task_args[4] = comm_ctx_dev; + + int t0 = 
runtime->add_task(task_args, 5, 0, CoreType::AIV); + std::cout << " Created task " << t0 << " (allreduce, func_id=0, AIV)" << '\n'; + + return 0; +} + +} // extern "C" diff --git a/examples/a2a3/tensormap_and_ringbuffer/allreduce_distributed/golden.py b/examples/a2a3/tensormap_and_ringbuffer/allreduce_distributed/golden.py new file mode 100644 index 00000000..0923630a --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/allreduce_distributed/golden.py @@ -0,0 +1,41 @@ +""" +Golden script for distributed AllReduce. + +Each rank r contributes input[i] = i + r * 100 for i in [0, 256). +Every rank independently reduces (Sum) all inputs, so all ranks +produce the same output. + +Expected output (same on every rank): + output[i] = sum_{r=0}^{nranks-1} (i + r * 100) + = nranks * i + 100 * nranks * (nranks - 1) / 2 +""" + +ALLREDUCE_COUNT = 256 +NRANKS = 4 + +__outputs__ = ["output"] + +RTOL = 1e-5 +ATOL = 1e-5 + + +def generate_distributed_inputs(rank: int, nranks: int, root: int, + comm_ctx=None) -> list: + """Each rank generates its own input; output is allocated on all ranks.""" + input_data = [float(i + rank * 100) for i in range(ALLREDUCE_COUNT)] + output_data = [0.0] * ALLREDUCE_COUNT + return [ + ("input", input_data), + ("output", output_data), + ("nranks", nranks), + ("root", root), + ] + + +def compute_golden(tensors: dict, params: dict) -> None: + """Compute expected output — same for every rank.""" + nranks = params.get("nranks", NRANKS) + output = tensors["output"] + for i in range(ALLREDUCE_COUNT): + output[i] = float( + nranks * i + 100 * nranks * (nranks - 1) // 2) diff --git a/examples/a2a3/tensormap_and_ringbuffer/allreduce_distributed/kernels/aiv/allreduce_kernel.cpp b/examples/a2a3/tensormap_and_ringbuffer/allreduce_distributed/kernels/aiv/allreduce_kernel.cpp new file mode 100644 index 00000000..b4268cc4 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/allreduce_distributed/kernels/aiv/allreduce_kernel.cpp @@ -0,0 +1,104 @@ +/** + * 
AllReduce kernel for simpler's kernel_entry signature. + * + * Every rank independently reads all ranks' inputs from the RDMA window, + * computes the element-wise sum, and writes the result to its own output. + * This is a symmetric allreduce — no designated root, all ranks active. + * + * args layout (all uint64_t, cast as needed): + * args[0] = __gm__ float* input (device addr in RDMA window) + * args[1] = __gm__ float* output (device addr, local) + * args[2] = int nranks + * args[3] = (unused, kept for ABI compatibility) + * args[4] = __gm__ CommDeviceContext* ctx (device addr) + */ + +#include +#include +#include "pto/comm/comm_types.hpp" +#include "pto/comm/pto_comm_inst.hpp" +#include "common/comm_context.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static constexpr size_t ALLREDUCE_COUNT = 256; +static constexpr int kMaxSupportedRanks = 16; + +template +AICORE inline __gm__ T *CommRemotePtr( + __gm__ CommDeviceContext *ctx, __gm__ T *localPtr, int pe) +{ + uint64_t localBase = ctx->windowsIn[ctx->rankId]; + uint64_t offset = (uint64_t)localPtr - localBase; + return (__gm__ T *)(ctx->windowsIn[pe] + offset); +} + + +extern "C" __aicore__ __attribute__((always_inline)) +void kernel_entry(__gm__ int64_t* args) { + __gm__ float* input = reinterpret_cast<__gm__ float*>(args[0]); + __gm__ float* output = reinterpret_cast<__gm__ float*>(args[1]); + int nranks = static_cast(args[2]); + int root = static_cast(args[3]); + __gm__ CommDeviceContext* commCtx = + reinterpret_cast<__gm__ CommDeviceContext*>(args[4]); + + using ShapeDyn = pto::Shape; + using StrideDyn = pto::Stride; + using Global = pto::GlobalTensor; + using TileData = pto::Tile; + + int my_rank = static_cast(commCtx->rankId); + + ShapeDyn shape(1, 1, 1, 1, ALLREDUCE_COUNT); + StrideDyn stride(ALLREDUCE_COUNT, ALLREDUCE_COUNT, ALLREDUCE_COUNT, + ALLREDUCE_COUNT, 1); + + TileData accTile(1, ALLREDUCE_COUNT); + TileData recvTile(1, 
ALLREDUCE_COUNT); + TASSIGN(accTile, 0x0); + TASSIGN(recvTile, 0x10000); + + if (nranks <= 0 || nranks > kMaxSupportedRanks) { + pipe_barrier(PIPE_ALL); + return; + } + + // Every rank reads all inputs and sums them into its own output. + Global outputG(output, shape, stride); + + __gm__ float* firstInput = CommRemotePtr(commCtx, input, 0); + Global firstG(firstInput, shape, stride); + TLOAD(accTile, firstG); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + for (int r = 1; r < nranks; ++r) { + __gm__ float* remoteInput = CommRemotePtr(commCtx, input, r); + Global remoteG(remoteInput, shape, stride); + TLOAD(recvTile, remoteG); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TADD(accTile, accTile, recvTile); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + } + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(outputG, accTile); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/allreduce_distributed/kernels/kernel_config.py b/examples/a2a3/tensormap_and_ringbuffer/allreduce_distributed/kernels/kernel_config.py new file mode 100644 index 00000000..9a9232da --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/allreduce_distributed/kernels/kernel_config.py @@ -0,0 +1,59 @@ +""" +Distributed AllReduce kernel configuration — tensormap_and_ringbuffer runtime. + +Every rank reads all inputs via RDMA and computes the sum locally. +Device-side orchestration via PTO2Runtime API. The orchestration function +wraps each arg as a PTOParam (tensor or scalar) and submits a single AIV task. 
+""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "allreduce_orch.cpp"), + "function_name": "aicpu_orchestration_entry", +} + +KERNELS = [ + { + "func_id": 0, + "source": str(_KERNELS_ROOT / "aiv" / "allreduce_kernel.cpp"), + "core_type": "aiv", + }, +] + +RUNTIME_CONFIG = { + "runtime": "tensormap_and_ringbuffer", + "aicpu_thread_num": 4, + "block_dim": 3, + "orch_thread_num": 1, + "rounds": 1, +} + +# Distributed layout contract consumed by DistributedCodeRunner/worker: +# - win_sync_prefix reserves a small header at the front of each rank's RDMA +# window before any placement="window" buffers are laid out. +# - buffers declares runtime allocation metadata: +# * count is the element count, not byte size. +# * placement="window": buffer lives in the shared RDMA window and may be +# accessed by remote ranks. +# * placement="device": buffer uses regular device_malloc and is local-only. +# - inputs/outputs control which buffers are loaded from .bin files and which +# are copied back after execution. +# - args defines the orchestration/kernel uint64_t* args order. +DISTRIBUTED_CONFIG = { + "nranks": 4, + "root": 0, + "win_sync_prefix": 256, + "buffers": [ + # Every rank reads all ranks' inputs via CommRemotePtr, so the + # input buffer must be placed in the shared RDMA window. + {"name": "input", "dtype": "float32", "count": 256, "placement": "window"}, + # Each rank writes the reduced sum to its own local output. 
+ {"name": "output", "dtype": "float32", "count": 256, "placement": "device"}, + ], + "inputs": ["input"], + "outputs": ["output"], + "args": ["input", "output", "nranks", "root", "deviceCtx"], +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/allreduce_distributed/kernels/orchestration/allreduce_orch.cpp b/examples/a2a3/tensormap_and_ringbuffer/allreduce_distributed/kernels/orchestration/allreduce_orch.cpp new file mode 100644 index 00000000..1dfa63ab --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/allreduce_distributed/kernels/orchestration/allreduce_orch.cpp @@ -0,0 +1,51 @@ +/** + * AllReduce Orchestration — tensormap_and_ringbuffer runtime (PTO2 API). + * + * All five arguments are passed as SCALAR params so the kernel receives + * raw uint64_t values (device pointers + integers) in the same flat + * args[] layout as host_build_graph / aicpu_build_graph. + * + * The Tensor/PTOParam system maps tensor params to Tensor-struct pointers + * (not device addresses) — that would break the allreduce kernel which reads + * args[] as raw pointers. Using all-scalar avoids this incompatibility. 
+ * + * args layout: + * [0] = input device pointer (in RDMA window) + * [1] = output device pointer (regular device memory) + * [2] = nranks + * [3] = root rank + * [4] = CommDeviceContext device pointer + */ + +#include +#include "pto_orchestration_api.h" + +extern "C" { + +__attribute__((visibility("default"))) +PTO2OrchestrationConfig aicpu_orchestration_config(uint64_t* args, int arg_count) { + (void)args; + (void)arg_count; + return PTO2OrchestrationConfig{ + .expected_arg_count = 5, + }; +} + +__attribute__((visibility("default"))) +void aicpu_orchestration_entry(PTO2Runtime* rt, uint64_t* args, int arg_count, + int orch_thread_num, int orch_thread_index) { + (void)arg_count; + (void)orch_thread_num; + + if (orch_thread_index != 0) return; + + PTOParam params; + params.add_scalar(args[0]); + params.add_scalar(args[1]); + params.add_scalar(args[2]); + params.add_scalar(args[3]); + params.add_scalar(args[4]); + pto2_rt_submit_aiv_task(rt, 0, params); +} + +} // extern "C" diff --git a/examples/scripts/README.md b/examples/scripts/README.md index 0afcb07c..1f6f6ca9 100644 --- a/examples/scripts/README.md +++ b/examples/scripts/README.md @@ -42,6 +42,32 @@ python examples/scripts/run_example.py \ -p a2a3sim ``` +#### Running Distributed (Multi-Rank) Tests + +Distributed examples are auto-detected when `kernel_config.py` contains a `DISTRIBUTED_CONFIG` dictionary. 
No separate script is needed — `run_example.py` handles it automatically: + +```bash +# Simulation (no hardware required, 8 ranks by default from kernel_config) +python examples/scripts/run_example.py \ + -k examples/a2a3/host_build_graph/allreduce_distributed/kernels \ + -g examples/a2a3/host_build_graph/allreduce_distributed/golden.py \ + -p a2a3sim + +# Hardware platform — pick specific devices (nranks inferred from device count) +python examples/scripts/run_example.py \ + -k examples/a2a3/host_build_graph/allreduce_distributed/kernels \ + -g examples/a2a3/host_build_graph/allreduce_distributed/golden.py \ + -p a2a3 --devices 0,1,2,3,4,5,6,7 + +# Hardware platform — non-contiguous devices +python examples/scripts/run_example.py \ + -k examples/a2a3/host_build_graph/allreduce_distributed/kernels \ + -g examples/a2a3/host_build_graph/allreduce_distributed/golden.py \ + -p a2a3 --devices 2,4,5,7 +``` + +The framework spawns one worker process per rank, each using the backend-neutral `comm_*` API. On simulation (`a2a3sim`), ranks communicate via POSIX shared memory; on hardware (`a2a3`), they use HCCL over RDMA. + ## Command Line Arguments ### `run_example.py` Parameters @@ -56,6 +82,7 @@ python examples/scripts/run_example.py \ | `--verbose` | `-v` | Enable verbose output (equivalent to `--log-level debug`) | False | | `--silent` | | Enable silent mode (equivalent to `--log-level error`) | False | | `--log-level` | | Set log level: `error`, `warn`, `info`, `debug` | `info` | +| `--nranks` | | Number of ranks for distributed tests | From `DISTRIBUTED_CONFIG` | | `--clone-protocol` | | Git protocol for cloning pto-isa: `ssh` or `https` | `ssh` | ### Platform Description @@ -161,7 +188,54 @@ ORCHESTRATION = { } ``` -### 3. `golden.py` Format +### 3. 
Distributed `kernel_config.py` Format + +To make a test distributed, add a `DISTRIBUTED_CONFIG` dictionary alongside the standard `KERNELS` and `ORCHESTRATION` fields: + +```python +DISTRIBUTED_CONFIG = { + "nranks": 8, # Number of ranks + "root": 0, # Root rank for collective ops + "comm_include_dirs": [...], # Extra include dirs for kernel compilation + "win_sync_prefix": 256, # Bytes reserved before window buffers + "buffers": [ + {"name": "input", "dtype": "float32", "count": 256, "placement": "window"}, + {"name": "output", "dtype": "float32", "count": 256, "placement": "device"}, + ], + "inputs": ["input"], # Buffers to load from .bin files + "outputs": ["output"], # Buffers to save after execution + "args": ["input", "output", "nranks", "root", "deviceCtx"], +} +``` + +- **`placement: "window"`** — Buffer is allocated in the RDMA window region (accessible by all ranks). +- **`placement: "device"`** — Buffer is allocated via `device_malloc` (local to each rank). +- **`args`** — Tokens passed as orchestration function arguments. Special tokens: `nranks`, `root`, `deviceCtx` (pointer to `CommDeviceContext`). + +### 4. Distributed `golden.py` Format + +The golden script for distributed tests uses `generate_distributed_inputs` instead of `generate_inputs`: + +```python +def generate_distributed_inputs(rank: int, nranks: int, root: int, + comm_ctx=None) -> list: + """Return a list of (name, data) tuples for this rank.""" + input_data = [float(i + rank * 100) for i in range(256)] + output_data = [0.0] * 256 + return [ + ("input", input_data), + ("output", output_data), + ] + +def compute_golden(tensors: dict, params: dict) -> None: + """Compute expected output for the root rank (in-place).""" + nranks = params.get("nranks", 8) + output = tensors["output"] + for i in range(256): + output[i] = float(nranks * i + 100 * nranks * (nranks - 1) // 2) +``` + +### 5. 
Standard `golden.py` Format ```python import torch @@ -365,6 +439,25 @@ TEST PASSED ============================================================ ``` +### Distributed Test Success Example + +``` +[INFO] Detected DISTRIBUTED_CONFIG — using distributed runner +[INFO] === Phase 1: Building runtime === +... +[INFO] === Launching 8 workers === +[INFO] Rank 0: OK +[INFO] Rank 1: OK +... +[INFO] Rank 7: OK +[INFO] VERIFY PASSED: output — 256 elements correct +[INFO] Sample: [2800.0, 2808.0, 2816.0, 2824.0, 2832.0] + +============================================================ +TEST PASSED +============================================================ +``` + ### Failure Example ``` @@ -378,8 +471,8 @@ TEST FAILED: Output 'f' does not match golden ## Reference Examples -- **Hardware Example**: [examples/host_build_graph/vector_example/](../host_build_graph/vector_example/) -- **Simulation Example**: [examples/host_build_graph/vector_example/](../host_build_graph/vector_example/) +- **Single-Card Example**: [examples/a2a3/host_build_graph/vector_example/](../a2a3/host_build_graph/vector_example/) +- **Distributed Example**: [examples/a2a3/host_build_graph/allreduce_distributed/](../a2a3/host_build_graph/allreduce_distributed/) ## FAQ @@ -521,6 +614,21 @@ runner = create_code_runner( runner.run() # Execute test ``` +### Distributed Programmatic Usage + +```python +from distributed_code_runner import DistributedCodeRunner + +runner = DistributedCodeRunner( + kernels_dir="examples/a2a3/host_build_graph/allreduce_distributed/kernels", + golden_path="examples/a2a3/host_build_graph/allreduce_distributed/golden.py", + platform="a2a3sim", + nranks=8, +) + +runner.run_all() # compile, prepare data, launch workers, verify +``` + ## Related Documentation - [Main Project README](../../README.md) diff --git a/examples/scripts/distributed_code_runner.py b/examples/scripts/distributed_code_runner.py new file mode 100644 index 00000000..83a29ffd --- /dev/null +++ 
b/examples/scripts/distributed_code_runner.py @@ -0,0 +1,465 @@ +""" +DistributedCodeRunner — compile, prepare data, launch workers, and verify +results for distributed (multi-card) PTO kernel tests. + +Parallel to CodeRunner, but handles DISTRIBUTED_CONFIG and spawns N +Python worker processes (one per rank) via distributed_worker.py. + +Usage: + runner = DistributedCodeRunner( + kernels_dir="examples/a2a3/.../allreduce_distributed/kernels", + golden_path="examples/a2a3/.../allreduce_distributed/golden.py", + platform="a2a3", nranks=8, + ) + runner.run() +""" + +import importlib.util +import logging +import os +import shutil +import struct +import subprocess +import sys +from pathlib import Path +from typing import Optional + +logger = logging.getLogger(__name__) + +SIMPLER_ROOT = Path(__file__).resolve().parent.parent.parent +SCRIPTS_DIR = Path(__file__).resolve().parent + +DTYPE_FORMAT = { + "float32": ("f", 4), + "float64": ("d", 8), + "int32": ("i", 4), + "int64": ("q", 8), + "uint32": ("I", 4), + "uint64": ("Q", 8), + "float16": ("e", 2), + "int16": ("h", 2), + "uint16": ("H", 2), + "int8": ("b", 1), + "uint8": ("B", 1), +} + + +def _load_module(path, name="mod"): + spec = importlib.util.spec_from_file_location(name, path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +class DistributedCodeRunner: + + def __init__( + self, + kernels_dir: str, + golden_path: Optional[str] = None, + platform: str = "a2a3", + nranks: Optional[int] = None, + device_ids: Optional[list[int]] = None, + root: Optional[int] = None, + build_dir: Optional[str] = None, + artifact_dir: Optional[str] = None, + orch_func: Optional[str] = None, + pto_isa_commit: Optional[str] = None, + clone_protocol: str = "ssh", + ): + self.kernels_dir = Path(kernels_dir).resolve() + self.platform = platform + self.build_dir = Path(build_dir).resolve() if build_dir else \ + SIMPLER_ROOT / "build" / "distributed" / "cache" + self.artifact_dir = 
Path(artifact_dir).resolve() if artifact_dir else \ + SIMPLER_ROOT / "build" / "distributed" / "artifacts" + self.pto_isa_commit = pto_isa_commit + self.clone_protocol = clone_protocol + + self._load_kernel_config() + dist = getattr(self.kcfg, "DISTRIBUTED_CONFIG", {}) + + self.nranks = nranks if nranks is not None else dist.get("nranks", 8) + self.root = root if root is not None else dist.get("root", 0) + self.orch_func = orch_func or self.kcfg.ORCHESTRATION["function_name"] + if self.nranks <= 0: + raise ValueError(f"Distributed nranks must be positive, got {self.nranks}") + if self.root < 0 or self.root >= self.nranks: + raise ValueError( + f"Distributed root must be in [0, {self.nranks}), got {self.root}" + ) + + if device_ids is None: + self.device_ids = list(range(self.nranks)) + else: + if len(device_ids) != self.nranks: + raise ValueError( + f"Expected {self.nranks} device ids, got {len(device_ids)}: {device_ids}" + ) + self.device_ids = list(device_ids) + + self.golden_path = Path(golden_path).resolve() if golden_path else None + self.golden_mod = None + + def _load_kernel_config(self): + config_path = self.kernels_dir / "kernel_config.py" + if not config_path.exists(): + raise FileNotFoundError(f"kernel_config.py not found in {self.kernels_dir}") + self.kcfg = _load_module(config_path, "kernel_config") + + def _load_golden(self): + if self.golden_mod is None and self.golden_path and self.golden_path.exists(): + self.golden_mod = _load_module(self.golden_path, "golden") + return self.golden_mod + + def _orch_artifact_name(self): + src = Path(self.kcfg.ORCHESTRATION["source"]) + return src.stem + ".so" + + def _kernel_artifact_name(self, kernel_cfg): + src = Path(kernel_cfg["source"]) + return src.stem + ".bin" + + def _get_buffer_config(self, name: str): + dist = getattr(self.kcfg, "DISTRIBUTED_CONFIG", {}) + for buf_cfg in dist.get("buffers", []): + if buf_cfg["name"] == name: + return buf_cfg + raise ValueError( + f"Buffer '{name}' from golden.py not 
found in DISTRIBUTED_CONFIG['buffers']" + ) + + def _get_dtype_format(self, dtype: str, buffer_name: str): + fmt = DTYPE_FORMAT.get(dtype) + if fmt is None: + raise ValueError( + f"Unsupported dtype '{dtype}' for buffer '{buffer_name}'" + ) + return fmt + + # ------------------------------------------------------------------ + # compile() + # ------------------------------------------------------------------ + + def compile(self): + self.artifact_dir.mkdir(parents=True, exist_ok=True) + for sub in ("aicore", "aicpu", "host"): + p = self.build_dir / sub + if p.exists(): + shutil.rmtree(p) + self.build_dir.mkdir(parents=True, exist_ok=True) + + python_dir = SIMPLER_ROOT / "python" + sys.path.insert(0, str(python_dir)) + sys.path.insert(0, str(SCRIPTS_DIR)) + + from runtime_builder import RuntimeBuilder + from elf_parser import extract_text_section + from code_runner import _ensure_pto_isa_root + + pto_isa_root = _ensure_pto_isa_root( + verbose=True, commit=self.pto_isa_commit, + clone_protocol=self.clone_protocol) + if pto_isa_root is None: + raise EnvironmentError("PTO_ISA_ROOT could not be resolved.") + + runtime_name = self.kcfg.RUNTIME_CONFIG.get("runtime", "host_build_graph") + builder = RuntimeBuilder(platform=self.platform) + kernel_compiler = builder.get_kernel_compiler() + + logger.info("=== Phase 1: Building runtime ===") + host_binary, aicpu_binary, aicore_binary = builder.build( + runtime_name, str(self.build_dir)) + + logger.info("=== Phase 2: Compiling orchestration ===") + orch_source = self.kcfg.ORCHESTRATION["source"] + if not os.path.isabs(orch_source): + orch_source = str(self.kernels_dir / orch_source) + orch_binary = kernel_compiler.compile_orchestration( + runtime_name, orch_source, build_dir=str(self.build_dir)) + + logger.info("=== Phase 3: Compiling kernels ===") + if self.platform in ("a2a3", "a2a3sim"): + arch = "a2a3" + elif self.platform in ("a5", "a5sim"): + arch = "a5" + else: + arch = "a2a3" + + runtime_include_dirs = [ + 
str(SIMPLER_ROOT / "src" / arch / "runtime" / runtime_name / "runtime") + ] + + dist_config = getattr(self.kcfg, "DISTRIBUTED_CONFIG", {}) + extra_includes = list(runtime_include_dirs) + [ + str(SIMPLER_ROOT / "src" / arch / "platform" / "include"), + ] + for d in dist_config.get("comm_include_dirs", []): + p = Path(pto_isa_root) / d if not os.path.isabs(d) else Path(d) + extra_includes.append(str(p)) + + kernel_bins = {} + for k in self.kcfg.KERNELS: + src = k["source"] + if not os.path.isabs(src): + src = str(self.kernels_dir / src) + incore_o = kernel_compiler.compile_incore( + src, + core_type=k.get("core_type", "aiv"), + pto_isa_root=pto_isa_root, + extra_include_dirs=extra_includes, + build_dir=str(self.build_dir), + ) + if self.platform.endswith("sim"): + kernel_bins[k["func_id"]] = (k, incore_o) + else: + kernel_bins[k["func_id"]] = (k, extract_text_section(incore_o)) + + logger.info("=== Phase 4: Saving artifacts ===") + + def save(name, data): + path = self.artifact_dir / name + path.write_bytes(data) + logger.info(f" {name}: {len(data)} bytes") + + save("libhost_runtime.so", host_binary) + save("libaicpu_kernel.so", aicpu_binary) + save("aicore_kernel.o", aicore_binary) + save(self._orch_artifact_name(), orch_binary) + for func_id, (kcfg, data) in kernel_bins.items(): + save(self._kernel_artifact_name(kcfg), data) + + logger.info(f"All artifacts saved to {self.artifact_dir}") + + # ------------------------------------------------------------------ + # prepare_data() + # ------------------------------------------------------------------ + + def prepare_data(self): + golden = self._load_golden() + if not golden or not hasattr(golden, "generate_distributed_inputs"): + logger.info("No golden.py or generate_distributed_inputs — skipping data prep") + return + + for r in range(self.nranks): + rank_dir = self.artifact_dir / f"rank_{r}" + rank_dir.mkdir(parents=True, exist_ok=True) + + inputs = golden.generate_distributed_inputs(r, self.nranks, self.root) + for 
name, data in inputs: + if isinstance(data, (list, tuple)): + buf_cfg = self._get_buffer_config(name) + fmt_char, _ = self._get_dtype_format(buf_cfg["dtype"], name) + bin_data = struct.pack(f"<{len(data)}{fmt_char}", *data) + path = rank_dir / f"{name}.bin" + path.write_bytes(bin_data) + logger.debug(f" rank_{r}/{name}.bin: {len(bin_data)} bytes") + + logger.info(f"Prepared data for {self.nranks} ranks in {self.artifact_dir}") + + # ------------------------------------------------------------------ + # run() + # ------------------------------------------------------------------ + + def _build_worker_cmd(self, r): + dist = getattr(self.kcfg, "DISTRIBUTED_CONFIG", {}) + rootinfo_file = self.artifact_dir / "rootinfo.bin" + + cmd = [ + sys.executable, + str(SCRIPTS_DIR / "distributed_worker.py"), + "--device-id", str(self.device_ids[r]), + "--rank", str(r), + "--nranks", str(self.nranks), + "--root", str(self.root), + "--artifact-dir", str(self.artifact_dir), + "--rootinfo-file", str(rootinfo_file), + "--data-dir", str(self.artifact_dir / f"rank_{r}"), + "--orch-file", self._orch_artifact_name(), + "--orch-func", self.orch_func, + ] + + rt_cfg = getattr(self.kcfg, "RUNTIME_CONFIG", {}) + cmd += ["--aicpu-thread-num", str(rt_cfg.get("aicpu_thread_num", 1))] + cmd += ["--block-dim", str(rt_cfg.get("block_dim", 1))] + cmd += ["--orch-thread-num", str(rt_cfg.get("orch_thread_num", 0))] + + win_sync = dist.get("win_sync_prefix", 0) + if win_sync: + cmd += ["--win-sync-prefix", str(win_sync)] + + for buf in dist.get("buffers", []): + spec = f"{buf['name']}:{buf['dtype']}:{buf['count']}" + if buf["placement"] == "window": + cmd += ["--win-buffer", spec] + else: + cmd += ["--dev-buffer", spec] + + for name in dist.get("inputs", []): + cmd += ["--load", name] + + for name in dist.get("outputs", []): + cmd += ["--save", name] + + for tok in dist.get("args", []): + cmd += ["--arg", tok] + + for k in self.kcfg.KERNELS: + cmd += ["--kernel-bin", + 
f"{k['func_id']}:{self._kernel_artifact_name(k)}"] + + return cmd + + def run(self): + rootinfo_file = self.artifact_dir / "rootinfo.bin" + + for f in self.artifact_dir.glob("barrier_*.ready"): + f.unlink() + if rootinfo_file.exists(): + rootinfo_file.unlink() + + shm_dir = Path("/dev/shm") + if shm_dir.is_dir(): + for f in shm_dir.glob("simpler_comm_*"): + try: + f.unlink() + except OSError: + pass + + logger.info(f"=== Launching {self.nranks} workers ===") + + procs = [] + log_files = [] + for r in range(self.nranks): + log_path = self.artifact_dir / f"rank{r}.log" + log_f = open(log_path, "w") + log_files.append(log_f) + + cmd = self._build_worker_cmd(r) + env = os.environ.copy() + runtime_env = getattr(self.kcfg, "RUNTIME_ENV", None) + if isinstance(runtime_env, dict): + env.update(runtime_env) + + proc = subprocess.Popen(cmd, stdout=log_f, stderr=subprocess.STDOUT, env=env) + procs.append(proc) + + fail_count = 0 + for r, proc in enumerate(procs): + proc.wait() + log_files[r].close() + if proc.returncode != 0: + fail_count += 1 + logger.error(f"Rank {r}: FAILED (exit code {proc.returncode})") + else: + logger.info(f"Rank {r}: OK") + + print() + for r in range(self.nranks): + log_path = self.artifact_dir / f"rank{r}.log" + lines = log_path.read_text().strip().split("\n") + print(f"--- RANK {r} (last 5 lines) ---") + for line in lines[-5:]: + print(line) + + print() + if fail_count == 0: + print(f"=== ALL {self.nranks} RANKS COMPLETED ===") + else: + print(f"=== {fail_count}/{self.nranks} RANKS FAILED ===") + + for f in self.artifact_dir.glob("barrier_*.ready"): + f.unlink() + + self._run_ok = (fail_count == 0) + return self._run_ok + + # ------------------------------------------------------------------ + # verify() + # ------------------------------------------------------------------ + + def verify(self): + golden = self._load_golden() + if not golden or not hasattr(golden, "compute_golden"): + logger.info("No golden.py or compute_golden — skipping 
verification") + return True + + dist = getattr(self.kcfg, "DISTRIBUTED_CONFIG", {}) + output_names = dist.get("outputs", []) + buf_map = {b["name"]: b for b in dist.get("buffers", [])} + + # Compute expected outputs once (same for all ranks in allreduce). + seed_dir = self.artifact_dir / f"rank_{self.root}" + seed_outputs = {} + for name in output_names: + path = seed_dir / f"{name}.bin" + if not path.exists(): + logger.error(f"Output file not found: {path}") + return False + raw = path.read_bytes() + dtype = buf_map.get(name, {}).get("dtype", "float32") + fmt_char, elem_sz = DTYPE_FORMAT.get(dtype, ("f", 4)) + count = len(raw) // elem_sz + seed_outputs[name] = list(struct.unpack(f"<{count}{fmt_char}", raw)) + + expected_outputs = {n: v.copy() for n, v in seed_outputs.items()} + params = {"nranks": self.nranks, "root": self.root} + golden.compute_golden(expected_outputs, params) + + rtol = getattr(golden, "RTOL", 1e-5) + atol = getattr(golden, "ATOL", 1e-5) + + all_ok = True + for rank in range(self.nranks): + rank_dir = self.artifact_dir / f"rank_{rank}" + for name in output_names: + path = rank_dir / f"{name}.bin" + if not path.exists(): + logger.error(f"Output file not found: {path}") + all_ok = False + continue + raw = path.read_bytes() + dtype = buf_map.get(name, {}).get("dtype", "float32") + fmt_char, elem_sz = DTYPE_FORMAT.get(dtype, ("f", 4)) + count = len(raw) // elem_sz + actual = list(struct.unpack(f"<{count}{fmt_char}", raw)) + expected = expected_outputs[name] + + mismatches = 0 + for i, (a, e) in enumerate(zip(actual, expected)): + if abs(a - e) > atol + rtol * abs(e): + if mismatches < 3: + logger.error(f" rank {rank} {name}[{i}]: got {a}, expected {e}") + mismatches += 1 + if mismatches > 0: + logger.error(f"VERIFY FAILED: rank {rank} {name} — {mismatches}/{len(actual)} mismatches") + all_ok = False + else: + logger.info(f"VERIFY PASSED: rank {rank} {name} — {len(actual)} elements correct") + if rank == 0 and len(actual) >= 5: + logger.info(f" 
Sample: {actual[:5]}") + + if all_ok: + print("\n=== VERIFICATION PASSED ===\n") + else: + print("\n=== VERIFICATION FAILED ===\n") + + return all_ok + + # ------------------------------------------------------------------ + # Full pipeline + # ------------------------------------------------------------------ + + def run_all(self, skip_compile=False, skip_verify=False): + if not skip_compile: + self.compile() + + if self.golden_path: + self.prepare_data() + + success = self.run() + + if success and self.golden_path and not skip_verify: + success = self.verify() + + return success diff --git a/examples/scripts/distributed_worker.py b/examples/scripts/distributed_worker.py new file mode 100644 index 00000000..68877121 --- /dev/null +++ b/examples/scripts/distributed_worker.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 +""" +Per-rank Python worker for distributed (multi-card) kernel execution. + +Replaces the monolithic C++ distributed_worker binary. Each rank runs +as a separate process, using the comm_* C API (via ctypes bindings) for +HCCL / sim communication and the existing PTO runtime C API for kernel +execution. + +Spawned by DistributedCodeRunner — not intended for direct invocation. 
+""" + +import argparse +import struct +import sys +from pathlib import Path + +script_dir = Path(__file__).parent.resolve() +project_root = script_dir.parent.parent +sys.path.insert(0, str(project_root / "python")) +sys.path.insert(0, str(script_dir)) + + +DTYPE_FORMAT = { + "float32": ("f", 4), + "float64": ("d", 8), + "int32": ("i", 4), + "int64": ("q", 8), + "uint32": ("I", 4), + "uint64": ("Q", 8), + "float16": ("e", 2), + "int16": ("h", 2), + "uint16": ("H", 2), + "int8": ("b", 1), + "uint8": ("B", 1), +} + + +def parse_buffer_spec(spec): + parts = spec.split(":") + return {"name": parts[0], "dtype": parts[1], "count": int(parts[2])} + + +def parse_kernel_spec(spec): + p = spec.index(":") + return {"func_id": int(spec[:p]), "filename": spec[p + 1:]} + + +def main(): + parser = argparse.ArgumentParser(description="Distributed per-rank worker") + parser.add_argument("--device-id", type=int, required=True) + parser.add_argument("--rank", type=int, required=True) + parser.add_argument("--nranks", type=int, required=True) + parser.add_argument("--root", type=int, default=0) + parser.add_argument("--artifact-dir", required=True) + parser.add_argument("--rootinfo-file", required=True) + parser.add_argument("--data-dir", default=None) + parser.add_argument("--orch-file", required=True) + parser.add_argument("--orch-func", required=True) + parser.add_argument("--win-sync-prefix", type=int, default=0) + parser.add_argument("--aicpu-thread-num", type=int, default=1) + parser.add_argument("--block-dim", type=int, default=1) + parser.add_argument("--orch-thread-num", type=int, default=0) + parser.add_argument("--win-buffer", action="append", default=[]) + parser.add_argument("--dev-buffer", action="append", default=[]) + parser.add_argument("--load", action="append", default=[], dest="loads") + parser.add_argument("--save", action="append", default=[], dest="saves") + parser.add_argument("--arg", action="append", default=[], dest="args") + 
parser.add_argument("--kernel-bin", action="append", default=[]) + args = parser.parse_args() + + artifact_dir = Path(args.artifact_dir) + data_dir = Path(args.data_dir) if args.data_dir else artifact_dir / f"rank_{args.rank}" + + buffers = [] + for spec in args.win_buffer: + b = parse_buffer_spec(spec) + b["placement"] = "window" + buffers.append(b) + for spec in args.dev_buffer: + b = parse_buffer_spec(spec) + b["placement"] = "device" + buffers.append(b) + + kernel_bins = [parse_kernel_spec(s) for s in args.kernel_bin] + + buf_by_name = {b["name"]: b for b in buffers} + + def elem_size(dtype): + return DTYPE_FORMAT.get(dtype, ("f", 4))[1] + + def buf_bytes(b): + return b["count"] * elem_size(b["dtype"]) + + # ---------------------------------------------------------------- + # 1. Load library + # ---------------------------------------------------------------- + from bindings import ( + bind_host_binary, set_device, launch_runtime, + device_malloc, device_free, copy_to_device, copy_from_device, + comm_init, comm_alloc_windows, comm_get_local_window_base, + comm_barrier, comm_destroy, + ARG_SCALAR, ARG_INPUT_PTR, ARG_OUTPUT_PTR, ARG_INOUT_PTR, + ) + + lib_path = artifact_dir / "libhost_runtime.so" + Runtime = bind_host_binary(str(lib_path)) + set_device(args.device_id) + + sys.stderr.write(f"[rank {args.rank}] Library loaded, device {args.device_id} set\n") + + # ---------------------------------------------------------------- + # 2. 
Comm init + alloc windows + # ---------------------------------------------------------------- + comm = comm_init(args.rank, args.nranks, args.rootinfo_file) + + total_win = args.win_sync_prefix + for b in buffers: + if b["placement"] == "window": + total_win += buf_bytes(b) + + device_ctx_ptr = comm_alloc_windows(comm, total_win) + local_base = comm_get_local_window_base(comm) + + sys.stderr.write(f"[rank {args.rank}] Comm initialized, local_base=0x{local_base:x}\n") + + # ---------------------------------------------------------------- + # 3. Allocate buffers + # ---------------------------------------------------------------- + win_offset = args.win_sync_prefix + + for b in buffers: + nbytes = buf_bytes(b) + if b["placement"] == "window": + b["dev_ptr"] = local_base + win_offset + win_offset += nbytes + else: + ptr = device_malloc(nbytes) + if not ptr: + sys.stderr.write(f"[rank {args.rank}] device_malloc failed for '{b['name']}'\n") + return 3 + b["dev_ptr"] = ptr + sys.stderr.write( + f"[rank {args.rank}] Buffer '{b['name']}': {b['placement']} " + f"{b['count']}x{b['dtype']}={nbytes}B @ 0x{b['dev_ptr']:x}\n" + ) + + # ---------------------------------------------------------------- + # 4. Load inputs + # ---------------------------------------------------------------- + for name in args.loads: + b = buf_by_name.get(name) + if not b: + sys.stderr.write(f"[rank {args.rank}] --load: buffer '{name}' not found\n") + return 1 + path = data_dir / f"{name}.bin" + host_data = path.read_bytes() + if len(host_data) != buf_bytes(b): + sys.stderr.write( + f"[rank {args.rank}] Size mismatch for '{name}': " + f"file={len(host_data)}, expected={buf_bytes(b)}\n" + ) + return 2 + import ctypes + host_buf = (ctypes.c_uint8 * len(host_data)).from_buffer_copy(host_data) + copy_to_device(b["dev_ptr"], ctypes.addressof(host_buf), len(host_data)) + + # ---------------------------------------------------------------- + # 5. 
Barrier before kernel execution + # ---------------------------------------------------------------- + comm_barrier(comm) + + # ---------------------------------------------------------------- + # 6. Run simpler runtime + # ---------------------------------------------------------------- + orch_binary = (artifact_dir / args.orch_file).read_bytes() + aicpu_binary = (artifact_dir / "libaicpu_kernel.so").read_bytes() + aicore_binary = (artifact_dir / "aicore_kernel.o").read_bytes() + + kernel_binaries = [] + for k in kernel_bins: + data = (artifact_dir / k["filename"]).read_bytes() + kernel_binaries.append((k["func_id"], data)) + + func_args = [] + arg_types = [] + arg_sizes = [] + for tok in args.args: + if tok == "nranks": + func_args.append(args.nranks) + elif tok == "root": + func_args.append(args.root) + elif tok == "deviceCtx": + func_args.append(device_ctx_ptr) + else: + b = buf_by_name.get(tok) + if not b: + sys.stderr.write(f"[rank {args.rank}] --arg: unknown token '{tok}'\n") + return 1 + func_args.append(b["dev_ptr"]) + # In distributed mode, all memory is pre-allocated by the worker + # (RDMA windows / device_malloc). Pass everything as scalar so + # the runtime doesn't try to re-allocate or copy. + arg_types.append(ARG_SCALAR) + arg_sizes.append(0) + + sys.stderr.write( + f"[rank {args.rank}] Launching kernel: {len(func_args)} args, " + f"{len(kernel_binaries)} kernels\n" + ) + + runtime = Runtime() + runtime.initialize( + orch_binary, + args.orch_func, + func_args, + arg_types=arg_types, + arg_sizes=arg_sizes, + kernel_binaries=kernel_binaries, + ) + + launch_runtime( + runtime, + aicpu_thread_num=args.aicpu_thread_num, + block_dim=args.block_dim, + device_id=args.device_id, + aicpu_binary=aicpu_binary, + aicore_binary=aicore_binary, + orch_thread_num=args.orch_thread_num, + ) + + runtime.finalize() + sys.stderr.write(f"[rank {args.rank}] Kernel execution complete\n") + + # ---------------------------------------------------------------- + # 7. 
Barrier + save outputs + # ---------------------------------------------------------------- + comm_barrier(comm) + + import ctypes + for name in args.saves: + b = buf_by_name.get(name) + if not b: + sys.stderr.write(f"[rank {args.rank}] --save: buffer '{name}' not found\n") + continue + nbytes = buf_bytes(b) + host_buf = (ctypes.c_uint8 * nbytes)() + copy_from_device(ctypes.addressof(host_buf), b["dev_ptr"], nbytes) + path = data_dir / f"{name}.bin" + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(bytes(host_buf)) + sys.stderr.write(f"[rank {args.rank}] Saved '{name}' to {path} ({nbytes}B)\n") + + # ---------------------------------------------------------------- + # 8. Cleanup + # ---------------------------------------------------------------- + for b in buffers: + if b["placement"] == "device" and b.get("dev_ptr"): + device_free(b["dev_ptr"]) + + comm_destroy(comm) + sys.stderr.write(f"[rank {args.rank}] Done\n") + return 0 + + +if __name__ == "__main__": + sys.exit(main() or 0) diff --git a/examples/scripts/run_example.py b/examples/scripts/run_example.py index 36c06a05..c2a27f6f 100644 --- a/examples/scripts/run_example.py +++ b/examples/scripts/run_example.py @@ -73,6 +73,36 @@ def _wait_for_new_device_log(log_dir, pre_run_logs, timeout=15, interval=0.5): return None +def _parse_device_spec(spec): + """Expand a device spec like '4-7' or '0,1,3,5' into device ids.""" + if spec is None: + return None + + spec = spec.strip() + if not spec: + raise ValueError("Device spec must not be empty") + + device_ids = [] + for item in spec.split(","): + item = item.strip() + if not item: + continue + if "-" in item: + start_str, end_str = item.split("-", 1) + start = int(start_str) + end = int(end_str) + if end < start: + raise ValueError(f"Invalid device range '{item}': end < start") + device_ids.extend(range(start, end + 1)) + else: + device_ids.append(int(item)) + + if not device_ids: + raise ValueError("Device spec must contain at least one device") 
+ + return device_ids + + def main(): parser = argparse.ArgumentParser( description="Run PTO runtime test with kernel config and golden script", @@ -192,10 +222,33 @@ def compute_golden(tensors: dict, params: dict) -> None: help="Git protocol for cloning pto-isa (default: ssh)" ) + parser.add_argument( + "--nranks", + type=int, + default=None, + help="Override number of ranks for distributed tests (default: from kernel_config)" + ) + + parser.add_argument( + "--device-range", + type=str, + default=None, + help="Explicit device range for distributed tests (e.g., 4-7)" + ) + + parser.add_argument( + "--devices", + type=str, + default=None, + help="Explicit distributed device list, supports comma lists/ranges (e.g., 0,1,3,5 or 4-7)" + ) + args = parser.parse_args() if args.all and args.case: parser.error("--all and --case are mutually exclusive") + if args.device_range and args.devices: + parser.error("--device-range and --devices are mutually exclusive") # Determine log level from arguments log_level_str = None @@ -246,6 +299,55 @@ def compute_golden(tensors: dict, params: dict) -> None: # Import and run try: + # Detect DISTRIBUTED_CONFIG to choose runner + import importlib.util as _ilu + _kc_spec = _ilu.spec_from_file_location("_kc_check", kernel_config_path) + _kc_mod = _ilu.module_from_spec(_kc_spec) + _kc_spec.loader.exec_module(_kc_mod) + is_distributed = hasattr(_kc_mod, "DISTRIBUTED_CONFIG") + + if is_distributed: + from distributed_code_runner import DistributedCodeRunner + + logger.info("Detected DISTRIBUTED_CONFIG — using distributed runner") + dist_cfg = getattr(_kc_mod, "DISTRIBUTED_CONFIG", {}) + + if args.devices is not None: + device_ids = _parse_device_spec(args.devices) + effective_nranks = len(device_ids) + elif args.device_range is not None: + device_ids = _parse_device_spec(args.device_range) + effective_nranks = len(device_ids) + else: + effective_nranks = args.nranks if args.nranks is not None else dist_cfg.get("nranks", 8) + device_ids = 
[args.device + i for i in range(effective_nranks)] + + if args.nranks is not None and args.nranks != effective_nranks: + raise ValueError( + f"--nranks={args.nranks} conflicts with device list " + f"({effective_nranks} devices)" + ) + + runner = DistributedCodeRunner( + kernels_dir=str(args.kernels), + golden_path=str(args.golden), + platform=args.platform, + nranks=effective_nranks, + device_ids=device_ids, + build_dir=args.savetemp, + pto_isa_commit=args.pto_isa_commit, + clone_protocol=args.clone_protocol, + ) + success = runner.run_all() + if success: + logger.info("=" * 60) + logger.info("TEST PASSED") + logger.info("=" * 60) + else: + logger.error("TEST FAILED") + return 1 + return 0 + from code_runner import create_code_runner runner = create_code_runner( diff --git a/python/bindings.py b/python/bindings.py index 6474f35d..fe50d9dc 100644 --- a/python/bindings.py +++ b/python/bindings.py @@ -180,6 +180,22 @@ def _setup_functions(self): self.lib.enable_runtime_profiling.argtypes = [c_void_p, c_int] self.lib.enable_runtime_profiling.restype = c_int + # --- Distributed communication API (comm_*) --- + self.lib.comm_init.argtypes = [c_int, c_int, c_char_p] + self.lib.comm_init.restype = c_void_p + + self.lib.comm_alloc_windows.argtypes = [c_void_p, c_size_t, POINTER(c_uint64)] + self.lib.comm_alloc_windows.restype = c_int + + self.lib.comm_get_local_window_base.argtypes = [c_void_p, POINTER(c_uint64)] + self.lib.comm_get_local_window_base.restype = c_int + + self.lib.comm_barrier.argtypes = [c_void_p] + self.lib.comm_barrier.restype = c_int + + self.lib.comm_destroy.argtypes = [c_void_p] + self.lib.comm_destroy.restype = c_int + # ============================================================================ # Python Wrapper Classes @@ -522,6 +538,123 @@ def launch_runtime( raise RuntimeError(f"launch_runtime failed: {rc}") +# ============================================================================ +# Distributed Communication Functions +# 
============================================================================ + + +def comm_init(rank: int, nranks: int, rootinfo_path: str) -> int: + """ + Initialize a distributed communicator for the given rank. + + Args: + rank: This process's rank (0-based) + nranks: Total number of ranks + rootinfo_path: Filesystem path for root info exchange + + Returns: + Opaque comm handle (as integer) + + Raises: + RuntimeError: If not loaded or initialization fails + """ + global _lib + if _lib is None: + raise RuntimeError("Runtime not loaded. Call bind_host_binary() first.") + + handle = _lib.comm_init(rank, nranks, rootinfo_path.encode('utf-8')) + if not handle: + raise RuntimeError(f"comm_init failed for rank {rank}") + return handle + + +def comm_alloc_windows(handle: int, win_size: int) -> int: + """ + Allocate RDMA / shared-memory windows. + + Args: + handle: Comm handle from comm_init() + win_size: Window size hint (bytes per rank) + + Returns: + Device pointer to CommDeviceContext struct + + Raises: + RuntimeError: If allocation fails + """ + global _lib + if _lib is None: + raise RuntimeError("Runtime not loaded. Call bind_host_binary() first.") + + device_ctx = c_uint64(0) + rc = _lib.comm_alloc_windows(ctypes.c_void_p(handle), win_size, ctypes.byref(device_ctx)) + if rc != 0: + raise RuntimeError(f"comm_alloc_windows failed: {rc}") + return device_ctx.value + + +def comm_get_local_window_base(handle: int) -> int: + """ + Get the base address of this rank's local window. + + Args: + handle: Comm handle from comm_init() + + Returns: + Device-pointer base address + + Raises: + RuntimeError: If query fails + """ + global _lib + if _lib is None: + raise RuntimeError("Runtime not loaded. 
Call bind_host_binary() first.") + + base = c_uint64(0) + rc = _lib.comm_get_local_window_base(ctypes.c_void_p(handle), ctypes.byref(base)) + if rc != 0: + raise RuntimeError(f"comm_get_local_window_base failed: {rc}") + return base.value + + +def comm_barrier(handle: int) -> None: + """ + Synchronize all ranks in the communicator. + + Args: + handle: Comm handle from comm_init() + + Raises: + RuntimeError: If barrier fails + """ + global _lib + if _lib is None: + raise RuntimeError("Runtime not loaded. Call bind_host_binary() first.") + + rc = _lib.comm_barrier(ctypes.c_void_p(handle)) + if rc != 0: + raise RuntimeError(f"comm_barrier failed: {rc}") + + +def comm_destroy(handle: int) -> None: + """ + Destroy the communicator and release all resources. + + Args: + handle: Comm handle from comm_init() + + Raises: + RuntimeError: If destruction fails + """ + global _lib + if _lib is None: + raise RuntimeError("Runtime not loaded. Call bind_host_binary() first.") + + rc = _lib.comm_destroy(ctypes.c_void_p(handle)) + if rc != 0: + raise RuntimeError(f"comm_destroy failed: {rc}") + + # ============================================================================ # Compile Strategy Functions # ============================================================================ diff --git a/src/a2a3/platform/include/common/comm_context.h b/src/a2a3/platform/include/common/comm_context.h new file mode 100644 index 00000000..d3b74c8b --- /dev/null +++ b/src/a2a3/platform/include/common/comm_context.h @@ -0,0 +1,30 @@ +/** + * CommDeviceContext — device-side distributed communication context. + * + * This struct is the ABI contract between host (comm_hccl.cpp / comm_sim.cpp) + * and device kernels. PTO communication instructions (TREDUCE, TGET, TPUT) + * access remote data through the GVA addresses in windowsIn[]/windowsOut[] + * via MTE2 DMA. + * + * On HCCL MESH topology the struct layout matches what HCCL returns directly. 
+ * On RING topology the host builds it by extracting remote RDMA addresses + * from HcclOpResParam's remoteRes array. + * On simulation the host fills it with malloc'd pointers. + */ + +#pragma once + +#include + +static constexpr uint32_t COMM_MAX_RANK_NUM = 64; + +struct CommDeviceContext { + uint64_t workSpace; + uint64_t workSpaceSize; + + uint32_t rankId; + uint32_t rankNum; + uint64_t winSize; + uint64_t windowsIn[COMM_MAX_RANK_NUM]; + uint64_t windowsOut[COMM_MAX_RANK_NUM]; +}; diff --git a/src/a2a3/platform/include/host/comm.h b/src/a2a3/platform/include/host/comm.h new file mode 100644 index 00000000..1744e68b --- /dev/null +++ b/src/a2a3/platform/include/host/comm.h @@ -0,0 +1,92 @@ +/** + * Backend-neutral distributed communication C API. + * + * Provides five primitives for multi-rank communication: init, allocate + * shared windows, query local window base, barrier, and destroy. + * + * Implementations: + * onboard/host/comm_hccl.cpp — HCCL backend (links CANN hccl/hccl_fwk) + * sim/host/comm_sim.cpp — malloc-based simulation + * + * All functions are compiled into libhost_runtime.so. The linker selects + * the implementation at build time (onboard vs sim), with no runtime + * dispatch or virtual functions. + */ + +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct CommHandle_* CommHandle; + +/** + * Initialize a communicator for the given rank. + * + * On the HCCL backend this performs ACL init, RootInfo exchange (rank 0 + * writes the file, others wait), and HcclCommInitRootInfo. + * + * @param rank This process's rank (0-based). + * @param nranks Total number of ranks. + * @param rootinfo_path Filesystem path used to exchange root info between + * ranks (rank 0 writes, others read). + * @return Opaque handle, or NULL on failure. + */ +CommHandle comm_init(int rank, int nranks, const char* rootinfo_path); + +/** + * Allocate RDMA / shared-memory windows and populate the device context. 
+ * + * On HCCL this calls HcclAllocComResourceByTiling and extracts per-rank + * window addresses (MESH or RING topology). On sim it mallocs a shared + * region and partitions it. + * + * @param h Handle from comm_init(). + * @param win_size Window size hint (bytes per rank). The backend + * may allocate more; actual size is stored in the + * returned device context. + * @param device_ctx_out Receives a device pointer to a CommDeviceContext + * struct that can be passed to device kernels. + * @return 0 on success, non-zero on failure. + */ +int comm_alloc_windows(CommHandle h, size_t win_size, uint64_t* device_ctx_out); + +/** + * Get the base address of this rank's local window. + * + * Window buffers allocated via comm_alloc_windows() are contiguous per + * rank. This returns the start of the local rank's region. + * + * @param h Handle from comm_init(). + * @param base_out Receives the device-pointer base address. + * @return 0 on success, non-zero on failure. + */ +int comm_get_local_window_base(CommHandle h, uint64_t* base_out); + +/** + * Synchronize all ranks. + * + * Blocks until every rank in the communicator has called comm_barrier(). + * + * @param h Handle from comm_init(). + * @return 0 on success, non-zero on failure. + */ +int comm_barrier(CommHandle h); + +/** + * Destroy the communicator and release all resources. + * + * After this call the handle is invalid. + * + * @param h Handle from comm_init(). + * @return 0 on success, non-zero on failure. 
+ */ +int comm_destroy(CommHandle h); + +#ifdef __cplusplus +} +#endif diff --git a/src/a2a3/platform/onboard/host/CMakeLists.txt b/src/a2a3/platform/onboard/host/CMakeLists.txt index d83ecae7..11163e97 100644 --- a/src/a2a3/platform/onboard/host/CMakeLists.txt +++ b/src/a2a3/platform/onboard/host/CMakeLists.txt @@ -28,6 +28,7 @@ list(APPEND HOST_RUNTIME_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/host_log.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/unified_log_host.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/performance_collector.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/comm_hccl.cpp" ) if(DEFINED CUSTOM_SOURCE_DIRS) foreach(SRC_DIR ${CUSTOM_SOURCE_DIRS}) @@ -87,6 +88,8 @@ target_link_libraries(host_runtime PRIVATE runtime ascendcl + hccl + hccl_fwk dl ) diff --git a/src/a2a3/platform/onboard/host/comm_hccl.cpp b/src/a2a3/platform/onboard/host/comm_hccl.cpp new file mode 100644 index 00000000..0ab606cb --- /dev/null +++ b/src/a2a3/platform/onboard/host/comm_hccl.cpp @@ -0,0 +1,475 @@ +/** + * HCCL backend for the comm_* distributed communication API. + * + * Implements the five functions declared in host/comm.h using Ascend + * HCCL (bundled with CANN). Handles both MESH and RING topologies + * when extracting per-rank RDMA window addresses. 
+ */ + +#include "host/comm.h" +#include "common/comm_context.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "acl/acl.h" +#include "hccl/hccl_comm.h" +#include "hccl/hccl_types.h" + +// Internal HCCL APIs (not in public headers) +extern "C" HcclResult HcclAllocComResourceByTiling(HcclComm comm, void* stream, + void* mc2Tiling, void** commContext); +extern "C" HcclResult HcomGetCommHandleByGroup(const char* group, HcclComm* commHandle); + +using CommTopo = uint32_t; +extern "C" HcclResult HcomGetL0TopoTypeEx(const char* group, CommTopo* topoType, + uint32_t isSetDevice); + +static constexpr uint32_t COMM_IS_NOT_SET_DEVICE = 0; +static constexpr uint32_t COMM_TOPO_MESH = 0b1u; + +using rtStream_t = void*; +static constexpr int32_t RT_STREAM_PRIORITY_DEFAULT = 0; +extern "C" int32_t rtStreamCreate(rtStream_t* stream, int32_t priority); +extern "C" int32_t rtStreamDestroy(rtStream_t stream); + +// ============================================================================ +// HCCL tiling structures (required by HcclAllocComResourceByTiling) +// ============================================================================ + +namespace { + +static constexpr uint32_t MAX_CC_TILING_NUM = 8U; +static constexpr uint32_t GROUP_NAME_SIZE = 128U; +static constexpr uint32_t ALG_CONFIG_SIZE = 128U; + +struct Mc2InitTilingInner { + uint32_t version; + uint32_t mc2HcommCnt; + uint32_t offset[MAX_CC_TILING_NUM]; + uint8_t debugMode; + uint8_t preparePosition; + uint16_t queueNum; + uint16_t commBlockNum; + uint8_t devType; + char reserved[17]; +}; + +struct Mc2cCTilingInner { + uint8_t skipLocalRankCopy; + uint8_t skipBufferWindowCopy; + uint8_t stepSize; + uint8_t version; + char reserved[9]; + uint8_t commEngine; + uint8_t srcDataType; + uint8_t dstDataType; + char groupName[GROUP_NAME_SIZE]; + char algConfig[ALG_CONFIG_SIZE]; + uint32_t opType; + uint32_t reduceType; +}; + +struct Mc2CommConfigV2 { + Mc2InitTilingInner init; + 
Mc2cCTilingInner inner; +}; + +// HCCL compat structs for RING topology parsing +struct HcclSignalInfo { + uint64_t resId; + uint64_t addr; + uint32_t devId; + uint32_t tsId; + uint32_t rankId; + uint32_t flag; +}; + +struct HcclStreamInfo { + int32_t streamIds; + uint32_t sqIds; + uint32_t cqIds; + uint32_t logicCqids; +}; + +struct ListCommon { + uint64_t nextHost; + uint64_t preHost; + uint64_t nextDevice; + uint64_t preDevice; +}; + +static constexpr uint32_t COMPAT_LOCAL_NOTIFY_MAX_NUM = 64; +static constexpr uint32_t COMPAT_LOCAL_STREAM_MAX_NUM = 19; +static constexpr uint32_t COMPAT_AICPU_OP_NOTIFY_MAX_NUM = 2; + +struct LocalResInfoV2 { + uint32_t streamNum; + uint32_t signalNum; + HcclSignalInfo localSignals[COMPAT_LOCAL_NOTIFY_MAX_NUM]; + HcclStreamInfo streamInfo[COMPAT_LOCAL_STREAM_MAX_NUM]; + HcclStreamInfo mainStreamInfo; + HcclSignalInfo aicpuOpNotify[COMPAT_AICPU_OP_NOTIFY_MAX_NUM]; + ListCommon nextTagRes; +}; + +struct AlgoTopoInfo { + uint32_t userRank; + uint32_t userRankSize; + int32_t deviceLogicId; + bool isSingleMeshAggregation; + uint32_t deviceNumPerAggregation; + uint32_t superPodNum; + uint32_t devicePhyId; + uint32_t topoType; + uint32_t deviceType; + uint32_t serverNum; + uint32_t meshAggregationRankSize; + uint32_t multiModuleDiffDeviceNumMode; + uint32_t multiSuperPodDiffServerNumMode; + uint32_t realUserRank; + bool isDiffDeviceModule; + bool isDiffDeviceType; + uint32_t gcdDeviceNumPerAggregation; + uint32_t moduleNum; + uint32_t isUsedRdmaRankPairNum; + uint64_t isUsedRdmaRankPair; + uint32_t pairLinkCounterNum; + uint64_t pairLinkCounter; + uint32_t nicNum; + uint64_t nicList; + uint64_t complanRankLength; + uint64_t complanRank; + uint64_t bridgeRankNum; + uint64_t bridgeRank; + uint64_t serverAndsuperPodRankLength; + uint64_t serverAndsuperPodRank; +}; + +struct HcclOpConfig { + uint8_t deterministic; + uint8_t retryEnable; + uint8_t highPerfEnable; + uint8_t padding[5]; + uint8_t linkTimeOut[8]; + uint64_t notifyWaitTime; + 
uint32_t retryHoldTime; + uint32_t retryIntervalTime; + bool interXLinkDisable; + uint32_t floatOverflowMode; + uint32_t multiQpThreshold; +}; + +struct RemoteResPtr { + uint64_t nextHostPtr; + uint64_t nextDevicePtr; +}; + +struct HcclMC2WorkSpace { + uint64_t workspace; + uint64_t workspaceSize; +}; + +struct HcclRankRelationResV2 { + uint32_t remoteUsrRankId; + uint32_t remoteWorldRank; + uint64_t windowsIn; + uint64_t windowsOut; + uint64_t windowsExp; + ListCommon nextTagRes; +}; + +struct HcclOpResParamHead { + uint32_t localUsrRankId; + uint32_t rankSize; + uint64_t winSize; + uint64_t localWindowsIn; + uint64_t localWindowsOut; + char hcomId[128]; + uint64_t winExpSize; + uint64_t localWindowsExp; +}; + +struct HcclOpResParam { + HcclMC2WorkSpace mc2WorkSpace; + uint32_t localUsrRankId; + uint32_t rankSize; + uint64_t winSize; + uint64_t localWindowsIn; + uint64_t localWindowsOut; + char hcomId[128]; + uint64_t winExpSize; + uint64_t localWindowsExp; + uint32_t rWinStart; + uint32_t rWinOffset; + uint64_t version; + LocalResInfoV2 localRes; + AlgoTopoInfo topoInfo; + HcclOpConfig config; + uint64_t hostStateInfo; + uint64_t aicpuStateInfo; + uint64_t lockAddr; + uint32_t rsv[16]; + uint32_t notifysize; + uint32_t remoteResNum; + RemoteResPtr remoteRes[1]; +}; + +} // anonymous namespace + +// ============================================================================ +// Internal state +// ============================================================================ + +struct CommHandle_ { + int rank; + int nranks; + std::string rootinfo_path; + + rtStream_t stream = nullptr; + HcclComm hccl_comm = nullptr; + + CommDeviceContext host_ctx{}; + CommDeviceContext* device_ctx = nullptr; + bool owns_device_ctx = false; +}; + +// ============================================================================ +// Helpers +// ============================================================================ + +static bool wait_for_file(const std::string& path, int 
timeout_sec = 120) { + for (int i = 0; i < timeout_sec * 10; ++i) { + std::ifstream f(path, std::ios::binary); + if (f.good()) { + auto sz = f.seekg(0, std::ios::end).tellg(); + if (sz >= static_cast(HCCL_ROOT_INFO_BYTES)) return true; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + return false; +} + +static void file_barrier(const std::string& dir, int rank, int nranks, const std::string& tag) { + std::string my_marker = dir + "/barrier_" + tag + "_" + std::to_string(rank) + ".ready"; + { std::ofstream(my_marker) << "1"; } + + for (int r = 0; r < nranks; ++r) { + std::string marker = dir + "/barrier_" + tag + "_" + std::to_string(r) + ".ready"; + while (true) { + std::ifstream f(marker); + if (f.good()) break; + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + } +} + +// ============================================================================ +// API implementation +// ============================================================================ + +extern "C" CommHandle comm_init(int rank, int nranks, const char* rootinfo_path) { + auto* h = new (std::nothrow) CommHandle_{}; + if (!h) return nullptr; + + h->rank = rank; + h->nranks = nranks; + h->rootinfo_path = rootinfo_path; + + // ACL init + constexpr int kAclRepeatInit = 100002; + aclError aRet = aclInit(nullptr); + if (aRet != ACL_SUCCESS && static_cast(aRet) != kAclRepeatInit) { + fprintf(stderr, "[comm rank %d] aclInit failed: %d\n", rank, (int)aRet); + delete h; + return nullptr; + } + + // NOTE: Do NOT call aclrtSetDevice here — the caller (distributed_worker) + // already set the correct physical device via set_device(device_id). + // Calling aclrtSetDevice(rank) would override the context when + // rank != device_id (e.g. devices=[2,4,5,7]). 
+ + // RootInfo exchange + HcclRootInfo rootInfo{}; + if (rank == 0) { + HcclResult hret = HcclGetRootInfo(&rootInfo); + if (hret != HCCL_SUCCESS) { + fprintf(stderr, "[comm rank 0] HcclGetRootInfo failed: %d\n", (int)hret); + delete h; + return nullptr; + } + std::ofstream fout(rootinfo_path, std::ios::binary); + fout.write(rootInfo.internal, HCCL_ROOT_INFO_BYTES); + fout.close(); + } else { + if (!wait_for_file(rootinfo_path)) { + fprintf(stderr, "[comm rank %d] Timeout waiting for rootinfo\n", rank); + delete h; + return nullptr; + } + std::ifstream fin(rootinfo_path, std::ios::binary); + fin.read(rootInfo.internal, HCCL_ROOT_INFO_BYTES); + } + + // Create stream for HCCL operations + rtStreamCreate(&h->stream, RT_STREAM_PRIORITY_DEFAULT); + + // Init communicator + HcclResult hret = HcclCommInitRootInfo( + static_cast(nranks), &rootInfo, static_cast(rank), &h->hccl_comm); + if (hret != HCCL_SUCCESS) { + fprintf(stderr, "[comm rank %d] HcclCommInitRootInfo failed: %d\n", rank, (int)hret); + if (h->stream) rtStreamDestroy(h->stream); + delete h; + return nullptr; + } + + return h; +} + +extern "C" int comm_alloc_windows(CommHandle h, size_t /*win_size*/, uint64_t* device_ctx_out) { + if (!h || !device_ctx_out) return -1; + + char group[128] = {}; + HcclResult hret = HcclGetCommName(h->hccl_comm, group); + if (hret != HCCL_SUCCESS) return -1; + + CommTopo topoType = 0; + hret = HcomGetL0TopoTypeEx(group, &topoType, COMM_IS_NOT_SET_DEVICE); + if (hret != HCCL_SUCCESS) return -1; + + HcclComm commHandle = nullptr; + hret = HcomGetCommHandleByGroup(group, &commHandle); + if (hret != HCCL_SUCCESS) return -1; + + // File barrier so all ranks have completed HcclCommInitRootInfo + std::string barrier_dir = h->rootinfo_path; + auto last_slash = barrier_dir.rfind('/'); + if (last_slash != std::string::npos) + barrier_dir = barrier_dir.substr(0, last_slash); + file_barrier(barrier_dir, h->rank, h->nranks, "hccl_init"); + + // Tiling configuration for 
HcclAllocComResourceByTiling + Mc2CommConfigV2 tiling{}; + memset(&tiling, 0, sizeof(tiling)); + tiling.init.version = 100U; + tiling.init.mc2HcommCnt = 1U; + tiling.init.commBlockNum = 48U; + tiling.init.devType = 4U; + tiling.init.offset[0] = static_cast( + reinterpret_cast(&tiling.inner) - reinterpret_cast(&tiling.init)); + tiling.inner.opType = 18U; + tiling.inner.commEngine = 3U; + tiling.inner.version = 1U; + strncpy(tiling.inner.groupName, group, GROUP_NAME_SIZE - 1); + strncpy(tiling.inner.algConfig, "BatchWrite=level0:fullmesh", ALG_CONFIG_SIZE - 1); + + void* ctxPtr = nullptr; + hret = HcclAllocComResourceByTiling(commHandle, h->stream, &tiling, &ctxPtr); + if (hret != HCCL_SUCCESS || ctxPtr == nullptr) return -1; + + // Extract CommDeviceContext (topology-dependent) + aclError aRet; + if (topoType == COMM_TOPO_MESH) { + h->device_ctx = reinterpret_cast(ctxPtr); + aRet = aclrtMemcpy(&h->host_ctx, sizeof(h->host_ctx), + h->device_ctx, sizeof(h->host_ctx), ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) return -1; + } else { + // RING topology: parse HcclOpResParam structure on device + auto* rawCtx = reinterpret_cast(ctxPtr); + + HcclOpResParamHead head{}; + const size_t headOff = offsetof(HcclOpResParam, localUsrRankId); + aRet = aclrtMemcpy(&head, sizeof(head), rawCtx + headOff, sizeof(head), + ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) return -1; + + const size_t remoteResOff = offsetof(HcclOpResParam, remoteRes); + const size_t remoteResBytes = head.rankSize * sizeof(RemoteResPtr); + std::vector remoteResArr(head.rankSize); + aRet = aclrtMemcpy(remoteResArr.data(), remoteResBytes, + rawCtx + remoteResOff, remoteResBytes, ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) return -1; + + memset(&h->host_ctx, 0, sizeof(h->host_ctx)); + + uint64_t wsFields[2] = {0, 0}; + aclrtMemcpy(wsFields, sizeof(wsFields), rawCtx, sizeof(wsFields), ACL_MEMCPY_DEVICE_TO_HOST); + h->host_ctx.workSpace = wsFields[0]; + h->host_ctx.workSpaceSize = 
wsFields[1]; + h->host_ctx.rankId = head.localUsrRankId; + h->host_ctx.rankNum = head.rankSize; + h->host_ctx.winSize = head.winSize; + + for (uint32_t i = 0; i < head.rankSize; ++i) { + if (i == head.localUsrRankId) { + h->host_ctx.windowsIn[i] = head.localWindowsIn; + continue; + } + uint64_t devPtr = remoteResArr[i].nextDevicePtr; + if (devPtr == 0) return -1; + + HcclRankRelationResV2 remoteInfo{}; + aRet = aclrtMemcpy(&remoteInfo, sizeof(remoteInfo), + reinterpret_cast(devPtr), sizeof(remoteInfo), + ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) return -1; + h->host_ctx.windowsIn[i] = remoteInfo.windowsIn; + } + + void* newDevMem = nullptr; + aRet = aclrtMalloc(&newDevMem, sizeof(CommDeviceContext), ACL_MEM_MALLOC_HUGE_FIRST); + if (aRet != ACL_SUCCESS) return -1; + + aRet = aclrtMemcpy(newDevMem, sizeof(CommDeviceContext), + &h->host_ctx, sizeof(CommDeviceContext), ACL_MEMCPY_HOST_TO_DEVICE); + if (aRet != ACL_SUCCESS) { + aclrtFree(newDevMem); + return -1; + } + h->device_ctx = reinterpret_cast(newDevMem); + h->owns_device_ctx = true; + } + + *device_ctx_out = reinterpret_cast(h->device_ctx); + return 0; +} + +extern "C" int comm_get_local_window_base(CommHandle h, uint64_t* base_out) { + if (!h || !base_out) return -1; + *base_out = h->host_ctx.windowsIn[h->rank]; + return 0; +} + +extern "C" int comm_barrier(CommHandle h) { + if (!h) return -1; + HcclBarrier(h->hccl_comm, (aclrtStream)h->stream); + aclrtSynchronizeStream((aclrtStream)h->stream); + return 0; +} + +extern "C" int comm_destroy(CommHandle h) { + if (!h) return -1; + + if (h->owns_device_ctx && h->device_ctx) { + aclrtFree(h->device_ctx); + } + if (h->stream) rtStreamDestroy(h->stream); + if (h->hccl_comm) HcclCommDestroy(h->hccl_comm); + + // NOTE: Do NOT call aclrtResetDevice / aclFinalize here. + // Device lifecycle is owned by DeviceRunner (static singleton) whose + // destructor frees all tracked device memory before resetting the device. 
+ // Resetting early would invalidate pointers still held by MemoryAllocator. + + delete h; + return 0; +} diff --git a/src/a2a3/platform/onboard/host/device_runner.cpp b/src/a2a3/platform/onboard/host/device_runner.cpp index 4e0431bd..50f9b577 100644 --- a/src/a2a3/platform/onboard/host/device_runner.cpp +++ b/src/a2a3/platform/onboard/host/device_runner.cpp @@ -9,6 +9,8 @@ #include +#include "acl/acl.h" + // Include HAL constants from CANN (header only, library loaded dynamically) #include "ascend_hal.h" #include "host/host_regs.h" // Register address retrieval @@ -577,10 +579,19 @@ int DeviceRunner::finalize() { // Free all remaining allocations (including handshake buffer and binGmAddr) mem_alloc_.finalize(); + int saved_device_id = device_id_; device_id_ = -1; worker_count_ = 0; aicore_kernel_binary_.clear(); + // Reset device and finalize ACL AFTER all device memory is freed. + // This was previously done in comm_destroy, but that ran before the + // static DeviceRunner destructor, causing rtFree failures (107000). 
+ if (saved_device_id >= 0) { + aclrtResetDevice(saved_device_id); + aclFinalize(); + } + LOG_INFO("DeviceRunner finalized"); return 0; } diff --git a/src/a2a3/platform/sim/host/CMakeLists.txt b/src/a2a3/platform/sim/host/CMakeLists.txt index 1e304455..11f9dad8 100644 --- a/src/a2a3/platform/sim/host/CMakeLists.txt +++ b/src/a2a3/platform/sim/host/CMakeLists.txt @@ -33,6 +33,7 @@ list(APPEND HOST_RUNTIME_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/unified_log_host.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/performance_collector.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/../aicpu/platform_aicpu_affinity.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/comm_sim.cpp" ) if(DEFINED CUSTOM_SOURCE_DIRS) @@ -76,6 +77,7 @@ target_link_libraries(host_runtime PRIVATE pthread dl + rt ) set_target_properties(host_runtime PROPERTIES diff --git a/src/a2a3/platform/sim/host/comm_sim.cpp b/src/a2a3/platform/sim/host/comm_sim.cpp new file mode 100644 index 00000000..01693a3f --- /dev/null +++ b/src/a2a3/platform/sim/host/comm_sim.cpp @@ -0,0 +1,198 @@ +/** + * Simulation backend for the comm_* distributed communication API. + * + * Uses POSIX shared memory (shm_open + mmap) so that multiple *processes* + * (one per rank, spawned by DistributedCodeRunner) share the same RDMA + * window region. Synchronization primitives (barrier counters) live in + * the shared region itself, using GCC __atomic builtins which are safe + * on lock-free-capable types in mmap'd memory. + * + * Shared memory layout (page-aligned header + per-rank windows): + * [ SharedHeader (4096 bytes) ][ rank-0 window ][ rank-1 window ] ... 
+ */
+
+#include "host/comm.h"
+#include "common/comm_context.h"
+
+// NOTE(review): the include list and all template arguments below were
+// stripped during extraction and have been reconstructed from usage —
+// verify against the original patch.
+#include <cerrno>
+#include <cstdint>
+#include <cstdio>
+#include <fcntl.h>
+#include <functional>
+#include <new>
+#include <string>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+static constexpr size_t HEADER_SIZE = 4096;
+
+namespace {
+
+// Control block at offset 0 of the shared segment. Counters are updated
+// with __atomic builtins, which work on lock-free ints in MAP_SHARED memory.
+struct SharedHeader {
+    volatile int nranks;
+    volatile int alloc_done;
+    volatile int ready_count;
+    volatile int barrier_count;
+    volatile int barrier_phase;
+    volatile int destroy_count;
+    size_t per_rank_win_size;
+};
+
+// Derive a stable, path-unique POSIX shm name from the rootinfo path so
+// concurrent runs with different paths never share a segment.
+std::string make_shm_name(const char* rootinfo_path) {
+    size_t h = std::hash<std::string>{}(rootinfo_path ? rootinfo_path : "default");
+    char buf[64];
+    std::snprintf(buf, sizeof(buf), "/simpler_comm_%zx", h);
+    return buf;
+}
+
+} // anonymous namespace
+
+// ============================================================================
+// Per-handle state (process-local)
+// ============================================================================
+
+struct CommHandle_ {
+    int rank;
+    int nranks;
+    std::string shm_name;
+
+    void* mmap_base = nullptr;
+    size_t mmap_size = 0;
+    bool is_creator = false;
+
+    CommDeviceContext host_ctx{};
+};
+
+// ============================================================================
+// API implementation
+// ============================================================================
+
+extern "C" CommHandle comm_init(int rank, int nranks, const char* rootinfo_path) {
+    auto* h = new (std::nothrow) CommHandle_{};
+    if (!h) return nullptr;
+
+    h->rank = rank;
+    h->nranks = nranks;
+    h->shm_name = make_shm_name(rootinfo_path);
+    return h;
+}
+
+extern "C" int comm_alloc_windows(CommHandle h, size_t win_size, uint64_t* device_ctx_out) {
+    if (!h || !device_ctx_out) return -1;
+
+    size_t total = HEADER_SIZE + win_size * static_cast<size_t>(h->nranks);
+
+    // O_EXCL decides the creator: exactly one rank wins and sizes the segment.
+    int fd = shm_open(h->shm_name.c_str(), O_CREAT | O_EXCL | O_RDWR, 0600);
+    if (fd >= 0) {
+        h->is_creator = true;
+        if (ftruncate(fd, static_cast<off_t>(total)) != 0) {
+            std::perror("comm_sim: ftruncate");
+            close(fd);
+            shm_unlink(h->shm_name.c_str());
+            return -1;
+        }
+    } else if (errno == EEXIST) {
+        fd = shm_open(h->shm_name.c_str(), O_RDWR, 0600);
+        if (fd < 0) { std::perror("comm_sim: shm_open"); return -1; }
+
+        // Wait for creator to finish ftruncate by checking file size
+        for (int i = 0; i < 5000; ++i) {
+            struct stat st;
+            if (fstat(fd, &st) == 0 && static_cast<size_t>(st.st_size) >= total) break;
+            usleep(1000);
+        }
+    } else {
+        std::perror("comm_sim: shm_open O_EXCL");
+        return -1;
+    }
+
+    void* base = mmap(nullptr, total, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
+    close(fd);
+    if (base == MAP_FAILED) { std::perror("comm_sim: mmap"); return -1; }
+
+    h->mmap_base = base;
+    h->mmap_size = total;
+
+    auto* hdr = static_cast<SharedHeader*>(base);
+
+    if (h->is_creator) {
+        hdr->per_rank_win_size = win_size;
+        hdr->ready_count = 0;
+        hdr->barrier_count = 0;
+        hdr->barrier_phase = 0;
+        hdr->destroy_count = 0;
+        // alloc_done is published last so non-creators see a fully
+        // initialized header once the flag flips.
+        __atomic_store_n(&hdr->nranks, h->nranks, __ATOMIC_RELEASE);
+        __atomic_store_n(&hdr->alloc_done, 1, __ATOMIC_RELEASE);
+    } else {
+        while (__atomic_load_n(&hdr->alloc_done, __ATOMIC_ACQUIRE) == 0) {
+            usleep(100);
+        }
+    }
+
+    auto* win_base = static_cast<char*>(base) + HEADER_SIZE;
+
+    auto& ctx = h->host_ctx;
+    ctx.workSpace = 0;
+    ctx.workSpaceSize = 0;
+    ctx.rankId = static_cast<uint32_t>(h->rank);
+    ctx.rankNum = static_cast<uint32_t>(h->nranks);
+    ctx.winSize = win_size;
+    for (int i = 0; i < h->nranks; ++i) {
+        ctx.windowsIn[i] = reinterpret_cast<uint64_t>(
+            win_base + static_cast<size_t>(i) * win_size);
+    }
+
+    *device_ctx_out = reinterpret_cast<uint64_t>(&h->host_ctx);
+
+    // Rendezvous: all ranks must have mapped their windows before any
+    // rank starts using remote addresses.
+    __atomic_add_fetch(&hdr->ready_count, 1, __ATOMIC_ACQ_REL);
+    while (__atomic_load_n(&hdr->ready_count, __ATOMIC_ACQUIRE) < h->nranks) {
+        usleep(100);
+    }
+
+    return 0;
+}
+
+extern "C" int comm_get_local_window_base(CommHandle h, uint64_t* base_out) {
+    if (!h || !base_out) return -1;
+    *base_out = h->host_ctx.windowsIn[h->rank];
+    return 0;
+}
+
+extern "C" int comm_barrier(CommHandle h) {
+    if (!h || !h->mmap_base) return -1;
+
+    // Sense-reversing style barrier: last arriver resets the count and
+    // bumps the phase; everyone else spins until the phase changes.
+    auto* hdr = static_cast<SharedHeader*>(h->mmap_base);
+    int phase = __atomic_load_n(&hdr->barrier_phase, __ATOMIC_ACQUIRE);
+    int arrived = __atomic_add_fetch(&hdr->barrier_count, 1, __ATOMIC_ACQ_REL);
+
+    if (arrived == h->nranks) {
+        __atomic_store_n(&hdr->barrier_count, 0, __ATOMIC_RELEASE);
+        __atomic_add_fetch(&hdr->barrier_phase, 1, __ATOMIC_ACQ_REL);
+    } else {
+        while (__atomic_load_n(&hdr->barrier_phase, __ATOMIC_ACQUIRE) == phase) {
+            usleep(50);
+        }
+    }
+
+    return 0;
+}
+
+extern "C" int comm_destroy(CommHandle h) {
+    if (!h) return -1;
+
+    if (h->mmap_base) {
+        auto* hdr = static_cast<SharedHeader*>(h->mmap_base);
+        int gone = __atomic_add_fetch(&hdr->destroy_count, 1, __ATOMIC_ACQ_REL);
+
+        munmap(h->mmap_base, h->mmap_size);
+        h->mmap_base = nullptr;
+
+        // Last rank out unlinks the segment name.
+        if (gone >= h->nranks) {
+            shm_unlink(h->shm_name.c_str());
+        }
+    }
+
+    delete h;
+    return 0;
+}
diff --git a/src/a5/platform/include/common/comm_context.h b/src/a5/platform/include/common/comm_context.h
new file mode 100644
index 00000000..d3b74c8b
--- /dev/null
+++ b/src/a5/platform/include/common/comm_context.h
@@ -0,0 +1,30 @@
+/**
+ * CommDeviceContext — device-side distributed communication context.
+ *
+ * This struct is the ABI contract between host (comm_hccl.cpp / comm_sim.cpp)
+ * and device kernels. PTO communication instructions (TREDUCE, TGET, TPUT)
+ * access remote data through the GVA addresses in windowsIn[]/windowsOut[]
+ * via MTE2 DMA.
+ *
+ * On HCCL MESH topology the struct layout matches what HCCL returns directly.
+ * On RING topology the host builds it by extracting remote RDMA addresses
+ * from HcclOpResParam's remoteRes array.
+ * On simulation the host fills it with malloc'd pointers.
+ */
+
+#pragma once
+
+#include <cstdint>
+
+static constexpr uint32_t COMM_MAX_RANK_NUM = 64;
+
+// ABI-fixed layout shared with device kernels: do not reorder or resize
+// fields without updating the device-side definition in lockstep.
+struct CommDeviceContext {
+    uint64_t workSpace;
+    uint64_t workSpaceSize;
+
+    uint32_t rankId;
+    uint32_t rankNum;
+    uint64_t winSize;
+    uint64_t windowsIn[COMM_MAX_RANK_NUM];
+    uint64_t windowsOut[COMM_MAX_RANK_NUM];
+};
diff --git a/src/a5/platform/include/host/comm.h b/src/a5/platform/include/host/comm.h
new file mode 100644
index 00000000..1744e68b
--- /dev/null
+++ b/src/a5/platform/include/host/comm.h
@@ -0,0 +1,92 @@
+/**
+ * Backend-neutral distributed communication C API.
+ *
+ * Provides five primitives for multi-rank communication: init, allocate
+ * shared windows, query local window base, barrier, and destroy.
+ *
+ * Implementations:
+ *   onboard/host/comm_hccl.cpp — HCCL backend (links CANN hccl/hccl_fwk)
+ *   sim/host/comm_sim.cpp      — malloc-based simulation
+ *
+ * All functions are compiled into libhost_runtime.so. The linker selects
+ * the implementation at build time (onboard vs sim), with no runtime
+ * dispatch or virtual functions.
+ */
+
+#pragma once
+
+// NOTE(review): header names were stripped in extraction; C-compatible
+// forms chosen because of the extern "C" guard — verify against original.
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef struct CommHandle_* CommHandle;
+
+/**
+ * Initialize a communicator for the given rank.
+ *
+ * On the HCCL backend this performs ACL init, RootInfo exchange (rank 0
+ * writes the file, others wait), and HcclCommInitRootInfo.
+ *
+ * @param rank          This process's rank (0-based).
+ * @param nranks        Total number of ranks.
+ * @param rootinfo_path Filesystem path used to exchange root info between
+ *                      ranks (rank 0 writes, others read).
+ * @return Opaque handle, or NULL on failure.
+ */
+CommHandle comm_init(int rank, int nranks, const char* rootinfo_path);
+
+/**
+ * Allocate RDMA / shared-memory windows and populate the device context.
+ *
+ * On HCCL this calls HcclAllocComResourceByTiling and extracts per-rank
+ * window addresses (MESH or RING topology). On sim it mallocs a shared
+ * region and partitions it.
+ *
+ * @param h              Handle from comm_init().
+ * @param win_size       Window size hint (bytes per rank). The backend
+ *                       may allocate more; actual size is stored in the
+ *                       returned device context.
+ * @param device_ctx_out Receives a device pointer to a CommDeviceContext
+ *                       struct that can be passed to device kernels.
+ * @return 0 on success, non-zero on failure.
+ */
+int comm_alloc_windows(CommHandle h, size_t win_size, uint64_t* device_ctx_out);
+
+/**
+ * Get the base address of this rank's local window.
+ *
+ * Window buffers allocated via comm_alloc_windows() are contiguous per
+ * rank. This returns the start of the local rank's region.
+ *
+ * @param h        Handle from comm_init().
+ * @param base_out Receives the device-pointer base address.
+ * @return 0 on success, non-zero on failure.
+ */
+int comm_get_local_window_base(CommHandle h, uint64_t* base_out);
+
+/**
+ * Synchronize all ranks.
+ *
+ * Blocks until every rank in the communicator has called comm_barrier().
+ *
+ * @param h Handle from comm_init().
+ * @return 0 on success, non-zero on failure.
+ */
+int comm_barrier(CommHandle h);
+
+/**
+ * Destroy the communicator and release all resources.
+ *
+ * After this call the handle is invalid.
+ *
+ * @param h Handle from comm_init().
+ * @return 0 on success, non-zero on failure.
+ */
+int comm_destroy(CommHandle h);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/src/a5/platform/onboard/host/CMakeLists.txt b/src/a5/platform/onboard/host/CMakeLists.txt
index d83ecae7..11163e97 100644
--- a/src/a5/platform/onboard/host/CMakeLists.txt
+++ b/src/a5/platform/onboard/host/CMakeLists.txt
@@ -28,6 +28,7 @@ list(APPEND HOST_RUNTIME_SOURCES
     "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/host_log.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/unified_log_host.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/performance_collector.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/comm_hccl.cpp"
 )
 if(DEFINED CUSTOM_SOURCE_DIRS)
     foreach(SRC_DIR ${CUSTOM_SOURCE_DIRS})
@@ -87,6 +88,8 @@ target_link_libraries(host_runtime
     PRIVATE
     runtime
     ascendcl
+    hccl
+    hccl_fwk
     dl
 )
diff --git a/src/a5/platform/onboard/host/comm_hccl.cpp b/src/a5/platform/onboard/host/comm_hccl.cpp
new file mode 100644
index 00000000..0ab606cb
--- /dev/null
+++ b/src/a5/platform/onboard/host/comm_hccl.cpp
@@ -0,0 +1,475 @@
+/**
+ * HCCL backend for the comm_* distributed communication API.
+ *
+ * Implements the five functions declared in host/comm.h using Ascend
+ * HCCL (bundled with CANN). Handles both MESH and RING topologies
+ * when extracting per-rank RDMA window addresses.
+ */
+
+#include "host/comm.h"
+#include "common/comm_context.h"
+
+// NOTE(review): the include list and all template arguments below were
+// stripped during extraction and have been reconstructed from usage —
+// verify against the original patch.
+#include <chrono>
+#include <cstddef>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "acl/acl.h"
+#include "hccl/hccl_comm.h"
+#include "hccl/hccl_types.h"
+
+// Internal HCCL APIs (not in public headers)
+extern "C" HcclResult HcclAllocComResourceByTiling(HcclComm comm, void* stream,
+                                                   void* mc2Tiling, void** commContext);
+extern "C" HcclResult HcomGetCommHandleByGroup(const char* group, HcclComm* commHandle);
+
+using CommTopo = uint32_t;
+extern "C" HcclResult HcomGetL0TopoTypeEx(const char* group, CommTopo* topoType,
+                                          uint32_t isSetDevice);
+
+static constexpr uint32_t COMM_IS_NOT_SET_DEVICE = 0;
+static constexpr uint32_t COMM_TOPO_MESH = 0b1u;
+
+using rtStream_t = void*;
+static constexpr int32_t RT_STREAM_PRIORITY_DEFAULT = 0;
+extern "C" int32_t rtStreamCreate(rtStream_t* stream, int32_t priority);
+extern "C" int32_t rtStreamDestroy(rtStream_t stream);
+
+// ============================================================================
+// HCCL tiling structures (required by HcclAllocComResourceByTiling)
+// ============================================================================
+
+namespace {
+
+static constexpr uint32_t MAX_CC_TILING_NUM = 8U;
+static constexpr uint32_t GROUP_NAME_SIZE = 128U;
+static constexpr uint32_t ALG_CONFIG_SIZE = 128U;
+
+struct Mc2InitTilingInner {
+    uint32_t version;
+    uint32_t mc2HcommCnt;
+    uint32_t offset[MAX_CC_TILING_NUM];
+    uint8_t debugMode;
+    uint8_t preparePosition;
+    uint16_t queueNum;
+    uint16_t commBlockNum;
+    uint8_t devType;
+    char reserved[17];
+};
+
+struct Mc2cCTilingInner {
+    uint8_t skipLocalRankCopy;
+    uint8_t skipBufferWindowCopy;
+    uint8_t stepSize;
+    uint8_t version;
+    char reserved[9];
+    uint8_t commEngine;
+    uint8_t srcDataType;
+    uint8_t dstDataType;
+    char groupName[GROUP_NAME_SIZE];
+    char algConfig[ALG_CONFIG_SIZE];
+    uint32_t opType;
+    uint32_t reduceType;
+};
+
+struct Mc2CommConfigV2 {
+    Mc2InitTilingInner init;
+    Mc2cCTilingInner inner;
+};
+
+// HCCL compat structs for RING topology parsing
+struct HcclSignalInfo {
+    uint64_t resId;
+    uint64_t addr;
+    uint32_t devId;
+    uint32_t tsId;
+    uint32_t rankId;
+    uint32_t flag;
+};
+
+struct HcclStreamInfo {
+    int32_t streamIds;
+    uint32_t sqIds;
+    uint32_t cqIds;
+    uint32_t logicCqids;
+};
+
+struct ListCommon {
+    uint64_t nextHost;
+    uint64_t preHost;
+    uint64_t nextDevice;
+    uint64_t preDevice;
+};
+
+static constexpr uint32_t COMPAT_LOCAL_NOTIFY_MAX_NUM = 64;
+static constexpr uint32_t COMPAT_LOCAL_STREAM_MAX_NUM = 19;
+static constexpr uint32_t COMPAT_AICPU_OP_NOTIFY_MAX_NUM = 2;
+
+struct LocalResInfoV2 {
+    uint32_t streamNum;
+    uint32_t signalNum;
+    HcclSignalInfo localSignals[COMPAT_LOCAL_NOTIFY_MAX_NUM];
+    HcclStreamInfo streamInfo[COMPAT_LOCAL_STREAM_MAX_NUM];
+    HcclStreamInfo mainStreamInfo;
+    HcclSignalInfo aicpuOpNotify[COMPAT_AICPU_OP_NOTIFY_MAX_NUM];
+    ListCommon nextTagRes;
+};
+
+struct AlgoTopoInfo {
+    uint32_t userRank;
+    uint32_t userRankSize;
+    int32_t deviceLogicId;
+    bool isSingleMeshAggregation;
+    uint32_t deviceNumPerAggregation;
+    uint32_t superPodNum;
+    uint32_t devicePhyId;
+    uint32_t topoType;
+    uint32_t deviceType;
+    uint32_t serverNum;
+    uint32_t meshAggregationRankSize;
+    uint32_t multiModuleDiffDeviceNumMode;
+    uint32_t multiSuperPodDiffServerNumMode;
+    uint32_t realUserRank;
+    bool isDiffDeviceModule;
+    bool isDiffDeviceType;
+    uint32_t gcdDeviceNumPerAggregation;
+    uint32_t moduleNum;
+    uint32_t isUsedRdmaRankPairNum;
+    uint64_t isUsedRdmaRankPair;
+    uint32_t pairLinkCounterNum;
+    uint64_t pairLinkCounter;
+    uint32_t nicNum;
+    uint64_t nicList;
+    uint64_t complanRankLength;
+    uint64_t complanRank;
+    uint64_t bridgeRankNum;
+    uint64_t bridgeRank;
+    uint64_t serverAndsuperPodRankLength;
+    uint64_t serverAndsuperPodRank;
+};
+
+struct HcclOpConfig {
+    uint8_t deterministic;
+    uint8_t retryEnable;
+    uint8_t highPerfEnable;
+    uint8_t padding[5];
+    uint8_t linkTimeOut[8];
+    uint64_t notifyWaitTime;
+    uint32_t retryHoldTime;
+    uint32_t retryIntervalTime;
+    bool interXLinkDisable;
+    uint32_t floatOverflowMode;
+    uint32_t multiQpThreshold;
+};
+
+struct RemoteResPtr {
+    uint64_t nextHostPtr;
+    uint64_t nextDevicePtr;
+};
+
+struct HcclMC2WorkSpace {
+    uint64_t workspace;
+    uint64_t workspaceSize;
+};
+
+struct HcclRankRelationResV2 {
+    uint32_t remoteUsrRankId;
+    uint32_t remoteWorldRank;
+    uint64_t windowsIn;
+    uint64_t windowsOut;
+    uint64_t windowsExp;
+    ListCommon nextTagRes;
+};
+
+struct HcclOpResParamHead {
+    uint32_t localUsrRankId;
+    uint32_t rankSize;
+    uint64_t winSize;
+    uint64_t localWindowsIn;
+    uint64_t localWindowsOut;
+    char hcomId[128];
+    uint64_t winExpSize;
+    uint64_t localWindowsExp;
+};
+
+struct HcclOpResParam {
+    HcclMC2WorkSpace mc2WorkSpace;
+    uint32_t localUsrRankId;
+    uint32_t rankSize;
+    uint64_t winSize;
+    uint64_t localWindowsIn;
+    uint64_t localWindowsOut;
+    char hcomId[128];
+    uint64_t winExpSize;
+    uint64_t localWindowsExp;
+    uint32_t rWinStart;
+    uint32_t rWinOffset;
+    uint64_t version;
+    LocalResInfoV2 localRes;
+    AlgoTopoInfo topoInfo;
+    HcclOpConfig config;
+    uint64_t hostStateInfo;
+    uint64_t aicpuStateInfo;
+    uint64_t lockAddr;
+    uint32_t rsv[16];
+    uint32_t notifysize;
+    uint32_t remoteResNum;
+    RemoteResPtr remoteRes[1];
+};
+
+} // anonymous namespace
+
+// ============================================================================
+// Internal state
+// ============================================================================
+
+struct CommHandle_ {
+    int rank;
+    int nranks;
+    std::string rootinfo_path;
+
+    rtStream_t stream = nullptr;
+    HcclComm hccl_comm = nullptr;
+
+    CommDeviceContext host_ctx{};
+    CommDeviceContext* device_ctx = nullptr;
+    bool owns_device_ctx = false;
+};
+
+// ============================================================================
+// Helpers
+// ============================================================================
+
+// Poll until `path` exists and holds at least HCCL_ROOT_INFO_BYTES.
+static bool wait_for_file(const std::string& path, int timeout_sec = 120) {
+    for (int i = 0; i < timeout_sec * 10; ++i) {
+        std::ifstream f(path, std::ios::binary);
+        if (f.good()) {
+            auto sz = f.seekg(0, std::ios::end).tellg();
+            if (sz >= static_cast<std::streamoff>(HCCL_ROOT_INFO_BYTES)) return true;
+        }
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    }
+    return false;
+}
+
+// Filesystem barrier: each rank drops a marker file, then waits until all
+// nranks markers for `tag` exist. NOTE(review): no timeout — hangs forever
+// if a rank dies before reaching the barrier.
+static void file_barrier(const std::string& dir, int rank, int nranks, const std::string& tag) {
+    std::string my_marker = dir + "/barrier_" + tag + "_" + std::to_string(rank) + ".ready";
+    { std::ofstream(my_marker) << "1"; }
+
+    for (int r = 0; r < nranks; ++r) {
+        std::string marker = dir + "/barrier_" + tag + "_" + std::to_string(r) + ".ready";
+        while (true) {
+            std::ifstream f(marker);
+            if (f.good()) break;
+            std::this_thread::sleep_for(std::chrono::milliseconds(50));
+        }
+    }
+}
+
+// ============================================================================
+// API implementation
+// ============================================================================
+
+extern "C" CommHandle comm_init(int rank, int nranks, const char* rootinfo_path) {
+    auto* h = new (std::nothrow) CommHandle_{};
+    if (!h) return nullptr;
+
+    h->rank = rank;
+    h->nranks = nranks;
+    h->rootinfo_path = rootinfo_path;
+
+    // ACL init (repeat-init error code is tolerated).
+    constexpr int kAclRepeatInit = 100002;
+    aclError aRet = aclInit(nullptr);
+    if (aRet != ACL_SUCCESS && static_cast<int>(aRet) != kAclRepeatInit) {
+        fprintf(stderr, "[comm rank %d] aclInit failed: %d\n", rank, (int)aRet);
+        delete h;
+        return nullptr;
+    }
+
+    // NOTE: Do NOT call aclrtSetDevice here — the caller (distributed_worker)
+    // already set the correct physical device via set_device(device_id).
+    // Calling aclrtSetDevice(rank) would override the context when
+    // rank != device_id (e.g. devices=[2,4,5,7]).
+
+    // RootInfo exchange
+    HcclRootInfo rootInfo{};
+    if (rank == 0) {
+        HcclResult hret = HcclGetRootInfo(&rootInfo);
+        if (hret != HCCL_SUCCESS) {
+            fprintf(stderr, "[comm rank 0] HcclGetRootInfo failed: %d\n", (int)hret);
+            delete h;
+            return nullptr;
+        }
+        std::ofstream fout(rootinfo_path, std::ios::binary);
+        fout.write(rootInfo.internal, HCCL_ROOT_INFO_BYTES);
+        fout.close();
+    } else {
+        if (!wait_for_file(rootinfo_path)) {
+            fprintf(stderr, "[comm rank %d] Timeout waiting for rootinfo\n", rank);
+            delete h;
+            return nullptr;
+        }
+        std::ifstream fin(rootinfo_path, std::ios::binary);
+        fin.read(rootInfo.internal, HCCL_ROOT_INFO_BYTES);
+    }
+
+    // Create stream for HCCL operations
+    rtStreamCreate(&h->stream, RT_STREAM_PRIORITY_DEFAULT);
+
+    // Init communicator
+    HcclResult hret = HcclCommInitRootInfo(
+        static_cast<uint32_t>(nranks), &rootInfo, static_cast<uint32_t>(rank), &h->hccl_comm);
+    if (hret != HCCL_SUCCESS) {
+        fprintf(stderr, "[comm rank %d] HcclCommInitRootInfo failed: %d\n", rank, (int)hret);
+        if (h->stream) rtStreamDestroy(h->stream);
+        delete h;
+        return nullptr;
+    }
+
+    return h;
+}
+
+extern "C" int comm_alloc_windows(CommHandle h, size_t /*win_size*/, uint64_t* device_ctx_out) {
+    if (!h || !device_ctx_out) return -1;
+
+    char group[128] = {};
+    HcclResult hret = HcclGetCommName(h->hccl_comm, group);
+    if (hret != HCCL_SUCCESS) return -1;
+
+    CommTopo topoType = 0;
+    hret = HcomGetL0TopoTypeEx(group, &topoType, COMM_IS_NOT_SET_DEVICE);
+    if (hret != HCCL_SUCCESS) return -1;
+
+    HcclComm commHandle = nullptr;
+    hret = HcomGetCommHandleByGroup(group, &commHandle);
+    if (hret != HCCL_SUCCESS) return -1;
+
+    // File barrier so all ranks have completed HcclCommInitRootInfo
+    std::string barrier_dir = h->rootinfo_path;
+    auto last_slash = barrier_dir.rfind('/');
+    if (last_slash != std::string::npos)
+        barrier_dir = barrier_dir.substr(0, last_slash);
+    file_barrier(barrier_dir, h->rank, h->nranks, "hccl_init");
+
+    // Tiling configuration for HcclAllocComResourceByTiling
+    Mc2CommConfigV2 tiling{};
+    memset(&tiling, 0, sizeof(tiling));
+    tiling.init.version = 100U;
+    tiling.init.mc2HcommCnt = 1U;
+    tiling.init.commBlockNum = 48U;
+    tiling.init.devType = 4U;
+    tiling.init.offset[0] = static_cast<uint32_t>(
+        reinterpret_cast<uintptr_t>(&tiling.inner) - reinterpret_cast<uintptr_t>(&tiling.init));
+    tiling.inner.opType = 18U;
+    tiling.inner.commEngine = 3U;
+    tiling.inner.version = 1U;
+    strncpy(tiling.inner.groupName, group, GROUP_NAME_SIZE - 1);
+    strncpy(tiling.inner.algConfig, "BatchWrite=level0:fullmesh", ALG_CONFIG_SIZE - 1);
+
+    void* ctxPtr = nullptr;
+    hret = HcclAllocComResourceByTiling(commHandle, h->stream, &tiling, &ctxPtr);
+    if (hret != HCCL_SUCCESS || ctxPtr == nullptr) return -1;
+
+    // Extract CommDeviceContext (topology-dependent)
+    aclError aRet;
+    if (topoType == COMM_TOPO_MESH) {
+        // MESH: HCCL's context already matches CommDeviceContext's layout.
+        h->device_ctx = reinterpret_cast<CommDeviceContext*>(ctxPtr);
+        aRet = aclrtMemcpy(&h->host_ctx, sizeof(h->host_ctx),
+                           h->device_ctx, sizeof(h->host_ctx), ACL_MEMCPY_DEVICE_TO_HOST);
+        if (aRet != ACL_SUCCESS) return -1;
+    } else {
+        // RING topology: parse HcclOpResParam structure on device
+        auto* rawCtx = reinterpret_cast<uint8_t*>(ctxPtr);
+
+        HcclOpResParamHead head{};
+        const size_t headOff = offsetof(HcclOpResParam, localUsrRankId);
+        aRet = aclrtMemcpy(&head, sizeof(head), rawCtx + headOff, sizeof(head),
+                           ACL_MEMCPY_DEVICE_TO_HOST);
+        if (aRet != ACL_SUCCESS) return -1;
+
+        const size_t remoteResOff = offsetof(HcclOpResParam, remoteRes);
+        const size_t remoteResBytes = head.rankSize * sizeof(RemoteResPtr);
+        std::vector<RemoteResPtr> remoteResArr(head.rankSize);
+        aRet = aclrtMemcpy(remoteResArr.data(), remoteResBytes,
+                           rawCtx + remoteResOff, remoteResBytes, ACL_MEMCPY_DEVICE_TO_HOST);
+        if (aRet != ACL_SUCCESS) return -1;
+
+        memset(&h->host_ctx, 0, sizeof(h->host_ctx));
+
+        uint64_t wsFields[2] = {0, 0};
+        aclrtMemcpy(wsFields, sizeof(wsFields), rawCtx, sizeof(wsFields), ACL_MEMCPY_DEVICE_TO_HOST);
+        h->host_ctx.workSpace = wsFields[0];
+        h->host_ctx.workSpaceSize = wsFields[1];
+        h->host_ctx.rankId = head.localUsrRankId;
+        h->host_ctx.rankNum = head.rankSize;
+        h->host_ctx.winSize = head.winSize;
+
+        for (uint32_t i = 0; i < head.rankSize; ++i) {
+            if (i == head.localUsrRankId) {
+                h->host_ctx.windowsIn[i] = head.localWindowsIn;
+                continue;
+            }
+            uint64_t devPtr = remoteResArr[i].nextDevicePtr;
+            if (devPtr == 0) return -1;
+
+            HcclRankRelationResV2 remoteInfo{};
+            aRet = aclrtMemcpy(&remoteInfo, sizeof(remoteInfo),
+                               reinterpret_cast<void*>(devPtr), sizeof(remoteInfo),
+                               ACL_MEMCPY_DEVICE_TO_HOST);
+            if (aRet != ACL_SUCCESS) return -1;
+            h->host_ctx.windowsIn[i] = remoteInfo.windowsIn;
+        }
+
+        // Publish the rebuilt context to the device in a layout kernels expect.
+        void* newDevMem = nullptr;
+        aRet = aclrtMalloc(&newDevMem, sizeof(CommDeviceContext), ACL_MEM_MALLOC_HUGE_FIRST);
+        if (aRet != ACL_SUCCESS) return -1;
+
+        aRet = aclrtMemcpy(newDevMem, sizeof(CommDeviceContext),
+                           &h->host_ctx, sizeof(CommDeviceContext), ACL_MEMCPY_HOST_TO_DEVICE);
+        if (aRet != ACL_SUCCESS) {
+            aclrtFree(newDevMem);
+            return -1;
+        }
+        h->device_ctx = reinterpret_cast<CommDeviceContext*>(newDevMem);
+        h->owns_device_ctx = true;
+    }
+
+    *device_ctx_out = reinterpret_cast<uint64_t>(h->device_ctx);
+    return 0;
+}
+
+extern "C" int comm_get_local_window_base(CommHandle h, uint64_t* base_out) {
+    if (!h || !base_out) return -1;
+    *base_out = h->host_ctx.windowsIn[h->rank];
+    return 0;
+}
+
+extern "C" int comm_barrier(CommHandle h) {
+    if (!h) return -1;
+    HcclBarrier(h->hccl_comm, (aclrtStream)h->stream);
+    aclrtSynchronizeStream((aclrtStream)h->stream);
+    return 0;
+}
+
+extern "C" int comm_destroy(CommHandle h) {
+    if (!h) return -1;
+
+    if (h->owns_device_ctx && h->device_ctx) {
+        aclrtFree(h->device_ctx);
+    }
+    if (h->stream) rtStreamDestroy(h->stream);
+    if (h->hccl_comm) HcclCommDestroy(h->hccl_comm);
+
+    // NOTE: Do NOT call aclrtResetDevice / aclFinalize here.
+    // Device lifecycle is owned by DeviceRunner (static singleton) whose
+    // destructor frees all tracked device memory before resetting the device.
+    // Resetting early would invalidate pointers still held by MemoryAllocator.
+
+    delete h;
+    return 0;
+}
diff --git a/src/a5/platform/sim/host/CMakeLists.txt b/src/a5/platform/sim/host/CMakeLists.txt
index bdb0ce53..2c2f7537 100644
--- a/src/a5/platform/sim/host/CMakeLists.txt
+++ b/src/a5/platform/sim/host/CMakeLists.txt
@@ -32,6 +32,7 @@ list(APPEND HOST_RUNTIME_SOURCES
     "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/host_log.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/unified_log_host.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/performance_collector.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/comm_sim.cpp"
 )
 if(DEFINED CUSTOM_SOURCE_DIRS)
@@ -75,6 +76,7 @@ target_link_libraries(host_runtime
     PRIVATE
     pthread
     dl
+    rt
 )
 set_target_properties(host_runtime PROPERTIES
diff --git a/src/a5/platform/sim/host/comm_sim.cpp b/src/a5/platform/sim/host/comm_sim.cpp
new file mode 100644
index 00000000..01693a3f
--- /dev/null
+++ b/src/a5/platform/sim/host/comm_sim.cpp
@@ -0,0 +1,198 @@
+/**
+ * Simulation backend for the comm_* distributed communication API.
+ *
+ * Uses POSIX shared memory (shm_open + mmap) so that multiple *processes*
+ * (one per rank, spawned by DistributedCodeRunner) share the same RDMA
+ * window region. Synchronization primitives (barrier counters) live in
+ * the shared region itself, using GCC __atomic builtins which are safe
+ * on lock-free-capable types in mmap'd memory.
+ *
+ * Shared memory layout (page-aligned header + per-rank windows):
+ *   [ SharedHeader (4096 bytes) ][ rank-0 window ][ rank-1 window ] ...
+ */
+
+#include "host/comm.h"
+#include "common/comm_context.h"
+
+// NOTE(review): the include list and all template arguments below were
+// stripped during extraction and have been reconstructed from usage —
+// verify against the original patch.
+#include <cerrno>
+#include <cstdint>
+#include <cstdio>
+#include <fcntl.h>
+#include <functional>
+#include <new>
+#include <string>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+static constexpr size_t HEADER_SIZE = 4096;
+
+namespace {
+
+// Control block at offset 0 of the shared segment. Counters are updated
+// with __atomic builtins, which work on lock-free ints in MAP_SHARED memory.
+struct SharedHeader {
+    volatile int nranks;
+    volatile int alloc_done;
+    volatile int ready_count;
+    volatile int barrier_count;
+    volatile int barrier_phase;
+    volatile int destroy_count;
+    size_t per_rank_win_size;
+};
+
+// Derive a stable, path-unique POSIX shm name from the rootinfo path so
+// concurrent runs with different paths never share a segment.
+std::string make_shm_name(const char* rootinfo_path) {
+    size_t h = std::hash<std::string>{}(rootinfo_path ? rootinfo_path : "default");
+    char buf[64];
+    std::snprintf(buf, sizeof(buf), "/simpler_comm_%zx", h);
+    return buf;
+}
+
+} // anonymous namespace
+
+// ============================================================================
+// Per-handle state (process-local)
+// ============================================================================
+
+struct CommHandle_ {
+    int rank;
+    int nranks;
+    std::string shm_name;
+
+    void* mmap_base = nullptr;
+    size_t mmap_size = 0;
+    bool is_creator = false;
+
+    CommDeviceContext host_ctx{};
+};
+
+// ============================================================================
+// API implementation
+// ============================================================================
+
+extern "C" CommHandle comm_init(int rank, int nranks, const char* rootinfo_path) {
+    auto* h = new (std::nothrow) CommHandle_{};
+    if (!h) return nullptr;
+
+    h->rank = rank;
+    h->nranks = nranks;
+    h->shm_name = make_shm_name(rootinfo_path);
+    return h;
+}
+
+extern "C" int comm_alloc_windows(CommHandle h, size_t win_size, uint64_t* device_ctx_out) {
+    if (!h || !device_ctx_out) return -1;
+
+    size_t total = HEADER_SIZE + win_size * static_cast<size_t>(h->nranks);
+
+    // O_EXCL decides the creator: exactly one rank wins and sizes the segment.
+    int fd = shm_open(h->shm_name.c_str(), O_CREAT | O_EXCL | O_RDWR, 0600);
+    if (fd >= 0) {
+        h->is_creator = true;
+        if (ftruncate(fd, static_cast<off_t>(total)) != 0) {
+            std::perror("comm_sim: ftruncate");
+            close(fd);
+            shm_unlink(h->shm_name.c_str());
+            return -1;
+        }
+    } else if (errno == EEXIST) {
+        fd = shm_open(h->shm_name.c_str(), O_RDWR, 0600);
+        if (fd < 0) { std::perror("comm_sim: shm_open"); return -1; }
+
+        // Wait for creator to finish ftruncate by checking file size
+        for (int i = 0; i < 5000; ++i) {
+            struct stat st;
+            if (fstat(fd, &st) == 0 && static_cast<size_t>(st.st_size) >= total) break;
+            usleep(1000);
+        }
+    } else {
+        std::perror("comm_sim: shm_open O_EXCL");
+        return -1;
+    }
+
+    void* base = mmap(nullptr, total, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
+    close(fd);
+    if (base == MAP_FAILED) { std::perror("comm_sim: mmap"); return -1; }
+
+    h->mmap_base = base;
+    h->mmap_size = total;
+
+    auto* hdr = static_cast<SharedHeader*>(base);
+
+    if (h->is_creator) {
+        hdr->per_rank_win_size = win_size;
+        hdr->ready_count = 0;
+        hdr->barrier_count = 0;
+        hdr->barrier_phase = 0;
+        hdr->destroy_count = 0;
+        // alloc_done is published last so non-creators see a fully
+        // initialized header once the flag flips.
+        __atomic_store_n(&hdr->nranks, h->nranks, __ATOMIC_RELEASE);
+        __atomic_store_n(&hdr->alloc_done, 1, __ATOMIC_RELEASE);
+    } else {
+        while (__atomic_load_n(&hdr->alloc_done, __ATOMIC_ACQUIRE) == 0) {
+            usleep(100);
+        }
+    }
+
+    auto* win_base = static_cast<char*>(base) + HEADER_SIZE;
+
+    auto& ctx = h->host_ctx;
+    ctx.workSpace = 0;
+    ctx.workSpaceSize = 0;
+    ctx.rankId = static_cast<uint32_t>(h->rank);
+    ctx.rankNum = static_cast<uint32_t>(h->nranks);
+    ctx.winSize = win_size;
+    for (int i = 0; i < h->nranks; ++i) {
+        ctx.windowsIn[i] = reinterpret_cast<uint64_t>(
+            win_base + static_cast<size_t>(i) * win_size);
+    }
+
+    *device_ctx_out = reinterpret_cast<uint64_t>(&h->host_ctx);
+
+    // Rendezvous: all ranks must have mapped their windows before any
+    // rank starts using remote addresses.
+    __atomic_add_fetch(&hdr->ready_count, 1, __ATOMIC_ACQ_REL);
+    while (__atomic_load_n(&hdr->ready_count, __ATOMIC_ACQUIRE) < h->nranks) {
+        usleep(100);
+    }
+
+    return 0;
+}
+
+extern "C" int comm_get_local_window_base(CommHandle h, uint64_t* base_out) {
+    if (!h || !base_out) return -1;
+    *base_out = h->host_ctx.windowsIn[h->rank];
+    return 0;
+}
+
+extern "C" int comm_barrier(CommHandle h) {
+    if (!h || !h->mmap_base) return -1;
+
+    // Phase-counting barrier: last arriver resets the count and bumps the
+    // phase; everyone else spins until the phase changes.
+    auto* hdr = static_cast<SharedHeader*>(h->mmap_base);
+    int phase = __atomic_load_n(&hdr->barrier_phase, __ATOMIC_ACQUIRE);
+    int arrived = __atomic_add_fetch(&hdr->barrier_count, 1, __ATOMIC_ACQ_REL);
+
+    if (arrived == h->nranks) {
+        __atomic_store_n(&hdr->barrier_count, 0, __ATOMIC_RELEASE);
+        __atomic_add_fetch(&hdr->barrier_phase, 1, __ATOMIC_ACQ_REL);
+    } else {
+        while (__atomic_load_n(&hdr->barrier_phase, __ATOMIC_ACQUIRE) == phase) {
+            usleep(50);
+        }
+    }
+
+    return 0;
+}
+
+extern "C" int comm_destroy(CommHandle h) {
+    if (!h) return -1;
+
+    if (h->mmap_base) {
+        auto* hdr = static_cast<SharedHeader*>(h->mmap_base);
+        int gone = __atomic_add_fetch(&hdr->destroy_count, 1, __ATOMIC_ACQ_REL);
+
+        munmap(h->mmap_base, h->mmap_size);
+        h->mmap_base = nullptr;
+
+        // Last rank out unlinks the segment name.
+        if (gone >= h->nranks) {
+            shm_unlink(h->shm_name.c_str());
+        }
+    }
+
+    delete h;
+    return 0;
+}