Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 134 additions & 23 deletions ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,17 @@ while [[ $# -gt 0 ]]; do
shift 2
;;
-d|--device)
DEVICE_RANGE="$2"
shift 2
shift
DEVICE_VALUES=()
while [[ $# -gt 0 && "$1" != -* ]]; do
DEVICE_VALUES+=("$1")
shift
done
if [[ ${#DEVICE_VALUES[@]} -eq 0 ]]; then
echo "Missing value for --device"
exit 1
fi
DEVICE_RANGE=$(IFS=,; echo "${DEVICE_VALUES[*]}")
;;
-r|--runtime)
RUNTIME="$2"
Expand Down Expand Up @@ -78,15 +87,22 @@ if [[ -n "$RUNTIME" ]]; then
fi
fi

# Parse device range (e.g., "5-8" or "5")
if [[ "$DEVICE_RANGE" == *-* ]]; then
IFS='-' read -r DEV_START DEV_END <<< "$DEVICE_RANGE"
DEVICES=()
for ((i=DEV_START; i<=DEV_END; i++)); do
DEVICES+=("$i")
done
# Parse device spec (e.g., "5-8", "5", or "0,1,3,5")
DEVICES=()
if [[ -z "$DEVICE_RANGE" ]]; then
DEVICES=("0")
else
DEVICES=("${DEVICE_RANGE:-0}")
IFS=',' read -r -a DEVICE_ITEMS <<< "$DEVICE_RANGE"
for item in "${DEVICE_ITEMS[@]}"; do
if [[ "$item" == *-* ]]; then
IFS='-' read -r DEV_START DEV_END <<< "$item"
for ((i=DEV_START; i<=DEV_END; i++)); do
DEVICES+=("$i")
done
else
DEVICES+=("$item")
fi
done
fi
NUM_DEVICES=${#DEVICES[@]}

Expand Down Expand Up @@ -199,13 +215,48 @@ pin_pto_isa_on_failure() {
return 0 # Pinned, caller should retry
}

get_task_device_count() {
local kernel_config="$1"
python - "$kernel_config" <<'PY'
import importlib.util
import sys

path = sys.argv[1]
spec = importlib.util.spec_from_file_location("kernel_config", path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
dist = getattr(mod, "DISTRIBUTED_CONFIG", None)
nranks = 1
if isinstance(dist, dict):
try:
nranks = int(dist.get("nranks", 1))
except (TypeError, ValueError):
nranks = 1
print(max(nranks, 1))
PY
}

format_device_spec() {
local count="$1"
if [[ "$count" -le 1 ]]; then
echo "${DEVICES[0]}"
return 0
fi

local selected=("${DEVICES[@]:0:$count}")
local joined
joined=$(IFS=,; echo "${selected[*]}")
echo "$joined"
}

# ---- Discover all tasks ----
EXAMPLES_DIR="examples"
DEVICE_TESTS_DIR="tests/device_tests"

declare -a HW_TASK_NAMES=()
declare -a HW_TASK_DIRS=()
declare -a HW_TASK_PLATS=()
declare -a HW_TASK_DEVICE_COUNTS=()
declare -a SIM_TASK_NAMES=()
declare -a SIM_TASK_DIRS=()
declare -a SIM_TASK_PLATS=()
Expand Down Expand Up @@ -245,18 +296,22 @@ while IFS= read -r -d '' example_dir; do
SIM_TASK_DIRS+=("${example_dir}")
SIM_TASK_PLATS+=("${PLATFORM}")
else
required_devices="$(get_task_device_count "$kernel_config")"
HW_TASK_NAMES+=("example:${example_name}")
HW_TASK_DIRS+=("${example_dir}")
HW_TASK_PLATS+=("${PLATFORM}")
HW_TASK_DEVICE_COUNTS+=("${required_devices}")
fi
elif [[ "$OS" == "Darwin" ]]; then
SIM_TASK_NAMES+=("example:${example_name}")
SIM_TASK_DIRS+=("${example_dir}")
SIM_TASK_PLATS+=("${example_arch}sim")
else
required_devices="$(get_task_device_count "$kernel_config")"
HW_TASK_NAMES+=("example:${example_name}")
HW_TASK_DIRS+=("${example_dir}")
HW_TASK_PLATS+=("${example_arch}")
HW_TASK_DEVICE_COUNTS+=("${required_devices}")
SIM_TASK_NAMES+=("example:${example_name}")
SIM_TASK_DIRS+=("${example_dir}")
SIM_TASK_PLATS+=("${example_arch}sim")
Expand Down Expand Up @@ -299,6 +354,7 @@ if [[ -d "$DEVICE_TESTS_DIR" ]]; then
HW_TASK_NAMES+=("device_test:${test_name}")
HW_TASK_DIRS+=("${test_dir}")
HW_TASK_PLATS+=("${PLATFORM:-${test_arch}}")
HW_TASK_DEVICE_COUNTS+=("$(get_task_device_count "$kernel_config")")
done < <(find "$DEVICE_TESTS_DIR" -mindepth 1 -type d -print0 | sort -z)
else
echo "Skipping device tests (hardware platforms only)"
Expand All @@ -314,7 +370,7 @@ MAX_RETRIES=3
# Log naming: ${safe_name}_${platform}_attempt${attempt}.log
# Result format: name|platform|PASS/FAIL|device:X|attempt:N|Xs
run_task() {
local name="$1" dir="$2" platform="$3" attempt="$4" device_id="$5" print_log_on_fail="${6:-true}"
local name="$1" dir="$2" platform="$3" attempt="$4" device_spec="$5" print_log_on_fail="${6:-true}" required_devices="${7:-1}"
local safe_name="${name//[:\/]/_}"
local task_log="${LOG_DIR}/${safe_name}_${platform}_attempt${attempt}.log"
local start_time=$SECONDS
Expand All @@ -323,10 +379,16 @@ run_task() {
cmd=(python examples/scripts/run_example.py
-k "${dir}/kernels" -g "${dir}/golden.py"
-p "$platform" --clone-protocol "$CLONE_PROTOCOL" "${commit_flag[@]}")
[[ -n "$device_id" ]] && cmd+=(-d "$device_id")
if [[ -n "$device_spec" ]]; then
if [[ "$required_devices" -gt 1 ]]; then
cmd+=(--devices "$device_spec" --nranks "$required_devices")
else
cmd+=(-d "$device_spec")
fi
fi

# Progress to stdout (not captured in log)
echo "[${platform}${device_id:+:dev${device_id}}] Running: $name (attempt $attempt)"
echo "[${platform}${device_spec:+:dev${device_spec}}] Running: $name (attempt $attempt)"

# Command output to log file only
"${cmd[@]}" > "$task_log" 2>&1
Expand All @@ -336,21 +398,46 @@ run_task() {
local status
if [[ $rc -eq 0 ]]; then
status="PASS"
echo "[${platform}${device_id:+:dev${device_id}}] PASS: $name (${elapsed}s)"
echo "[${platform}${device_spec:+:dev${device_spec}}] PASS: $name (${elapsed}s)"
else
status="FAIL"
echo "[${platform}${device_id:+:dev${device_id}}] FAIL: $name (${elapsed}s)"
echo "[${platform}${device_spec:+:dev${device_spec}}] FAIL: $name (${elapsed}s)"
if [[ "$print_log_on_fail" == "true" ]]; then
echo "--- LOG: $name (attempt $attempt) ---"
cat "$task_log"
echo "--- END ---"
fi
fi
echo "${name}|${platform}|${status}|device:${device_id:-sim}|attempt:${attempt}|${elapsed}s" \
echo "${name}|${platform}|${status}|device:${device_spec:-sim}|attempt:${attempt}|${elapsed}s" \
>> "$RESULTS_FILE"
return $rc
}

run_hw_multidevice_tasks() {
local attempt="$1"; shift
local indices=("$@")
HW_MULTI_FAILURES=()

for idx in "${indices[@]}"; do
local required_devices="${HW_TASK_DEVICE_COUNTS[$idx]}"
local platform="${HW_TASK_PLATS[$idx]}"
local name="${HW_TASK_NAMES[$idx]}"

if [[ "$required_devices" -gt "$NUM_DEVICES" ]]; then
echo "[${platform}] FAIL: $name requires ${required_devices} devices, only ${NUM_DEVICES} available"
echo "${name}|${platform}|FAIL|device:insufficient|attempt:${attempt}|0s" >> "$RESULTS_FILE"
HW_MULTI_FAILURES+=("$idx")
continue
fi

local device_spec
device_spec="$(format_device_spec "$required_devices")"
if ! run_task "$name" "${HW_TASK_DIRS[$idx]}" "$platform" "$attempt" "$device_spec" "true" "$required_devices"; then
HW_MULTI_FAILURES+=("$idx")
fi
done
}

# ---- SIM executor ----
# run_sim_tasks <attempt> <idx1> <idx2> ...
# Sets SIM_FAILURES to array of failed indices after return.
Expand Down Expand Up @@ -429,7 +516,7 @@ run_hw_tasks() {

IFS=':' read -r idx attempt <<< "$entry"

if run_task "${HW_TASK_NAMES[$idx]}" "${HW_TASK_DIRS[$idx]}" "${HW_TASK_PLATS[$idx]}" "$attempt" "$device_id" "false"; then
if run_task "${HW_TASK_NAMES[$idx]}" "${HW_TASK_DIRS[$idx]}" "${HW_TASK_PLATS[$idx]}" "$attempt" "$device_id" "false" "1"; then
echo "${idx}|PASS" >> "$hw_marker"
else
next=$((attempt + 1))
Expand Down Expand Up @@ -473,12 +560,36 @@ fi

# HW phase
if [[ ${#HW_TASK_NAMES[@]} -gt 0 ]]; then
ALL_HW=($(seq 0 $((${#HW_TASK_NAMES[@]} - 1))))
echo "---- HW: ${#ALL_HW[@]} tasks on ${NUM_DEVICES} devices ----"
run_hw_tasks "${ALL_HW[@]}"
if [[ ${#HW_FAILURES[@]} -gt 0 ]] && pin_pto_isa_on_failure; then
echo "[CI] Retrying ${#HW_FAILURES[@]} HW failures with pinned PTO-ISA"
run_hw_tasks "${HW_FAILURES[@]}"
ALL_HW_SINGLE=()
ALL_HW_MULTI=()
for idx in $(seq 0 $((${#HW_TASK_NAMES[@]} - 1))); do
if [[ "${HW_TASK_DEVICE_COUNTS[$idx]}" -gt 1 ]]; then
ALL_HW_MULTI+=("$idx")
else
ALL_HW_SINGLE+=("$idx")
fi
done

echo "---- HW: ${#ALL_HW_SINGLE[@]} single-device tasks, ${#ALL_HW_MULTI[@]} multi-device tasks on ${NUM_DEVICES} devices ----"

HW_MULTI_FAILURES=()
if [[ ${#ALL_HW_MULTI[@]} -gt 0 ]]; then
run_hw_multidevice_tasks 0 "${ALL_HW_MULTI[@]}"
if [[ ${#HW_MULTI_FAILURES[@]} -gt 0 ]] && pin_pto_isa_on_failure; then
echo "[CI] Retrying ${#HW_MULTI_FAILURES[@]} multi-device HW failures with pinned PTO-ISA"
run_hw_multidevice_tasks 1 "${HW_MULTI_FAILURES[@]}"
fi
fi

HW_SINGLE_FAILURES=()
if [[ ${#ALL_HW_SINGLE[@]} -gt 0 ]]; then
run_hw_tasks "${ALL_HW_SINGLE[@]}"
HW_SINGLE_FAILURES=("${HW_FAILURES[@]}")
if [[ ${#HW_SINGLE_FAILURES[@]} -gt 0 ]] && pin_pto_isa_on_failure; then
echo "[CI] Retrying ${#HW_SINGLE_FAILURES[@]} HW failures with pinned PTO-ISA"
run_hw_tasks "${HW_SINGLE_FAILURES[@]}"
HW_SINGLE_FAILURES=("${HW_FAILURES[@]}")
fi
fi
fi

Expand Down
40 changes: 40 additions & 0 deletions examples/a2a3/aicpu_build_graph/treduce_distributed/golden.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
Golden script for distributed TREDUCE.

Each rank r contributes input[i] = i + r * 100 for i in [0, 256).
Root rank reduces (Sum) all inputs.

Expected output on root:
output[i] = sum_{r=0}^{nranks-1} (i + r * 100)
= nranks * i + 100 * nranks * (nranks - 1) / 2
"""

TREDUCE_COUNT = 256
NRANKS = 4

__outputs__ = ["output"]

RTOL = 1e-5
ATOL = 1e-5


def generate_distributed_inputs(rank: int, nranks: int, root: int,
comm_ctx=None) -> list:
"""Each rank generates its own input; output is allocated on all ranks."""
input_data = [float(i + rank * 100) for i in range(TREDUCE_COUNT)]
output_data = [0.0] * TREDUCE_COUNT
return [
("input", input_data),
("output", output_data),
("nranks", nranks),
("root", root),
]


def compute_golden(tensors: dict, params: dict) -> None:
"""Compute expected output for the root rank."""
nranks = params.get("nranks", NRANKS)
output = tensors["output"]
for i in range(TREDUCE_COUNT):
output[i] = float(
nranks * i + 100 * nranks * (nranks - 1) // 2)
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/**
* TREDUCE kernel for simpler's kernel_entry signature.
*
* Performs collective reduce (Sum) across multiple NPU ranks using PTO comm
* instructions. Each rank's input data resides in an RDMA window;
* the root rank gathers and sums all inputs into the output buffer.
*
* PTO communication instructions access remote data through GVA addresses
* (windowsIn[]) via MTE2 DMA over HCCS; no bound stream is required.
*
* args layout (all uint64_t, cast as needed):
* args[0] = __gm__ float* input (device addr in RDMA window)
* args[1] = __gm__ float* output (device addr, regular allocation)
* args[2] = int nranks
* args[3] = int root
* args[4] = __gm__ CommDeviceContext* ctx (device addr)
*/

#include <cstdint>
#include <pto/pto-inst.hpp>
#include "pto/comm/comm_types.hpp"
#include "pto/comm/pto_comm_inst.hpp"
#include "common/comm_context.h"

#ifndef __gm__
#define __gm__
#endif

#ifndef __aicore__
#define __aicore__ [aicore]
#endif

static constexpr size_t TREDUCE_COUNT = 256;
static constexpr int kMaxSupportedRanks = 16;

template <typename T>
AICORE inline __gm__ T *CommRemotePtr(
__gm__ CommDeviceContext *ctx, __gm__ T *localPtr, int pe)
{
uint64_t localBase = ctx->windowsIn[ctx->rankId];
uint64_t offset = (uint64_t)localPtr - localBase;
return (__gm__ T *)(ctx->windowsIn[pe] + offset);
}

extern "C" __aicore__ __attribute__((always_inline))
void kernel_entry(__gm__ int64_t* args) {
__gm__ float* input = reinterpret_cast<__gm__ float*>(args[0]);
__gm__ float* output = reinterpret_cast<__gm__ float*>(args[1]);
int nranks = static_cast<int>(args[2]);
int root = static_cast<int>(args[3]);
__gm__ CommDeviceContext* commCtx =
reinterpret_cast<__gm__ CommDeviceContext*>(args[4]);

using ShapeDyn = pto::Shape<pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC,
pto::DYNAMIC, pto::DYNAMIC>;
using StrideDyn = pto::Stride<pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC,
pto::DYNAMIC, pto::DYNAMIC>;
using Global = pto::GlobalTensor<float, ShapeDyn, StrideDyn,
pto::Layout::ND>;
using TileData = pto::Tile<pto::TileType::Vec, float, 1, TREDUCE_COUNT,
pto::BLayout::RowMajor, -1, -1>;

int my_rank = static_cast<int>(commCtx->rankId);

ShapeDyn shape(1, 1, 1, 1, TREDUCE_COUNT);
StrideDyn stride(TREDUCE_COUNT, TREDUCE_COUNT, TREDUCE_COUNT,
TREDUCE_COUNT, 1);

TileData accTile(1, TREDUCE_COUNT);
TileData recvTile(1, TREDUCE_COUNT);
TASSIGN(accTile, 0x0);
TASSIGN(recvTile, 0x10000);

if (nranks <= 0 || nranks > kMaxSupportedRanks || root < 0 || root >= nranks) {
pipe_barrier(PIPE_ALL);
return;
}

if (my_rank == root) {
Global outputG(output, shape, stride);
Global tensors[kMaxSupportedRanks];
for (int i = 0; i < nranks; ++i) {
__gm__ float* remoteInput = CommRemotePtr(commCtx, input, i);
tensors[i] = Global(remoteInput, shape, stride);
}
pto::comm::ParallelGroup<Global> pg(tensors, nranks, root);
pto::comm::TREDUCE(pg, outputG, accTile, recvTile,
pto::comm::ReduceOp::Sum);
}

pipe_barrier(PIPE_ALL);
}
Loading