From d7c643c04434494b3707073128e7690b366abd16 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 16 Mar 2026 17:52:30 -0500 Subject: [PATCH 1/5] Initial benchmark porting to ASV --- .github/workflows/rocm-ci.yml | 73 +++++++++++ .gitignore | 1 + asv.conf.json | 16 +++ benchmarks/asv/README.md | 166 ++++++++++++++++++++++++++ benchmarks/asv/__init__.py | 0 benchmarks/asv/bench_attention.py | 56 +++++++++ benchmarks/asv/bench_casting.py | 51 ++++++++ benchmarks/asv/bench_gemm.py | 55 +++++++++ benchmarks/asv/bench_gemm_fp8.py | 60 ++++++++++ benchmarks/asv/bench_grouped_gemm.py | 64 ++++++++++ benchmarks/asv/bench_normalization.py | 36 ++++++ 11 files changed, 578 insertions(+) create mode 100644 asv.conf.json create mode 100644 benchmarks/asv/README.md create mode 100644 benchmarks/asv/__init__.py create mode 100644 benchmarks/asv/bench_attention.py create mode 100644 benchmarks/asv/bench_casting.py create mode 100644 benchmarks/asv/bench_gemm.py create mode 100644 benchmarks/asv/bench_gemm_fp8.py create mode 100644 benchmarks/asv/bench_grouped_gemm.py create mode 100644 benchmarks/asv/bench_normalization.py diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 32c3cb2a2..c35e5d0ad 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -368,6 +368,79 @@ jobs: EOF )" + - name: Restore previous ASV results + if: github.event_name == 'push' && github.ref_name == 'dev' + continue-on-error: true + env: + ARTIFACTORY_API_KEY: ${{ secrets.ARTIFACTORY_API_KEY }} + run: | + set -x + BASE_URL="https://compute-artifactory.amd.com:5000/artifactory/rocm-generic-local/te-ci/asv-results" + ARTIFACT_URL="${BASE_URL}/${{ matrix.runner }}/results.tar.gz" + + curl -sf -H "X-JFrog-Art-Api:${ARTIFACTORY_API_KEY}" \ + -o /tmp/asv-results.tar.gz "$ARTIFACT_URL" || { + echo "::notice::No previous ASV results found. Starting fresh." 
+ exit 0 + } + + mkdir -p asv-results + tar xzf /tmp/asv-results.tar.gz -C asv-results/ + + # Copy into the container's ASV results directory + docker exec te-runner mkdir -p /workspace/benchmarks/.asv/results + docker cp asv-results/. te-runner:/workspace/benchmarks/.asv/results/ + echo "Restored previous ASV results from Artifactory." + + - name: Performance benchmarks (ASV) + if: github.event_name == 'push' && github.ref_name == 'dev' + continue-on-error: true + env: + RUNNER_NAME: ${{ matrix.runner }} + run: | + set -ex + + # Derive a stable machine name from the runner label + case "${RUNNER_NAME}" in + linux-te-mi325*) MACHINE_NAME="mi325" ;; + linux-te-mi355*) MACHINE_NAME="mi355" ;; + *) MACHINE_NAME="${RUNNER_NAME}" ;; + esac + + docker exec -e MACHINE_NAME="$MACHINE_NAME" te-runner bash -c "$(cat <<'OUTER' + set -ex + pip install asv + cd /workspace + asv machine --yes --machine "$MACHINE_NAME" + asv run --python=same --launch-method spawn \ + 2>&1 | tee /workspace/asv_results.txt + OUTER + )" + + # Copy results out of the container for upload + rm -rf asv-results + docker cp te-runner:/workspace/benchmarks/.asv/results/. ./asv-results/ || true + + - name: Upload ASV results + if: github.event_name == 'push' && github.ref_name == 'dev' + continue-on-error: true + env: + ARTIFACTORY_API_KEY: ${{ secrets.ARTIFACTORY_API_KEY }} + run: | + set -ex + if [[ ! -d asv-results ]] || [[ -z "$(ls -A asv-results)" ]]; then + echo "::notice::No ASV results to upload." + exit 0 + fi + + BASE_URL="https://compute-artifactory.amd.com:5000/artifactory/rocm-generic-local/te-ci/asv-results" + tar czf /tmp/asv-results.tar.gz -C asv-results . + + curl -sf -H "X-JFrog-Art-Api:${ARTIFACTORY_API_KEY}" \ + -T /tmp/asv-results.tar.gz \ + "${BASE_URL}/${{ matrix.runner }}/results.tar.gz" + echo "Uploaded ASV results to Artifactory." 
+ - name: Check Test Failure Status if: always() run: | diff --git a/.gitignore b/.gitignore index d3b18b358..a5fd89b4b 100644 --- a/.gitignore +++ b/.gitignore @@ -55,3 +55,4 @@ artifacts/ **/times.csv transformer_engine/build_info.txt transformer_engine/common/util/hip_nvml.* +.asv/ diff --git a/asv.conf.json b/asv.conf.json new file mode 100644 index 000000000..dc71bf345 --- /dev/null +++ b/asv.conf.json @@ -0,0 +1,16 @@ +{ + "version": 1, + "project": "TransformerEngine", + "project_url": "https://github.com/ROCm/TransformerEngine", + "repo": ".", + "branches": ["dev"], + "environment_type": "existing", + "install_command": [], + "build_command": [], + "benchmark_dir": "benchmarks/asv", + "results_dir": "benchmarks/.asv/results", + "html_dir": "benchmarks/.asv/html", + "install_timeout": 600, + "benchmark_timeout": 1200, + "launch_method": "spawn" +} diff --git a/benchmarks/asv/README.md b/benchmarks/asv/README.md new file mode 100644 index 000000000..7de4fd6c5 --- /dev/null +++ b/benchmarks/asv/README.md @@ -0,0 +1,166 @@ +# ASV Benchmarks for TransformerEngine + +Performance benchmarks built on [ASV (Air Speed Velocity)](https://asv.readthedocs.io/), +a framework for benchmarking Python packages over their lifetime. + +## Prerequisites + +- TransformerEngine must already be built and installed in the current Python environment. +- A ROCm or CUDA GPU must be available. +- Install ASV: `pip install asv` + +ASV is configured with `environment_type: "existing"` (in `asv.conf.json` at the repo root), +meaning it uses the current Python environment directly — it does not create virtualenvs or +attempt to build TE itself. + +## Local usage + +All commands are run from the **repository root** (where `asv.conf.json` lives). + +### Register your machine + +```bash +asv machine --yes --machine my-machine-name +``` + +This creates a machine profile in `benchmarks/.asv/results/my-machine-name/machine.json`. 
+Use a descriptive name (e.g., `mi325`, `mi300x-dev`) — results are stored per machine, so +the name must be consistent across runs for historical comparison. + +### Run all benchmarks + +```bash +asv run --python=same --launch-method spawn +``` + +- `--python=same` — use the current interpreter (required with `environment_type: "existing"`) +- `--launch-method spawn` — required for CUDA (fork causes "Cannot re-initialize CUDA in forked subprocess") + +### Run a single suite + +```bash +asv run --python=same --launch-method spawn --bench bench_casting +``` + +The `--bench` argument accepts a regex that matches benchmark file or class names. + +### Quick smoke test + +```bash +asv run --python=same --launch-method spawn --quick --bench bench_casting +``` + +`--quick` runs each benchmark only once with no statistical analysis. Useful for verifying +benchmarks work, but note that results are **not saved to disk** in quick mode. + +### Compare two commits + +```bash +asv continuous --python=same --launch-method spawn HEAD~1 HEAD +``` + +This checks out each commit, runs benchmarks on both, and reports regressions. +Note: this only works if the benchmark files exist on both commits. + +### Generate an HTML dashboard + +```bash +asv publish +asv preview +``` + +`asv publish` generates static HTML from stored results into `benchmarks/.asv/html/`. +`asv preview` serves it locally on `http://localhost:8080`. + +## How results are stored + +### Local results + +ASV stores results as JSON files under `benchmarks/.asv/results/`: + +``` +benchmarks/.asv/results/ + my-machine-name/ + machine.json # Hardware/OS metadata + .json # Timing results for that commit + .json + ... +``` + +Each commit JSON contains the wall-clock timings for every benchmark + parameter combination +run on that machine. The `benchmarks/.asv/` directory is in `.gitignore`. + +### CI results (Artifactory) + +In CI, benchmarks run **only on pushes to `dev`** (not on PRs). 
This builds a historical +record of performance on the main branch. + +The CI pipeline (`.github/workflows/rocm-ci.yml`) follows this flow: + +1. **Restore** — Download `results.tar.gz` from Artifactory for the current runner +2. **Benchmark** — Run `asv run`, which appends a new `{commit}.json` to the results directory +3. **Upload** — Tar up the results directory and upload back to Artifactory + +Results are stored per machine at: +``` +https://compute-artifactory.amd.com:5000/artifactory/rocm-generic-local/te-ci/asv-results/ + linux-te-mi325-8/results.tar.gz + linux-te-mi355-8/results.tar.gz +``` + +Each tarball contains the full ASV results directory for that machine, accumulating +a new commit JSON on every push to `dev`. ASV machine names map to hardware: +`mi325` for MI325X runners, `mi355` for MI355X runners. + +### Downloading CI results locally + +To inspect CI results on your local machine (requires Artifactory access): + +```bash +# Download results for a specific machine +curl -sf -H "X-JFrog-Art-Api:${ARTIFACTORY_API_KEY}" \ + -o results.tar.gz \ + "https://compute-artifactory.amd.com:5000/artifactory/rocm-generic-local/te-ci/asv-results/linux-te-mi325-8/results.tar.gz" + +# Extract into your local ASV results directory +mkdir -p benchmarks/.asv/results +tar xzf results.tar.gz -C benchmarks/.asv/results/ + +# Generate and view the dashboard +asv publish +asv preview +``` + +This can also be provided statically via github pages. + +## Writing new benchmarks + +Create a new file in `benchmarks/asv/` following the naming convention `bench_.py`. + +```python +import torch +import transformer_engine.pytorch as te + +class BenchSomething: + params = [[1024, 4096], ["config_a", "config_b"]] + param_names = ["M", "config"] + timeout = 300 # seconds, per parameter combination + + def setup(self, M, config): + # Allocate tensors, create modules. + # This runs before each time_* method but is NOT timed. + ... 
+ + def time_forward(self, M, config): + # ASV times this method (adaptive iterations + statistics). + # MUST call torch.cuda.synchronize() to ensure GPU work completes. + self.module(self.x) + torch.cuda.synchronize() +``` + +Key rules: +- Method names starting with `time_` are automatically timed by ASV. +- Always call `torch.cuda.synchronize()` at the end of `time_*` methods. +- Clear `.grad` attributes in backward benchmarks to prevent memory accumulation. +- ASV runs each `time_*` method in a **separate subprocess** — no shared state between methods. +- The `params` list defines a cross-product; keep the matrix size reasonable. diff --git a/benchmarks/asv/__init__.py b/benchmarks/asv/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/benchmarks/asv/bench_attention.py b/benchmarks/asv/bench_attention.py new file mode 100644 index 000000000..9c64888f6 --- /dev/null +++ b/benchmarks/asv/bench_attention.py @@ -0,0 +1,56 @@ +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""Fused multi-head attention (GQA) benchmarks via te.DotProductAttention. 
+ +Forward FLOPs = 4 * batch * num_q_heads * seq_len^2 * head_dim +Backward FLOPs ~ 2x forward +""" + +import torch +import transformer_engine.pytorch as te + +BATCH = 2 + +# (num_q_heads, num_kv_heads, head_dim, tp) +MODELS = { + "Llama3-8B_TP1": (32, 8, 128, 1), + "Llama3-8B_TP8": (32, 8, 128, 8), + "Llama3-70B_TP8": (64, 8, 128, 8), + "Llama3-405B_TP8": (128, 8, 128, 8), + "Qwen2.5-7B_TP1": (28, 4, 128, 1), + "Qwen2.5-72B_TP8": (64, 8, 128, 8), +} + + +class BenchAttention: + params = [[1024, 2048, 4096, 8192], list(MODELS)] + param_names = ["seq_len", "model"] + timeout = 300 + + def setup(self, seq_len, model): + n_q, n_kv, hd, tp = MODELS[model] + qh, kvh = n_q // tp, n_kv // tp + dtype = torch.bfloat16 + + self.attn = te.DotProductAttention( + num_attention_heads=qh, kv_channels=hd, + num_gqa_groups=kvh, attn_mask_type="causal", + ).to(device="cuda", dtype=dtype) + + self.q = torch.randn(seq_len, BATCH, qh, hd, dtype=dtype, device="cuda", requires_grad=True) + self.k = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True) + self.v = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn_like(self.attn(self.q, self.k, self.v)) + + def time_forward(self, seq_len, model): + self.attn(self.q, self.k, self.v) + torch.cuda.synchronize() + + def time_forward_backward(self, seq_len, model): + out = self.attn(self.q, self.k, self.v) + out.backward(self.grad_out) + self.q.grad = self.k.grad = self.v.grad = None + torch.cuda.synchronize() diff --git a/benchmarks/asv/bench_casting.py b/benchmarks/asv/bench_casting.py new file mode 100644 index 000000000..7195a01ab --- /dev/null +++ b/benchmarks/asv/bench_casting.py @@ -0,0 +1,51 @@ +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. 
+############################################################################### +"""FP8 casting micro-benchmarks. + +Memory-bound quantization/dequantization between BF16 and FP8 formats. +""" + +import torch + +if hasattr(torch, "float8_e4m3fnuz"): + FP8_E4M3 = torch.float8_e4m3fnuz + FP8_E5M2 = torch.float8_e5m2fnuz +else: + FP8_E4M3 = torch.float8_e4m3fn + FP8_E5M2 = torch.float8_e5m2 + +HIDDEN_SIZES = { + "Llama3-8B": 4096, + "Llama3-70B": 8192, + "Llama3-405B": 16384, + "Qwen2.5-7B": 3584, + "Qwen2.5-72B": 8192, +} + +CAST_CONFIGS = { + "BF16_to_E4M3": (torch.bfloat16, FP8_E4M3), + "E4M3_to_BF16": (FP8_E4M3, torch.bfloat16), + "BF16_to_E5M2": (torch.bfloat16, FP8_E5M2), + "E5M2_to_BF16": (FP8_E5M2, torch.bfloat16), +} + + +class BenchCasting: + params = [[1024, 2048, 4096, 8192], list(HIDDEN_SIZES), list(CAST_CONFIGS)] + param_names = ["M", "model", "cast"] + timeout = 120 + + def setup(self, M, model, cast): + hidden = HIDDEN_SIZES[model] + src_dtype, self.dst_dtype = CAST_CONFIGS[cast] + if src_dtype in (FP8_E4M3, FP8_E5M2): + self.x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda").to(src_dtype) + else: + self.x = torch.randn(M, hidden, dtype=src_dtype, device="cuda") + + def time_cast(self, M, model, cast): + self.x.to(self.dst_dtype) + torch.cuda.synchronize() diff --git a/benchmarks/asv/bench_gemm.py b/benchmarks/asv/bench_gemm.py new file mode 100644 index 000000000..6a09a08b5 --- /dev/null +++ b/benchmarks/asv/bench_gemm.py @@ -0,0 +1,55 @@ +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""BF16 GEMM benchmarks via te.Linear. + +GEMM shapes derived from transformer layer projections: + QKV, AttnOut, GateUp (SwiGLU), Down. 
+""" + +import torch +import transformer_engine.pytorch as te + +# (hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) +MODELS = { + "Llama3-8B_TP1": (4096, 14336, 32, 8, 128, 1), + "Llama3-8B_TP8": (4096, 14336, 32, 8, 128, 8), + "Llama3-70B_TP8": (8192, 28672, 64, 8, 128, 8), + "Llama3-405B_TP8": (16384, 53248, 128, 8, 128, 8), + "Qwen2.5-7B_TP1": (3584, 18944, 28, 4, 128, 1), + "Qwen2.5-72B_TP8": (8192, 29568, 64, 8, 128, 8), +} + +# Pre-compute (N, K) for each GEMM shape +SHAPES = {} +for _name, (h, inter, nq, nkv, hd, tp) in MODELS.items(): + SHAPES[f"{_name}-QKV"] = ((nq * hd + 2 * nkv * hd) // tp, h) + SHAPES[f"{_name}-AttnOut"] = (h, (nq * hd) // tp) + SHAPES[f"{_name}-GateUp"] = ((2 * inter) // tp, h) + SHAPES[f"{_name}-Down"] = (h, inter // tp) + + +class BenchGemm: + params = [[1024, 2048, 4096, 8192], list(SHAPES)] + param_names = ["M", "shape"] + timeout = 300 + + def setup(self, M, shape): + N, K = SHAPES[shape] + dtype = torch.bfloat16 + self.linear = te.Linear(K, N, bias=False).to(device="cuda", dtype=dtype) + self.x = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn_like(self.linear(self.x)) + + def time_forward(self, M, shape): + self.linear(self.x) + torch.cuda.synchronize() + + def time_forward_backward(self, M, shape): + out = self.linear(self.x) + out.backward(self.grad_out) + self.x.grad = None + self.linear.weight.grad = None + torch.cuda.synchronize() diff --git a/benchmarks/asv/bench_gemm_fp8.py b/benchmarks/asv/bench_gemm_fp8.py new file mode 100644 index 000000000..9d70d8879 --- /dev/null +++ b/benchmarks/asv/bench_gemm_fp8.py @@ -0,0 +1,60 @@ +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""FP8 GEMM benchmarks via te.Linear under fp8_autocast. 
+ +Same shapes as bench_gemm.py but with FP8 quantized compute. +""" + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling, Format + +# (hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) +MODELS = { + "Llama3-8B_TP1": (4096, 14336, 32, 8, 128, 1), + "Llama3-8B_TP8": (4096, 14336, 32, 8, 128, 8), + "Llama3-70B_TP8": (8192, 28672, 64, 8, 128, 8), + "Llama3-405B_TP8": (16384, 53248, 128, 8, 128, 8), + "Qwen2.5-7B_TP1": (3584, 18944, 28, 4, 128, 1), + "Qwen2.5-72B_TP8": (8192, 29568, 64, 8, 128, 8), +} + +SHAPES = {} +for _name, (h, inter, nq, nkv, hd, tp) in MODELS.items(): + SHAPES[f"{_name}-QKV"] = ((nq * hd + 2 * nkv * hd) // tp, h) + SHAPES[f"{_name}-AttnOut"] = (h, (nq * hd) // tp) + SHAPES[f"{_name}-GateUp"] = ((2 * inter) // tp, h) + SHAPES[f"{_name}-Down"] = (h, inter // tp) + +FP8_RECIPE = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max", +) + + +class BenchGemmFP8: + params = [[1024, 2048, 4096, 8192], list(SHAPES)] + param_names = ["M", "shape"] + timeout = 300 + + def setup(self, M, shape): + N, K = SHAPES[shape] + dtype = torch.bfloat16 + self.linear = te.Linear(K, N, bias=False).to(device="cuda", dtype=dtype) + self.x = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn(M, N, dtype=dtype, device="cuda") + + def time_forward(self, M, shape): + with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + self.linear(self.x) + torch.cuda.synchronize() + + def time_forward_backward(self, M, shape): + with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + out = self.linear(self.x) + out.backward(self.grad_out) + self.x.grad = None + self.linear.weight.grad = None + torch.cuda.synchronize() diff --git a/benchmarks/asv/bench_grouped_gemm.py b/benchmarks/asv/bench_grouped_gemm.py new file mode 100644 index 000000000..3c35737f5 --- /dev/null +++ b/benchmarks/asv/bench_grouped_gemm.py @@ -0,0 +1,64 
@@ +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""Grouped GEMM benchmarks via te.GroupedLinear. + +MoE model configurations with GateUp and Down projections. +""" + +import torch +import transformer_engine.pytorch as te + +# (n_routed_experts, moe_intermediate_size, hidden_size) +MOE_MODELS = { + "DSV2-Lite": (64, 1408, 2048), + "DSV2": (160, 1536, 5120), + "DSV3": (256, 2048, 7168), + "Grok-V2": (8, 16384, 8192), +} + +# Build (config_key -> (num_gemms, N, K)) mapping +CONFIGS = {} +for model, (n_experts, inter, hidden) in MOE_MODELS.items(): + for ep in [32, 16, 8]: + if n_experts % ep != 0: + continue + B = n_experts // ep + CONFIGS[f"{model}_EP{ep}-GateUp"] = (B, 2 * inter, hidden) + CONFIGS[f"{model}_EP{ep}-Down"] = (B, hidden, inter) + + +class BenchGroupedGemm: + params = [[512, 1024, 2048, 4096], list(CONFIGS)] + param_names = ["M", "config"] + timeout = 300 + + def setup(self, M, config): + B, N, K = CONFIGS[config] + dtype = torch.bfloat16 + + self.module = te.GroupedLinear( + num_gemms=B, in_features=K, out_features=N, bias=False, + ).to(device="cuda", dtype=dtype) + + self.xs = [ + torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + for _ in range(B) + ] + outs = self.module(self.xs) + self.grad_outs = [torch.randn_like(o) for o in outs] + + def time_forward(self, M, config): + self.module(self.xs) + torch.cuda.synchronize() + + def time_forward_backward(self, M, config): + outs = self.module(self.xs) + torch.autograd.backward(outs, self.grad_outs) + for x in self.xs: + x.grad = None + for p in self.module.parameters(): + p.grad = None + torch.cuda.synchronize() diff --git a/benchmarks/asv/bench_normalization.py b/benchmarks/asv/bench_normalization.py new file mode 100644 index 000000000..f68b60a51 --- 
/dev/null +++ b/benchmarks/asv/bench_normalization.py @@ -0,0 +1,36 @@ +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""RMSNorm and LayerNorm benchmarks on activation-sized tensors.""" + +import torch +import transformer_engine.pytorch as te + +NORMS = {"RMSNorm": te.RMSNorm, "LayerNorm": te.LayerNorm} +HIDDEN_SIZES = [3584, 4096, 8192, 16384] + + +class BenchNormalization: + params = [[1024, 2048, 4096, 8192], HIDDEN_SIZES, list(NORMS)] + param_names = ["M", "hidden", "norm_type"] + timeout = 120 + + def setup(self, M, hidden, norm_type): + dtype = torch.bfloat16 + self.norm = NORMS[norm_type](hidden).to(device="cuda", dtype=dtype) + self.x = torch.randn(M, hidden, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn_like(self.norm(self.x)) + + def time_forward(self, M, hidden, norm_type): + self.norm(self.x) + torch.cuda.synchronize() + + def time_forward_backward(self, M, hidden, norm_type): + out = self.norm(self.x) + out.backward(self.grad_out) + self.x.grad = None + for p in self.norm.parameters(): + p.grad = None + torch.cuda.synchronize() From b8291223203b89bdb098a27a54728fdb174fd755 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 17 Mar 2026 08:51:18 -0500 Subject: [PATCH 2/5] Update casting benchmark --- benchmarks/asv/bench_casting.py | 42 ++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/benchmarks/asv/bench_casting.py b/benchmarks/asv/bench_casting.py index 7195a01ab..fb594a7c2 100644 --- a/benchmarks/asv/bench_casting.py +++ b/benchmarks/asv/bench_casting.py @@ -5,17 +5,13 @@ ############################################################################### """FP8 casting micro-benchmarks. 
-Memory-bound quantization/dequantization between BF16 and FP8 formats. +Memory-bound quantization/dequantization between BF16 and FP8 formats +using Transformer Engine's quantized tensor infrastructure. """ import torch - -if hasattr(torch, "float8_e4m3fnuz"): - FP8_E4M3 = torch.float8_e4m3fnuz - FP8_E5M2 = torch.float8_e5m2fnuz -else: - FP8_E4M3 = torch.float8_e4m3fn - FP8_E5M2 = torch.float8_e5m2 +from transformer_engine.pytorch import Float8CurrentScalingQuantizer +from transformer_engine_torch import DType as TE_DType HIDDEN_SIZES = { "Llama3-8B": 4096, @@ -26,10 +22,10 @@ } CAST_CONFIGS = { - "BF16_to_E4M3": (torch.bfloat16, FP8_E4M3), - "E4M3_to_BF16": (FP8_E4M3, torch.bfloat16), - "BF16_to_E5M2": (torch.bfloat16, FP8_E5M2), - "E5M2_to_BF16": (FP8_E5M2, torch.bfloat16), + "BF16_to_E4M3": ("quantize", TE_DType.kFloat8E4M3), + "E4M3_to_BF16": ("dequantize", TE_DType.kFloat8E4M3), + "BF16_to_E5M2": ("quantize", TE_DType.kFloat8E5M2), + "E5M2_to_BF16": ("dequantize", TE_DType.kFloat8E5M2), } @@ -40,12 +36,24 @@ class BenchCasting: def setup(self, M, model, cast): hidden = HIDDEN_SIZES[model] - src_dtype, self.dst_dtype = CAST_CONFIGS[cast] - if src_dtype in (FP8_E4M3, FP8_E5M2): - self.x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda").to(src_dtype) + direction, fp8_dtype = CAST_CONFIGS[cast] + self.direction = direction + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=fp8_dtype, + device=torch.device("cuda"), + rowwise=True, + columnwise=False, + ) + if direction == "dequantize": + bf16_tensor = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") + self.x = quantizer.quantize(bf16_tensor) else: - self.x = torch.randn(M, hidden, dtype=src_dtype, device="cuda") + self.x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") + self.quantizer = quantizer def time_cast(self, M, model, cast): - self.x.to(self.dst_dtype) + if self.direction == "quantize": + self.quantizer.quantize(self.x) + else: + 
self.x.dequantize(dtype=torch.bfloat16) torch.cuda.synchronize() From 21678b41426f3bc7e30f56022c93f66444a462d6 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 18 Mar 2026 16:26:16 -0500 Subject: [PATCH 3/5] Added helper script and documentation --- benchmarks/asv/README.md | 31 +++++++++++++++- benchmarks/asv/run_benchmarks.sh | 63 ++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 1 deletion(-) create mode 100755 benchmarks/asv/run_benchmarks.sh diff --git a/benchmarks/asv/README.md b/benchmarks/asv/README.md index 7de4fd6c5..bd02b4991 100644 --- a/benchmarks/asv/README.md +++ b/benchmarks/asv/README.md @@ -13,7 +13,36 @@ ASV is configured with `environment_type: "existing"` (in `asv.conf.json` at the meaning it uses the current Python environment directly — it does not create virtualenvs or attempt to build TE itself. -## Local usage +## Helper script + +A convenience wrapper (`benchmarks/asv/run_benchmarks.sh`) is provided for common tasks. +It can be run from anywhere — it automatically `cd`s to the repo root. Available benchmark +suites are discovered dynamically from `bench_*.py` files. + +```bash +bash benchmarks/asv/run_benchmarks.sh [options] +``` + +| Command | Description | +|---|---| +| `setup [name]` | Register machine with ASV (defaults to `hostname`) | +| `run [suite]` | Run all benchmarks, or a single suite (e.g. 
`bench_casting`) | +| `quick [suite]` | Smoke test — single iteration, results not saved | +| `compare [ref] [new]` | Compare two commits (defaults to `HEAD~1` vs `HEAD`) | +| `view` | Generate HTML dashboard and serve on `localhost:8080` | +| `list` | List available benchmark suites | + +Examples: + +```bash +bash benchmarks/asv/run_benchmarks.sh setup mi325 +bash benchmarks/asv/run_benchmarks.sh run bench_casting +bash benchmarks/asv/run_benchmarks.sh quick +bash benchmarks/asv/run_benchmarks.sh compare HEAD~3 HEAD +bash benchmarks/asv/run_benchmarks.sh view +``` + +## Local usage (manual ASV commands) All commands are run from the **repository root** (where `asv.conf.json` lives). diff --git a/benchmarks/asv/run_benchmarks.sh b/benchmarks/asv/run_benchmarks.sh new file mode 100755 index 000000000..5f07c23ff --- /dev/null +++ b/benchmarks/asv/run_benchmarks.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +# Helper script for common ASV benchmark tasks. +# Run from the repository root (where asv.conf.json lives). +set -euo pipefail + +cd "$(git rev-parse --show-toplevel)" + +BENCH_DIR="benchmarks/asv" +mapfile -t SUITES < <(find "$BENCH_DIR" -maxdepth 1 -name 'bench_*.py' -printf '%f\n' | sed 's/\.py$//' | sort) + +usage() { + cat < [options] + +Commands: + setup Register this machine with ASV + run [SUITE] Run all benchmarks, or a single suite (e.g. 
bench_casting) + quick [SUITE] Smoke-test run (single iteration, results not saved) + compare [REF] [NEW] Compare two commits (default: HEAD~1 vs HEAD) + view Generate HTML dashboard and open preview server + list List available benchmark suites + +EOF +} + +case "${1:-}" in + setup) + MACHINE="${2:-$(hostname)}" + echo "Registering machine as: $MACHINE" + asv machine --yes --machine "$MACHINE" + ;; + run) + CMD=(asv run --python=same --launch-method spawn) + [[ -n "${2:-}" ]] && CMD+=(--bench "$2") + echo "Running: ${CMD[*]}" + "${CMD[@]}" + ;; + quick) + CMD=(asv run --python=same --launch-method spawn --quick) + [[ -n "${2:-}" ]] && CMD+=(--bench "$2") + echo "Running (quick): ${CMD[*]}" + "${CMD[@]}" + ;; + compare) + REF="${2:-HEAD~1}" + NEW="${3:-HEAD}" + echo "Comparing $REF vs $NEW" + asv continuous --python=same --launch-method spawn "$REF" "$NEW" + ;; + view) + asv publish + echo "Starting preview server at http://localhost:8080" + asv preview + ;; + list) + echo "Available benchmark suites:" + for s in "${SUITES[@]}"; do echo " $s"; done + ;; + *) + usage + exit 1 + ;; +esac From 6cb91a56a86bfb3653d333de851a282bf7b8ee97 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 19 Mar 2026 13:09:32 -0500 Subject: [PATCH 4/5] Corrected local benchmarking --- asv.conf.json | 2 +- benchmarks/asv/README.md | 13 ++++++++----- benchmarks/asv/run_benchmarks.sh | 6 ++++-- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/asv.conf.json b/asv.conf.json index dc71bf345..1b17c9c9e 100644 --- a/asv.conf.json +++ b/asv.conf.json @@ -3,7 +3,7 @@ "project": "TransformerEngine", "project_url": "https://github.com/ROCm/TransformerEngine", "repo": ".", - "branches": ["dev"], + "branches": ["HEAD", "dev"], "environment_type": "existing", "install_command": [], "build_command": [], diff --git a/benchmarks/asv/README.md b/benchmarks/asv/README.md index bd02b4991..5d2686ed2 100644 --- a/benchmarks/asv/README.md +++ b/benchmarks/asv/README.md @@ -11,7 +11,8 @@ a 
framework for benchmarking Python packages over their lifetime. ASV is configured with `environment_type: "existing"` (in `asv.conf.json` at the repo root), meaning it uses the current Python environment directly — it does not create virtualenvs or -attempt to build TE itself. +attempt to build TE itself. The config sets `branches: ["HEAD", "dev"]` so that `asv publish` +accepts results from both the currently checked-out branch and `dev` (for CI history). ## Helper script @@ -26,7 +27,7 @@ bash benchmarks/asv/run_benchmarks.sh [options] | Command | Description | |---|---| | `setup [name]` | Register machine with ASV (defaults to `hostname`) | -| `run [suite]` | Run all benchmarks, or a single suite (e.g. `bench_casting`) | +| `run [suite]` | Run benchmarks for the current commit (optionally a single suite) | | `quick [suite]` | Smoke test — single iteration, results not saved | | `compare [ref] [new]` | Compare two commits (defaults to `HEAD~1` vs `HEAD`) | | `view` | Generate HTML dashboard and serve on `localhost:8080` | @@ -59,16 +60,18 @@ the name must be consistent across runs for historical comparison. ### Run all benchmarks ```bash -asv run --python=same --launch-method spawn +asv run --python=same --launch-method spawn --set-commit-hash $(git rev-parse HEAD) ``` - `--python=same` — use the current interpreter (required with `environment_type: "existing"`) - `--launch-method spawn` — required for CUDA (fork causes "Cannot re-initialize CUDA in forked subprocess") +- `--set-commit-hash` — **required** with `environment_type: "existing"`. Without it, ASV + runs benchmarks but silently discards results. The helper script sets this automatically. ### Run a single suite ```bash -asv run --python=same --launch-method spawn --bench bench_casting +asv run --python=same --launch-method spawn --set-commit-hash $(git rev-parse HEAD) --bench bench_casting ``` The `--bench` argument accepts a regex that matches benchmark file or class names. 
@@ -76,7 +79,7 @@ The `--bench` argument accepts a regex that matches benchmark file or class name ### Quick smoke test ```bash -asv run --python=same --launch-method spawn --quick --bench bench_casting +asv run --python=same --launch-method spawn --quick --set-commit-hash $(git rev-parse HEAD) --bench bench_casting ``` `--quick` runs each benchmark only once with no statistical analysis. Useful for verifying diff --git a/benchmarks/asv/run_benchmarks.sh b/benchmarks/asv/run_benchmarks.sh index 5f07c23ff..7e9a21d23 100755 --- a/benchmarks/asv/run_benchmarks.sh +++ b/benchmarks/asv/run_benchmarks.sh @@ -30,13 +30,15 @@ case "${1:-}" in asv machine --yes --machine "$MACHINE" ;; run) - CMD=(asv run --python=same --launch-method spawn) + CMD=(asv run --python=same --launch-method spawn + --set-commit-hash "$(git rev-parse HEAD)") [[ -n "${2:-}" ]] && CMD+=(--bench "$2") echo "Running: ${CMD[*]}" "${CMD[@]}" ;; quick) - CMD=(asv run --python=same --launch-method spawn --quick) + CMD=(asv run --python=same --launch-method spawn --quick + --set-commit-hash "$(git rev-parse HEAD)") [[ -n "${2:-}" ]] && CMD+=(--bench "$2") echo "Running (quick): ${CMD[*]}" "${CMD[@]}" From 1a98989da5e79b9f53da90485b0157d5d256670a Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 19 Mar 2026 13:28:09 -0500 Subject: [PATCH 5/5] Added direct-run option to bypass subprocess overhead --- asv.conf.json | 2 +- benchmarks/asv/README.md | 8 ++- benchmarks/asv/direct_run.py | 92 ++++++++++++++++++++++++++++++++ benchmarks/asv/run_benchmarks.sh | 11 ++++ 4 files changed, 110 insertions(+), 3 deletions(-) create mode 100644 benchmarks/asv/direct_run.py diff --git a/asv.conf.json b/asv.conf.json index 1b17c9c9e..482e20c60 100644 --- a/asv.conf.json +++ b/asv.conf.json @@ -3,7 +3,7 @@ "project": "TransformerEngine", "project_url": "https://github.com/ROCm/TransformerEngine", "repo": ".", - "branches": ["HEAD", "dev"], + "branches": ["HEAD"], "environment_type": "existing", "install_command": [], 
"build_command": [], diff --git a/benchmarks/asv/README.md b/benchmarks/asv/README.md index 5d2686ed2..5d88d8ad3 100644 --- a/benchmarks/asv/README.md +++ b/benchmarks/asv/README.md @@ -11,8 +11,9 @@ a framework for benchmarking Python packages over their lifetime. ASV is configured with `environment_type: "existing"` (in `asv.conf.json` at the repo root), meaning it uses the current Python environment directly — it does not create virtualenvs or -attempt to build TE itself. The config sets `branches: ["HEAD", "dev"]` so that `asv publish` -accepts results from both the currently checked-out branch and `dev` (for CI history). +attempt to build TE itself. The config sets `branches: ["HEAD"]` so that `asv publish` accepts results from +whichever branch is currently checked out — this works for both local development +and CI (where `HEAD` points to `dev`). ## Helper script @@ -29,6 +30,7 @@ bash benchmarks/asv/run_benchmarks.sh [options] | `setup [name]` | Register machine with ASV (defaults to `hostname`) | | `run [suite]` | Run benchmarks for the current commit (optionally a single suite) | | `quick [suite]` | Smoke test — single iteration, results not saved | +| `direct suite [method]` | Fast in-process run — no subprocesses, no ASV overhead | | `compare [ref] [new]` | Compare two commits (defaults to `HEAD~1` vs `HEAD`) | | `view` | Generate HTML dashboard and serve on `localhost:8080` | | `list` | List available benchmark suites | @@ -39,6 +41,8 @@ Examples: bash benchmarks/asv/run_benchmarks.sh setup mi325 bash benchmarks/asv/run_benchmarks.sh run bench_casting bash benchmarks/asv/run_benchmarks.sh quick +bash benchmarks/asv/run_benchmarks.sh direct bench_casting +bash benchmarks/asv/run_benchmarks.sh direct bench_gemm time_forward bash benchmarks/asv/run_benchmarks.sh compare HEAD~3 HEAD bash benchmarks/asv/run_benchmarks.sh view ``` diff --git a/benchmarks/asv/direct_run.py b/benchmarks/asv/direct_run.py new file mode 100644 index 000000000..1e59e99e8 --- 
/dev/null
+++ b/benchmarks/asv/direct_run.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python3
+"""Run ASV benchmark classes directly in-process, bypassing subprocess overhead.
+
+Usage:
+    python benchmarks/asv/direct_run.py [options] suite [method_filter]
+
+Examples:
+    python benchmarks/asv/direct_run.py bench_casting
+    python benchmarks/asv/direct_run.py bench_gemm time_forward
+    python benchmarks/asv/direct_run.py -w 5 -n 20 bench_casting
+"""
+
+import argparse
+import importlib
+import itertools
+import math
+import sys
+import time
+
+
+def run_class(cls, class_name, method_filter=None, warmup=3, iters=7):
+    methods = sorted(m for m in dir(cls) if m.startswith("time_"))
+    if method_filter:
+        methods = [m for m in methods if method_filter in m]
+    if not methods:
+        return
+
+    params = getattr(cls, "params", [[]])
+    param_names = getattr(cls, "param_names", [])
+    combos = list(itertools.product(*params))
+
+    print(f"\n{class_name} ({len(combos)} combos x {len(methods)} methods, "
+          f"{warmup} warmup, {iters} timed)")
+    print("-" * 90)
+    print(f"  {'median':>10} {'mean':>10} {'stdev':>10}  {'method':<30} params")
+    print("-" * 90)
+
+    for combo in combos:
+        label = ", ".join(f"{n}={v}" for n, v in zip(param_names, combo))
+        instance = cls()
+        try:
+            instance.setup(*combo)
+        except Exception as e:
+            print(f"  SKIP {label} setup failed: {e}")
+            continue
+
+        for method_name in methods:
+            method = getattr(instance, method_name)
+
+            for _ in range(warmup):
+                method(*combo)
+
+            times = []
+            for _ in range(iters):
+                t0 = time.perf_counter()
+                method(*combo)
+                times.append(time.perf_counter() - t0)
+
+            times.sort()
+            median = times[len(times) // 2]
+            mean = sum(times) / len(times)
+            stdev = math.sqrt(sum((t - mean) ** 2 for t in times) / len(times))
+            print(f"  {median*1000:>8.3f}ms {mean*1000:>8.3f}ms "
+                  f"{stdev*1000:>8.3f}ms {method_name:<30} {label}")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Run ASV benchmarks directly in-process (no subprocess overhead).")
+ parser.add_argument("suite", help="Benchmark module name (e.g. bench_casting)") + parser.add_argument("method_filter", nargs="?", default=None, + help="Only run time_* methods containing this string") + parser.add_argument("-w", "--warmup", type=int, default=3, + help="Number of warmup iterations (default: 3)") + parser.add_argument("-n", "--iters", type=int, default=7, + help="Number of timed iterations (default: 7)") + args = parser.parse_args() + + mod = importlib.import_module(args.suite) + + for name in sorted(dir(mod)): + obj = getattr(mod, name) + if isinstance(obj, type) and name.startswith("Bench"): + run_class(obj, name, args.method_filter, args.warmup, args.iters) + + +if __name__ == "__main__": + import os + + os.chdir(os.path.dirname(os.path.abspath(__file__))) + sys.path.insert(0, ".") + main() diff --git a/benchmarks/asv/run_benchmarks.sh b/benchmarks/asv/run_benchmarks.sh index 7e9a21d23..4ca71881c 100755 --- a/benchmarks/asv/run_benchmarks.sh +++ b/benchmarks/asv/run_benchmarks.sh @@ -16,6 +16,8 @@ Commands: setup Register this machine with ASV run [SUITE] Run all benchmarks, or a single suite (e.g. bench_casting) quick [SUITE] Smoke-test run (single iteration, results not saved) + direct [-w W] [-n N] SUITE [METHOD] + Fast in-process run (no subprocesses, no ASV overhead) compare [REF] [NEW] Compare two commits (default: HEAD~1 vs HEAD) view Generate HTML dashboard and open preview server list List available benchmark suites @@ -43,6 +45,15 @@ case "${1:-}" in echo "Running (quick): ${CMD[*]}" "${CMD[@]}" ;; + direct) + shift + if [[ $# -eq 0 ]]; then + echo "Usage: $0 direct [options] SUITE [METHOD]" + echo "Options: -w WARMUP -n ITERS" + exit 1 + fi + python "$BENCH_DIR/direct_run.py" "$@" + ;; compare) REF="${2:-HEAD~1}" NEW="${3:-HEAD}"