diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 32c3cb2a2..84271b4aa 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -34,6 +34,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true +permissions: + pull-requests: write + jobs: build_and_test: name: Build and Test on GPU (${{ matrix.runner }}) @@ -368,6 +371,168 @@ jobs: EOF )" + - name: Performance regression check + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.pull_request.number }} + GH_REPO: ${{ github.repository }} + RUNNER_NAME: ${{ matrix.runner }} + run: | + set -ex + + # Map runner names to display names + case "${RUNNER_NAME}" in + linux-te-mi325*) DISPLAY_NAME="MI325" ;; + linux-te-mi355*) DISPLAY_NAME="MI355" ;; + *) DISPLAY_NAME="${RUNNER_NAME}" ;; + esac + + # Restore PR checkout no matter how this step exits, + # in case a later step needs to access the PR code. + # Note that the PR code is *not* recompiled. 
+ trap 'git checkout ${{ github.sha }} && git submodule update --init --recursive' EXIT + + # Benchmark PR branch (already built) + docker exec te-runner bash -c "$(cat <<'OUTER' + set -ex + pip install pandas tabulate + cd /workspace + + mkdir -p perf_results/pr + for bench in benchmarks/microbenchmarks/benchmark_*.py; do + name=$(basename "$bench" .py) + echo "=== Running $name (PR) ===" + python "$bench" + mv "${name}.csv" perf_results/pr/ + done + + # Stash benchmark scripts so they survive the base branch checkout + mkdir -p .perf_stash + cp benchmarks/microbenchmarks/benchmark_*.py benchmarks/microbenchmarks/compare_results.py .perf_stash/ + OUTER + )" + + # Checkout base branch (on host, where git credentials exist) + git fetch origin ${{ github.base_ref }} --depth=1 + git checkout FETCH_HEAD + git submodule update --init --recursive + + # Rebuild base, benchmark, compare, build report + docker exec \ + -e GPU_ARCH=${{ steps.container-diag.outputs.arch }} \ + te-runner bash -c "$(cat <<'OUTER' + set -ex + cd /workspace + + # Rebuild base branch + export HIP_PATH="" + export PYTORCH_ROCM_ARCH=$GPU_ARCH + export NVTE_ROCM_ARCH=$GPU_ARCH + export NVTE_AITER_PREBUILT_BASE_URL=https://compute-artifactory.amd.com:5000/artifactory/rocm-generic-local/te-ci/aiter-prebuilts + pip install ninja + git config --global --add safe.directory '*' + pip install --no-build-isolation . 
2>&1 + + # Benchmark base branch + mkdir -p perf_results/base + for bench in .perf_stash/benchmark_*.py; do + name=$(basename "$bench" .py) + echo "=== Running $name (base) ===" + python "$bench" + mv "${name}.csv" perf_results/base/ + done + + # Compare and build report + mkdir -p perf_results/reports + SUMMARY="perf_results/reports/summary.md" + DETAILS="perf_results/reports/details.md" + : > "$SUMMARY" + : > "$DETAILS" + + for pr_csv in perf_results/pr/benchmark_*.csv; do + name=$(basename "$pr_csv" .csv) + base_csv="perf_results/base/${name}.csv" + [ -f "$base_csv" ] || continue + echo "========== Comparing: $name ==========" + python .perf_stash/compare_results.py "$base_csv" "$pr_csv" \ + --bench-name "$name" \ + --summary-file "$SUMMARY" \ + >> "$DETAILS" + done + OUTER + )" + + # Assemble this runner's section + SUMMARY="perf_results/reports/summary.md" + DETAILS="perf_results/reports/details.md" + [ -f "$SUMMARY" ] || exit 0 + + SECTION_START="" + SECTION_END="" + + CI_TRIGGERED_AT="$(TZ='America/Chicago' date -d '${{ github.event.pull_request.updated_at }}' '+%Y-%m-%d %H:%M:%S %Z')" + + SECTION=$(cat <PR commit: ${{ github.sha }} | Base: \`${{ github.base_ref }}\` | ${CI_TRIGGERED_AT} + + | Benchmark suite | Median speedup | Min speedup | Max speedup | + |---|---|---|---| + $(cat "$SUMMARY") + + $(cat "$DETAILS") + ${SECTION_END} + EOF + ) + + echo "$SECTION" > /tmp/perf_section.md + + echo "" + echo "========== Performance Report ==========" + cat /tmp/perf_section.md + echo "========================================" + + # Post or update the single shared PR comment (skip under nektos act) + if [ -n "${ACT:-}" ]; then + echo "Running under nektos act, skipping PR comment." 
+ exit 0 + fi + + COMMENT_MARKER="" + + COMMENT_ID=$(gh api "repos/${GH_REPO}/issues/${PR_NUMBER}/comments" \ + --paginate --jq ".[] | select(.body | contains(\"${COMMENT_MARKER}\")) | .id" \ + | head -1) + + if [ -n "$COMMENT_ID" ]; then + gh api "repos/${GH_REPO}/issues/comments/${COMMENT_ID}" --jq .body \ + > /tmp/perf_existing.md + + if grep -qF "$SECTION_START" /tmp/perf_existing.md; then + awk -v start="$SECTION_START" -v end="$SECTION_END" -v sf="/tmp/perf_section.md" ' + $0 ~ start { skip=1; while((getline l < sf)>0) print l; next } + $0 ~ end { skip=0; next } + !skip { print } + ' /tmp/perf_existing.md > /tmp/perf_comment.md + else + { cat /tmp/perf_existing.md; echo ""; cat /tmp/perf_section.md; } > /tmp/perf_comment.md + fi + + gh api "repos/${GH_REPO}/issues/comments/${COMMENT_ID}" \ + --method PATCH --field body=@/tmp/perf_comment.md + else + { + echo "${COMMENT_MARKER}" + echo "## Performance Report" + echo "" + cat /tmp/perf_section.md + } > /tmp/perf_comment.md + + gh pr comment "$PR_NUMBER" --repo "$GH_REPO" \ + --body-file /tmp/perf_comment.md + fi + - name: Check Test Failure Status if: always() run: | diff --git a/benchmarks/microbenchmarks/benchmark_attention.py b/benchmarks/microbenchmarks/benchmark_attention.py new file mode 100755 index 000000000..6f419b62d --- /dev/null +++ b/benchmarks/microbenchmarks/benchmark_attention.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +Attention micro-benchmark using te.DotProductAttention. + +Benchmarks fused multi-head attention (with flash attention backend) for +model configurations with grouped-query attention (GQA). 
# Sweep parameters
BATCH_SIZE = 2
SEQ_LEN_LIST = [1024, 2048, 4096, 8192]

# (name, num_q_heads, num_kv_heads, head_dim, tp)
MODEL_CONFIGS = [
    ("Llama3-8B/TP1", 32, 8, 128, 1),
    ("Llama3-8B/TP8", 32, 8, 128, 8),
    ("Llama3-70B/TP8", 64, 8, 128, 8),
    ("Llama3-405B/TP8", 128, 8, 128, 8),
    ("Qwen2.5-7B/TP1", 28, 4, 128, 1),
    ("Qwen2.5-72B/TP8", 64, 8, 128, 8),
]


def _generate_attn_test_cases():
    """Expand MODEL_CONFIGS x SEQ_LEN_LIST into per-GPU attention shapes.

    Head counts are divided by the tensor-parallel degree ``tp``; a config
    that would leave fewer than one query or KV head per GPU is skipped.

    Returns:
        list[dict]: one dict per (model, seq_len) with the keys "Case",
        "batch", "seq_len", "num_q_heads", "num_kv_heads" and "head_dim".
    """
    test_cases = []
    for name, n_q, n_kv, hd, tp in MODEL_CONFIGS:
        q_per_gpu = n_q // tp
        kv_per_gpu = n_kv // tp
        if q_per_gpu < 1 or kv_per_gpu < 1:
            continue
        test_cases.extend(
            {
                "Case": name,
                "batch": BATCH_SIZE,
                "seq_len": seq_len,
                "num_q_heads": q_per_gpu,
                "num_kv_heads": kv_per_gpu,
                "head_dim": hd,
            }
            for seq_len in SEQ_LEN_LIST
        )
    return test_cases


def _summarize_timings(fwd_ms, fwd_bwd_ms, fwd_flops):
    """Derive backward time/throughput and build the CSV metric columns.

    Args:
        fwd_ms: mean forward time in milliseconds.
        fwd_bwd_ms: mean forward+backward time in milliseconds.
        fwd_flops: FLOPs of one forward pass; backward is counted as twice that.

    Returns:
        dict[str, str]: the four "TE ..." columns. Times carry four decimals
        (as in the casting benchmark) because the small cases can run in
        hundredths of a millisecond, where two decimals leave roughly one
        significant digit for the regression comparison.
    """
    bwd_flops = 2 * fwd_flops
    # Backward is not timed on its own; clamp so timing noise cannot produce
    # a negative duration.
    bwd_ms = max(fwd_bwd_ms - fwd_ms, 0.0)

    fwd_tflops = fwd_flops / (fwd_ms * 1e-3) / 1e12
    bwd_tflops = bwd_flops / (bwd_ms * 1e-3) / 1e12 if bwd_ms > 0 else 0.0

    print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS")
    print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)")

    return {
        "TE Forward Time (ms)": f"{fwd_ms:.4f}",
        "TE Forward TFLOPS": f"{fwd_tflops:.2f}",
        "TE Backward Time (ms)": f"{bwd_ms:.4f}",
        "TE Backward TFLOPS": f"{bwd_tflops:.2f}",
    }


def bench_attention(batch, seq_len, num_q_heads, num_kv_heads, head_dim):
    """Time te.DotProductAttention (causal mask, bf16) for one shape on "cuda".

    Inputs are random tensors of shape (seq_len, batch, heads, head_dim).
    Forward and forward+backward are each warmed up 20 times and then timed
    over 100 iterations; the backward time is the difference of the two.

    Returns:
        dict[str, str]: formatted forward/backward time and TFLOPS columns.
    """
    device = "cuda"
    dtype = torch.bfloat16

    attn = te.DotProductAttention(
        num_attention_heads=num_q_heads,
        kv_channels=head_dim,
        num_gqa_groups=num_kv_heads,
        attn_mask_type="causal",
    ).to(device=device, dtype=dtype)

    def _rand_input(n_heads):
        return torch.randn(
            seq_len, batch, n_heads, head_dim,
            dtype=dtype, device=device, requires_grad=True,
        )

    q = _rand_input(num_q_heads)
    k = _rand_input(num_kv_heads)
    v = _rand_input(num_kv_heads)

    def fwd_func():
        return attn(q, k, v)

    grad_out = torch.randn_like(fwd_func())

    def fwd_bwd_func():
        out = attn(q, k, v)
        out.backward(grad_out)
        # Drop grads so they do not accumulate across iterations
        q.grad = None
        k.grad = None
        v.grad = None

    fwd_bwd_func()

    # FLOPs: two matmuls (Q@K^T and attn@V), each 2*b*h*s^2*d
    # NOTE(review): this counts the full s x s score matrix although the mask
    # is causal; if the kernel skips masked blocks the reported TFLOPS are
    # nominal rather than achieved -- confirm the intended convention.
    fwd_flops = 4 * batch * num_q_heads * seq_len * seq_len * head_dim

    # Warmup
    for _ in range(20):
        fwd_func()
        fwd_bwd_func()
    torch.cuda.synchronize()

    # Benchmark
    n_iters = 100
    fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).timeit(n_iters).mean * 1e3
    fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_bwd_func}).timeit(n_iters).mean * 1e3

    return _summarize_timings(fwd_ms, fwd_bwd_ms, fwd_flops)


if __name__ == "__main__":
    import pandas as pd

    test_cases = _generate_attn_test_cases()

    columns = [
        "Case", "batch", "seq_len", "num_q_heads", "num_kv_heads", "head_dim",
        "TE Forward Time (ms)",
        "TE Forward TFLOPS",
        "TE Backward Time (ms)",
        "TE Backward TFLOPS",
    ]

    def _banner(prefix, c):
        print(f"\n{'='*60}")
        print(f"{prefix}: {c['Case']} b={c['batch']} s={c['seq_len']} "
              f"qh={c['num_q_heads']} kvh={c['num_kv_heads']} hd={c['head_dim']}")
        print(f"{'='*60}")

    def _shape_kwargs(c):
        # Everything except the label is a bench_attention keyword argument.
        return {key: value for key, value in c.items() if key != "Case"}

    # Warmup run: the first case is executed once and its result discarded
    _banner("WARMUP", test_cases[0])
    bench_attention(**_shape_kwargs(test_cases[0]))

    rows = []
    for case in test_cases:
        _banner("Testing", case)
        try:
            metrics = bench_attention(**_shape_kwargs(case))
        except Exception as e:
            print(f"FAILED: {case['Case']}: {e}")
            raise
        rows.append({**case, **metrics})

    results = pd.DataFrame(rows, columns=columns)

    out_csv = "benchmark_attention.csv"
    results.to_csv(out_csv, index=False)
    print(f"\nResults saved to {out_csv}")
def _generate_cast_test_cases():
    """Expand MODEL_HIDDEN_SIZES x CAST_CONFIGS x M_SIZE_LIST into test cases.

    Returns:
        list[dict]: one dict per combination with the keys "Case", "M",
        "hidden_size", "direction", "fp8_dtype" and "dtype_str".
    """
    test_cases = []
    for model_name, hidden in MODEL_HIDDEN_SIZES:
        for cast_name, direction, fp8_dtype in CAST_CONFIGS:
            test_cases.extend(
                {
                    "Case": f"{model_name}/{cast_name}",
                    "M": M,
                    "hidden_size": hidden,
                    "direction": direction,
                    "fp8_dtype": fp8_dtype,
                    "dtype_str": cast_name,
                }
                for M in M_SIZE_LIST
            )
    return test_cases


def bench_cast(M, hidden_size, direction, fp8_dtype):
    """Time one BF16<->FP8 cast of an (M, hidden_size) tensor on "cuda".

    Args:
        M: number of rows.
        hidden_size: number of columns.
        direction: "quantize" (BF16 -> FP8) or "dequantize" (FP8 -> BF16).
        fp8_dtype: FP8 dtype handed to Float8Quantizer.

    Returns:
        dict[str, str]: "Cast Time (ms)" and "Cast GB/s", where the bandwidth
        counts 3 bytes per element (one 2-byte BF16 side plus one 1-byte FP8
        side).

    Raises:
        ValueError: if ``direction`` is neither "quantize" nor "dequantize".
    """
    # Reject typos up front; previously any other string was silently
    # benchmarked as a dequantize.
    if direction not in ("quantize", "dequantize"):
        raise ValueError(
            f"direction must be 'quantize' or 'dequantize', got {direction!r}"
        )

    device = "cuda"

    numel = M * hidden_size
    scale = torch.ones(1, dtype=torch.float32, device=device)
    amax = torch.zeros(1, dtype=torch.float32, device=device)
    quantizer = Float8Quantizer(scale, amax, fp8_dtype)

    # Both directions start from the same random BF16 input.
    x = torch.randn(M, hidden_size, dtype=torch.bfloat16, device=device)

    if direction == "quantize":
        def cast_func():
            return quantizer(x)

        # BF16 read (2 bytes) + FP8 write (1 byte)
        total_bytes = numel * (2 + 1)
    else:
        # Quantize once outside the timed region; only dequantize is measured.
        fp8_tensor = quantizer(x)

        def cast_func():
            return fp8_tensor.dequantize()

        # FP8 read (1 byte) + BF16 write (2 bytes)
        total_bytes = numel * (1 + 2)

    cast_func()

    # Warmup
    for _ in range(20):
        cast_func()
    torch.cuda.synchronize()

    # Benchmark
    n_iters = 100
    ms = benchmark.Timer(stmt="fn()", globals={"fn": cast_func}).timeit(n_iters).mean * 1e3
    gbps = total_bytes / (ms * 1e-3) / 1e9

    print(f" {ms:.4f} ms | {gbps:.1f} GB/s")

    return {
        "Cast Time (ms)": f"{ms:.4f}",
        "Cast GB/s": f"{gbps:.1f}",
    }


if __name__ == "__main__":
    import pandas as pd

    test_cases = _generate_cast_test_cases()

    # Case fields that are copied into the CSV; the remaining two
    # ("direction", "fp8_dtype") are only arguments for bench_cast.
    id_columns = ["Case", "M", "hidden_size", "dtype_str"]
    columns = [
        *id_columns,
        "Cast Time (ms)",
        "Cast GB/s",
    ]

    def _banner(prefix, c):
        print(f"\n{'='*60}")
        print(f"{prefix}: {c['Case']} M={c['M']} hidden={c['hidden_size']}")
        print(f"{'='*60}")

    def _run(c):
        return bench_cast(
            M=c["M"],
            hidden_size=c["hidden_size"],
            direction=c["direction"],
            fp8_dtype=c["fp8_dtype"],
        )

    # Warmup run: the first case is executed once and its result discarded
    _banner("WARMUP", test_cases[0])
    _run(test_cases[0])

    rows = []
    for case in test_cases:
        _banner("Testing", case)
        try:
            metrics = _run(case)
        except Exception as e:
            print(f"FAILED: {case['Case']}: {e}")
            raise
        rows.append({**{key: case[key] for key in id_columns}, **metrics})

    results = pd.DataFrame(rows, columns=columns)

    out_csv = "benchmark_casting.csv"
    results.to_csv(out_csv, index=False)
    print(f"\nResults saved to {out_csv}")
# Sequence / batch-token sizes to sweep
M_SIZE_LIST = [1024, 2048, 4096, 8192]

# Model configurations
# Sources:
#   - Llama 3 8B (hidden=4096, intermediate=14336, heads=32, kv_heads=8, head_dim=128)
#     https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json
#   - Llama 3 70B (hidden=8192, intermediate=28672, heads=64, kv_heads=8, head_dim=128)
#     https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json
#   - Llama 3 405B (hidden=16384, intermediate=53248, heads=128, kv_heads=8, head_dim=128)
#     https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json
#   - Qwen 2.5 7B (hidden=3584, intermediate=18944, heads=28, kv_heads=4, head_dim=128)
#     https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json
#   - Qwen 2.5 72B (hidden=8192, intermediate=29568, heads=64, kv_heads=8, head_dim=128)
#     https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json
MODEL_CONFIGS = [
    # (name, hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp)
    ("Llama3-8B/TP1", 4096, 14336, 32, 8, 128, 1),
    ("Llama3-8B/TP8", 4096, 14336, 32, 8, 128, 8),
    ("Llama3-70B/TP8", 8192, 28672, 64, 8, 128, 8),
    ("Llama3-405B/TP8", 16384, 53248, 128, 8, 128, 8),
    ("Qwen2.5-7B/TP1", 3584, 18944, 28, 4, 128, 1),
    ("Qwen2.5-72B/TP8", 8192, 29568, 64, 8, 128, 8),
]


def _generate_gemm_test_cases():
    """Build the (M, N, K) sweep: four GEMM shapes per model for every M.

    For each entry of MODEL_CONFIGS the QKV, attention-output, gate+up and
    down projection shapes are derived, with the sharded dimension divided
    by the tensor-parallel degree.

    Returns:
        list[dict]: dicts with the keys "Case", "M", "N", "K" and "dtype"
        (always torch.bfloat16).
    """
    test_cases = []
    for name, hidden, intermediate, n_q, n_kv, hd, tp in MODEL_CONFIGS:
        shapes = {
            f"{name}-QKV": ((n_q * hd + 2 * n_kv * hd) // tp, hidden),
            f"{name}-AttnOut": (hidden, (n_q * hd) // tp),
            f"{name}-GateUp": ((2 * intermediate) // tp, hidden),
            f"{name}-Down": (hidden, intermediate // tp),
        }
        for M in M_SIZE_LIST:
            test_cases.extend(
                {"Case": case_name, "M": M, "N": N, "K": K, "dtype": torch.bfloat16}
                for case_name, (N, K) in shapes.items()
            )
    return test_cases


def _summarize_timings(fwd_ms, fwd_bwd_ms, fwd_flops):
    """Derive backward time/throughput and build the CSV metric columns.

    Args:
        fwd_ms: mean forward time in milliseconds.
        fwd_bwd_ms: mean forward+backward time in milliseconds.
        fwd_flops: FLOPs of the forward GEMM; backward (dX + dW) is twice that.

    Returns:
        dict[str, str]: the four "TE ..." columns. Times carry four decimals
        because the small TP-sharded GEMMs can run in hundredths of a
        millisecond, where two decimals leave roughly one significant digit.
    """
    bwd_flops = 2 * fwd_flops  # dX + dW
    # Backward is not timed on its own; clamp so timing noise cannot produce
    # a negative duration.
    bwd_ms = max(fwd_bwd_ms - fwd_ms, 0.0)

    fwd_tflops = fwd_flops / (fwd_ms * 1e-3) / 1e12
    bwd_tflops = bwd_flops / (bwd_ms * 1e-3) / 1e12 if bwd_ms > 0 else 0.0

    print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS")
    print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)")

    return {
        "TE Forward Time (ms)": f"{fwd_ms:.4f}",
        "TE Forward TFLOPS": f"{fwd_tflops:.2f}",
        "TE Backward Time (ms)": f"{bwd_ms:.4f}",
        "TE Backward TFLOPS": f"{bwd_tflops:.2f}",
    }


def bench_gemm(M, N, K, dtype):
    """Time a bias-free te.Linear(K, N) on an (M, K) input on "cuda".

    Forward and forward+backward are each warmed up 20 times and timed over
    100 iterations; the backward time is the difference of the two.

    Returns:
        dict[str, str]: formatted forward/backward time and TFLOPS columns.
    """
    device = "cuda"

    linear = te.Linear(K, N, bias=False).to(device=device, dtype=dtype)
    x = torch.randn(M, K, dtype=dtype, device=device, requires_grad=True)

    def fwd_func():
        return linear(x)

    grad_out = torch.randn_like(fwd_func())

    # Named for what it runs: a full forward followed by backward.
    def fwd_bwd_func():
        out = linear(x)
        out.backward(grad_out)
        # Clear grads so they don't accumulate across iterations
        x.grad = None
        linear.weight.grad = None

    fwd_bwd_func()

    fwd_flops = 2 * M * N * K

    # Warmup
    for _ in range(20):
        fwd_func()
        fwd_bwd_func()
    torch.cuda.synchronize()

    # Benchmark
    n_iters = 100
    fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).timeit(n_iters).mean * 1e3
    fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_bwd_func}).timeit(n_iters).mean * 1e3

    return _summarize_timings(fwd_ms, fwd_bwd_ms, fwd_flops)


if __name__ == "__main__":
    import pandas as pd

    test_cases = _generate_gemm_test_cases()

    columns = [
        "Case", "M", "N", "K", "dtype",
        "TE Forward Time (ms)",
        "TE Forward TFLOPS",
        "TE Backward Time (ms)",
        "TE Backward TFLOPS",
    ]

    def _banner(prefix, c):
        print(f"\n{'='*60}")
        print(f"{prefix}: {c}")
        print(f"{'='*60}")

    def _run(c):
        return bench_gemm(M=c["M"], N=c["N"], K=c["K"], dtype=c["dtype"])

    # Warmup run: the first case is executed once and its result discarded
    _banner("WARMUP", test_cases[0])
    _run(test_cases[0])

    rows = []
    for case in test_cases:
        _banner("Testing", case)
        try:
            metrics = _run(case)
        except Exception as e:
            print(f"FAILED: {case}: {e}")
            raise
        # The dtype is stored by name so the CSV stays plain text.
        rows.append({**case, "dtype": str(case["dtype"]), **metrics})

    results = pd.DataFrame(rows, columns=columns)

    out_csv = "benchmark_gemm.csv"
    results.to_csv(out_csv, index=False)
    print(f"\nResults saved to {out_csv}")
def _generate_gemm_test_cases():
    """Build the (M, N, K) sweep: four GEMM shapes per model for every M.

    For each entry of MODEL_CONFIGS the QKV, attention-output, gate+up and
    down projection shapes are derived, with the sharded dimension divided
    by the tensor-parallel degree.

    Returns:
        list[dict]: dicts with the keys "Case", "M", "N", "K" and "dtype"
        (always torch.bfloat16).
    """
    test_cases = []
    for name, hidden, intermediate, n_q, n_kv, hd, tp in MODEL_CONFIGS:
        shapes = {
            f"{name}-QKV": ((n_q * hd + 2 * n_kv * hd) // tp, hidden),
            f"{name}-AttnOut": (hidden, (n_q * hd) // tp),
            f"{name}-GateUp": ((2 * intermediate) // tp, hidden),
            f"{name}-Down": (hidden, intermediate // tp),
        }
        for M in M_SIZE_LIST:
            test_cases.extend(
                {"Case": case_name, "M": M, "N": N, "K": K, "dtype": torch.bfloat16}
                for case_name, (N, K) in shapes.items()
            )
    return test_cases


def _summarize_timings(fwd_ms, fwd_bwd_ms, fwd_flops):
    """Derive backward time/throughput and build the CSV metric columns.

    Args:
        fwd_ms: mean forward time in milliseconds.
        fwd_bwd_ms: mean forward+backward time in milliseconds.
        fwd_flops: FLOPs of the forward GEMM; backward is counted as twice that.

    Returns:
        dict[str, str]: the four "FP8 ..." columns. Times carry four decimals
        because the small TP-sharded GEMMs can run in hundredths of a
        millisecond, where two decimals leave roughly one significant digit.
    """
    bwd_flops = 2 * fwd_flops
    # Backward is not timed on its own; clamp so timing noise cannot produce
    # a negative duration.
    bwd_ms = max(fwd_bwd_ms - fwd_ms, 0.0)

    fwd_tflops = fwd_flops / (fwd_ms * 1e-3) / 1e12
    bwd_tflops = bwd_flops / (bwd_ms * 1e-3) / 1e12 if bwd_ms > 0 else 0.0

    print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS")
    print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)")

    return {
        "FP8 Forward Time (ms)": f"{fwd_ms:.4f}",
        "FP8 Forward TFLOPS": f"{fwd_tflops:.2f}",
        "FP8 Backward Time (ms)": f"{bwd_ms:.4f}",
        "FP8 Backward TFLOPS": f"{bwd_tflops:.2f}",
    }


def bench_fp8_gemm(M, N, K, dtype):
    """Time a bias-free te.Linear(K, N) under fp8_autocast on "cuda".

    The input is (M, K) and the incoming gradient (M, N), both of ``dtype``.
    Forward and forward+backward are each warmed up 20 times and timed over
    100 iterations; the backward time is the difference of the two.

    Returns:
        dict[str, str]: formatted forward/backward time and TFLOPS columns.
    """
    device = "cuda"

    linear = te.Linear(K, N, bias=False).to(device=device, dtype=dtype)
    x = torch.randn(M, K, dtype=dtype, device=device, requires_grad=True)
    grad_out = torch.randn(M, N, dtype=dtype, device=device)

    # Forward under fp8_autocast
    def fwd_func():
        with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
            return linear(x)

    # Combined fwd+bwd (TE consumes saved state on backward, no retain_graph).
    # Only the forward runs inside the autocast context.
    def fwd_bwd_func():
        with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
            out = linear(x)
        out.backward(grad_out)
        x.grad = None
        linear.weight.grad = None

    # Sanity run
    fwd_func()
    fwd_bwd_func()

    fwd_flops = 2 * M * N * K

    # Warmup
    for _ in range(20):
        fwd_func()
        fwd_bwd_func()
    torch.cuda.synchronize()

    # Benchmark
    n_iters = 100
    fwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func}).timeit(n_iters).mean * 1e3
    fwd_bwd_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_bwd_func}).timeit(n_iters).mean * 1e3

    return _summarize_timings(fwd_ms, fwd_bwd_ms, fwd_flops)


if __name__ == "__main__":
    import pandas as pd

    test_cases = _generate_gemm_test_cases()

    columns = [
        "Case", "M", "N", "K", "dtype",
        "FP8 Forward Time (ms)",
        "FP8 Forward TFLOPS",
        "FP8 Backward Time (ms)",
        "FP8 Backward TFLOPS",
    ]

    def _banner(prefix, c):
        print(f"\n{'='*60}")
        print(f"{prefix}: {c}")
        print(f"{'='*60}")

    def _run(c):
        return bench_fp8_gemm(M=c["M"], N=c["N"], K=c["K"], dtype=c["dtype"])

    # Warmup run: the first case is executed once and its result discarded
    _banner("WARMUP", test_cases[0])
    _run(test_cases[0])

    rows = []
    for case in test_cases:
        _banner("Testing", case)
        try:
            metrics = _run(case)
        except Exception as e:
            print(f"FAILED: {case}: {e}")
            raise
        # The dtype is stored by name so the CSV stays plain text.
        rows.append({**case, "dtype": str(case["dtype"]), **metrics})

    results = pd.DataFrame(rows, columns=columns)

    out_csv = "benchmark_gemm_fp8.csv"
    results.to_csv(out_csv, index=False)
    print(f"\nResults saved to {out_csv}")
def generate_grouped_gemm_group_lens(b, m, balance: bool):
    """Return an int64 tensor of ``b`` group lengths that sum to ``b * m``.

    Args:
        b: number of groups.
        m: average rows per group.
        balance: if True every group gets exactly ``m`` rows; otherwise the
            rows are split by random weights drawn from [0.2, 1.0).

    Returns:
        torch.Tensor: shape (b,), dtype int64, on the CPU.
    """
    if balance:
        return torch.full((b,), m, dtype=torch.int64)
    if b == 0:
        # Nothing to distribute; the normalisation below would divide by
        # zero and then index an empty tensor.
        return torch.zeros((0,), dtype=torch.int64)

    dist = 0.2 + 0.8 * torch.rand(b)
    dist /= dist.sum()
    group_lens = (dist * b * m).to(torch.int64)
    # The integer cast truncates, so hand the missing rows to the last group.
    group_lens[-1] += b * m - group_lens.sum()
    return group_lens


# Rows per expert to sweep (8192 and 16384 are currently disabled)
M_SIZE_LIST = [512, 1024, 2048, 4096]
# Expert-parallel degrees to sweep
EP_SIZE_LIST = [32, 16, 8]


def _generate_moe_test_cases(
    name_prefix: str,
    n_routed_experts: int,
    moe_intermediate_size: int,
    hidden_size: int,
):
    """Build grouped-GEMM cases for one MoE model.

    For every expert-parallel degree in EP_SIZE_LIST that divides
    ``n_routed_experts`` evenly, the local expert count ``B`` is combined
    with every M in M_SIZE_LIST and the two shapes GateUP
    (2 * moe_intermediate_size, hidden_size) and Down
    (hidden_size, moe_intermediate_size).

    Returns:
        list[dict]: dicts with the keys "Case", "B", "M", "N", "K" and
        "dtype" (always torch.bfloat16).
    """
    shapes_dict = {
        f"{name_prefix}-GateUP": (2 * moe_intermediate_size, hidden_size),
        f"{name_prefix}-Down": (hidden_size, moe_intermediate_size),
    }

    test_cases = []
    for ep in EP_SIZE_LIST:
        if n_routed_experts % ep != 0:
            continue
        B = n_routed_experts // ep
        if B < 1:
            continue
        for M in M_SIZE_LIST:
            test_cases.extend(
                {
                    "Case": name,
                    "B": B,
                    "M": M,
                    "N": N,
                    "K": K,
                    "dtype": torch.bfloat16,
                }
                for name, (N, K) in shapes_dict.items()
            )
    return test_cases


def generate_deepseekv3_test_cases():
    """Return the grouped-GEMM cases for the "DSV3" configuration."""
    return _generate_moe_test_cases(
        "DSV3", n_routed_experts=256, moe_intermediate_size=2048, hidden_size=7168
    )


def generate_deepseekv2_test_cases():
    """Return the grouped-GEMM cases for the "DSV2" configuration."""
    return _generate_moe_test_cases(
        "DSV2", n_routed_experts=160, moe_intermediate_size=1536, hidden_size=5120
    )


def generate_deepseekv2_lite_test_cases():
    """Return the grouped-GEMM cases for the "DSV2-Lite" configuration."""
    return _generate_moe_test_cases(
        "DSV2-Lite", n_routed_experts=64, moe_intermediate_size=1408, hidden_size=2048
    )


def generate_grok_v2_test_cases():
    """Return the grouped-GEMM cases for the "Grok-V2" configuration."""
    return _generate_moe_test_cases(
        "Grok-V2", n_routed_experts=8, moe_intermediate_size=16384, hidden_size=8192
    )


def make_fwd_bwd_funcs_te(x, w, group_lens, activation_dtype):
    """Build forward/backward closures around TE's general_grouped_gemm.

    All output buffers are allocated once here, so the closures themselves
    only launch the grouped GEMMs.

    Args:
        x: input whose leading dimension equals ``sum(group_lens)``.
        w: stacked weights; ``w[i]`` is the weight of group ``i`` and
            ``w.shape[1:]`` is (N, K).
        group_lens: 1-D tensor with the row count of each group.
        activation_dtype: dtype of every output buffer.

    Returns:
        tuple: ``(fwd_func, bwd_func)``. ``fwd_func()`` returns the
        (sum_M, N) output buffer; ``bwd_func(grad_out)`` returns the
        ``(dx, dw_stacked)`` buffers of shape (sum_M, K) and (B, N, K).
    """
    # Imported lazily so the module can be imported without Transformer Engine.
    from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace
    from transformer_engine.pytorch.cpp_extensions import general_grouped_gemm

    B = int(group_lens.numel())
    N = int(w.shape[1])
    K = int(w.shape[2])

    m_splits = [int(v) for v in group_lens.tolist()]
    assert len(m_splits) == B
    sum_M = sum(m_splits)
    assert x.numel() > 0 and x.shape[0] == sum_M

    x_view = x.reshape(-1, x.shape[-1])
    xs = list(torch.split(x_view, m_splits))
    weights = [w[i] for i in range(B)]

    workspaces = get_multi_stream_cublas_workspace()

    # Forward output buffer
    out = torch.empty((sum_M, N), device=x.device, dtype=activation_dtype)

    def fwd_func_te():
        general_grouped_gemm(
            A=weights,
            B=xs,
            out=[out],
            out_dtype=activation_dtype,
            workspaces=workspaces,
            single_output=True,
            m_splits=m_splits,
            use_bias=False,
            bias=None,
            layout="TN",
        )
        return out

    # dx buffers: per-group views into one contiguous tensor
    dx = torch.empty((sum_M, K), device=x.device, dtype=activation_dtype)
    dxs = list(torch.split(dx, m_splits))

    # dw buffers: per-group views into one stacked tensor
    dw_stacked = torch.empty((B, N, K), device=x.device, dtype=activation_dtype)
    dws = [dw_stacked[i] for i in range(B)]

    def bwd_func_te(grad_out):
        go = grad_out.view(-1, grad_out.shape[-1])
        splits = torch.split(go, m_splits)

        # First grouped GEMM writes into the dx views ...
        general_grouped_gemm(
            A=weights,
            B=splits,
            out=dxs,
            out_dtype=activation_dtype,
            workspaces=workspaces,
            single_output=False,
            layout="NN",
            m_splits=m_splits,
            grad=False,
            use_bias=False,
            bias=None,
        )

        # ... the second one into the dw views.
        general_grouped_gemm(
            A=xs,
            B=splits,
            out=dws,
            out_dtype=activation_dtype,
            workspaces=workspaces,
            single_output=False,
            layout="NT",
            m_splits=m_splits,
            grad=False,
            use_bias=False,
            bias=None,
            accumulate=False,
        )

        return dx, dw_stacked

    return fwd_func_te, bwd_func_te


def bench_grouped_gemm(B, M, N, K, dtype):
    """Time TE's grouped GEMM forward and backward for B groups of M rows.

    Forward and backward are timed separately (20 warmup rounds, then 100
    timed iterations each). Sets NVTE_USE_CUTLASS_GROUPED_GEMM and
    NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK to "1" in the process
    environment.

    Returns:
        dict[str, str]: formatted forward/backward time and TFLOPS columns.
    """
    device = "cuda"

    # The backward pass is launched by hand, so autograd is not involved and
    # the tensors are used directly; no requires_grad or detached copies,
    # which would only double the memory held by the largest cases.
    x = torch.randn((B * M, K), dtype=dtype, device=device)
    w = torch.randn((B, N, K), dtype=dtype, device=device)
    group_lens = generate_grouped_gemm_group_lens(B, M, balance=True).to(device)
    print("group_lens: ", group_lens)

    os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
    os.environ["NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK"] = "1"

    # TE grouped (CK_Tile)
    fwd_func_te, bwd_func_te_inner = make_fwd_bwd_funcs_te(
        x, w, group_lens, activation_dtype=dtype
    )

    grad_out = torch.randn_like(fwd_func_te())

    def bwd_func_te():
        return bwd_func_te_inner(grad_out)

    # Sanity run of the backward path before timing
    bwd_func_te()

    # FLOPs
    fwd_total_flops = 2 * B * M * N * K
    bwd_total_flops = 2 * fwd_total_flops

    # Warmup
    for _ in range(20):
        fwd_func_te()
        bwd_func_te()
    torch.cuda.synchronize()

    # Benchmark
    n_iters = 100
    fwd_te_ms = benchmark.Timer(stmt="fn()", globals={"fn": fwd_func_te}).timeit(n_iters).mean * 1e3
    bwd_te_ms = benchmark.Timer(stmt="fn()", globals={"fn": bwd_func_te}).timeit(n_iters).mean * 1e3

    fwd_te_tflops = fwd_total_flops / (fwd_te_ms * 1e-3) / 1e12
    bwd_te_tflops = bwd_total_flops / (bwd_te_ms * 1e-3) / 1e12

    print(f"TE (CK_Tile) Forward {fwd_te_ms:.3f} ms | {fwd_te_tflops:.2f} TFLOPS")
    print(f"TE (CK_Tile) Backward {bwd_te_ms:.3f} ms | {bwd_te_tflops:.2f} TFLOPS")

    # Four decimals keep sub-0.1 ms timings meaningful in the CSV.
    return {
        "TE (CK_Tile) Forward Time (ms)": f"{fwd_te_ms:.4f}",
        "TE (CK_Tile) Forward TFLOPS": f"{fwd_te_tflops:.2f}",
        "TE (CK_Tile) Backward Time (ms)": f"{bwd_te_ms:.4f}",
        "TE (CK_Tile) Backward TFLOPS": f"{bwd_te_tflops:.2f}",
    }


if __name__ == "__main__":
    import pandas as pd

    test_cases = (
        generate_deepseekv2_lite_test_cases()
        + generate_deepseekv2_test_cases()
        + generate_deepseekv3_test_cases()
        + generate_grok_v2_test_cases()
    )

    columns = [
        "Case", "B", "M", "N", "K", "dtype",
        "TE (CK_Tile) Forward Time (ms)",
        "TE (CK_Tile) Forward TFLOPS",
        "TE (CK_Tile) Backward Time (ms)",
        "TE (CK_Tile) Backward TFLOPS",
    ]

    def _banner(prefix, c):
        print(f"\n{'='*50}")
        print(f"{prefix}: {c}")
        print(f"{'='*50}")

    def _run(c):
        return bench_grouped_gemm(B=c["B"], M=c["M"], N=c["N"], K=c["K"], dtype=c["dtype"])

    # Warmup run: the first case is executed once and its result discarded
    _banner("WARMUP", test_cases[0])
    _run(test_cases[0])

    rows = []
    for case in test_cases:
        _banner("Testing", case)
        try:
            metrics = _run(case)
        except Exception as e:
            print(f"FAILED: {case}: {e}")
            raise
        # The dtype is stored by name so the CSV stays plain text.
        rows.append({**case, "dtype": str(case["dtype"]), **metrics})

    results = pd.DataFrame(rows, columns=columns)

    out_csv = "benchmark_grouped_gemm.csv"
    results.to_csv(out_csv, index=False)
    print(f"\nResults saved to {out_csv}")
# Sources for model configs:
#   https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json
#   https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json
#   https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json
#   https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json
#   https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json
#
# Output: benchmark_normalization.csv (written to cwd)

import torch
import torch.utils.benchmark as benchmark

import transformer_engine.pytorch as te

# Token counts (the M dimension) swept for every model / norm pair.
M_SIZE_LIST = [1024, 2048, 4096, 8192]

# Pairs of (model_name, hidden_size).
MODEL_HIDDEN_SIZES = [
    ("Llama3-8B", 4096),
    ("Llama3-70B", 8192),
    ("Llama3-405B", 16384),
    ("Qwen2.5-7B", 3584),
    ("Qwen2.5-72B", 8192),
]

# Pairs of (display name, module class to instantiate).
NORM_TYPES = [
    ("RMSNorm", te.RMSNorm),
    ("LayerNorm", te.LayerNorm),
]


def _generate_norm_test_cases():
    """Return one case dict per (model, norm type, M) combination.

    Models vary slowest, then norm types, then M sizes. Every case uses
    ``torch.bfloat16``.
    """
    return [
        {
            "Case": f"{model}/{norm_label}",
            "M": tokens,
            "hidden_size": width,
            "norm_name": norm_label,
            "norm_cls": norm_type,
            "dtype": torch.bfloat16,
        }
        for model, width in MODEL_HIDDEN_SIZES
        for norm_label, norm_type in NORM_TYPES
        for tokens in M_SIZE_LIST
    ]


def bench_norm(M, hidden_size, norm_cls, dtype):
    """Time one normalization layer on the "cuda" device.

    Args:
        M: Number of rows of the input.
        hidden_size: Size of the normalized (last) dimension.
        norm_cls: Module class, called as ``norm_cls(hidden_size)``.
        dtype: Dtype used for both the module and the input.

    Returns:
        Dict of pre-formatted strings: forward/backward time in ms and
        effective bandwidth in GB/s. The backward time is not measured
        directly; it is the forward+backward time minus the forward time,
        clamped at zero.
    """
    device = "cuda"

    layer = norm_cls(hidden_size).to(device=device, dtype=dtype)
    x = torch.randn(M, hidden_size, dtype=dtype, device=device, requires_grad=True)

    def run_forward():
        return layer(x)

    first_out = run_forward()
    grad_out = torch.randn_like(first_out)

    def run_forward_backward():
        result = layer(x)
        result.backward(grad_out)
        # Reset gradients so they do not accumulate across iterations.
        x.grad = None
        for param in layer.parameters():
            param.grad = None

    run_forward_backward()

    # Bandwidth model: only activation-sized tensors are counted.
    #   forward : read x, write y                       -> 2 tensors
    #   backward: read grad_out, x and y, write grad_x  -> 4 tensors
    # Weight/bias vectors are not included in the byte counts.
    elem_bytes = x.element_size()
    tensor_bytes = M * hidden_size * elem_bytes
    fwd_bytes = 2 * tensor_bytes
    bwd_bytes = 4 * tensor_bytes

    # Warmup
    warmup_iters = 20
    for _ in range(warmup_iters):
        run_forward()
        run_forward_backward()
    torch.cuda.synchronize()

    # Benchmark
    n_iters = 100

    def mean_ms(fn):
        timer = benchmark.Timer(stmt="fn()", globals={"fn": fn})
        return timer.timeit(n_iters).mean * 1e3

    fwd_ms = mean_ms(run_forward)
    fwd_bwd_ms = mean_ms(run_forward_backward)

    bwd_ms = max(fwd_bwd_ms - fwd_ms, 0.0)

    fwd_gbps = fwd_bytes / (fwd_ms * 1e-3) / 1e9
    if bwd_ms > 0:
        bwd_gbps = bwd_bytes / (bwd_ms * 1e-3) / 1e9
    else:
        # A clamped (zero) backward time has no meaningful bandwidth.
        bwd_gbps = 0.0

    print(f" Forward {fwd_ms:.3f} ms | {fwd_gbps:.1f} GB/s")
    print(f" Backward {bwd_ms:.3f} ms | {bwd_gbps:.1f} GB/s (derived)")

    return {
        "TE Forward Time (ms)": f"{fwd_ms:.4f}",
        "TE Forward GB/s": f"{fwd_gbps:.1f}",
        "TE Backward Time (ms)": f"{bwd_ms:.4f}",
        "TE Backward GB/s": f"{bwd_gbps:.1f}",
    }


if __name__ == "__main__":
    import pandas as pd

    def _banner(text):
        # Blank line, rule, message, rule.
        rule = "=" * 60
        print(f"\n{rule}")
        print(text)
        print(rule)

    def _run_case(spec):
        return bench_norm(
            M=spec["M"],
            hidden_size=spec["hidden_size"],
            norm_cls=spec["norm_cls"],
            dtype=spec["dtype"],
        )

    test_cases = _generate_norm_test_cases()

    columns = [
        "Case", "M", "hidden_size", "dtype",
        "TE Forward Time (ms)",
        "TE Forward GB/s",
        "TE Backward Time (ms)",
        "TE Backward GB/s",
    ]
    rows = []

    # Warmup run: the first case is executed once and its result discarded.
    c = test_cases[0]
    _banner(f"WARMUP: {c['Case']} M={c['M']} hidden={c['hidden_size']}")
    _run_case(c)

    for case in test_cases:
        _banner(f"Testing: {case['Case']} M={case['M']} hidden={case['hidden_size']}")
        try:
            metrics = _run_case(case)
            row = {
                "Case": case["Case"],
                "M": case["M"],
                "hidden_size": case["hidden_size"],
                "dtype": str(case["dtype"]),
            }
            row.update(metrics)
            rows.append(row)
        except Exception as e:
            # Report which case failed, then abort the whole run.
            print(f"FAILED: {case['Case']}: {e}")
            raise

    results = pd.DataFrame(rows, columns=columns)

    out_csv = "benchmark_normalization.csv"
    results.to_csv(out_csv, index=False)
    print(f"\nResults saved to {out_csv}")

# ---------------------------------------------------------------------------
# Patch boundary. The next file in this diff is:
#   diff --git a/benchmarks/microbenchmarks/compare_results.py
#              b/benchmarks/microbenchmarks/compare_results.py
#   new file mode 100755
#   index 000000000..7353066bc
#   --- /dev/null
#   +++ b/benchmarks/microbenchmarks/compare_results.py
#   @@ -0,0 +1,135 @@
#
# #!/usr/bin/env python
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################
# Compare two CSVs from the same benchmark (base branch vs PR branch).
#
# Auto-detects metric columns (containing "TFLOPS"/ "GB/s") and key columns.
# Outputs a markdown
# ...continued module description: the comparison is emitted as a markdown
# <details> block to stdout with per-config results, and optionally appends
# a summary table row to --summary-file.
#
# Usage:
#     python compare_results.py base.csv pr.csv --bench-name NAME --summary-file FILE

import argparse
import sys

import numpy as np
import pandas as pd

SKIP_COLS = {"TestID", "Label"}

# Substrings that mark a column as a throughput metric (higher is better).
_METRIC_MARKERS = ("TFLOPS", "GB/s")

# Baseline values below this threshold are not compared: the ratio would be
# dominated by noise, and a zero baseline would divide by zero.
_MIN_BASELINE = 0.5


def auto_detect_columns(df):
    """Split the columns of *df* into key columns and metric columns.

    Metric columns are those whose name contains "TFLOPS" or "GB/s".
    Key columns are every remaining column that is not listed in
    ``SKIP_COLS`` and whose name does not contain "Time".

    Returns:
        ``(key_cols, metric_cols)``, both lists in original column order.
    """
    metric_cols = [
        c for c in df.columns
        if any(marker in c for marker in _METRIC_MARKERS)
    ]
    metric_set = set(metric_cols)
    key_cols = [
        c for c in df.columns
        if c not in metric_set and c not in SKIP_COLS and "Time" not in c
    ]
    return key_cols, metric_cols


def _short_metric_name(metric):
    """Drop the unit suffix (" TFLOPS" / " GB/s") for compact table headers."""
    for marker in _METRIC_MARKERS:
        metric = metric.replace(f" {marker}", "")
    return metric


def _collect_speedups(merged, key_cols, metric_cols):
    """Compute PR/base ratios for every merged row and metric.

    Returns:
        ``(per_row_data, all_speedups)``. ``per_row_data`` holds one entry
        per row with at least one valid comparison; ``all_speedups`` is the
        flat list of every ratio, used for the summary statistics.
    """
    per_row_data = []
    all_speedups = []
    for record in merged.to_dict("records"):
        row_metrics = {}
        for metric in metric_cols:
            base_val = record[f"{metric}_base"]
            pr_val = record[f"{metric}_pr"]
            if pd.isna(base_val) or pd.isna(pr_val) or base_val < _MIN_BASELINE:
                continue
            speedup = pr_val / base_val
            all_speedups.append(speedup)
            row_metrics[metric] = {"base": base_val, "pr": pr_val, "speedup": speedup}
        if row_metrics:
            per_row_data.append({
                "keys": {k: record[k] for k in key_cols},
                "metrics": row_metrics,
            })
    return per_row_data, all_speedups


def _print_details(bench_name, key_cols, metric_cols, per_row_data, stats):
    """Print the collapsible per-config markdown table to stdout."""
    median_sp, min_sp, max_sp = stats

    # The HTML wrapper tags are what make the block collapsible; the blank
    # line after <summary> separates it from the markdown table.
    print("<details>")
    print(f"<summary>{bench_name} "
          f"(median {median_sp:.3f}x, min {min_sp:.3f}x, max {max_sp:.3f}x)</summary>")
    print()

    header_cols = list(key_cols)
    for metric in metric_cols:
        short = _short_metric_name(metric)
        header_cols.extend([f"{short} Base", f"{short} PR", f"{short} Speedup"])

    print("| " + " | ".join(header_cols) + " |")
    print("|" + "|".join(["---"] * len(header_cols)) + "|")

    for row in per_row_data:
        cells = [str(row["keys"].get(k, "")) for k in key_cols]
        for metric in metric_cols:
            values = row["metrics"].get(metric)
            if values is None:
                # Comparison skipped for this metric; keep columns aligned.
                cells.extend(["", "", ""])
            else:
                cells.append(f"{values['base']:.2f}")
                cells.append(f"{values['pr']:.2f}")
                cells.append(f"{values['speedup']:.3f}x")
        print("| " + " | ".join(cells) + " |")

    print()
    print("</details>")
    print()


def main(argv=None):
    """Compare a base-branch CSV with a PR-branch CSV.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        Always 0. Situations with nothing to compare are reported on stdout
        instead of failing.
    """
    parser = argparse.ArgumentParser(description="Compare benchmark CSVs")
    parser.add_argument("base_csv", help="Base branch CSV")
    parser.add_argument("pr_csv", help="PR branch CSV")
    parser.add_argument("--bench-name", default="benchmark",
                        help="Benchmark name for markdown output")
    parser.add_argument("--summary-file", default=None,
                        help="Append a summary table row (markdown) to this file")
    args = parser.parse_args(argv)

    base_df = pd.read_csv(args.base_csv)
    pr_df = pd.read_csv(args.pr_csv)

    key_cols, base_metric_cols = auto_detect_columns(base_df)

    # Columns are detected from the base CSV; a PR that renames or drops a
    # column must not crash the comparison with a KeyError.
    metric_cols = [c for c in base_metric_cols if c in pr_df.columns]
    dropped = [c for c in base_metric_cols if c not in pr_df.columns]
    if dropped:
        print(f"WARNING: Metric columns missing from PR CSV: {', '.join(dropped)}")

    if not metric_cols:
        print("No metric columns found.")
        return 0

    missing_keys = [k for k in key_cols if k not in pr_df.columns]
    if missing_keys:
        print(f"WARNING: Key columns missing from PR CSV: {', '.join(missing_keys)}")
        return 0

    for col in metric_cols:
        base_df[col] = pd.to_numeric(base_df[col], errors="coerce")
        pr_df[col] = pd.to_numeric(pr_df[col], errors="coerce")

    merged = base_df.merge(pr_df, on=key_cols, suffixes=("_base", "_pr"), how="inner")
    if merged.empty:
        print("WARNING: No matching rows between base and PR.")
        return 0

    per_row_data, all_speedups = _collect_speedups(merged, key_cols, metric_cols)
    if not all_speedups:
        print("WARNING: No valid comparisons found.")
        return 0

    speedups = np.array(all_speedups)
    median_sp = float(np.median(speedups))
    min_sp = float(np.min(speedups))
    max_sp = float(np.max(speedups))

    _print_details(args.bench_name, key_cols, metric_cols, per_row_data,
                   (median_sp, min_sp, max_sp))

    # Summary row
    if args.summary_file:
        with open(args.summary_file, "a", encoding="utf-8") as f:
            f.write(f"| {args.bench_name} | {median_sp:.3f}x | {min_sp:.3f}x | {max_sp:.3f}x |\n")

    return 0


if __name__ == "__main__":
    sys.exit(main())