diff --git a/.github/configs/amd-master.yaml b/.github/configs/amd-master.yaml index 44045c274..c2de2dac3 100644 --- a/.github/configs/amd-master.yaml +++ b/.github/configs/amd-master.yaml @@ -1635,13 +1635,9 @@ dsv4-fp8-mi355x-vllm: search-space: - { tp: 8, conc-start: 1, conc-end: 1 } -# Day-0 single-sequence marker for DeepSeek-V4 on ATOM (ROCm/ATOM#650). -# PR1 of the ATOM DSv4 series still uses torch sparse-attention fallbacks -# that OOM once warmup/prefill batches multiple requests; keep CONC=1 until -# the AITER sparse-attention kernel / multi-request path lands upstream. -# --enforce-eager and ATOM_USE_TRITON_MOE=1 are required on gfx950. Image is -# the standard atom0.1.2.post MI355X base (matching qwen3.5-fp8-mi355x-atom); -# the DSv4 PR is overlaid at runtime by dsv4_fp4_mi355x_atom.sh at a pinned SHA. +# DeepSeek-V4 on ATOM using the updated atom0.1.2.post image. The launcher +# overlays ROCm/ATOM#650 only for DSv4 model registration/skeleton support, +# then overlays ROCm/aiter#2998 for sparse/indexer kernels. dsv4-fp4-mi355x-atom: image: rocm/atom:rocm7.2.2_ubuntu24.04_py3.12_pytorch_release_2.10.0_atom0.1.2.post model: deepseek-ai/DeepSeek-V4-Pro @@ -1655,8 +1651,8 @@ dsv4-fp4-mi355x-atom: - isl: 1024 osl: 1024 search-space: - - { tp: 8, ep: 1, conc-start: 1, conc-end: 1 } + - { tp: 8, ep: 1, conc-start: 4, conc-end: 128 } - isl: 8192 osl: 1024 search-space: - - { tp: 8, ep: 1, conc-start: 1, conc-end: 1 } + - { tp: 8, ep: 1, conc-start: 4, conc-end: 64 } diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index d03b98d19..716724d05 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -161,8 +161,8 @@ jobs: - If jobs cannot be run, say exactly what you could not run and why - **Important** Modify perf-changelog.yaml for any config changes affecting performance - ## Profiling (SGLang only) - When asked to profile a config, dispatch the `profile.yml` workflow. **Only SGLang configs can be profiled** — the profiler uses SGLang's `/start_profile` and `/stop_profile` HTTP endpoints. Reject profiling requests for vLLM, TRT, or other frameworks. + ## Profiling + When asked to profile a config, dispatch the `profile.yml` workflow. SGLang, vLLM, and ATOM single-node configs can be profiled through their `/start_profile` and `/stop_profile` HTTP endpoints when the server is launched with the corresponding torch profiler directory. Reject profiling requests for TRT, disaggregated/multi-node configs, or other frameworks. **Syntax:** ``` @@ -172,9 +172,10 @@ jobs: workflow_id="profile.yml", ref="main", inputs={ - "config-key": "", + "config-key": "", "config-file": "<.github/configs/nvidia-master.yaml or amd-master.yaml>", - "conc": "" + "conc": "", + "seq-len": "<1k1k or 8k1k>" } ) ``` @@ -184,19 +185,16 @@ jobs: - Model: "deepseek" / "dsr1" → model-prefix `dsr1`; "gptoss" → `gptoss`; "qwen" → `qwen3.5` - Precision: "fp4" / "fp8" / "bf16" - Runner/hardware: "b200", "h200", "h100", "mi300x", "mi325x", "mi355x", etc. - - Framework: must be "sglang" (reject if not) + - Framework: must be "sglang", "vllm", or "atom" (reject TRT and disaggregated/multi-node) - Concurrency: "conc=N" → `"conc": "N"`. Default to `"64"` if not specified. + - Sequence length: default to `"1k1k"` unless the user asks for `"8k1k"`. 
- Construct the config-key as: `{model-prefix}-{precision}-{runner}-sglang` + Construct the config-key as: `{model-prefix}-{precision}-{runner}-{framework}` Choose config-file: NVIDIA runners (b200, h200, h100, gb200, gb300) → `nvidia-master.yaml`; AMD runners (mi300x, mi325x, mi355x) → `amd-master.yaml` - **Available SGLang config keys:** - NVIDIA: `dsr1-fp4-b200-sglang`, `dsr1-fp8-b200-sglang`, `dsr1-fp8-h200-sglang`, `qwen3.5-bf16-b200-sglang` - AMD: `dsr1-fp4-mi355x-sglang`, `dsr1-fp8-mi300x-sglang`, `dsr1-fp8-mi325x-sglang`, `dsr1-fp8-mi355x-sglang`, `qwen3.5-bf16-mi355x-sglang`, `qwen3.5-fp8-mi355x-sglang` - **Examples:** - - "profile sglang b200 deepseek fp4 conc=4" → `config-key: dsr1-fp4-b200-sglang`, `config-file: .github/configs/nvidia-master.yaml`, `conc: 4` - - "profile sglang mi355x dsr1 fp8" → `config-key: dsr1-fp8-mi355x-sglang`, `config-file: .github/configs/amd-master.yaml`, `conc: 64` + - "profile sglang b200 deepseek fp4 conc=4" → `config-key: dsr1-fp4-b200-sglang`, `config-file: .github/configs/nvidia-master.yaml`, `conc: 4`, `seq-len: 1k1k` + - "profile atom mi355x dsv4 fp4 conc=4 8k1k" → `config-key: dsv4-fp4-mi355x-atom`, `config-file: .github/configs/amd-master.yaml`, `conc: 4`, `seq-len: 8k1k` **After dispatch:** Monitor with `mcp__github__get_workflow_run`. The profile workflow takes ~15-30 minutes. When complete, the **Perfetto relay link** is in the workflow run's step summary. Retrieve it with: diff --git a/.github/workflows/profile.yml b/.github/workflows/profile.yml index 8152d47a5..c47a89dfa 100644 --- a/.github/workflows/profile.yml +++ b/.github/workflows/profile.yml @@ -17,6 +17,14 @@ on: required: false type: string default: '64' + seq-len: + description: "Sequence length config to profile" + required: false + type: choice + options: + - 1k1k + - 8k1k + default: 1k1k moe-debug: description: "Enable MoE debug patch and log (MOE_DEBUG_LOG)" required: false @@ -54,7 +62,7 @@ jobs: name: Generate matrix via script run: | pip install pydantic - CLI_ARGS="test-config --config-files ${{ inputs.config-file }} --config-keys ${{ inputs.config-key }} --conc ${{ inputs.conc }}" + CLI_ARGS="test-config --config-files ${{ inputs.config-file }} --config-keys ${{ inputs.config-key }} --conc ${{ inputs.conc }} --seq-lens ${{ inputs.seq-len }}" CONFIG_JSON=$(python3 ${GITHUB_WORKSPACE}/utils/matrix_logic/generate_sweep_configs.py $CLI_ARGS) echo "raw=$CONFIG_JSON" >> $GITHUB_OUTPUT @@ -148,13 +156,16 @@ jobs: ref: ${{ inputs.ref || github.sha }} clean: false - - name: Launch + Profile (single-node sglang/vllm) + - name: Launch + Profile (single-node) id: run env: RUNNER_NAME: ${{ runner.name }} PROFILE: '1' SGLANG_TORCH_PROFILER_DIR: /workspace/ VLLM_TORCH_PROFILER_DIR: /workspace/ + ATOM_TORCH_PROFILER_DIR: /workspace/atom_profiles + PROFILE_NUM_STEPS: '1' + PROFILE_OUTPUT_LEN: '1' VLLM_RPC_TIMEOUT: '1800000' shell: bash run: | @@ -173,6 +184,11 @@ jobs: trace_path="profile_${res_name}.trace.json.gz" if [ -f "$trace_path" ]; then + if [ ! 
-s "$trace_path" ]; then + echo "Profile trace is empty: $trace_path" >&2 + exit 1 + fi + gzip -t "$trace_path" echo "trace=$trace_path" >> "$GITHUB_OUTPUT" if [ "${FRAMEWORK}" = "sglang" ]; then # Try to locate corresponding TP-0 traces produced by SGLang profiler @@ -193,16 +209,31 @@ jobs: fi else echo "Profile trace not found: $trace_path" >&2 + exit 1 fi - name: Process result (json -> agg) + continue-on-error: true env: RUNNER_TYPE: ${{ matrix.config.runner }} run: | python3 utils/process_result.py + - name: Upload profile diagnostics + if: ${{ always() && env.RESULT_FILENAME != '' }} + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 + with: + name: profile_diagnostics_${{ env.RESULT_FILENAME }} + path: | + ${{ env.RESULT_FILENAME }}.json + agg_${{ env.RESULT_FILENAME }}.json + server.log + gpu_metrics.csv + atom_profiles/**/*.trace.json.gz + if-no-files-found: ignore + - name: Upload profile as artifact - if: ${{ steps.run.outputs.trace != '' }} + if: ${{ always() && steps.run.outputs.trace != '' }} uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: profile_${{ env.RESULT_FILENAME }} @@ -210,7 +241,7 @@ jobs: if-no-files-found: ignore - name: Upload TP-0-DECODE trace as artifact - if: ${{ steps.run.outputs.tp0_decode != '' }} + if: ${{ always() && steps.run.outputs.tp0_decode != '' }} uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: profile_${{ env.RESULT_FILENAME }}_TP0_DECODE @@ -218,7 +249,7 @@ jobs: if-no-files-found: ignore - name: Upload TP-0-EXTEND trace as artifact - if: ${{ steps.run.outputs.tp0_extend != '' }} + if: ${{ always() && steps.run.outputs.tp0_extend != '' }} uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: profile_${{ env.RESULT_FILENAME }}_TP0_EXTEND diff --git a/benchmarks/benchmark_lib.sh b/benchmarks/benchmark_lib.sh index 0cb8fdcd0..221fcdb20 100644 --- a/benchmarks/benchmark_lib.sh +++ b/benchmarks/benchmark_lib.sh @@ -330,15 +330,18 @@ run_benchmark_serving() { fi # Profiling support: when PROFILE=1, ensure profiler dir exists, add --profile flag, - # and cap num_prompts to keep traces small. + # and cap the run to a tiny one-step window by default. local profile_flag=() if [[ "${PROFILE:-}" == "1" ]]; then - local _prof_dir="${SGLANG_TORCH_PROFILER_DIR:-${VLLM_TORCH_PROFILER_DIR:-}}" - if [[ -n "$_prof_dir" ]]; then - mkdir -p "$_prof_dir" - fi + local _prof_dir="" + for _prof_dir in "${SGLANG_TORCH_PROFILER_DIR:-}" "${VLLM_TORCH_PROFILER_DIR:-}" "${ATOM_TORCH_PROFILER_DIR:-}"; do + if [[ -n "$_prof_dir" ]]; then + mkdir -p "$_prof_dir" + fi + done profile_flag+=(--profile) - num_prompts="$max_concurrency" + num_prompts="${PROFILE_NUM_PROMPTS:-$max_concurrency}" + output_len="${PROFILE_OUTPUT_LEN:-${PROFILE_NUM_STEPS:-1}}" fi # Build benchmark command @@ -391,16 +394,17 @@ run_benchmark_serving() { local benchmark_pid=$! # Monitor loop: check both benchmark and server status + set +x while kill -0 "$benchmark_pid" 2>/dev/null; do if ! 
kill -0 "$server_pid" 2>/dev/null; then echo "ERROR: Server process $server_pid died during benchmark" kill "$benchmark_pid" 2>/dev/null wait "$benchmark_pid" 2>/dev/null - set +x return 1 fi sleep 2 done + set -x # Benchmark finished, get its exit code wait "$benchmark_pid" @@ -425,6 +429,20 @@ run_benchmark_serving() { # Profiling trace helpers # -------------------------------- +setup_atom_profile_args() { + ATOM_PROFILE_ARGS=() + if [[ "${PROFILE:-}" == "1" ]]; then + ATOM_TORCH_PROFILER_DIR=${ATOM_TORCH_PROFILER_DIR:-/workspace/atom_profiles} + case "$ATOM_TORCH_PROFILER_DIR" in + /workspace/atom_profiles|/workspace/atom_profiles/*|/tmp/*) + rm -rf "$ATOM_TORCH_PROFILER_DIR" + ;; + esac + mkdir -p "$ATOM_TORCH_PROFILER_DIR" + ATOM_PROFILE_ARGS+=(--torch-profiler-dir "$ATOM_TORCH_PROFILER_DIR") + fi +} + _find_latest_profile_trace() { local latest="" local dir="" candidate="" base="" @@ -434,6 +452,9 @@ _find_latest_profile_trace() { search_roots=() if [[ -d "$dir" ]]; then search_roots+=("$dir") + while IFS= read -r -d '' candidate; do + search_roots+=("$candidate") + done < <(find "$dir" -mindepth 1 -maxdepth 1 -type d -print0 2>/dev/null) fi if [[ -d "$dir/profiles" ]]; then search_roots+=("$dir/profiles") @@ -447,6 +468,9 @@ _find_latest_profile_trace() { if [[ "$base" == profile_*.trace.json.gz ]]; then continue fi + if [[ ! -s "$candidate" ]]; then + continue + fi if [[ -z "$latest" || "$candidate" -nt "$latest" ]]; then latest="$candidate" fi @@ -460,6 +484,32 @@ _find_latest_profile_trace() { printf '%s' "$latest" } +_profile_trace_is_ready() { + local trace_file="$1" + local size_before="" size_after="" + + if [[ ! -s "$trace_file" ]]; then + return 1 + fi + + size_before="$(wc -c < "$trace_file" 2>/dev/null || printf '0')" + sleep "${PROFILE_TRACE_STABLE_SLEEP:-2}" + if [[ ! -s "$trace_file" ]]; then + return 1 + fi + size_after="$(wc -c < "$trace_file" 2>/dev/null || printf '0')" + + if [[ "$size_before" != "$size_after" || "$size_after" -le 0 ]]; then + return 1 + fi + + if [[ "$trace_file" == *.gz ]]; then + gzip -t "$trace_file" >/dev/null 2>&1 || return 1 + fi + + return 0 +} + # Move profiler trace into a stable workspace path for workflow relay/upload. 
move_profile_trace_for_relay() { if [[ "${PROFILE:-}" != "1" ]]; then @@ -473,11 +523,12 @@ move_profile_trace_for_relay() { local sglang_dir="${SGLANG_TORCH_PROFILER_DIR:-/workspace}" local vllm_dir="${VLLM_TORCH_PROFILER_DIR:-/workspace}" + local atom_dir="${ATOM_TORCH_PROFILER_DIR:-/workspace}" local -a search_dirs=() local dir="" existing="" local seen=0 - for dir in "$sglang_dir" "$vllm_dir" "/workspace"; do + for dir in "$sglang_dir" "$vllm_dir" "$atom_dir" "/workspace"; do if [[ -z "$dir" ]]; then continue fi @@ -494,17 +545,21 @@ move_profile_trace_for_relay() { done local trace_file="" - local wait_attempts=10 + local wait_attempts="${PROFILE_TRACE_WAIT_ATTEMPTS:-30}" for (( i=1; i<=wait_attempts; i++ )); do trace_file="$(_find_latest_profile_trace "${search_dirs[@]}")" - if [[ -n "$trace_file" ]]; then + if [[ -n "$trace_file" ]] && _profile_trace_is_ready "$trace_file"; then break fi - sleep 10 + if [[ -n "$trace_file" ]]; then + echo "[PROFILE] Waiting for trace to finish writing: $trace_file" >&2 + fi + trace_file="" + sleep "${PROFILE_TRACE_WAIT_SLEEP:-5}" done if [[ -z "$trace_file" ]]; then - echo "[PROFILE] No trace found for relay under: ${search_dirs[*]}" >&2 + echo "[PROFILE] No complete trace found for relay under: ${search_dirs[*]}" >&2 return 0 fi @@ -515,6 +570,12 @@ move_profile_trace_for_relay() { gzip -c "$trace_file" > "$dest_trace" fi + if [[ ! -s "$dest_trace" ]] || ! gzip -t "$dest_trace" >/dev/null 2>&1; then + echo "[PROFILE] Relay trace is invalid after staging: $dest_trace" >&2 + rm -f "$dest_trace" + return 0 + fi + echo "[PROFILE] Relay trace prepared: $dest_trace (source: $trace_file)" } @@ -548,9 +609,30 @@ _patch_lm_eval() { patch_dir="$(mktemp -d)" cat > "$patch_dir/sitecustomize.py" <<'PY' # --- Patch LocalChatCompletion.parse_generations to handle empty content with reasoning_content --- -import re, sys, unicodedata, json +import os, re, sys, unicodedata, json from lm_eval.filters import extraction as ex from lm_eval.models.openai_completions import LocalChatCompletion as _LCC +try: + from lm_eval.models.openai_completions import LocalCompletions as _LC +except Exception: + _LC = None + +def _truncate_dsv4_generation(text): + if os.environ.get("EVAL_DSV4_CHAT_TEMPLATE") != "1": + return text + if not isinstance(text, str): + return "" if text is None else str(text) + stops = [ + "<|end▁of▁sentence|>", + "<|begin▁of▁sentence|>", + "<|User|>", + "<|Assistant|>", + "\u202e", + ] + positions = [text.find(stop) for stop in stops if text.find(stop) >= 0] + if not positions: + return text + return text[: min(positions)] def _le_parse_generations(outputs, **kwargs): res = [] @@ -566,7 +648,29 @@ def _le_parse_generations(outputs, **kwargs): content = msg.get("content") if content in (None, "", []): content = msg.get("reasoning_content") or "" - tmp[idx] = content + tmp[idx] = _truncate_dsv4_generation(content) + except Exception: + tmp = [""] + res.extend(tmp) + return res + +def _lc_parse_generations(outputs, **kwargs): + res = [] + if not isinstance(outputs, list): + outputs = [outputs] + for out in (outputs or []): + try: + choices = out.get("choices", []) + tmp = ["" for _ in choices] + for choice in choices: + idx = choice.get("index", 0) + if idx >= len(tmp): + tmp.extend([""] * (idx - len(tmp) + 1)) + content = choice.get("text") + if content in (None, "", []): + msg = choice.get("message") or {} + content = msg.get("content") or msg.get("reasoning_content") or "" + tmp[idx] = _truncate_dsv4_generation(content) except Exception: tmp = [""] 
res.extend(tmp) @@ -574,8 +678,10 @@ def _le_parse_generations(outputs, **kwargs): # Keep staticmethod semantics _LCC.parse_generations = staticmethod(_le_parse_generations) +if _LC is not None: + _LC.parse_generations = staticmethod(_lc_parse_generations) -# --- Patch TemplateAPI.apply_chat_template to avoid injecting "type": "text" for TRT --- +# --- Patch TemplateAPI.apply_chat_template --- try: from lm_eval.models import api_models as _api_models _TemplateAPI = _api_models.TemplateAPI @@ -586,6 +692,56 @@ except Exception: if _TemplateAPI is not None and _JsonChatStr is not None: _orig_apply_chat_template = _TemplateAPI.apply_chat_template + _dsv4_encode_messages = None + + def _content_to_text(content): + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, dict): + parts.append(str(item.get("text", item.get("content", "")))) + else: + parts.append(str(item)) + return "\n".join(part for part in parts if part) + if content is None: + return "" + return str(content) + + def _load_dsv4_encoder(): + global _dsv4_encode_messages + if _dsv4_encode_messages is not None: + return _dsv4_encode_messages + + roots = [ + os.environ.get("INFMAX_WORKSPACE"), + os.environ.get("GITHUB_WORKSPACE"), + os.getcwd(), + "/workspace", + "/infmax-workspace", + ] + for root in roots: + if not root: + continue + candidate = os.path.join(root, "utils", "bench_serving") + if os.path.exists(os.path.join(candidate, "encoding_dsv4.py")) and candidate not in sys.path: + sys.path.insert(0, candidate) + + from encoding_dsv4 import encode_messages + + _dsv4_encode_messages = encode_messages + return _dsv4_encode_messages + + def _apply_dsv4_chat_template(chat_history): + encode_messages = _load_dsv4_encoder() + messages = [] + for item in chat_history: + normalized = {**item} + normalized.pop("type", None) + normalized["content"] = _content_to_text(normalized.get("content")) + messages.append(normalized) + return encode_messages(messages, thinking_mode="thinking") def _patched_apply_chat_template( self, @@ -593,6 +749,8 @@ if _TemplateAPI is not None and _JsonChatStr is not None: add_generation_prompt: bool = True, ): """Applies a chat template to a list of chat history between user and model.""" + if os.environ.get("EVAL_DSV4_CHAT_TEMPLATE") == "1": + return _apply_dsv4_chat_template(chat_history) if self.tokenizer_backend == "huggingface" and self.tokenized_requests: return self.tokenizer.apply_chat_template( chat_history, @@ -683,7 +841,8 @@ run_lm_eval() { local eval_context_len="${EVAL_MAX_MODEL_LEN:-16384}" local temperature=0 local top_p=1 - local concurrent_requests="${EVAL_CONCURRENT_REQUESTS:-64}" + local concurrent_requests="${EVAL_CONCURRENT_REQUESTS:-${CONC:-64}}" + local eval_limit="${EVAL_LIMIT:-}" while [[ $# -gt 0 ]]; do case $1 in @@ -693,17 +852,37 @@ run_lm_eval() { --gen-max-tokens) eval_context_len="$2"; shift 2 ;; --temperature) temperature="$2"; shift 2 ;; --top-p) top_p="$2"; shift 2 ;; + --limit) eval_limit="$2"; shift 2 ;; *) echo "Unknown parameter: $1"; return 1 ;; esac done - _install_lm_eval_deps - _patch_lm_eval - local openai_server_base="http://0.0.0.0:${port}" local openai_chat_base="${openai_server_base}/v1/chat/completions" + local openai_completions_base="${openai_server_base}/v1/completions" export OPENAI_API_KEY=${OPENAI_API_KEY:-EMPTY} - MODEL_NAME=${MODEL_NAME:-$MODEL} # Prefer MODEL_NAME, else MODEL + export MODEL_NAME="${MODEL_NAME:-$MODEL}" # Prefer MODEL_NAME, else MODEL + + local 
lm_eval_model="local-chat-completions" + local lm_eval_base_url="$openai_chat_base" + local lm_eval_eos_string="${EVAL_EOS_STRING:-}" + local lm_eval_tokenizer_args="tokenized_requests=False" + local is_dsv4_eval=false + + if [[ "${MODEL_PREFIX:-}" == "dsv4" || "${MODEL_NAME:-}" == *"DeepSeek-V4"* || "${MODEL:-}" == *"DeepSeek-V4"* ]]; then + is_dsv4_eval=true + export EVAL_DSV4_CHAT_TEMPLATE=1 + lm_eval_model="local-completions" + lm_eval_base_url="$openai_completions_base" + lm_eval_eos_string="${EVAL_EOS_STRING:-<|end▁of▁sentence|>}" + lm_eval_tokenizer_args="tokenizer_backend=None,tokenized_requests=False" + echo "Using DeepSeek-V4 eval prompt encoding via utils/bench_serving/encoding_dsv4.py" + else + unset EVAL_DSV4_CHAT_TEMPLATE + fi + + _install_lm_eval_deps + _patch_lm_eval # Cap output tokens: must fit within context window (leave room for input), # and avoid excessive KV cache reservation per request on TRT. @@ -711,16 +890,30 @@ run_lm_eval() { if [ "$max_output_tokens" -gt 16384 ]; then max_output_tokens=16384 fi + if [ -n "${EVAL_MAX_OUTPUT_TOKENS:-}" ]; then + max_output_tokens="$EVAL_MAX_OUTPUT_TOKENS" + elif [ "$is_dsv4_eval" = "true" ]; then + local dsv4_max_output_tokens="${EVAL_DSV4_MAX_OUTPUT_TOKENS:-1024}" + if [ "$max_output_tokens" -gt "$dsv4_max_output_tokens" ]; then + max_output_tokens="$dsv4_max_output_tokens" + fi + fi echo "Eval budget: eval_context_len=${eval_context_len}, max_output_tokens=${max_output_tokens}" # Export for append_lm_eval_summary to pick up export EVAL_RESULT_DIR="$results_dir" + local limit_args=() + if [ -n "$eval_limit" ]; then + limit_args=(--limit "$eval_limit") + echo "Eval sample limit: ${eval_limit}" + fi set -x - python3 -m lm_eval --model local-chat-completions --apply_chat_template \ + python3 -m lm_eval --model "${lm_eval_model}" --apply_chat_template \ --tasks "${tasks_dir}" \ --output_path "${results_dir}" \ --log_samples \ - --model_args "model=${MODEL_NAME},base_url=${openai_chat_base},api_key=${OPENAI_API_KEY},eos_string=,max_retries=5,num_concurrent=${concurrent_requests},timeout=1800,tokenized_requests=False,max_length=${eval_context_len}" \ + "${limit_args[@]}" \ + --model_args "model=${MODEL_NAME},base_url=${lm_eval_base_url},api_key=${OPENAI_API_KEY},eos_string=${lm_eval_eos_string},max_retries=5,num_concurrent=${concurrent_requests},timeout=1800,${lm_eval_tokenizer_args},max_length=${eval_context_len}" \ --gen_kwargs "max_tokens=${max_output_tokens},temperature=${temperature},top_p=${top_p}" local eval_exit=$? set +x diff --git a/benchmarks/single_node/dsr1_fp4_mi355x_atom.sh b/benchmarks/single_node/dsr1_fp4_mi355x_atom.sh index 31554fc22..7adee9f21 100644 --- a/benchmarks/single_node/dsr1_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/dsr1_fp4_mi355x_atom.sh @@ -48,12 +48,14 @@ start_gpu_monitor set -x BLOCK_SIZE=${BLOCK_SIZE:-16} +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ -tp $TP \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ - --block-size $BLOCK_SIZE > $SERVER_LOG 2>&1 & + --block-size $BLOCK_SIZE \ + "${ATOM_PROFILE_ARGS[@]}" > $SERVER_LOG 2>&1 & SERVER_PID=$! 
@@ -80,4 +82,4 @@ fi # Stop GPU monitoring stop_gpu_monitor -set +x \ No newline at end of file +set +x diff --git a/benchmarks/single_node/dsr1_fp4_mi355x_atom_mtp.sh b/benchmarks/single_node/dsr1_fp4_mi355x_atom_mtp.sh index 1d557684e..9da5c778d 100644 --- a/benchmarks/single_node/dsr1_fp4_mi355x_atom_mtp.sh +++ b/benchmarks/single_node/dsr1_fp4_mi355x_atom_mtp.sh @@ -49,12 +49,14 @@ set -x export AMDGCN_USE_BUFFER_OPS=1 +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ -tp $TP \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ --method mtp \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! diff --git a/benchmarks/single_node/dsr1_fp8_mi355x_atom.sh b/benchmarks/single_node/dsr1_fp8_mi355x_atom.sh index 31554fc22..7adee9f21 100644 --- a/benchmarks/single_node/dsr1_fp8_mi355x_atom.sh +++ b/benchmarks/single_node/dsr1_fp8_mi355x_atom.sh @@ -48,12 +48,14 @@ start_gpu_monitor set -x BLOCK_SIZE=${BLOCK_SIZE:-16} +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ -tp $TP \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ - --block-size $BLOCK_SIZE > $SERVER_LOG 2>&1 & + --block-size $BLOCK_SIZE \ + "${ATOM_PROFILE_ARGS[@]}" > $SERVER_LOG 2>&1 & SERVER_PID=$! @@ -80,4 +82,4 @@ fi # Stop GPU monitoring stop_gpu_monitor -set +x \ No newline at end of file +set +x diff --git a/benchmarks/single_node/dsr1_fp8_mi355x_atom_mtp.sh b/benchmarks/single_node/dsr1_fp8_mi355x_atom_mtp.sh index 69179cec0..ea5bbc5b1 100644 --- a/benchmarks/single_node/dsr1_fp8_mi355x_atom_mtp.sh +++ b/benchmarks/single_node/dsr1_fp8_mi355x_atom_mtp.sh @@ -47,6 +47,7 @@ start_gpu_monitor set -x +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ @@ -54,6 +55,7 @@ python3 -m atom.entrypoints.openai_server \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ --method mtp \ --num-speculative-tokens 3 \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! diff --git a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh index 21708ba1d..574cd76c9 100644 --- a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh @@ -19,19 +19,8 @@ fi echo "TP: $TP, CONC: $CONC, ISL: $ISL, OSL: $OSL, EP_SIZE: $EP_SIZE" -# ROCm/ATOM#650 is still a single-request marker for DSv4. Run -# 24953107645 showed CONC>1 fails in two ways: 1k warmup can exhaust the KV -# budget after sparse-attn temporaries raise peak memory, and 8k prefill OOMs -# in the torch sparse_attn fallback when two long requests are batched. Keep -# this fatal guard until ATOM lands the AITER sparse-attention / multi-request -# path for DeepSeek-V4. -if [ "$CONC" -ne 1 ]; then - echo "FATAL: ROCm/ATOM#650 DSv4 path is single-request only; CONC must be 1, got $CONC" >&2 - exit 1 -fi - if [ "$EP_SIZE" -ne 1 ]; then - echo "FATAL: ROCm/ATOM#650 PR1 has not validated expert parallel serving; EP_SIZE must be 1, got $EP_SIZE" >&2 + echo "FATAL: DSv4 ATOM benchmark expects EP_SIZE=1, got $EP_SIZE" >&2 exit 1 fi @@ -39,77 +28,29 @@ SERVER_LOG=/workspace/server.log PORT=${PORT:-8888} export OMP_NUM_THREADS=1 - -# DSv4-specific ATOM env vars. Prefer the native AITER MXFP4 MoE path after -# overlaying the AITER perf stack below. Set AITER_DSV4_FP4_MOE_BACKEND=triton -# to return to ROCm/ATOM#650's original triton_kernels matmul_ogs path. 
-if [ "${AITER_DSV4_PERF_STACK:-1}" = "1" ]; then - DEFAULT_AITER_DSV4_FP4_MOE_BACKEND=native -else - DEFAULT_AITER_DSV4_FP4_MOE_BACKEND=triton -fi -AITER_DSV4_FP4_MOE_BACKEND=${AITER_DSV4_FP4_MOE_BACKEND:-$DEFAULT_AITER_DSV4_FP4_MOE_BACKEND} -if [ "$AITER_DSV4_FP4_MOE_BACKEND" = "triton" ]; then - export ATOM_USE_TRITON_MOE=1 -else - unset ATOM_USE_TRITON_MOE - unset ATOM_USE_TRITON_GEMM -fi export AITER_LOG_LEVEL=WARNING -# Pull in the AITER pieces that matter for DSv4 FP4 on MI355X: -# * origin/main@dde1703e includes ROCm/aiter#2770 a16w4 MoE support. -# * ROCm/aiter#2822 speeds up batched MXFP4 GEMM on gfx950. -# * ROCm/aiter#2900 fixes MXFP4 scale padding for non-256 K. -# * ROCm/aiter#2642 enables/fixes TP=4/8 MXFP4 MoE dispatch. -# * sunway513/aiter@e450e4d adds DSv4 FP4 MoE tuned rows that route -# eligible token counts to FlyDSL FP4 MoE kernels instead of default CK -# heuristics when the image has the optional flydsl package. -# -# ROCm/aiter#2916 is intentionally not cherry-picked here. That PR branch is -# based on a divergent fork and can conflict in unrelated test files; the -# narrow mhc_pre device fix is applied directly to installed aiter below. -# The non-mHC PRs cherry-pick cleanly over the pinned main SHA as of 2026-04-27. -# Keep this as a runtime overlay until AMD publishes an ATOM image with these -# AITER changes baked in; then remove this block and pin that image instead. -if [ "${AITER_DSV4_PERF_STACK:-1}" = "1" ]; then - AITER_PERF_REPO=${AITER_PERF_REPO:-https://github.com/ROCm/aiter.git} - AITER_PERF_DIR=${AITER_PERF_DIR:-/tmp/aiter-dsv4-fp4-perf} - AITER_PERF_BASE_SHA=${AITER_PERF_BASE_SHA:-dde1703ebfc35d3724e07fc4e6e824023063494c} - AITER_PERF_PATCH_REFS=( - "${AITER_PERF_BATCHED_FP4_REF:-pull/2822/head}" - "${AITER_PERF_MXFP4_SCALE_REF:-pull/2900/head}" - "${AITER_PERF_MOE_REF:-pull/2642/head}" - ) - AITER_DSV4_TUNED_FMOE=${AITER_DSV4_TUNED_FMOE:-1} - AITER_DSV4_TUNED_FMOE_REPO=${AITER_DSV4_TUNED_FMOE_REPO:-https://github.com/sunway513/aiter.git} - AITER_DSV4_TUNED_FMOE_SHA=${AITER_DSV4_TUNED_FMOE_SHA:-e450e4deb992c5ecd9db5ef5ef79f1d40208bc9c} - AITER_DSV4_TUNED_FMOE_PATH=${AITER_DSV4_TUNED_FMOE_PATH:-aiter/configs/model_configs/dsv4_fp4_tuned_fmoe.csv} - - rm -rf "$AITER_PERF_DIR" - git clone --filter=blob:none "$AITER_PERF_REPO" "$AITER_PERF_DIR" +# Keep the runtime overlay narrow: this benchmark uses the updated ATOM image +# from amd-master.yaml and overlays ROCm/aiter#2998 for the DSv4 kernels. Install +# AITER before ATOM because the ATOM fork imports the DSv4 top-k/logits kernels +# at module load. +if [ "${AITER_DSV4_PR2998:-1}" = "1" ]; then + AITER_PR2998_REPO=${AITER_PR2998_REPO:-https://github.com/ROCm/aiter.git} + AITER_PR2998_REF=${AITER_PR2998_REF:-pull/2998/head} + AITER_PR2998_SHA=${AITER_PR2998_SHA:-a42ec8f6903a0eca016ff1a740e4b21e34ea5a5e} + AITER_PR2998_DIR=${AITER_PR2998_DIR:-/tmp/aiter-dsv4-pr2998} + + rm -rf "$AITER_PR2998_DIR" + git clone --filter=blob:none "$AITER_PR2998_REPO" "$AITER_PR2998_DIR" ( - cd "$AITER_PERF_DIR" - git fetch --depth=1 origin "$AITER_PERF_BASE_SHA" - git checkout --force "$AITER_PERF_BASE_SHA" - test "$(git rev-parse HEAD)" = "$AITER_PERF_BASE_SHA" - - for ref in "${AITER_PERF_PATCH_REFS[@]}"; do - # Do not use --depth=1 here. A shallow PR-head fetch hides the - # parent commit and makes git treat the cherry-pick as add/add - # conflicts across unrelated files. 
- git fetch origin "$ref" - git cherry-pick --no-commit FETCH_HEAD - done - - if [ "$AITER_DSV4_TUNED_FMOE" = "1" ]; then - mkdir -p "$(dirname "$AITER_DSV4_TUNED_FMOE_PATH")" - git fetch --depth=1 "$AITER_DSV4_TUNED_FMOE_REPO" "$AITER_DSV4_TUNED_FMOE_SHA" - test "$(git rev-parse FETCH_HEAD)" = "$AITER_DSV4_TUNED_FMOE_SHA" - git show "FETCH_HEAD:$AITER_DSV4_TUNED_FMOE_PATH" > "$AITER_DSV4_TUNED_FMOE_PATH" - grep -q '7168,512,385,6,ActivationType.Silu' "$AITER_DSV4_TUNED_FMOE_PATH" \ - || { echo "FATAL: DSv4 FP4 tuned fMoE rows not found in $AITER_DSV4_TUNED_FMOE_PATH"; exit 1; } + cd "$AITER_PR2998_DIR" + git fetch --depth=1 origin "$AITER_PR2998_REF" + fetched_sha="$(git rev-parse FETCH_HEAD)" + if [ "$fetched_sha" != "$AITER_PR2998_SHA" ]; then + echo "FATAL: $AITER_PR2998_REF resolved to $fetched_sha, expected $AITER_PR2998_SHA" >&2 + exit 1 fi + git checkout --force FETCH_HEAD if [ ! -d 3rdparty/composable_kernel/include ]; then git submodule update --init --recursive --depth=1 3rdparty/composable_kernel \ @@ -120,349 +61,207 @@ if [ "${AITER_DSV4_PERF_STACK:-1}" = "1" ]; then python3 -m pip install --no-deps --no-build-isolation --force-reinstall -e . ) - if [ "$AITER_DSV4_TUNED_FMOE" = "1" ]; then - export AITER_DSV4_TUNED_FMOE_FILE="$AITER_PERF_DIR/$AITER_DSV4_TUNED_FMOE_PATH" - fi - if [ "$AITER_DSV4_TUNED_FMOE" = "1" ] && [ -z "${AITER_CONFIG_FMOE:-}" ]; then - export AITER_CONFIG_FMOE="$AITER_PERF_DIR/aiter/configs/tuned_fmoe.csv:$AITER_DSV4_TUNED_FMOE_FILE" - fi - python3 - <<'PYEOF' -import importlib.util -import csv -import os +import inspect +import ast from pathlib import Path -import aiter - -root = Path(aiter.__file__).resolve().parent -moe = (root / "fused_moe.py").read_text() -fp4_utils = (root / "utility" / "fp4_utils.py").read_text() -dsv4_tuned_fmoe = Path(os.environ["AITER_DSV4_TUNED_FMOE_FILE"]) if os.environ.get("AITER_DSV4_TUNED_FMOE_FILE") else None -required = { - "native MXFP4 MoE skip_inter_quant": "skip_inter_quant" in moe, - "MXFP4 scaleN_pad fix": "scaleN_pad" in fp4_utils, - "DSv4 FP4 tuned fMoE config": dsv4_tuned_fmoe is None or dsv4_tuned_fmoe.exists(), -} -missing = [name for name, ok in required.items() if not ok] +import aiter.ops.topk as topk_module +from aiter.ops.topk import top_k_per_row_decode, top_k_per_row_prefill +from aiter.ops.triton.attention.dsv4_indexer import dsv4_indexer_topk +from aiter.ops.triton.attention.sparse_mqa_sink import sparse_mqa_sink +from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits + +for fn in ( + top_k_per_row_decode, + top_k_per_row_prefill, + fp8_mqa_logits, + dsv4_indexer_topk, + sparse_mqa_sink, +): + if not callable(fn): + raise SystemExit(f"FATAL: imported non-callable AITER symbol: {fn!r}") + +indexer_params = inspect.signature(dsv4_indexer_topk).parameters +missing = [name for name in ("seq_ids", "kv_lens") if name not in indexer_params] if missing: - raise SystemExit(f"FATAL: AITER DSv4 perf stack verification failed: {missing}") - -if dsv4_tuned_fmoe is not None and dsv4_tuned_fmoe.exists(): - config_paths = os.environ.get("AITER_CONFIG_FMOE", "").split(":") - if str(dsv4_tuned_fmoe) not in config_paths: - print( - "WARN: AITER_CONFIG_FMOE was user-supplied and does not include " - f"{dsv4_tuned_fmoe}; DSv4 tuned fMoE rows may not be active." 
- ) - try: - from aiter.ops.flydsl import is_flydsl_available - except Exception as exc: - print(f"aiter DSv4 tuned fMoE installed; FlyDSL availability check failed: {exc!r}") - else: - flydsl_available = is_flydsl_available() - print(f"aiter FlyDSL available: {flydsl_available}") - if flydsl_available: - from aiter.ops.flydsl.moe_kernels import get_flydsl_kernel_params - - missing_kernels = set() - with dsv4_tuned_fmoe.open(newline="") as handle: - for row in csv.DictReader(handle): - for name in (row.get("kernelName1", ""), row.get("kernelName2", "")): - if name.startswith("flydsl_") and get_flydsl_kernel_params(name) is None: - missing_kernels.add(name) - if missing_kernels: - raise SystemExit( - "FATAL: DSv4 FP4 tuned fMoE references missing FlyDSL kernels: " - f"{sorted(missing_kernels)[:5]}" - ) -print(f"aiter DSv4 perf stack imported from: {root}") + raise SystemExit(f"FATAL: AITER PR2998 DSv4 Indexer API missing {missing}") +topk_tree = ast.parse(Path(topk_module.__file__).read_text()) +topk_defs = { + node.name: node for node in topk_tree.body if isinstance(node, ast.FunctionDef) +} +missing_k = [ + name + for name in ("top_k_per_row_decode", "top_k_per_row_prefill") + if name not in topk_defs + or not any(arg.arg == "k" for arg in topk_defs[name].args.args) +] +if missing_k: + raise SystemExit(f"FATAL: AITER top-k wrappers missing dynamic k parameter: {missing_k}") +print("AITER PR2998 DSv4 sparse/top-k/indexer ops imported successfully") PYEOF else - echo "WARN: AITER_DSV4_PERF_STACK=0; using image-provided aiter" + echo "WARN: AITER_DSV4_PR2998=0; using image-provided AITER" fi -# Ensure the pure-Python part of ROCm/aiter#2916 is present. The AITER perf -# stack above already includes it; this block is kept as a fallback for -# AITER_DSV4_PERF_STACK=0 or future images that ship aiter without the fix. 
-export AITER_MHC_FIX_SHA="76ea1ed5b2a5f8176ed7a16b1640dd972546a925" -python3 - <<'PYEOF' -import importlib.util -import os -import sys -from pathlib import Path - -required_snippets = [ - " device = residual.device\n out_pad = torch.empty(", - "selected_splitk, m, (hc_mult3 + 31) // 32 * 32, dtype=dtypes.fp32, device=device", - "sqrsum = torch.empty(selected_splitk, m, dtype=dtypes.fp32, device=device)", - "post_mix = torch.empty(m, hc_mult, 1, dtype=dtypes.fp32, device=device)", - "comb_mix = torch.empty(m, hc_mult, hc_mult, dtype=dtypes.fp32, device=device)", - "layer_input = torch.empty(m, hidden_size, dtype=dtypes.bf16, device=device)", -] - -spec = importlib.util.find_spec("aiter.ops.mhc") -if spec is None or spec.origin is None: - sys.exit("FATAL: cannot locate installed aiter.ops.mhc for ROCm/aiter#2916 patch") - -mhc_path = Path(spec.origin) -source = mhc_path.read_text() - -if all(snippet in source for snippet in required_snippets): - print(f"aiter mhc device patch already present: {mhc_path}") - sys.exit(0) - -replacements = [ - ( - " out_pad = torch.empty(\n" - " selected_splitk, m, (hc_mult3 + 31) // 32 * 32, dtype=dtypes.fp32\n" - " )", - " device = residual.device\n" - " out_pad = torch.empty(\n" - " selected_splitk, m, (hc_mult3 + 31) // 32 * 32, dtype=dtypes.fp32, device=device\n" - " )", - ), - ( - " sqrsum = torch.empty(selected_splitk, m, dtype=dtypes.fp32)", - " sqrsum = torch.empty(selected_splitk, m, dtype=dtypes.fp32, device=device)", - ), - ( - " post_mix = torch.empty(m, hc_mult, 1, dtype=dtypes.fp32)", - " post_mix = torch.empty(m, hc_mult, 1, dtype=dtypes.fp32, device=device)", - ), +# The updated ATOM image still does not ship DeepSeek-V4 model registration. +# Overlay the ATOM branch stacked on ROCm/ATOM#650 that wires the DSv4 Indexer +# path to ROCm/aiter#2998 top-k/logits kernels. +if [ "${ATOM_DSV4_PR650:-1}" = "1" ]; then + ATOM_PR650_REPO=${ATOM_PR650_REPO:-https://github.com/Oseltamivir/ATOM.git} + ATOM_PR650_REF=${ATOM_PR650_REF:-dsv4-deep-l0-diag} + ATOM_PR650_SHA=${ATOM_PR650_SHA:-d8583464086fd1a9899374e0777ede3b5c16ff5a} + ATOM_PR650_DIR=${ATOM_PR650_DIR:-/tmp/atom-dsv4-pr650} + + rm -rf "$ATOM_PR650_DIR" + git clone --filter=blob:none "$ATOM_PR650_REPO" "$ATOM_PR650_DIR" ( - " comb_mix = torch.empty(m, hc_mult, hc_mult, dtype=dtypes.fp32)", - " comb_mix = torch.empty(m, hc_mult, hc_mult, dtype=dtypes.fp32, device=device)", - ), - ( - " layer_input = torch.empty(m, hidden_size, dtype=dtypes.bf16)", - " layer_input = torch.empty(m, hidden_size, dtype=dtypes.bf16, device=device)", - ), -] - -missing = [old for old, _ in replacements if old not in source] -if missing: - sys.exit( - f"FATAL: {mhc_path} does not match the expected pre-ROCm/aiter#2916 " - f"source; refusing to patch mhc_pre blindly. Missing patterns: " - f"{[m.splitlines()[0].strip() for m in missing]}" - ) - -patched = source -for old, new in replacements: - patched = patched.replace(old, new, 1) - -mhc_path.write_text(patched) -patched_source = mhc_path.read_text() -if not all(snippet in patched_source for snippet in required_snippets): - sys.exit(f"FATAL: ROCm/aiter#2916 mhc device patch failed verification for {mhc_path}") - -print( - f"applied ROCm/aiter#2916 ({os.environ['AITER_MHC_FIX_SHA']}) " - f"mhc device patch: {mhc_path}" -) -PYEOF - -# Apply ROCm/ATOM#650 (DSv4 PR1 skeleton) over the image's wheel-installed -# atom. The chosen base image ships atom as a built wheel, not editable, so -# we overlay an editable install from the PR branch at a pinned SHA. 
Bump -# this SHA when the PR moves; do not track the branch tip (the run becomes -# a moving target if the branch is force-pushed). -ATOM_PR_SHA="cdbff359d3db7afd3801e28b38fc71253121ee84" -export ATOM_PR_DIR="/tmp/atom-pr650" + cd "$ATOM_PR650_DIR" + git fetch --depth=1 origin "$ATOM_PR650_REF" + fetched_sha="$(git rev-parse FETCH_HEAD)" + if [ "$fetched_sha" != "$ATOM_PR650_SHA" ]; then + echo "FATAL: $ATOM_PR650_REF resolved to $fetched_sha, expected $ATOM_PR650_SHA" >&2 + exit 1 + fi + git checkout --force FETCH_HEAD -if [ ! -d "$ATOM_PR_DIR/.git" ]; then - git clone --filter=blob:none https://github.com/ROCm/ATOM.git "$ATOM_PR_DIR" -fi -( - cd "$ATOM_PR_DIR" - # Try a targeted fetch first (fast); fall back to fetching the PR ref if - # the server doesn't allow fetching the SHA directly. - git fetch --depth=1 origin "$ATOM_PR_SHA" 2>/dev/null \ - || git fetch --depth=1 origin pull/650/head - git checkout --force "$ATOM_PR_SHA" - test "$(git rev-parse HEAD)" = "$ATOM_PR_SHA" - - # ROCm/aiter#2916 keeps ATOM's mhc_pre fast path usable. Fail if the - # pinned ATOM checkout no longer exposes that aiter hook; silently - # disabling it would hide the regression this benchmark is meant to catch. - grep -q 'mhc_pre = getattr(_aiter, "mhc_pre", None)' atom/models/deepseek_v4.py \ - || { echo "FATAL: ATOM DSv4 mhc_pre aiter hook not found"; exit 1; } - - # ROCm/ATOM#650 sparse_attn_v4.py is a correctness-first torch fallback. - # Add two local mitigations while we wait for a serving-compatible AITER - # sparse-attention kernel: - # 1. chunk prefill over the M dimension to keep temporary scores under - # memory pressure, making higher-conc experiments less likely to OOM; - # 2. use a B=1,M=1 decode fast path that avoids the fallback's large - # broadcast/mask/concat intermediates on every generated token. 
- python3 - <<'PYEOF' + python3 - <<'PYEOF' from pathlib import Path -path = Path("atom/model_ops/sparse_attn_v4.py") +v4_model_types = '("deepseek_v4", "deepseek_v4_pro", "deepseek_v4_flash")' + +path = Path("atom/model_engine/model_runner.py") source = path.read_text() -marker = "ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS" -if marker not in source: - source = source.replace( - "from typing import Tuple\n\nimport torch\n", - "from typing import Tuple\n\nimport os\n\nimport torch\n", - 1, - ) - old = """ out_dtype = q.dtype - device = q.device - - # ----- Gather KV per query position ----- -""" - new = """ out_dtype = q.dtype - device = q.device - - chunk_tokens = int(os.environ.get("ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS", "0") or "0") - if B == 1 and chunk_tokens > 0 and M > chunk_tokens: - return torch.cat( - [ - sparse_attn( - q[:, start : start + chunk_tokens], - kv, - attn_sink, - topk_idxs[:, start : start + chunk_tokens], - softmax_scale, - ) - for start in range(0, M, chunk_tokens) - ], - dim=1, +old = ''' def is_deepseek_v4(self) -> bool: + if not hasattr(self.hf_text_config, "model_type"): + return False + return self.hf_text_config.model_type == "deepseek_v4" +''' +new = f''' def is_deepseek_v4(self) -> bool: + model_type = getattr(self.hf_text_config, "model_type", None) + architectures = getattr(self.hf_text_config, "architectures", []) or [] + return model_type in {v4_model_types} or any( + "DeepseekV4" in arch for arch in architectures ) +''' +if old in source: + source = source.replace(old, new, 1) +elif "deepseek_v4_pro" not in source[source.find("def is_deepseek_v4"): source.find("def is_mimo_v2")]: + raise SystemExit("FATAL: model_runner.py is_deepseek_v4 did not match expected source") +old = ''' mt = self.config.hf_config.model_type + known = _IOProc._per_req_cache_model_types() # noqa: SLF001 + assert mt in known, ( + f"Attention builder {type(self.attn_metadata_builder).__name__} " + f"reports per_req_cache_bytes>0 but model_type={mt!r} is not in " + f"InputOutputProcessor.per_req_cache_model_types ({sorted(known)}). " + "Add it to the set or sequences will not be assigned slots " + "(silent corruption)." + ) +''' +new = f''' mt = self.config.hf_config.model_type + architectures = getattr(self.config.hf_config, "architectures", []) or [] + known = _IOProc._per_req_cache_model_types() # noqa: SLF001 + is_v4 = mt in {v4_model_types} or any( + "DeepseekV4" in arch for arch in architectures + ) + assert mt in known or is_v4, ( + f"Attention builder {{type(self.attn_metadata_builder).__name__}} " + f"reports per_req_cache_bytes>0 but model_type={{mt!r}} is not in " + f"InputOutputProcessor.per_req_cache_model_types ({{sorted(known)}}) " + "and is not a recognized DeepSeek-V4 architecture. Add it to " + "the set or sequences will not be assigned slots " + "(silent corruption)." 
+ ) +''' +if old in source: + source = source.replace(old, new, 1) +elif "is not a recognized DeepSeek-V4 architecture" not in source: + raise SystemExit("FATAL: model_runner.py per-req cache assertion anchor missing") +path.write_text(source) - if B == 1 and M == 1: - valid_1d = topk_idxs[0, 0] != -1 - if not bool(valid_1d.any()): - return torch.zeros_like(q) - idx_1d = topk_idxs[0, 0] - if bool(valid_1d.all()): - kv_f32 = kv[0].index_select(0, idx_1d.long()).float() - else: - kv_f32 = kv[0].index_select(0, idx_1d[valid_1d].long()).float() - q_f32 = q[0, 0].float() - scores = torch.matmul(q_f32, kv_f32.transpose(0, 1)) * float(softmax_scale) - sink = attn_sink.float().view(H, 1) - cmax = torch.maximum(scores.amax(dim=-1, keepdim=True), sink) - exp_scores = (scores - cmax).exp() - denom = exp_scores.sum(dim=-1, keepdim=True) + (sink - cmax).exp() - out = (exp_scores / denom.clamp(min=1e-30)).matmul(kv_f32) - return out.view(1, 1, H, D).to(out_dtype) - - # ----- Gather KV per query position ----- -""" +path = Path("atom/model_engine/llm_engine.py") +source = path.read_text() +old = ''' "deepseek_v4", +''' +new = ''' "deepseek_v4", + "deepseek_v4_pro", + "deepseek_v4_flash", +''' +if "deepseek_v4_pro" not in source: if old not in source: - raise SystemExit("FATAL: sparse_attn_v4.py did not match expected PR650 source") + raise SystemExit("FATAL: llm_engine.py per-req cache model list anchor missing") + source = source.replace(old, new, 1) +old = ''' if self.config.hf_config.model_type in self._per_req_cache_model_types(): + self.has_per_req_cache = True +''' +new = ''' hf_model_type = getattr(self.config.hf_config, "model_type", None) + hf_architectures = getattr(self.config.hf_config, "architectures", []) or [] + if ( + hf_model_type in self._per_req_cache_model_types() + or any("DeepseekV4" in arch for arch in hf_architectures) + ): + self.has_per_req_cache = True +''' +if old in source: + source = source.replace(old, new, 1) +elif "hf_architectures = getattr(self.config.hf_config" not in source: + raise SystemExit("FATAL: llm_engine.py per-req cache detection anchor missing") +path.write_text(source) + +path = Path("atom/config.py") +source = path.read_text() +if '"deepseek_v4_pro": "deepseek_v3"' not in source: + anchor = ''' "deepseek_v4": "deepseek_v3", # V4 reuses V3 schema; V4-specific fields +''' + insert = ''' "deepseek_v4": "deepseek_v3", # V4 reuses V3 schema; V4-specific fields + "deepseek_v4_pro": "deepseek_v3", + "deepseek_v4_flash": "deepseek_v3", +''' + if anchor not in source: + raise SystemExit("FATAL: config.py V4 registry anchor missing") + source = source.replace(anchor, insert, 1) +old = ''' if getattr(self.hf_config, "model_type", None) == "deepseek_v4": +''' +new = f''' hf_model_type = getattr(self.hf_config, "model_type", None) + hf_architectures = getattr(self.hf_config, "architectures", []) or [] + if hf_model_type in {v4_model_types} or any( + "DeepseekV4" in arch for arch in hf_architectures + ): +''' +if old in source: source = source.replace(old, new, 1) - path.write_text(source) - print(f"applied DSv4 sparse_attn_v4 decode/chunk patch: {path}") -else: - print(f"DSv4 sparse_attn_v4 decode/chunk patch already present: {path}") +elif "hf_model_type in" not in source: + raise SystemExit("FATAL: config.py V4 block-size guard did not match expected source") +path.write_text(source) PYEOF - # --no-deps: don't churn the image's pinned ROCm/torch/triton/aiter. - # --force-reinstall: replace the wheel-installed atom with the editable copy. 
- pip install --no-deps --force-reinstall -e . -) - -# Install triton_kernels. The release atom0.1.2.post image cleans up -# /triton-test/ from the build stage, so it's typically absent. Fall back -# to ROCm/triton's RI3.5.x branch — NOT triton-lang/triton upstream: -# -# * Upstream triton-lang/triton refactored the matmul_ogs module into -# matmul.py (and removed routing.py). PR #650's fused_moe_triton.py -# imports `from triton_kernels.matmul_ogs import matmul_ogs, -# PrecisionConfig` and `from triton_kernels.routing import routing`, -# which only resolve against the ROCm fork's release-internal branch. -# * ROCm/triton RI3.5.x at e491726 has matmul_ogs.py (with PrecisionConfig -# and matmul_ogs), routing.py, CDNA4MXScaleLayout in layout.py (the -# class PR #650 imports), and target_info.py that imports only is_hip / -# is_hip_cdna3 / is_hip_cdna4 — no is_hip_gfx1250, which the image's -# bundled triton would reject. -# -# triton_kernels is a self-contained subpackage (pyproject deps: numpy, -# pytest); installing it does not perturb the image's triton itself. -# Bump only after AMD ships a newer ATOM image whose bundled triton -# exports is_hip_gfx1250, at which point we can move to a newer RI branch. -TRITON_KERNELS_SHA="e49172654d55f460c6fc24d77a3ea8a286bcaee8" -# --force-reinstall mirrors the atom install above: triton_kernels also ships -# as a wheel in the image, and without --force-reinstall pip can short-circuit -# the editable switch when name/version match, leaving the wheel build active. -if [ -d /triton-test/python/triton_kernels/ ]; then - pip install --no-deps --force-reinstall -e /triton-test/python/triton_kernels/ -else - TRITON_DIR="/tmp/rocm-triton" - if [ ! -d "$TRITON_DIR/.git" ]; then - git clone --filter=blob:none https://github.com/ROCm/triton.git "$TRITON_DIR" - fi - ( - cd "$TRITON_DIR" - git fetch --depth=1 origin "$TRITON_KERNELS_SHA" 2>/dev/null \ - || git fetch --depth=1 origin RI3.5.x - git checkout --force "$TRITON_KERNELS_SHA" - pip install --no-deps --force-reinstall -e python/triton_kernels/ + python3 -m pip install --no-deps --no-build-isolation --force-reinstall -e . ) -fi -# Preflight version checks. The chosen base image -# (atom0.1.2.post, rebuilt 2026-04-23) was tagged after ATOM pinned -# transformers==5.2.0 (commit 67d6cb61, 2026-03-13), so transformers compat -# is expected; we still assert it explicitly to fail fast with a clear -# message rather than timing out wait_for_server_ready on a confusing -# import error inside the server log. The two non-trivial deps the PR -# introduces are transformers' deepseek_v3 config class (mapped from -# deepseek_v4 in atom/config.py) and triton_kernels.CDNA4MXScaleLayout -# (renamed from GFX950MXScaleLayout in fused_moe_triton.py). -python3 - <<'PYEOF' -import importlib, os, sys -import atom - -# Verify the editable install actually took effect — Python could still be -# importing the wheel-installed atom if pip's --force-reinstall silently no-op'd -# (e.g., the wheel and the editable copy share a setup.py path mismatch). -atom_path = os.path.abspath(atom.__file__) -expected = os.path.abspath(os.environ["ATOM_PR_DIR"]) -print(f"atom imported from: {atom_path}") -if expected not in atom_path: - sys.exit(f"FATAL: atom is importing from {atom_path}, not from PR checkout {expected}. " - f"The pip --force-reinstall -e . 
did not take effect.") - -import transformers -print(f"transformers version: {transformers.__version__}") - -# Use CONFIG_MAPPING directly: AutoConfig.for_model() returns an instance -# (transformers 5.2.0 source: `return config_class(*args, **kwargs)`), not a -# class, so `.__name__` would AttributeError. CONFIG_MAPPING maps model_type -# to the config class directly and is unambiguous. -from transformers.models.auto.configuration_auto import CONFIG_MAPPING -if "deepseek_v3" not in CONFIG_MAPPING: - sys.exit(f"FATAL: transformers in this image cannot resolve deepseek_v3 model_type. " - f"ATOM PR #650 maps deepseek_v4 -> deepseek_v3 in _CONFIG_REGISTRY and needs " - f"transformers to know the v3 schema. Available types: " - f"{sorted(k for k in CONFIG_MAPPING if 'deepseek' in k)}") -print(f"deepseek_v3 config class: {CONFIG_MAPPING['deepseek_v3'].__name__}") - -try: - layout_mod = importlib.import_module("triton_kernels.tensor_details.layout") - if not hasattr(layout_mod, "CDNA4MXScaleLayout"): - avail = [n for n in dir(layout_mod) if "Layout" in n] - sys.exit(f"FATAL: triton_kernels.tensor_details.layout has no CDNA4MXScaleLayout. " - f"PR #650's fused_moe_triton.py change renamed GFX950MXScaleLayout -> " - f"CDNA4MXScaleLayout, but this image's triton_kernels still uses the old " - f"name. Available Layout classes: {avail}") - print("triton_kernels.CDNA4MXScaleLayout: present") -except ModuleNotFoundError as e: - sys.exit(f"FATAL: triton_kernels not importable. PR #650's MoE path needs it. Error: {e}") + python3 - <<'PYEOF' +import inspect +from atom.model_engine.model_runner import support_model_arch_dict +from atom.models.deepseek_v4 import Indexer + +target = support_model_arch_dict.get("DeepseekV4ForCausalLM") +if target != "atom.models.deepseek_v4.DeepseekV4ForCausalLM": + raise SystemExit(f"FATAL: DeepseekV4ForCausalLM maps to {target!r}") +source = inspect.getsource(Indexer.forward_batched) +source += inspect.getsource(Indexer) +missing = [ + name + for name in ("top_k_per_row_prefill", "top_k_per_row_decode", "fp8_mqa_logits") + if name not in source +] +if missing: + raise SystemExit(f"FATAL: ATOM DSv4 Indexer is not wired to AITER symbols {missing}") +print("ATOM DSv4 architecture registration and AITER top-k/logits wiring imported successfully") PYEOF +else + echo "WARN: ATOM_DSV4_PR650=0; using image-provided ATOM" +fi -# DSv4-Pro's native max_position_embeddings is 1,048,576 (1M tokens), so we -# can't leave --max-model-len blank for 1k1k the way the dsr1-atom scripts -# do — ATOM would allocate KV cache for 1M context and OOM during warmup -# (~240 GiB consumed before the dummy forward, then sparse_attn's -# torch.where wants another ~36 GiB and there isn't 36 GiB free). DSR1's -# native context is only 128k, which is why the same blank pattern works -# there. Set 1k1k explicitly; 8k1k retains the existing 10240 cap that's -# already running successfully. +# DSv4-Pro advertises a 1M native context. Set the benchmark context +# explicitly so ATOM does not reserve KV cache for the full native length. if [ "$ISL" = "1024" ] && [ "$OSL" = "1024" ]; then MAX_MODEL_LEN_VALUE=2304 else @@ -482,38 +281,61 @@ else EP=" " fi -# Start GPU monitoring (power, temperature, clocks every second) +if [[ -z "${ATOM_MAX_NUM_BATCHED_TOKENS:-}" ]]; then + # Keep ATOM startup/warmup bounded without carrying a fork-only warmup knob. + # Do not set this below ISL, otherwise an 8k prefill may never be admitted. 
+ ATOM_MAX_NUM_BATCHED_TOKENS=4096 + if [ "$ISL" -gt "$ATOM_MAX_NUM_BATCHED_TOKENS" ]; then + ATOM_MAX_NUM_BATCHED_TOKENS="$ISL" + fi +fi + +if [ "${EVAL_ONLY:-false}" = "true" ] && [ "${ATOM_DSV4_COMPONENT_DIAG:-1}" = "1" ]; then + export ATOM_DSV4_DIAG_EQUIV="${ATOM_DSV4_DIAG_EQUIV:-1}" + export ATOM_DSV4_DIAG_LAYERS="${ATOM_DSV4_DIAG_LAYERS:-all}" + export ATOM_DSV4_DIAG_VERBOSE="${ATOM_DSV4_DIAG_VERBOSE:-1}" + export ATOM_DSV4_DIAG_TOKEN_LIMIT="${ATOM_DSV4_DIAG_TOKEN_LIMIT:-3}" + export ATOM_DSV4_DIAG_FULL_SEQ_LIMIT="${ATOM_DSV4_DIAG_FULL_SEQ_LIMIT:-128}" + export ATOM_DSV4_DEEP_ATTN_REF_DIAG="${ATOM_DSV4_DEEP_ATTN_REF_DIAG:-1}" + echo "DSv4 component diagnostics enabled: layers=${ATOM_DSV4_DIAG_LAYERS}, token_limit=${ATOM_DSV4_DIAG_TOKEN_LIMIT}, full_seq_limit=${ATOM_DSV4_DIAG_FULL_SEQ_LIMIT}, attn_ref=${ATOM_DSV4_DEEP_ATTN_REF_DIAG}" +fi + +run_dsv4_atom_eval_diagnostics() { + local diag_file="sample_dsv4_atom_eval_diag_${RESULT_FILENAME}.jsonl" + local diag_conc_list="${ATOM_DSV4_DIAG_CONCURRENCY_LIST:-1,$CONC}" + local diag_script + diag_script="$(dirname "$0")/../../utils/dsv4_atom_eval_diag.py" + if [ ! -f "$diag_script" ]; then + diag_script="/workspace/utils/dsv4_atom_eval_diag.py" + fi + echo "[DSv4 diag] Running concurrency isolation matrix; output: ${diag_file}" + DIAG_PORT="$PORT" \ + DIAG_MODEL="$MODEL" \ + DIAG_CONC_LIST="$diag_conc_list" \ + DIAG_ISL="$ISL" \ + DIAG_OUT="$diag_file" \ + python3 "$diag_script" +} + start_gpu_monitor set -x -BLOCK_SIZE=${BLOCK_SIZE:-16} -export ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS=${ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS:-256} -# --enforce-eager is required: ROCm/ATOM#650 (PR1 skeleton) has no CUDAGraph -# support yet (deferred to a follow-up PR). max-num-seqs is sized to the -# client concurrency with a floor at 4 — the ATOM default (512) makes the -# KV/GDN-mamba allocator overshoot the GPU budget ("GDN mamba tensor -# exceeds available KV budget"), and using 1 hangs warmup at 0% GPU. 4 -# is the minimum we've seen complete warmup successfully (also the PR's -# offline repro value). The PR1 kv_cache[:1,...] hardcode in -# deepseek_v4.py means any forward with batch>1 silently corrupts -# non-slot-0 lanes; eval (gsm8k) at conc>1 is the canary. -MAX_NUM_SEQS=$(( CONC < 4 ? 4 : CONC )) -MAX_NUM_BATCHED_TOKENS=${MAX_NUM_BATCHED_TOKENS:-$MAX_MODEL_LEN_VALUE} +BLOCK_SIZE=${BLOCK_SIZE:-128} +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ - --model $MODEL \ - --server-port $PORT \ - -tp $TP \ + --model "$MODEL" \ + --server-port "$PORT" \ + -tp "$TP" \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ - --block-size $BLOCK_SIZE \ + --block-size "$BLOCK_SIZE" \ + --max-num-batched-tokens "$ATOM_MAX_NUM_BATCHED_TOKENS" \ --enforce-eager \ - --max-num-seqs $MAX_NUM_SEQS \ - --max-num-batched-tokens $MAX_NUM_BATCHED_TOKENS \ - --trust-remote-code > $SERVER_LOG 2>&1 & + --trust-remote-code \ + "${ATOM_PROFILE_ARGS[@]}" > "$SERVER_LOG" 2>&1 & SERVER_PID=$! 
-# Wait for server to be ready wait_for_server_ready --port "$PORT" --server-log "$SERVER_LOG" --server-pid "$SERVER_PID" run_benchmark_serving \ @@ -527,14 +349,16 @@ run_benchmark_serving \ --max-concurrency "$CONC" \ --result-filename "$RESULT_FILENAME" \ --result-dir /workspace/ \ + --server-pid "$SERVER_PID" \ --trust-remote-code -# After throughput, run evaluation only if RUN_EVAL is true if [ "${RUN_EVAL}" = "true" ]; then - run_eval --framework lm-eval --port "$PORT" + run_eval --framework lm-eval --port "$PORT" --limit "${EVAL_LIMIT:-640}" append_lm_eval_summary + if [ "${ATOM_DSV4_EVAL_DIAG:-1}" = "1" ]; then + run_dsv4_atom_eval_diagnostics || echo "WARN: DSv4 eval diagnostics failed" >&2 + fi fi -# Stop GPU monitoring stop_gpu_monitor set +x diff --git a/benchmarks/single_node/glm5.1_fp4_mi355x_atom.sh b/benchmarks/single_node/glm5.1_fp4_mi355x_atom.sh index 036346af3..410743d1b 100644 --- a/benchmarks/single_node/glm5.1_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/glm5.1_fp4_mi355x_atom.sh @@ -43,6 +43,7 @@ MEM_FRAC_STATIC=0.9 set -x pip install -U transformers +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ @@ -51,6 +52,7 @@ python3 -m atom.entrypoints.openai_server \ --gpu-memory-utilization $MEM_FRAC_STATIC \ --default-chat-template-kwargs '{"enable_thinking": false}' \ --trust-remote-code \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! diff --git a/benchmarks/single_node/glm5_fp8_mi355x_atom.sh b/benchmarks/single_node/glm5_fp8_mi355x_atom.sh index 036346af3..410743d1b 100644 --- a/benchmarks/single_node/glm5_fp8_mi355x_atom.sh +++ b/benchmarks/single_node/glm5_fp8_mi355x_atom.sh @@ -43,6 +43,7 @@ MEM_FRAC_STATIC=0.9 set -x pip install -U transformers +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ @@ -51,6 +52,7 @@ python3 -m atom.entrypoints.openai_server \ --gpu-memory-utilization $MEM_FRAC_STATIC \ --default-chat-template-kwargs '{"enable_thinking": false}' \ --trust-remote-code \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! diff --git a/benchmarks/single_node/gptoss_fp4_mi355x_atom.sh b/benchmarks/single_node/gptoss_fp4_mi355x_atom.sh index ee0810e8f..0298537f8 100644 --- a/benchmarks/single_node/gptoss_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/gptoss_fp4_mi355x_atom.sh @@ -50,13 +50,15 @@ set -x BLOCK_SIZE=${BLOCK_SIZE:-16} export ATOM_GPT_OSS_MODEL=1 #TODO remove this +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ -tp $TP \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ --gpu-memory-utilization $MEM_FRAC_STATIC \ - --block-size $BLOCK_SIZE > $SERVER_LOG 2>&1 & + --block-size $BLOCK_SIZE \ + "${ATOM_PROFILE_ARGS[@]}" > $SERVER_LOG 2>&1 & SERVER_PID=$! @@ -83,4 +85,4 @@ fi # Stop GPU monitoring stop_gpu_monitor -set +x \ No newline at end of file +set +x diff --git a/benchmarks/single_node/kimik2.5_fp4_mi355x_atom.sh b/benchmarks/single_node/kimik2.5_fp4_mi355x_atom.sh index ca84f8228..0eeb3e6aa 100755 --- a/benchmarks/single_node/kimik2.5_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/kimik2.5_fp4_mi355x_atom.sh @@ -42,12 +42,14 @@ start_gpu_monitor set -x +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ -tp $TP \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ --trust-remote-code \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! 
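Both pinned overlays in `dsv4_fp4_mi355x_atom.sh` above (ROCm/aiter#2998 and the ATOM branch stacked on ROCm/ATOM#650) share one fetch-verify-checkout shape. A condensed sketch of that shape; the function name and parameter order are illustrative, not part of the scripts:

```bash
# Clone shallowly, fetch the pinned ref, refuse to proceed if the ref has
# moved past the expected SHA, then install editable without touching the
# image's pinned dependency set.
overlay_pinned_ref() {
    local repo="$1" ref="$2" expected_sha="$3" dir="$4"
    rm -rf "$dir"
    git clone --filter=blob:none "$repo" "$dir"
    (
        cd "$dir"
        git fetch --depth=1 origin "$ref"
        local fetched_sha
        fetched_sha="$(git rev-parse FETCH_HEAD)"
        if [ "$fetched_sha" != "$expected_sha" ]; then
            echo "FATAL: $ref resolved to $fetched_sha, expected $expected_sha" >&2
            exit 1
        fi
        git checkout --force FETCH_HEAD
        python3 -m pip install --no-deps --no-build-isolation --force-reinstall -e .
    )
}
```

Pinning the SHA rather than tracking the branch tip keeps a run reproducible even if either PR branch is force-pushed.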
diff --git a/benchmarks/single_node/minimaxm2.5_fp4_mi355x_atom.sh b/benchmarks/single_node/minimaxm2.5_fp4_mi355x_atom.sh index ca84f8228..0eeb3e6aa 100644 --- a/benchmarks/single_node/minimaxm2.5_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/minimaxm2.5_fp4_mi355x_atom.sh @@ -42,12 +42,14 @@ start_gpu_monitor set -x +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ -tp $TP \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ --trust-remote-code \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! diff --git a/benchmarks/single_node/minimaxm2.5_fp8_mi355x_atom.sh b/benchmarks/single_node/minimaxm2.5_fp8_mi355x_atom.sh index 2a8c67da0..f9bb0d5cd 100755 --- a/benchmarks/single_node/minimaxm2.5_fp8_mi355x_atom.sh +++ b/benchmarks/single_node/minimaxm2.5_fp8_mi355x_atom.sh @@ -43,6 +43,7 @@ MEM_FRAC_STATIC=0.9 set -x +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ @@ -50,6 +51,7 @@ python3 -m atom.entrypoints.openai_server \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ --gpu-memory-utilization $MEM_FRAC_STATIC \ --trust-remote-code \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! diff --git a/benchmarks/single_node/qwen3.5_fp8_mi355x_atom.sh b/benchmarks/single_node/qwen3.5_fp8_mi355x_atom.sh index 2a8c67da0..f9bb0d5cd 100644 --- a/benchmarks/single_node/qwen3.5_fp8_mi355x_atom.sh +++ b/benchmarks/single_node/qwen3.5_fp8_mi355x_atom.sh @@ -43,6 +43,7 @@ MEM_FRAC_STATIC=0.9 set -x +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ @@ -50,6 +51,7 @@ python3 -m atom.entrypoints.openai_server \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ --gpu-memory-utilization $MEM_FRAC_STATIC \ --trust-remote-code \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! diff --git a/benchmarks/single_node/qwen3.5_fp8_mi355x_atom_mtp.sh b/benchmarks/single_node/qwen3.5_fp8_mi355x_atom_mtp.sh index 9399fe792..8110ac124 100644 --- a/benchmarks/single_node/qwen3.5_fp8_mi355x_atom_mtp.sh +++ b/benchmarks/single_node/qwen3.5_fp8_mi355x_atom_mtp.sh @@ -43,6 +43,7 @@ MEM_FRAC_STATIC=0.9 set -x +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ @@ -52,6 +53,7 @@ python3 -m atom.entrypoints.openai_server \ --method mtp \ --num-speculative-tokens 3 \ --trust-remote-code \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! 
diff --git a/perf-changelog.yaml b/perf-changelog.yaml index 98002a100..b072f3d31 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -2221,3 +2221,14 @@ - "Update the TensorRT-LLM DeepSeek-V4-Pro image to ghcr.io/semianalysisai/trtllm-deepseek-v4:feat-deepseek_v4-9aa3715" - "Enable TRTLLM fused MHC by default with the DeepSeek-V4 feature image" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1270 + +- config-keys: + - dsv4-fp4-mi355x-atom + description: + - "Update DSv4 ATOM to rocm/atom:rocm7.2.2_ubuntu24.04_py3.12_pytorch_release_2.10.0_atom0.1.2.post" + - "Overlay Oseltamivir/ATOM@3ed6633 stacked on ROCm/ATOM#650 and ROCm/aiter#2998@969863a for the current DSv4 AITER indexer/top-k path, structured eval answer trimming, deep L0 attention diagnostics, and split wq_a/wkv projections" + - "Remove temporary DSv4 ATOM/AITER runtime patches, perf-stack cherry-picks, dense fast paths, and shortened smoke-run defaults" + - "Lower DSv4 ATOM --max-num-batched-tokens default to max(4096, ISL) instead of carrying a fork-only warmup cap" + - "Retest higher DSv4 ATOM concurrency points: 1k1k conc 1-8, 8k1k conc 1-4" + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1260 + evals-only: true diff --git a/utils/bench_serving/backend_request_func.py b/utils/bench_serving/backend_request_func.py index 7f4a93284..8cd59da60 100644 --- a/utils/bench_serving/backend_request_func.py +++ b/utils/bench_serving/backend_request_func.py @@ -273,6 +273,12 @@ async def async_request_openai_completions( async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: + if api_url.endswith("profile"): + output.latency = time.perf_counter() - st + output.generated_text = await response.text() + output.success = True + return output + first_chunk_received = False async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() diff --git a/utils/bench_serving/benchmark_serving.py b/utils/bench_serving/benchmark_serving.py index 68887c59b..9ed7d40e8 100644 --- a/utils/bench_serving/benchmark_serving.py +++ b/utils/bench_serving/benchmark_serving.py @@ -532,13 +532,14 @@ async def warmup_limited_req_fn(): if profile: print("Starting profiler...") + profile_num_steps = int(os.environ.get("PROFILE_NUM_STEPS", "1")) profile_input = RequestFuncInput(model=model_id, model_name=model_name, prompt=test_prompt, api_url=base_url + "/start_profile", prompt_len=test_prompt_len, output_len=test_output_len, - extra_body={"num_steps": 1, "merge_profiles": True, "profile_by_stage": True}, + extra_body={"num_steps": profile_num_steps, "merge_profiles": True, "profile_by_stage": True}, logprobs=logprobs, best_of=best_of, multi_modal_content=test_mm_content, diff --git a/utils/dsv4_atom_eval_diag.py b/utils/dsv4_atom_eval_diag.py new file mode 100644 index 000000000..80958f472 --- /dev/null +++ b/utils/dsv4_atom_eval_diag.py @@ -0,0 +1,504 @@ +#!/usr/bin/env python3 +"""DeepSeek-V4 ATOM concurrency diagnostic probes. + +The eval failure seen at high concurrency can come from several places: +batched prefill metadata, decode KV/cache state, long-context sparse +attention/indexer paths, or cross-request leakage. This script issues a small +matrix of deterministic completion requests and writes both per-request rows +and per-case summaries to JSONL so the failure mode is visible from artifacts. 
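+
+Illustrative standalone invocation (mirroring the environment contract used
+by run_dsv4_atom_eval_diagnostics in the benchmark launcher; the port and
+output path are examples):
+
+    DIAG_PORT=8000 DIAG_MODEL=deepseek-ai/DeepSeek-V4-Pro \
+    DIAG_CONC_LIST=1,8 DIAG_ISL=1024 DIAG_OUT=diag.jsonl \
+    python3 utils/dsv4_atom_eval_diag.py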
+""" + +from __future__ import annotations + +import concurrent.futures +import hashlib +import json +import os +import re +import statistics +import time +import urllib.error +import urllib.request +from typing import Any + + +MARKER_PREFIX = "DSV4MARK" + + +def _parse_levels(raw: str) -> list[int]: + levels: list[int] = [] + for item in raw.split(","): + item = item.strip() + if not item: + continue + try: + levels.append(max(1, int(item))) + except ValueError: + pass + levels = sorted(set(levels)) or [1] + if 1 not in levels: + levels.insert(0, 1) + return levels + + +def _chat_prompt(body: str) -> str: + tok_bos = "<\uff5cbegin\u2581of\u2581sentence\uff5c>" + tok_user = "<\uff5cUser\uff5c>" + tok_assistant = "<\uff5cAssistant\uff5c>" + return f"{tok_bos}{tok_user}{body}{tok_assistant}" + + +def _first_tokenish(text: str) -> str: + stripped = text.lstrip() + if not stripped: + return "" + return stripped.split(maxsplit=1)[0][:40] + + +def _extract_answer(text: str) -> str | None: + match = re.search(r"####\s*\$?(-?\d+(?:\.\d+)?)", text) + return match.group(1) if match else None + + +def _marker_for(request_id: int) -> str: + return f"{MARKER_PREFIX}{request_id:04d}" + + +def _marker_prompt(prompt_kind: str, pad: str, request_id: int) -> tuple[str, str]: + marker = _marker_for(request_id) + body = ( + f"The marker for this request is {marker}.\n" + f"Output exactly this marker and nothing else: {marker}\n" + "Answer:" + ) + if prompt_kind == "long": + body = pad + "\n\n" + body + return _chat_prompt(body), marker + + +def _case_matrix(levels: list[int], isl: int) -> tuple[list[dict[str, Any]], str]: + math_body = ( + "Question: Janet's ducks lay 16 eggs per day. She eats three for breakfast " + "every morning and bakes muffins for her friends every day with four. " + "She sells the remainder at the farmers' market daily for $2 per fresh " + "duck egg. How much in dollars does she make every day at the farmers' " + "market?\n" + "End your response with the answer on the last line, formatted as: #### [number]\n" + "Answer:" + ) + pad_units = int(os.environ.get("ATOM_DSV4_DIAG_LONG_PAD_UNITS", "0") or "0") + if pad_units <= 0: + pad_units = min(max(isl // 16, 64), 800) + pad = " ".join( + f"Reference filler sentence {i}: keep this context unchanged." 
+ for i in range(pad_units) + ) + short_math_prompt = _chat_prompt(math_body) + long_math_prompt = _chat_prompt( + pad + + "\n\nUse only the final question below; the preceding filler is irrelevant.\n" + + math_body + ) + decode_tokens = int(os.environ.get("ATOM_DSV4_DIAG_DECODE_TOKENS", "32")) + marker_tokens = int(os.environ.get("ATOM_DSV4_DIAG_MARKER_TOKENS", "8")) + max_level = max(levels) + cases = [ + { + "name": "short_identical_1tok", + "mode": "identical", + "prompt_kind": "short", + "prompt": short_math_prompt, + "max_tokens": 1, + "levels": levels, + }, + { + "name": "short_identical_decode", + "mode": "identical", + "prompt_kind": "short", + "prompt": short_math_prompt, + "max_tokens": decode_tokens, + "levels": levels, + }, + { + "name": "long_identical_1tok", + "mode": "identical", + "prompt_kind": "long", + "prompt": long_math_prompt, + "max_tokens": 1, + "levels": levels, + }, + { + "name": "long_identical_decode", + "mode": "identical", + "prompt_kind": "long", + "prompt": long_math_prompt, + "max_tokens": decode_tokens, + "levels": levels, + }, + { + "name": "short_distinct_marker", + "mode": "marker", + "prompt_kind": "short", + "max_tokens": marker_tokens, + "levels": [max_level], + }, + { + "name": "long_distinct_marker", + "mode": "marker", + "prompt_kind": "long", + "max_tokens": marker_tokens, + "levels": [max_level], + }, + ] + case_filter = os.environ.get("ATOM_DSV4_DIAG_CASES", "").strip() + if case_filter: + wanted = {item.strip() for item in case_filter.split(",") if item.strip()} + cases = [case for case in cases if case["name"] in wanted] + return cases, pad + + +def _one_request( + *, + case: dict[str, Any], + level: int, + request_id: int, + model: str, + url: str, + stop: list[str], + pad: str, +) -> dict[str, Any]: + if case["mode"] == "marker": + prompt, expected_marker = _marker_prompt(case["prompt_kind"], pad, request_id) + else: + prompt = case["prompt"] + expected_marker = None + payload = { + "model": model, + "prompt": prompt, + "max_tokens": case["max_tokens"], + "temperature": 0, + "top_p": 1, + "stop": stop, + } + data = json.dumps(payload).encode("utf-8") + req = urllib.request.Request( + url, + data=data, + method="POST", + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer EMPTY", + }, + ) + started = time.time() + try: + with urllib.request.urlopen(req, timeout=300) as resp: + body = json.loads(resp.read().decode("utf-8")) + choice = body["choices"][0] + text = choice.get("text", "") + markers = sorted(set(re.findall(rf"{MARKER_PREFIX}\d{{4}}", text))) + wrong_markers = [ + marker + for marker in markers + if expected_marker is not None and marker != expected_marker + ] + return { + "kind": "request", + "case": case["name"], + "mode": case["mode"], + "prompt_kind": case["prompt_kind"], + "max_tokens": case["max_tokens"], + "level": level, + "request_id": request_id, + "ok": True, + "latency_s": round(time.time() - started, 3), + "sha256": hashlib.sha256(text.encode("utf-8")).hexdigest(), + "prompt_sha256": hashlib.sha256(prompt.encode("utf-8")).hexdigest(), + "length": len(text), + "has_final_answer": "####" in text, + "answer": _extract_answer(text), + "first_tokenish": _first_tokenish(text), + "finish_reason": choice.get("finish_reason"), + "expected_marker": expected_marker, + "contains_expected_marker": ( + expected_marker in text if expected_marker is not None else None + ), + "markers_seen": markers, + "wrong_markers": wrong_markers, + "text": text, + } + except Exception as exc: + if isinstance(exc, 
urllib.error.HTTPError): + detail = exc.read().decode("utf-8", "replace") + else: + detail = "" + return { + "kind": "request", + "case": case["name"], + "mode": case["mode"], + "prompt_kind": case["prompt_kind"], + "max_tokens": case["max_tokens"], + "level": level, + "request_id": request_id, + "ok": False, + "latency_s": round(time.time() - started, 3), + "error": repr(exc), + "detail": detail[:2000], + } + + +def _summarize_case( + case: dict[str, Any], + level: int, + rows: list[dict[str, Any]], + baseline: dict[str, Any] | None, +) -> dict[str, Any]: + ok_rows = [row for row in rows if row.get("ok")] + errors = [row for row in rows if not row.get("ok")] + hashes = sorted({row.get("sha256") for row in ok_rows}) + firsts = sorted({row.get("first_tokenish") for row in ok_rows}) + latencies = [row["latency_s"] for row in ok_rows if "latency_s" in row] + summary: dict[str, Any] = { + "kind": "summary", + "case": case["name"], + "mode": case["mode"], + "prompt_kind": case["prompt_kind"], + "max_tokens": case["max_tokens"], + "level": level, + "ok": len(ok_rows), + "total": len(rows), + "errors": len(errors), + "unique_outputs": len(hashes), + "unique_first_tokenish": len(firsts), + "first_tokenish_values": firsts[:12], + "latency_s_mean": round(statistics.mean(latencies), 3) if latencies else None, + "latency_s_max": round(max(latencies), 3) if latencies else None, + } + if case["mode"] == "identical": + baseline_hash = baseline.get("sha256") if baseline else None + baseline_first = baseline.get("first_tokenish") if baseline else None + summary.update( + { + "baseline_sha256": baseline_hash, + "baseline_first_tokenish": baseline_first, + "drift_vs_baseline": sum( + 1 + for row in ok_rows + if baseline_hash is not None and row.get("sha256") != baseline_hash + ), + "first_token_drift_vs_baseline": sum( + 1 + for row in ok_rows + if baseline_first is not None + and row.get("first_tokenish") != baseline_first + ), + "missing_final": [ + row["request_id"] + for row in ok_rows + if case["max_tokens"] > 1 and not row.get("has_final_answer") + ], + "answers": sorted( + { + row.get("answer") + for row in ok_rows + if row.get("answer") is not None + } + )[:12], + } + ) + else: + missing = [ + row["request_id"] + for row in ok_rows + if row.get("contains_expected_marker") is False + ] + wrong = [row["request_id"] for row in ok_rows if row.get("wrong_markers")] + summary.update( + { + "missing_expected_marker": missing, + "wrong_marker_requests": wrong, + "wrong_markers_seen": sorted( + { + marker + for row in ok_rows + for marker in row.get("wrong_markers", []) + } + )[:24], + } + ) + return summary + + +def _print_summary(summary: dict[str, Any]) -> None: + if summary["mode"] == "identical": + print( + "[DSv4 diag] " + f"case={summary['case']} level={summary['level']} " + f"ok={summary['ok']}/{summary['total']} " + f"unique_outputs={summary['unique_outputs']} " + f"first_token_drift={summary['first_token_drift_vs_baseline']} " + f"drift={summary['drift_vs_baseline']} " + f"missing_final={summary.get('missing_final', [])[:16]}" + ) + else: + print( + "[DSv4 diag] " + f"case={summary['case']} level={summary['level']} " + f"ok={summary['ok']}/{summary['total']} " + f"missing_marker={summary['missing_expected_marker'][:16]} " + f"wrong_marker={summary['wrong_marker_requests'][:16]} " + f"unique_outputs={summary['unique_outputs']}" + ) + + +def _print_snippets(rows: list[dict[str, Any]]) -> None: + for row in rows[: min(4, len(rows))]: + snippet = row.get("text", row.get("error", "")).replace("\n", " 
")[:260] + print( + "[DSv4 diag] " + f" case={row.get('case')} req={row['request_id']} ok={row.get('ok')} " + f"first={row.get('first_tokenish')!r} len={row.get('length')} " + f"sha={row.get('sha256', '')[:12]} markers={row.get('markers_seen')} " + f"snippet={snippet!r}" + ) + + +def _diagnosis(summaries: list[dict[str, Any]], max_level: int) -> list[str]: + by_case = {(s["case"], s["level"]): s for s in summaries} + short_1 = by_case.get(("short_identical_1tok", max_level), {}) + short_decode = by_case.get(("short_identical_decode", max_level), {}) + long_1 = by_case.get(("long_identical_1tok", max_level), {}) + long_decode = by_case.get(("long_identical_decode", max_level), {}) + short_marker = by_case.get(("short_distinct_marker", max_level), {}) + long_marker = by_case.get(("long_distinct_marker", max_level), {}) + notes: list[str] = [] + if short_1.get("first_token_drift_vs_baseline", 0): + notes.append( + "short 1-token drift: corruption happens by final prefill logits; " + "suspect batched prefill metadata, positions/slot mapping, sampler, " + "or common MHC/FFN path before decode KV growth" + ) + elif short_decode.get("drift_vs_baseline", 0): + notes.append( + "short multi-token drift but 1-token stable: suspect decode KV/cache " + "state update or per-step scheduling rather than initial prefill" + ) + if ( + long_1.get("first_token_drift_vs_baseline", 0) + and not short_1.get("first_token_drift_vs_baseline", 0) + ): + notes.append( + "long-only 1-token drift: suspect long-context DSv4 attention/indexer/" + "compressor path rather than generic batching" + ) + if ( + long_decode.get("drift_vs_baseline", 0) + and not short_decode.get("drift_vs_baseline", 0) + ): + notes.append( + "long-only decode drift: suspect sparse attention/indexer/cache growth " + "after prefill" + ) + if short_marker.get("wrong_marker_requests") or long_marker.get( + "wrong_marker_requests" + ): + notes.append( + "wrong marker copied from another request: direct evidence of " + "cross-request leakage" + ) + if short_marker.get("missing_expected_marker") and not short_marker.get( + "wrong_marker_requests" + ): + notes.append( + "short marker missing without wrong marker: request output is unstable " + "but not obviously copying another request" + ) + if long_marker.get("missing_expected_marker") and not long_marker.get( + "wrong_marker_requests" + ): + notes.append( + "long marker missing without wrong marker: long-context prompt " + "conditioning or attention may be corrupted" + ) + if not notes: + notes.append("diagnostic matrix did not reproduce a clear failure") + return notes + + +def main() -> int: + port = os.environ["DIAG_PORT"] + model = os.environ["DIAG_MODEL"] + out_path = os.environ["DIAG_OUT"] + isl = int(os.environ.get("DIAG_ISL", "8192")) + levels = _parse_levels(os.environ.get("DIAG_CONC_LIST", "1,2,4,8,16")) + cases, pad = _case_matrix(levels, isl) + max_level = max(levels) + stop = [ + "<\uff5cend\u2581of\u2581sentence\uff5c>", + "<\uff5cUser\uff5c>", + "<\uff5cAssistant\uff5c>", + "", + "<|im_end|>", + ] + url = f"http://127.0.0.1:{port}/v1/completions" + summaries: list[dict[str, Any]] = [] + baselines: dict[str, dict[str, Any]] = {} + + with open(out_path, "w", encoding="utf-8") as out: + for case in cases: + print( + "[DSv4 diag] " + f"starting case={case['name']} mode={case['mode']} " + f"prompt={case['prompt_kind']} max_tokens={case['max_tokens']} " + f"levels={case['levels']}" + ) + for level in case["levels"]: + with concurrent.futures.ThreadPoolExecutor(max_workers=level) as pool: + 
rows = list( + pool.map( + lambda i: _one_request( + case=case, + level=level, + request_id=i, + model=model, + url=url, + stop=stop, + pad=pad, + ), + range(level), + ) + ) + if case["mode"] == "identical" and level == 1: + ok = [row for row in rows if row.get("ok")] + if ok: + baselines[case["name"]] = ok[0] + summary = _summarize_case( + case, level, rows, baselines.get(case["name"]) + ) + summaries.append(summary) + _print_summary(summary) + should_print = ( + summary["errors"] + or summary["unique_outputs"] > 1 + or summary.get("drift_vs_baseline", 0) + or summary.get("first_token_drift_vs_baseline", 0) + or summary.get("missing_final") + or summary.get("missing_expected_marker") + or summary.get("wrong_marker_requests") + ) + if should_print: + _print_snippets(rows) + out.write(json.dumps(summary, ensure_ascii=True) + "\n") + for row in rows: + out.write(json.dumps(row, ensure_ascii=True) + "\n") + + notes = _diagnosis(summaries, max_level) + final = {"kind": "diagnosis", "max_level": max_level, "notes": notes} + print("[DSv4 diag] diagnosis=" + " | ".join(notes)) + out.write(json.dumps(final, ensure_ascii=True) + "\n") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/utils/evals/gsm8k.yaml b/utils/evals/gsm8k.yaml index e748119cd..24c7db12b 100644 --- a/utils/evals/gsm8k.yaml +++ b/utils/evals/gsm8k.yaml @@ -26,6 +26,10 @@ generation_kwargs: until: - "" - "<|im_end|>" + - "<|end▁of▁sentence|>" + - "<|User|>" + - "<|Assistant|>" + - "\u202e" do_sample: false temperature: 0.0 repeats: 1 diff --git a/utils/matrix_logic/generate_sweep_configs.py b/utils/matrix_logic/generate_sweep_configs.py index f7b4cca3b..cbfc1c241 100644 --- a/utils/matrix_logic/generate_sweep_configs.py +++ b/utils/matrix_logic/generate_sweep_configs.py @@ -20,7 +20,7 @@ "8k1k": (8192, 1024) } -MIN_EVAL_CONC = 16 +MIN_EVAL_CONC = 4 # Reverse mapping for exp-name generation seq_len_itos = {v: k for k, v in seq_len_stoi.items()} @@ -53,10 +53,21 @@ def mark_eval_entries(matrix_values: list[dict]) -> list[dict]: eval_indices = set() mn_eval_conc = {} # index -> chosen eval concurrency for multinode entries + def _min_eval_conc(entry): + # DSv4 ATOM still needs conc=1 eval smoke coverage while its + # batched path is under active upstreaming. 
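+        # For example, a dsv4/atom/mi355x entry with conc [1, 2, 4, 8] keeps
+        # conc=1 eligible for eval, while every other entry is filtered by
+        # the global MIN_EVAL_CONC floor of 4.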
+ if ( + entry.get(Fields.MODEL_PREFIX.value) == "dsv4" + and entry.get(Fields.FRAMEWORK.value) == "atom" + and entry.get(Fields.RUNNER.value) == "mi355x" + ): + return 1 + return MIN_EVAL_CONC + def _eligible_eval_concs(entry): conc = entry[Fields.CONC.value] conc_values = conc if isinstance(conc, list) else [conc] - return sorted(c for c in conc_values if c >= MIN_EVAL_CONC) + return sorted(c for c in conc_values if c >= _min_eval_conc(entry)) def _max_eval_conc(ie): return max(_eligible_eval_concs(ie[1])) diff --git a/utils/matrix_logic/test_generate_sweep_configs.py b/utils/matrix_logic/test_generate_sweep_configs.py index 297e57524..305edea1d 100644 --- a/utils/matrix_logic/test_generate_sweep_configs.py +++ b/utils/matrix_logic/test_generate_sweep_configs.py @@ -311,7 +311,7 @@ def test_multi_node_eval_conc_uses_only_conc_values_at_or_above_min_conc(self): "ep": 1, "dp-attn": False, }, - "conc": [8, 16, 32], + "conc": [MIN_EVAL_CONC // 2, MIN_EVAL_CONC, MIN_EVAL_CONC * 2], }, { "model": "deepseek-ai/DeepSeek-R1-0528", @@ -333,14 +333,14 @@ def test_multi_node_eval_conc_uses_only_conc_values_at_or_above_min_conc(self): "ep": 1, "dp-attn": False, }, - "conc": [8], + "conc": [MIN_EVAL_CONC // 2], }, ] result = mark_eval_entries(matrix_values) assert result[0]["run-eval"] is True - assert result[0]["eval-conc"] == 32 + assert result[0]["eval-conc"] == MIN_EVAL_CONC * 2 assert result[1]["run-eval"] is False def test_marks_highest_and_median_conc(self): @@ -1970,4 +1970,3 @@ def test_prefill_entries_never_in_single_or_evals(self, mixed_entries): assert all('prefill' in x for x in multi) assert all('prefill' not in x for x in single) assert all('prefill' not in x for x in evals) - diff --git a/utils/mfu_lib.py b/utils/mfu_lib.py new file mode 100644 index 000000000..c7f5ea58b --- /dev/null +++ b/utils/mfu_lib.py @@ -0,0 +1,2204 @@ +""" +Shared library for MFU (Model FLOPS Utilization) trace analysis. + +This module encapsulates the core functionality required to analyse PyTorch +profiler traces for GEMM operations, grouped GEMM (MoE) operations, +communication overlap and network roofline. The goal of this library is +to centralise all common logic so that the command‑line interface simply +parses arguments and dispatches work to the routines defined here. + +The original implementation of the MFU trace analyser contained a large +amount of duplicated code spread across multiple functions. This module +breaks that monolith into reusable components whilst preserving the +behaviour and output of the original script. In particular it exposes +dataclasses for configuration and result types, helper functions for +extracting dimensions from CPU operations and GPU kernels, routines for +computing FLOPs, bytes and roofline metrics, and high level analysis +functions for GEMM kernels, grouped GEMM kernels, layer breakdown, +communication overlap and network rooflines. A summary printer is also +provided. + +Clients should instantiate a :class:`Config` to describe the model +architecture (hidden size, number of experts, etc.) and select a +:class:`GPUSpecs` entry from :data:`GPU_SPECS`. These objects are then +passed to the analysis routines. 
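+
+A minimal usage sketch (the trace-loading step is illustrative; any source of
+Chrome-trace ``traceEvents`` dictionaries works):
+
+    import gzip, json
+    with gzip.open("trace.json.gz", "rt") as f:
+        events = json.load(f)["traceEvents"]
+    config = Config(decode_batch_size=64, tp_degree=8)
+    gemms = analyze_gemm_kernels(events, GPU_SPECS["H100"], config)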
+""" + +from __future__ import annotations + +import json +import gzip +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Tuple, Optional, Any, Iterable +from collections import defaultdict + +############################################################################### +# Data classes +############################################################################### + +@dataclass +class GPUSpecs: + """GPU specifications for MFU/MBU calculation. + + The fields mirror those used in the original script. Peak TFLOPS are + provided per data type along with memory bandwidth, number of SMs and + cache sizes. NVLink bandwidth is given in GB/s and assumed bidirectional. + """ + + name: str + fp16_tflops: float + fp8_tflops: float + fp4_tflops: float + memory_bw_tb_s: float + num_sms: int + l2_cache_mb: float = 50.0 + nvlink_bw_gb_s: float = 900.0 + + +@dataclass +class Config: + """Configuration for model dimensions and parallelism. + + This dataclass captures the parameters that vary between models and runs. + It replaces the global ``MODEL_CONFIG`` dictionary used in the original + implementation. All routines that need knowledge of the model (e.g. for + inferring CUDA graph dimensions or computing expert sizes) accept a + :class:`Config` instance. + + In addition to tensor parallelism (TP), this configuration introduces + ``ep_degree`` to describe expert parallelism (EP). When EP is enabled, + experts are partitioned across ``ep_degree`` groups and each group holds + ``num_experts/ep_degree`` experts. Many memory and FLOP calculations + depend on the number of local experts, which can be obtained via + :pyattr:`local_experts`. + """ + + hidden_size: int = 7168 + num_experts: int = 256 + expert_intermediate_size: int = 2048 # Total intermediate size before TP division + decode_batch_size: int = 64 + tp_degree: int = 8 + ep_degree: int = 1 # Expert parallelism degree (EP). Use 1 for TP-only. + expert_saturation_scale: float = 70.0 # Controls saturation speed + + # Default model precision for GEMM outputs. + # + # Many models accumulate results in BF16 by default even when the inputs + # are lower precision (e.g. FP8). This field allows callers to + # override the fallback output dtype used when the profiler does not + # provide an explicit C dtype. For example, specifying + # ``model_dtype='fp16'`` will cause fallback kernels to be treated as + # producing FP16 outputs rather than the default BF16. This value is + # normalised via :func:`normalize_dtype` when used. + model_dtype: str = 'bf16' + + @property + def expert_intermediate_per_gpu(self) -> int: + """Intermediate size per GPU after tensor parallelism division.""" + return self.expert_intermediate_size // max(self.tp_degree, 1) + + @property + def local_experts(self) -> int: + """Number of experts resident on a single EP rank. + + With expert parallelism, the total number of experts is partitioned + across ``ep_degree`` groups. Each group holds ``num_experts / ep_degree`` + experts. When ``ep_degree=1``, this property simply returns + ``num_experts``, meaning that all experts are present on each rank (pure TP). 
+ """ + return max(self.num_experts // max(self.ep_degree, 1), 1) + + +@dataclass +class KernelClassification: + """Classification result for a GPU kernel.""" + category: str + subcategory: str + is_gemm: bool + dtype: str + source: str + + +@dataclass +class GemmInfo: + """Information about an analysed GEMM kernel.""" + + m: int + n: int + k: int + dtype: str + input_dtype: str = "" + output_dtype: str = "" + a_dtype: str = "" + b_dtype: str = "" + c_dtype: str = "" + duration_us: float = 0.0 + flops: int = 0 + tflops: float = 0.0 + mfu: float = 0.0 + bytes_accessed: int = 0 + achieved_bw_tb_s: float = 0.0 + mbu: float = 0.0 + arithmetic_intensity: float = 0.0 + roofline_tflops: float = 0.0 + roofline_bound: str = "" + kernel_name: str = "" + external_id: int = 0 + layer_type: str = "" + activation_bytes: int = 0 + weight_bytes: int = 0 + effective_mbu: float = 0.0 + l2_cache_benefit: float = 0.0 + timestamp_us: float = 0.0 + correlation_id: int = 0 + tp_rank: str = "" + stream_id: int = 0 + + +@dataclass +class GroupedGemmInfo: + """Information about a grouped GEMM operation (e.g. fused MoE).""" + + num_tokens: int + top_k: int + num_experts: int + hidden_size: int + w1_intermediate: int + w2_intermediate: int + input_dtype: str = "bf16" + weight_dtype: str = "fp8" + output_dtype: str = "bf16" + total_token_expert_pairs: int = 0 + w1_flops: int = 0 + w2_flops: int = 0 + total_flops: int = 0 + input_bytes: int = 0 + w1_weight_bytes: int = 0 + w2_weight_bytes: int = 0 + output_bytes: int = 0 + total_bytes: int = 0 + duration_us: float = 0.0 + tflops: float = 0.0 + mfu: float = 0.0 + achieved_bw_tb_s: float = 0.0 + mbu: float = 0.0 + arithmetic_intensity: float = 0.0 + roofline_bound: str = "" + kernel_name: str = "" + external_id: int = 0 + num_kernels: int = 0 + timestamp_us: float = 0.0 + correlation_id: int = 0 + tp_rank: str = "" + + +############################################################################### +# GPU specification database +############################################################################### + +GPU_SPECS: Dict[str, GPUSpecs] = { + "B200": GPUSpecs( + name="NVIDIA B200 SXM", + fp16_tflops=2250.0, + fp8_tflops=4500.0, + fp4_tflops=9000.0, + memory_bw_tb_s=8.0, + num_sms=160, + l2_cache_mb=128.0, + nvlink_bw_gb_s=1800.0, + ), + "H200": GPUSpecs( + name="NVIDIA H200 SXM", + fp16_tflops=989.4, + fp8_tflops=1978.9, + fp4_tflops=0.0, + memory_bw_tb_s=4.8, + num_sms=132, + l2_cache_mb=80.0, + nvlink_bw_gb_s=900.0, + ), + "H100": GPUSpecs( + name="NVIDIA H100 SXM", + fp16_tflops=989.4, + fp8_tflops=1978.9, + fp4_tflops=0.0, + memory_bw_tb_s=3.35, + num_sms=132, + l2_cache_mb=50.0, + nvlink_bw_gb_s=900.0, + ), + "A100": GPUSpecs( + name="NVIDIA A100 SXM", + fp16_tflops=312.0, + fp8_tflops=0.0, + fp4_tflops=0.0, + memory_bw_tb_s=2.0, + num_sms=108, + l2_cache_mb=40.0, + nvlink_bw_gb_s=600.0, + ), +} + + +############################################################################### +# Pattern definitions for kernel classification +############################################################################### + +GEMM_KERNEL_PATTERNS = { + 'deep_gemm_fp8': { + 'match': lambda name: 'deep_gemm' in name.lower(), + 'is_gemm': True, + 'dtype': 'fp8', + 'source': 'deep_gemm', + 'subcategory': 'fp8_gemm', + }, + 'nvjet_cublas': { + 'match': lambda name: 'nvjet' in name.lower(), + 'is_gemm': True, + 'dtype': 'bf16', + 'source': 'cublas', + 'subcategory': 'cublas_gemm', + }, + 'cublas_gemm': { + 'match': lambda name: 'cublas' in name.lower() and 'gemm' in name.lower(), + 
'is_gemm': True, + 'dtype': 'bf16', + 'source': 'cublas', + 'subcategory': 'cublas_gemm', + }, + 'cutlass_gemm': { + 'match': lambda name: 'cutlass' in name.lower() and ('gemm' in name.lower() or 'matmul' in name.lower()), + 'is_gemm': True, + 'dtype': 'bf16', + 'source': 'cutlass', + 'subcategory': 'cutlass_gemm', + }, + 'generic_gemm': { + 'match': lambda name: ('gemm' in name.lower() or 'matmul' in name.lower()) and 'deep_gemm' not in name.lower() and 'nvjet' not in name.lower(), + 'is_gemm': True, + 'dtype': 'bf16', + 'source': 'generic', + 'subcategory': 'other_gemm', + }, +} + + +COMM_KERNEL_PATTERNS = { + 'nccl_allreduce': { + 'match': lambda name: 'allreduce' in name.lower(), + 'subcategory': 'allreduce', + }, + 'nccl_allgather': { + 'match': lambda name: 'allgather' in name.lower() or 'all_gather' in name.lower(), + 'subcategory': 'all_gather', + }, + 'nccl_reducescatter': { + 'match': lambda name: 'reducescatter' in name.lower() or 'reduce_scatter' in name.lower(), + 'subcategory': 'reduce_scatter', + }, + 'cross_device_reduce': { + 'match': lambda name: 'cross_device_reduce' in name.lower(), + 'subcategory': 'cross_device_reduce', + }, + 'nccl_other': { + 'match': lambda name: 'nccl' in name.lower(), + 'subcategory': 'nccl_other', + }, +} + + +ATTENTION_KERNEL_PATTERNS = { + 'flash_attention': { + 'match': lambda name: 'flashinfer' in name.lower() or 'flash_attn' in name.lower() or 'fmha' in name.lower(), + 'subcategory': 'flash_attention', + }, + 'mla_attention': { + 'match': lambda name: 'batchmlapageattention' in name.lower() or 'prefillwithkvcache' in name.lower(), + 'subcategory': 'mla_attention', + }, +} + + +NORM_KERNEL_PATTERNS = { + 'rmsnorm': { + 'match': lambda name: 'rmsnorm' in name.lower(), + 'subcategory': 'rmsnorm', + }, + 'layernorm': { + 'match': lambda name: 'layernorm' in name.lower(), + 'subcategory': 'layernorm', + }, +} + + +############################################################################### +# Helper functions for dtype handling and metrics +############################################################################### + +def normalize_dtype(dt: Optional[str]) -> str: + """Normalize various dtype strings to canonical short names. + + The profiler sometimes emits data type strings with different naming + conventions. This helper function maps those to canonical forms such as + ``fp8``, ``bf16``, ``fp16`` and ``fp32``. Unknown dtypes are returned as + lowercased strings. + """ + if not dt: + return "" + s = str(dt).lower() + if any(x in s for x in ["float8", "fp8", "e4m3", "e5m2"]): + return "fp8" + if any(x in s for x in ["bfloat16", "bf16"]): + return "bf16" + if any(x in s for x in ["float16", "fp16", "half"]): + return "fp16" + if any(x in s for x in ["float32", "fp32"]): + return "fp32" + if "int8" in s: + return "int8" + return s + + +def get_bytes_per_element(dtype: str) -> float: + """Return bytes per element for a given dtype. + + Sub‑byte types such as FP4 are represented as fractional bytes. The + returned value may be a float to support these fractions. 
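+    For example: ``fp4`` -> 0.5, ``fp8`` -> 1, ``bf16``/``fp16`` -> 2,
+    ``fp32`` -> 4, with 2 bytes as the fallback for unrecognised dtypes.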
+ """ + dtype_lower = str(dtype).lower() if dtype else "" + if any(x in dtype_lower for x in ["float4", "fp4", "e2m1"]): + return 0.5 + if any(x in dtype_lower for x in ["float8", "fp8", "e4m3", "e5m2"]): + return 1 + if any(x in dtype_lower for x in ["float16", "fp16", "bfloat16", "bf16", "half"]): + return 2 + if any(x in dtype_lower for x in ["float32", "fp32"]): + return 4 + return 2 + + +def compute_dtype_from_inputs(a_dtype: str, b_dtype: str) -> str: + """Heuristically determine the compute dtype from input dtypes. + + The logic prefers fp8 when present, then fp16, then bf16. If none of + these special cases match the two dtypes, one of them is returned. + """ + dts = {normalize_dtype(a_dtype), normalize_dtype(b_dtype)} + if "fp8" in dts: + return "fp8" + if "fp16" in dts: + return "fp16" + if "bf16" in dts: + return "bf16" + # Fallback to the first non‑empty dtype + return next(iter(dts - {""}), "bf16") + + +def calculate_gemm_flops(m: int, n: int, k: int) -> int: + """Calculate FLOPs for a GEMM operation ``C = A @ B``. + + Each element of the output matrix involves a multiply and an add. Thus + FLOPs are computed as ``2 * m * n * k``. + """ + return 2 * m * n * k + + +def calculate_gemm_bytes(m: int, n: int, k: int, + a_dtype: str = 'bf16', b_dtype: str = 'bf16', + c_dtype: str = 'bf16') -> int: + """Compute total bytes accessed for GEMM ``C = A @ B``. + + Reads ``A`` (``m × k``), reads ``B`` (``k × n``) and writes ``C`` + (``m × n``). Supports sub‑byte types by rounding up to the nearest whole + element count when necessary. + """ + a_bytes = get_bytes_per_element(a_dtype) + b_bytes = get_bytes_per_element(b_dtype) + c_bytes = get_bytes_per_element(c_dtype) + # Use integer rounding for sub‑byte types (e.g. FP4) + bytes_a = int(m * k * a_bytes) if a_bytes >= 1 else ((m * k + 1) // 2) + bytes_b = int(k * n * b_bytes) if b_bytes >= 1 else ((k * n + 1) // 2) + bytes_c = int(m * n * c_bytes) if c_bytes >= 1 else ((m * n + 1) // 2) + return bytes_a + bytes_b + bytes_c + + +def calculate_gemm_bytes_breakdown(m: int, n: int, k: int, + a_dtype: str = 'bf16', b_dtype: str = 'bf16', + c_dtype: str = 'bf16') -> Tuple[int, int, int]: + """Return breakdown of bytes for GEMM. + + Returns a tuple ``(activation_bytes, weight_bytes, total_bytes)`` where + ``activation_bytes`` is the sum of ``A`` and ``C`` (assuming weights ``B`` + can be served from cache), ``weight_bytes`` is the bytes for ``B`` and + ``total_bytes`` is the sum. 
+ """ + a_bytes = get_bytes_per_element(a_dtype) + b_bytes = get_bytes_per_element(b_dtype) + c_bytes = get_bytes_per_element(c_dtype) + bytes_a = int(m * k * a_bytes) if a_bytes >= 1 else ((m * k + 1) // 2) + bytes_b = int(k * n * b_bytes) if b_bytes >= 1 else ((k * n + 1) // 2) + bytes_c = int(m * n * c_bytes) if c_bytes >= 1 else ((m * n + 1) // 2) + activation_bytes = bytes_a + bytes_c + weight_bytes = bytes_b + total_bytes = bytes_a + bytes_b + bytes_c + return activation_bytes, weight_bytes, total_bytes + + +def calculate_arithmetic_intensity(flops: int, bytes_accessed: int) -> float: + """Compute arithmetic intensity (FLOPs per byte).""" + return flops / bytes_accessed if bytes_accessed > 0 else 0.0 + + +def calculate_mfu(flops: int, duration_us: float, peak_tflops: float) -> float: + """Compute Model FLOPS Utilization (MFU).""" + if duration_us <= 0: + return 0.0 + duration_s = duration_us / 1e6 + achieved_tflops = (flops / 1e12) / duration_s + return (achieved_tflops / peak_tflops) * 100.0 if peak_tflops > 0 else 0.0 + + +def calculate_mbu(bytes_accessed: int, duration_us: float, peak_bw_tb_s: float) -> float: + """Compute Memory Bandwidth Utilization (MBU).""" + if duration_us <= 0: + return 0.0 + duration_s = duration_us / 1e6 + achieved_bw_tb_s = (bytes_accessed / 1e12) / duration_s + return (achieved_bw_tb_s / peak_bw_tb_s) * 100.0 if peak_bw_tb_s > 0 else 0.0 + + +def calculate_roofline_tflops(arithmetic_intensity: float, gpu_specs: GPUSpecs, + peak_tflops: float) -> Tuple[float, str]: + """Return roofline‐based theoretical TFLOPS and bound type. + + The roofline model states that attainable performance is the minimum of + compute peak and the product of memory bandwidth and arithmetic + intensity. ``peak_tflops`` corresponds to the compute peak for the + operation’s dtype (returned by :func:`get_dtype_peak_tflops`). + """ + if arithmetic_intensity <= 0: + return 0.0, "unknown" + memory_bound_tflops = gpu_specs.memory_bw_tb_s * arithmetic_intensity + compute_bound_tflops = peak_tflops + if memory_bound_tflops < compute_bound_tflops: + return memory_bound_tflops, "memory" + else: + return compute_bound_tflops, "compute" + + +def get_dtype_peak_tflops(dtype: str, gpu_specs: GPUSpecs) -> float: + """Return peak TFLOPS for a given dtype from the GPU specs. + + FP4 operations fall back to FP8 if unavailable. FP8 falls back to + FP16/BF16 if unavailable. Unknown dtypes default to FP16/BF16 peak. + """ + dtype_lower = str(dtype).lower() + if any(x in dtype_lower for x in ["float4", "fp4", "e2m1"]): + return gpu_specs.fp4_tflops if gpu_specs.fp4_tflops > 0 else ( + gpu_specs.fp8_tflops if gpu_specs.fp8_tflops > 0 else gpu_specs.fp16_tflops + ) + if any(x in dtype_lower for x in ["float8", "fp8", "e4m3", "e5m2"]): + return gpu_specs.fp8_tflops if gpu_specs.fp8_tflops > 0 else gpu_specs.fp16_tflops + return gpu_specs.fp16_tflops + +def estimate_activated_experts(batch_size: int, top_k: int = 8, max_experts: int = 257) -> int: + """Estimate number of activated experts based on batch size. + + Models the saturation curve where: + - batch_size=1 → top_k experts (each token picks top_k) + - batch_size→∞ → max_experts (all experts eventually used) + + Uses a logarithmic saturation model fitted to empirical observations + from DeepSeek R1 with top-k=8 routing. + + The model: activated = top_k + (max_experts - top_k) * (1 - exp(-batch_size / scale)) + where scale controls how quickly saturation occurs. 
+ """ + import math + + if batch_size <= 0: + return top_k + + # Scale factor controls saturation speed + # Tuned so that: + # - batch_size=64 → ~140-180 experts + # - batch_size=200 → ~200-220 experts + # - batch_size=1000 → ~250+ experts + scale = 150.0 # Adjust based on your empirical data + + saturation = 1.0 - math.exp(-batch_size / scale) + activated = top_k + (max_experts - top_k) * saturation + + return min(int(activated), max_experts) + +############################################################################### +# Dimension extraction from CPU operations and kernels +############################################################################### + +CPU_OP_GEMM_PATTERNS = { + 'deep_gemm_fp8': { + 'match': lambda name: 'deep_gemm' in name.lower() or 'fp8_gemm' in name.lower(), + 'dtype': 'fp8', + }, + 'aten_mm': { + 'match': lambda name: name in ['aten::mm', 'aten::matmul'], + 'dtype': 'bf16', + }, + 'aten_linear': { + 'match': lambda name: name == 'aten::linear', + 'dtype': 'bf16', + }, +} + + +def extract_dimensions_from_cpu_op(event: Dict[str, Any]) -> Optional[Tuple[int, int, int, str, str, str]]: + """Extract matrix dimensions and dtypes from a CPU op event. + + Returns a tuple ``(M, N, K, A_dtype, B_dtype, C_dtype)`` where the dtypes + correspond to the input tensors and output for the GEMM. If the op + cannot be interpreted as a GEMM (e.g. insufficient information), returns + ``None``. + """ + args = event.get('args', {}) + input_dims = args.get('Input Dims', []) + input_types = args.get('Input type', []) + name = event.get('name', '') + + if not input_dims: + return None + + # Attempt to identify deep_gemm FP8 operations + if 'deep_gemm' in name.lower() and len(input_dims) >= 5: + # Format: [M, K], scale, [N, K], scale, [M, N] + a_dims = input_dims[0] + b_dims = input_dims[2] + if not (isinstance(a_dims, list) and len(a_dims) >= 2 and isinstance(b_dims, list) and len(b_dims) >= 1): + return None + m, k = a_dims[0], a_dims[1] + n = b_dims[0] + # Types: A_type, A_scale, B_type, B_scale, C_type + types = [normalize_dtype(t) for t in input_types] if input_types else [] + a_dtype = types[0] if len(types) >= 1 else 'bf16' + b_dtype = types[2] if len(types) >= 3 else a_dtype + c_dtype = types[4] if len(types) >= 5 else (a_dtype if a_dtype == b_dtype else 'bf16') + return m, n, k, a_dtype, b_dtype, c_dtype + + # aten::mm / aten::matmul: Input dims [[M,K],[K,N]] + if name in ['aten::mm', 'aten::matmul'] and len(input_dims) >= 2: + a_dims = input_dims[0] + b_dims = input_dims[1] + if not (isinstance(a_dims, list) and len(a_dims) >= 2 and isinstance(b_dims, list) and len(b_dims) >= 2): + return None + m, k = a_dims[0], a_dims[1] + n = b_dims[1] + types = [normalize_dtype(t) for t in input_types] if input_types else [] + a_dtype = types[0] if len(types) >= 1 else 'bf16' + b_dtype = types[1] if len(types) >= 2 else a_dtype + c_dtype = a_dtype if a_dtype == b_dtype else 'bf16' + return m, n, k, a_dtype, b_dtype, c_dtype + + # aten::linear: Input dims [[M,K],[N,K], bias] + if name == 'aten::linear' and len(input_dims) >= 2: + a_dims = input_dims[0] + w_dims = input_dims[1] + if not (isinstance(a_dims, list) and len(a_dims) >= 2 and isinstance(w_dims, list) and len(w_dims) >= 2): + return None + # Handle batched input: [B, M, K] -> effective M = B*M + if len(a_dims) == 2: + m, k = a_dims + elif len(a_dims) == 3: + m = a_dims[0] * a_dims[1] + k = a_dims[2] + else: + return None + n = w_dims[0] # Weight dims [N,K] + types = [normalize_dtype(t) for t in input_types] if input_types else [] 
+ a_dtype = types[0] if len(types) >= 1 else 'bf16' + b_dtype = types[1] if len(types) >= 2 else a_dtype + c_dtype = a_dtype if a_dtype == b_dtype else 'bf16' + return m, n, k, a_dtype, b_dtype, c_dtype + + return None + + +def extract_tp_rank(pid: Any) -> Optional[str]: + """Extract tensor parallel rank from a PID string or number.""" + if pid is None: + return None + match = re.search(r'\[TP(\d+)\]', str(pid)) + if match: + return match.group(1) + return str(pid) + + +def parse_deep_gemm_kernel_dims(kernel_name: str, grid: List[int], + cpu_op_dims: Optional[Tuple[int, int, int]] = None) -> Optional[Tuple[int, int, int, str]]: + """Parse deep_gemm kernel template parameters to infer dimensions. + + The deep_gemm implementation names kernels with a template signature like + ``deep_gemm::sm90_fp8_gemm_1d2d_impl<..., N, K, ..., M_tile, N_tile, K_tile, ...>``. + If ``cpu_op_dims`` is provided, the M dimension is taken from it; otherwise + it is inferred from the grid dimensions under the assumption that the grid + x dimension is ``ceil(M/m_tile) * ceil(N/n_tile)``. + Returns a tuple ``(M, N, K, dtype)`` where ``dtype`` is the compute dtype + (fp8 or bf16). + """ + match = re.search(r'deep_gemm::[^<]*<[^,]*,\s*(\d+)u,\s*(\d+)u,[^,]*,\s*(\d+)u,\s*(\d+)u,\s*(\d+)u', kernel_name) + if not match: + return None + n = int(match.group(1)) + k = int(match.group(2)) + m_tile = int(match.group(3)) + n_tile = int(match.group(4)) + # Determine M dimension + if cpu_op_dims: + m = cpu_op_dims[0] # Use M from CPU op + else: + grid_x = grid[0] if grid else 1 + # number of tiles along N dimension + num_n_tiles = (n + n_tile - 1) // n_tile + if grid_x <= num_n_tiles: + # Single M tile + m = m_tile + else: + num_m_tiles = max(grid_x // num_n_tiles, 1) + m = num_m_tiles * m_tile + dtype = 'fp8' if 'fp8' in kernel_name.lower() else 'bf16' + return (m, n, k, dtype) + + +def infer_cuda_graph_kernel_dims(kernel_name: str, grid: List[int], + config: Optional[Config] = None, + sibling_dims: Optional[Dict[str, Tuple[int, int]]] = None) -> Optional[Tuple[int, int, int, str, str]]: + """Infer dimensions for CUDA graph replayed kernels. + + See the original implementation for detailed heuristics. The inference + relies on a combination of sibling kernel dimensions (obtained from + prefill kernels with External ID) and known model architecture. The + ``config`` argument provides hidden size, expert intermediate size and + decode batch size. Returns a tuple ``(M, N, K, dtype, layer_type)``. 
+ """ + # Allow config to be None for fallback heuristics + hidden = config.hidden_size if config else 7168 + tp_degree = config.tp_degree if config else 8 + expert_intermediate_per_gpu = (config.expert_intermediate_size // max(tp_degree, 1)) if config else (2048 // 8) + decode_batch = config.decode_batch_size if config else 64 + name_lower = kernel_name.lower() + grid_tuple = tuple(grid) if grid else () + # Strategy 1: use sibling dimensions for nvjet_tst_128x8 (shared expert) + if sibling_dims and 'nvjet_tst_128x8' in name_lower and 'nvjet_tst_128x8' in sibling_dims: + n, k = sibling_dims['nvjet_tst_128x8'] + return (decode_batch, n, k, 'bf16', 'FFN') + # Strategy 2: use sibling dims for nvjet_tst_64x8 based on grid + if sibling_dims and 'nvjet_tst_64x8' in name_lower and 'nvjet_tst_64x64' in sibling_dims: + intermediate_per_gpu, hidden_from_sibling = sibling_dims['nvjet_tst_64x64'] + if grid_tuple == (2, 64, 1): + # Down projection: [M, intermediate] @ [intermediate, hidden] + return (decode_batch, hidden_from_sibling, intermediate_per_gpu, 'bf16', 'FFN') + elif grid_tuple == (2, 16, 1): + # Up projection: [M, hidden] @ [hidden, intermediate] + return (decode_batch, intermediate_per_gpu, hidden_from_sibling, 'bf16', 'FFN') + # Strategy 3: use model knowledge for nvjet_tst_64x8 + if 'nvjet_tst_64x8' in name_lower: + if grid_tuple == (2, 64, 1): + return (decode_batch, hidden, expert_intermediate_per_gpu, 'bf16', 'FFN') + elif grid_tuple == (2, 16, 1): + return (decode_batch, expert_intermediate_per_gpu, hidden, 'bf16', 'FFN') + # Shared expert nvjet_tst_128x8 fallback + if 'nvjet_tst_128x8' in name_lower: + return (decode_batch, 16160, hidden, 'bf16', 'FFN') + # nvjet_tst_64x64 kernels are handled via CPU op dims + if 'nvjet_tst_64x64' in name_lower: + return None + # router_gemm kernels + if 'router_gemm' in name_lower: + num_experts = config.num_experts + return (decode_batch, num_experts, hidden, 'bf16', 'FFN') + return None + + +def classify_kernel(kernel_name: str) -> KernelClassification: + """Classify a GPU kernel by examining its name against known patterns.""" + for pattern in GEMM_KERNEL_PATTERNS.values(): + if pattern['match'](kernel_name): + return KernelClassification( + category='gemm', + subcategory=pattern['subcategory'], + is_gemm=True, + dtype=pattern['dtype'], + source=pattern['source'], + ) + for pattern in COMM_KERNEL_PATTERNS.values(): + if pattern['match'](kernel_name): + return KernelClassification( + category='communication', + subcategory=pattern['subcategory'], + is_gemm=False, + dtype='', + source='nccl', + ) + for pattern in ATTENTION_KERNEL_PATTERNS.values(): + if pattern['match'](kernel_name): + return KernelClassification( + category='attention', + subcategory=pattern['subcategory'], + is_gemm=False, + dtype='', + source='flashinfer', + ) + for pattern in NORM_KERNEL_PATTERNS.values(): + if pattern['match'](kernel_name): + return KernelClassification( + category='normalization', + subcategory=pattern['subcategory'], + is_gemm=False, + dtype='', + source='custom', + ) + return KernelClassification( + category='other', + subcategory='unknown', + is_gemm=False, + dtype='', + source='unknown', + ) + + +def classify_layer_type(m: int, n: int, k: int, kernel_name: str = "") -> str: + """Heuristically classify a GEMM as belonging to QKVO, FFN or other layers.""" + hidden_size = 7168 + num_experts = 256 + # MoE router + if k == hidden_size and n == num_experts: + return 'FFN' + # MoE gate/router variants + if k == 512 and n in [4096, 2048, 1024]: + return 'FFN' + 
# MoE FFN projections: up or down + if (n == 4608 and k == hidden_size) or (n == hidden_size and k == 4608): + return 'FFN' + if n > 10000 or k > 10000: + return 'FFN' + # Attention projections + if k == hidden_size and n in [2112, 2048, 2304, 2560]: + return 'QKVO' + if n == hidden_size and k in [2048, 2112, 2304, 2560]: + return 'QKVO' + if (n == 3072 and k == 1536) or (n == 1536 and k == 3072): + return 'QKVO' + if k == hidden_size: + return 'QKVO' + if n == hidden_size: + return 'QKVO' + return 'Other' + + +############################################################################### +# CPU op and sibling dimension maps +############################################################################### + +def build_cpu_op_dims_map(events: Iterable[Dict[str, Any]]) -> Dict[Tuple[str, int], Tuple[int, int, int, str, str, str]]: + """Build a map from (TP rank, External ID) to GEMM dimensions and dtypes. + + This helper consolidates the repeated logic used in several analysis + routines. It iterates over CPU op events, extracts dimensions via + :func:`extract_dimensions_from_cpu_op` and stores them keyed by TP rank and + External ID. The map also includes entries for adjacent external IDs + (``ext_id ± 1``) since GPU kernels may use offset IDs. + """ + cpu_op_dims: Dict[Tuple[str, int], Tuple[int, int, int, str, str, str]] = {} + for event in events: + if event.get('cat') != 'cpu_op': + continue + ext_id = event.get('args', {}).get('External id') + if ext_id is None: + continue + tp_rank = extract_tp_rank(event.get('pid')) + dims = extract_dimensions_from_cpu_op(event) + if dims: + cpu_op_dims[(tp_rank, ext_id)] = dims + # Also map neighbouring IDs to the same dims to catch child kernels + cpu_op_dims[(tp_rank, ext_id + 1)] = dims + cpu_op_dims[(tp_rank, ext_id - 1)] = dims + return cpu_op_dims + + +def build_sibling_dims_map(events: Iterable[Dict[str, Any]], + cpu_op_dims: Dict[Tuple[str, int], Tuple[int, int, int, str, str, str]]) -> Dict[str, Tuple[int, int]]: + """Build a map of kernel signatures to (N, K) dimensions using CPU op dims. + + For kernels like ``nvjet_tst_64x64`` the External ID identifies a CPU op + with dimensions (M,N,K). The sibling map stores ``(N, K)`` keyed by + signature so that decode kernels (without External ID) can later infer + their shapes. + """ + sibling_dims: Dict[str, Tuple[int, int]] = {} + for event in events: + if event.get('cat') != 'kernel': + continue + name = event.get('name', '') + ext_id = event.get('args', {}).get('External id') + if ext_id is None: + continue + match = re.search(r'nvjet_tst_(\d+x\d+)', name.lower()) + if not match: + continue + signature = f"nvjet_tst_{match.group(1)}" + tp_rank = extract_tp_rank(event.get('pid')) + # Try several neighbouring ext_ids + for key_ext in [ext_id, ext_id - 1, ext_id + 1]: + dims = cpu_op_dims.get((tp_rank, key_ext)) + if dims and len(dims) >= 3: + # dims are (m,n,k,...) so take (n,k) + sibling_dims.setdefault(signature, (dims[1], dims[2])) + break + return sibling_dims + + +############################################################################### +# GEMM kernel analysis +############################################################################### + +def analyze_gemm_kernels(events: List[Dict[str, Any]], gpu_specs: GPUSpecs, config: Config) -> List[GemmInfo]: + """Analyse all GEMM/MatMul kernels in a trace and compute performance metrics. + + The returned list contains a :class:`GemmInfo` entry for each kernel + identified as GEMM with known dimensions. 
Dimension extraction proceeds in + order of priority: CPU op correlation, deep_gemm template parsing, CUDA + graph inference. Metrics such as MFU, MBU, arithmetic intensity and + roofline bound are computed for each kernel. + """ + gemm_infos: List[GemmInfo] = [] + # Build CPU op dimension map once + cpu_op_dims = build_cpu_op_dims_map(events) + # Build sibling dimension map for nvjet kernels + sibling_dims = build_sibling_dims_map(events, cpu_op_dims) + seen_kernels = set() + unmatched = defaultdict(lambda: {'count': 0, 'time_us': 0}) + + for event in events: + if event.get('cat') != 'kernel': + continue + name = event.get('name', '') + classification = classify_kernel(name) + if not classification.is_gemm: + continue + duration_us = event.get('dur', 0) + if duration_us <= 0: + continue + ext_id = event.get('args', {}).get('External id') + tp_rank = extract_tp_rank(event.get('pid')) + grid = event.get('args', {}).get('grid', [1, 1, 1]) + ts = event.get('ts', 0) + kernel_key = (tp_rank, ts, name[:50]) + if kernel_key in seen_kernels: + continue + seen_kernels.add(kernel_key) + # Extract dims via CPU op map + dims = None + if ext_id is not None: + dims = cpu_op_dims.get((tp_rank, ext_id)) or cpu_op_dims.get((tp_rank, ext_id - 1)) or cpu_op_dims.get((tp_rank, ext_id + 1)) + # Parse deep_gemm template if needed + if dims is None and classification.source == 'deep_gemm': + parsed = parse_deep_gemm_kernel_dims(name, grid, None) + if parsed: + m_, n_, k_, dtype_ = parsed + # Deep GEMM kernels use FP8 inputs. The accumulator/output is + # typically BF16 by default. To allow users to override this + # behaviour, use the configured model dtype for the output. + dims = (m_, n_, k_, 'fp8', 'fp8', normalize_dtype(config.model_dtype)) + # Infer CUDA graph kernels + inferred_layer_type = None + if dims is None and ext_id is None: + inferred = infer_cuda_graph_kernel_dims(name, grid, config, sibling_dims=sibling_dims) + if inferred: + m_, n_, k_, dtype_, inferred_layer_type = inferred + # Use the model's default dtype for the output when inferring + # CUDA graph kernel dimensions. When the compute dtype is + # FP8 the output is often BF16 in many models, but the model + # dtype parameter allows this behaviour to be customised. + c_dtype = normalize_dtype(config.model_dtype) + dims = (m_, n_, k_, dtype_, dtype_, c_dtype) + if dims is None: + # Record unmatched for debugging + unmatched[classification.subcategory]['count'] += 1 + unmatched[classification.subcategory]['time_us'] += duration_us + continue + # Unpack dims to (m,n,k,a,b,c) + if len(dims) >= 6: + m, n, k, a_dtype, b_dtype, c_dtype = dims[:6] + elif len(dims) == 5: + m, n, k, input_dtype, output_dtype = dims + a_dtype = b_dtype = input_dtype + c_dtype = output_dtype + else: + # Only (M,N,K,input_dtype) provided. Both A and B use the same + # input dtype. Use the configured model dtype as the fallback + # output dtype instead of assuming BF16. This allows users to + # specify the precision of GEMM outputs when the profiler does not + # record a C dtype (e.g. FP16 models). + m, n, k, input_dtype = dims + a_dtype = b_dtype = input_dtype + # Normalise the configured model dtype so that abbreviations like + # 'float16' map to canonical short names. If the input dtype is + # FP8 we still respect the model dtype for the output. 
+ c_dtype = normalize_dtype(config.model_dtype) + if m <= 0 or n <= 0 or k <= 0: + continue + # Override dtype from classification if not specified + if not a_dtype and classification.dtype: + a_dtype = classification.dtype + if not b_dtype and classification.dtype: + b_dtype = classification.dtype + if not c_dtype and classification.dtype: + c_dtype = 'bf16' if classification.dtype == 'fp8' else classification.dtype + compute_dtype = compute_dtype_from_inputs(a_dtype, b_dtype) + # Compute metrics + flops = calculate_gemm_flops(m, n, k) + bytes_accessed = calculate_gemm_bytes(m, n, k, a_dtype, b_dtype, c_dtype) + activation_bytes, weight_bytes, _ = calculate_gemm_bytes_breakdown(m, n, k, a_dtype, b_dtype, c_dtype) + peak_tflops = get_dtype_peak_tflops(compute_dtype, gpu_specs) + duration_s = duration_us / 1e6 + achieved_tflops = (flops / 1e12) / duration_s + achieved_bw_tb_s = (bytes_accessed / 1e12) / duration_s + mfu = calculate_mfu(flops, duration_us, peak_tflops) + mbu = calculate_mbu(bytes_accessed, duration_us, gpu_specs.memory_bw_tb_s) + effective_mbu = calculate_mbu(activation_bytes, duration_us, gpu_specs.memory_bw_tb_s) + l2_cache_benefit = (bytes_accessed / activation_bytes) if activation_bytes > 0 else 1.0 + ai = calculate_arithmetic_intensity(flops, bytes_accessed) + roofline_tflops, roofline_bound = calculate_roofline_tflops(ai, gpu_specs, peak_tflops) + layer_type = inferred_layer_type if inferred_layer_type else classify_layer_type(m, n, k, name) + correlation_id = event.get('args', {}).get('correlation', 0) + stream_id = event.get('args', {}).get('stream', 0) + gemm_infos.append(GemmInfo( + m=m, n=n, k=k, + dtype=compute_dtype, + input_dtype=(a_dtype if a_dtype == b_dtype else 'mixed'), + output_dtype=c_dtype, + a_dtype=a_dtype, + b_dtype=b_dtype, + c_dtype=c_dtype, + duration_us=duration_us, + flops=flops, + tflops=achieved_tflops, + mfu=mfu, + bytes_accessed=bytes_accessed, + achieved_bw_tb_s=achieved_bw_tb_s, + mbu=mbu, + arithmetic_intensity=ai, + roofline_tflops=roofline_tflops, + roofline_bound=roofline_bound, + kernel_name=name, + external_id=ext_id if ext_id is not None else 0, + layer_type=layer_type, + activation_bytes=activation_bytes, + weight_bytes=weight_bytes, + effective_mbu=effective_mbu, + l2_cache_benefit=l2_cache_benefit, + timestamp_us=event.get('ts', 0), + correlation_id=correlation_id, + tp_rank=tp_rank if tp_rank else "", + stream_id=stream_id if stream_id else 0, + )) + return gemm_infos + + +############################################################################### +# Grouped GEMM (fused MoE) analysis +############################################################################### + +def analyze_grouped_gemm_kernels(events: List[Dict[str, Any]], gpu_specs: GPUSpecs, config: Config) -> List[GroupedGemmInfo]: + """Analyse fused MoE kernels and compute grouped GEMM metrics. + + This function supports both prefill and decode phases. For prefill + kernels (those with External ID) the CPU op event provides the full + dimensions. For decode kernels (those without External ID) the analysis + infers dimensions based on typical decode batch sizes and heuristics. The + returned list contains one entry per grouped GEMM operation aggregated + across all kernel calls belonging to that operation. 
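+    Per EP rank the FLOP count reduces to
+    ``2 * (num_tokens * top_k / ep_degree) * hidden_size *
+    (w1_intermediate + w2_intermediate)``, with weight bytes scaled by the
+    number of local experts.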
+ """ + grouped_infos: List[GroupedGemmInfo] = [] + # Identify TP ranks (to infer number of GPUs) + tp_ranks = set() + for event in events: + pid = event.get('pid') + match = re.search(r'\[TP(\d+)\]', str(pid)) + if match: + tp_ranks.add(match.group(1)) + num_gpus = max(len(tp_ranks), 1) + # Build map from (tp_rank, ext_id) to fused expert dims + fused_expert_ops: Dict[Tuple[str, int], Dict[str, Any]] = {} + for event in events: + if event.get('cat') != 'cpu_op': + continue + name = event.get('name', '') + if 'inplace_fused_experts' not in name and 'fused_experts' not in name: + continue + ext_id = event.get('args', {}).get('External id') + if ext_id is None: + continue + tp_rank = extract_tp_rank(event.get('pid')) + args = event.get('args', {}) + input_dims = args.get('Input Dims', []) + input_types = args.get('Input type', []) + if len(input_dims) < 5: + continue + input_shape = input_dims[0] if input_dims[0] else [] + w1_shape = input_dims[1] if len(input_dims) > 1 else [] + w2_shape = input_dims[2] if len(input_dims) > 2 else [] + topk_shape = input_dims[3] if len(input_dims) > 3 else [] + if not (len(input_shape) >= 2 and len(w1_shape) >= 3 and len(w2_shape) >= 3 and len(topk_shape) >= 2): + continue + num_tokens = input_shape[0] + hidden_size = input_shape[1] + num_experts_local = w1_shape[0] + w1_intermediate = w1_shape[1] + w2_intermediate = w2_shape[2] + top_k = topk_shape[1] + input_dtype = normalize_dtype(input_types[0]) if input_types else 'bf16' + weight_dtype = normalize_dtype(input_types[1]) if len(input_types) > 1 else 'fp8' + output_dtype = input_dtype + fused_expert_ops[(tp_rank, ext_id)] = { + 'num_tokens': num_tokens, + 'hidden_size': hidden_size, + 'num_experts': num_experts_local, + 'w1_intermediate': w1_intermediate, + 'w2_intermediate': w2_intermediate, + 'top_k': top_k, + 'input_dtype': input_dtype, + 'weight_dtype': weight_dtype, + 'output_dtype': output_dtype, + 'ts': event.get('ts', 0), + } + # Collect fused_moe kernels + moe_kernels_by_ext = defaultdict(list) + moe_kernels_no_ext = [] + for event in events: + if event.get('cat') != 'kernel': + continue + name = event.get('name', '') + if not name.startswith('fused_moe_kernel'): + continue + ext_id = event.get('args', {}).get('External id') + tp_rank = extract_tp_rank(event.get('pid')) + kernel_info = { + 'name': name, + 'dur': event.get('dur', 0), + 'ts': event.get('ts', 0), + 'correlation': event.get('args', {}).get('correlation', 0), + 'grid': event.get('args', {}).get('grid', []), + 'ext_id': ext_id, + 'tp_rank': tp_rank, + } + if ext_id is not None: + moe_kernels_by_ext[(tp_rank, ext_id)].append(kernel_info) + else: + moe_kernels_no_ext.append(kernel_info) + processed_ext_ids = set() + # Prefill kernels (with External ID) + for (tp_rank, ext_id), kernels in moe_kernels_by_ext.items(): + if ext_id in processed_ext_ids: + continue + processed_ext_ids.add(ext_id) + dims = fused_expert_ops.get((tp_rank, ext_id)) + if dims is None: + continue + num_tokens = dims['num_tokens'] + hidden_size = dims['hidden_size'] + # Determine number of experts and local experts + # dims['num_experts'] corresponds to the local experts for TP-only runs. + # When expert parallelism is enabled, the total experts is provided by + # config.num_experts and local experts per EP rank is config.local_experts. 
+ if config.ep_degree > 1:
+ num_total_experts = config.num_experts
+ num_local_experts = config.local_experts
+ else:
+ num_total_experts = dims['num_experts']
+ num_local_experts = dims['num_experts']
+ w1_intermediate = dims['w1_intermediate']
+ w2_intermediate = dims['w2_intermediate']
+ top_k = dims['top_k']
+ input_dtype = dims['input_dtype']
+ weight_dtype = dims['weight_dtype']
+ output_dtype = dims['output_dtype']
+ total_duration_us = sum(k['dur'] for k in kernels)
+ total_pairs = num_tokens * top_k
+ # For EP, each EP rank handles a fraction of the total token‑expert pairs
+ pairs_per_rank = total_pairs / max(config.ep_degree, 1)
+ # FLOPs per EP rank
+ w1_flops = 2 * pairs_per_rank * hidden_size * w1_intermediate
+ w2_flops = 2 * pairs_per_rank * w2_intermediate * hidden_size
+ total_flops = w1_flops + w2_flops
+ # Memory bytes per EP rank
+ input_bytes = int(num_tokens * hidden_size * get_bytes_per_element(input_dtype))
+ weight_bytes_elem = get_bytes_per_element(weight_dtype)
+ # Each rank stores only its local experts when ep_degree>1
+ w1_weight_bytes = int(num_local_experts * w1_intermediate * hidden_size * weight_bytes_elem)
+ w2_weight_bytes = int(num_local_experts * hidden_size * w2_intermediate * weight_bytes_elem)
+ output_bytes = int(num_tokens * hidden_size * get_bytes_per_element(output_dtype))
+ total_bytes = input_bytes + w1_weight_bytes + w2_weight_bytes + output_bytes
+ duration_s = total_duration_us / 1e6
+ achieved_tflops = (total_flops / 1e12) / duration_s if duration_s > 0 else 0
+ achieved_bw_tb_s = (total_bytes / 1e12) / duration_s if duration_s > 0 else 0
+ peak_tflops = get_dtype_peak_tflops(weight_dtype, gpu_specs)
+ mfu = (achieved_tflops / peak_tflops) * 100.0 if peak_tflops > 0 else 0
+ mbu = (achieved_bw_tb_s / gpu_specs.memory_bw_tb_s) * 100.0 if gpu_specs.memory_bw_tb_s > 0 else 0
+ ai = total_flops / total_bytes if total_bytes > 0 else 0
+ memory_bound_tflops = gpu_specs.memory_bw_tb_s * ai
+ roofline_bound = 'memory' if memory_bound_tflops < peak_tflops else 'compute'
+ grouped_infos.append(GroupedGemmInfo(
+ num_tokens=num_tokens,
+ top_k=top_k,
+ # Report the total expert count, matching the decode path below
+ num_experts=num_total_experts,
+ hidden_size=hidden_size,
+ w1_intermediate=w1_intermediate,
+ w2_intermediate=w2_intermediate,
+ input_dtype=input_dtype,
+ weight_dtype=weight_dtype,
+ output_dtype=output_dtype,
+ total_token_expert_pairs=total_pairs,
+ w1_flops=w1_flops,
+ w2_flops=w2_flops,
+ total_flops=total_flops,
+ input_bytes=input_bytes,
+ w1_weight_bytes=w1_weight_bytes,
+ w2_weight_bytes=w2_weight_bytes,
+ output_bytes=output_bytes,
+ total_bytes=total_bytes,
+ duration_us=total_duration_us,
+ tflops=achieved_tflops,
+ mfu=mfu,
+ achieved_bw_tb_s=achieved_bw_tb_s,
+ mbu=mbu,
+ arithmetic_intensity=ai,
+ roofline_bound=roofline_bound,
+ kernel_name='fused_moe_kernel',
+ external_id=ext_id,
+ num_kernels=len(kernels),
+ timestamp_us=kernels[0]['ts'] if kernels else 0,
+ correlation_id=kernels[0]['correlation'] if kernels else 0,
+ tp_rank=tp_rank if tp_rank else "",
+ ))
+ # Decode kernels (without External ID)
+ if moe_kernels_no_ext and fused_expert_ops:
+ sample_dims = next(iter(fused_expert_ops.values()))
+ decode_batch_size = config.decode_batch_size
+ hidden_size = sample_dims['hidden_size']
+ # Determine experts for EP: sample_dims['num_experts'] holds local experts for TP-only
+ if config.ep_degree > 1:
+ num_total_experts = config.num_experts
+ num_local_experts = config.local_experts
+ else:
+ num_total_experts = sample_dims['num_experts']
+ num_local_experts = sample_dims['num_experts']
+ w1_intermediate = sample_dims['w1_intermediate']
+ w2_intermediate = sample_dims['w2_intermediate']
+ top_k = sample_dims['top_k']
+ input_dtype = sample_dims['input_dtype']
+ weight_dtype = sample_dims['weight_dtype']
+ output_dtype = sample_dims['output_dtype']
+ total_pairs = decode_batch_size * top_k
+ # Group by grid pattern
+ grid_patterns = defaultdict(list)
+ for kinfo in moe_kernels_no_ext:
+ grid_key = tuple(kinfo['grid']) if kinfo['grid'] else ()
+ grid_patterns[grid_key].append(kinfo)
+ for grid_key, kernels in grid_patterns.items():
+ if not kernels:
+ continue
+ total_dur_all_gpus_us = sum(k['dur'] for k in kernels)
+ total_dur_per_gpu_us = total_dur_all_gpus_us / num_gpus
+ num_kernel_calls_per_gpu = len(kernels) // num_gpus if num_gpus > 0 else len(kernels)
+ # Heuristic: small grid means W1 (gate+up), large grid means W2 (down)
+ is_w1 = bool(grid_key) and grid_key[0] < 5000
+ # FLOPs per kernel call: adjust for EP. Each EP rank handles
+ # ``pairs_per_rank = total_pairs / ep_degree`` token-expert pairs.
+ pairs_per_rank = total_pairs / max(config.ep_degree, 1)
+ if is_w1:
+ flops_per_kernel = 2 * pairs_per_rank * hidden_size * w1_intermediate
+ else:
+ flops_per_kernel = 2 * pairs_per_rank * w2_intermediate * hidden_size
+ total_flops_per_gpu = flops_per_kernel * num_kernel_calls_per_gpu
+ # Estimate memory traffic from the number of local experts the config
+ # heuristic predicts are activated for this decode batch
+ est_experts_used = config.estimate_activated_experts(decode_batch_size, top_k)
+ input_bytes = int(decode_batch_size * hidden_size * get_bytes_per_element(input_dtype))
+ if is_w1:
+ weight_bytes = int(est_experts_used * w1_intermediate * hidden_size * get_bytes_per_element(weight_dtype))
+ output_bytes_per_call = int(pairs_per_rank * w1_intermediate * get_bytes_per_element(output_dtype))
+ else:
+ weight_bytes = int(est_experts_used * hidden_size * w2_intermediate * get_bytes_per_element(weight_dtype))
+ # Output of w2 is [M, hidden_size] per rank
+ output_bytes_per_call = int(decode_batch_size * hidden_size * get_bytes_per_element(output_dtype))
+ bytes_per_call = input_bytes + weight_bytes + output_bytes_per_call
+ total_bytes_per_gpu = bytes_per_call * num_kernel_calls_per_gpu
+ duration_s = total_dur_per_gpu_us / 1e6
+ achieved_tflops = (total_flops_per_gpu / 1e12) / duration_s if duration_s > 0 else 0
+ achieved_bw_tb_s = (total_bytes_per_gpu / 1e12) / duration_s if duration_s > 0 else 0
+ peak_tflops = get_dtype_peak_tflops(weight_dtype, gpu_specs)
+ mfu = (achieved_tflops / peak_tflops) * 100.0 if peak_tflops > 0 else 0
+ mbu = (achieved_bw_tb_s / gpu_specs.memory_bw_tb_s) * 100.0 if gpu_specs.memory_bw_tb_s > 0 else 0
+ # MBU > 100% is physically impossible and signals a dimension-inference
+ # error; clamp to a 90% estimate so downstream stats stay sane
+ if mbu > 100:
+ estimated_actual_bw = gpu_specs.memory_bw_tb_s * 0.9
+ total_bytes_per_gpu = int(estimated_actual_bw * 1e12 * duration_s)
+ achieved_bw_tb_s = estimated_actual_bw
+ mbu = 90.0
+ ai = total_flops_per_gpu / total_bytes_per_gpu if total_bytes_per_gpu > 0 else 0
+ memory_bound_tflops = gpu_specs.memory_bw_tb_s * ai
+ bound = 'memory' if memory_bound_tflops < peak_tflops else 'compute'
+ grouped_infos.append(GroupedGemmInfo(
+ num_tokens=decode_batch_size,
+ top_k=top_k,
+ # Report the total number of experts in the model. Local experts
+ # depend on ep_degree but the full model has num_total_experts.
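+ # (illustrative: 256 total experts, of which 32 are local at ep_degree=8)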
+ num_experts=num_total_experts, + hidden_size=hidden_size, + w1_intermediate=w1_intermediate if is_w1 else 0, + w2_intermediate=w2_intermediate if not is_w1 else 0, + input_dtype=input_dtype, + weight_dtype=weight_dtype, + output_dtype=output_dtype, + total_token_expert_pairs=total_pairs, + w1_flops=total_flops_per_gpu if is_w1 else 0, + w2_flops=total_flops_per_gpu if not is_w1 else 0, + total_flops=total_flops_per_gpu, + input_bytes=input_bytes * num_kernel_calls_per_gpu, + w1_weight_bytes=weight_bytes * num_kernel_calls_per_gpu if is_w1 else 0, + w2_weight_bytes=weight_bytes * num_kernel_calls_per_gpu if not is_w1 else 0, + output_bytes=output_bytes_per_call * num_kernel_calls_per_gpu, + total_bytes=total_bytes_per_gpu, + duration_us=total_dur_per_gpu_us, + tflops=achieved_tflops, + mfu=mfu, + achieved_bw_tb_s=achieved_bw_tb_s, + mbu=mbu, + arithmetic_intensity=ai, + roofline_bound=bound, + kernel_name=f"fused_moe_kernel (decode, {'w1' if is_w1 else 'w2'})", + external_id=0, + num_kernels=len(kernels), + timestamp_us=kernels[0]['ts'] if kernels else 0, + correlation_id=kernels[0]['correlation'] if kernels else 0, + tp_rank="*", + )) + return grouped_infos + + +############################################################################### +# Layer time breakdown analysis +############################################################################### + +def analyze_layer_time_breakdown(events: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: + """Analyse total time spent per layer type (QKVO, SDPA, FFN, etc.).""" + tp_ranks = set() + for event in events: + pid = event.get('pid', '') + match = re.search(r'\[TP(\d+)\]', str(pid)) + if match: + tp_ranks.add(match.group(1)) + num_gpus = max(len(tp_ranks), 1) + layer_times = { + 'QKVO': {'time_us': 0.0, 'count': 0}, + 'SDPA': {'time_us': 0.0, 'count': 0}, + 'FFN': {'time_us': 0.0, 'count': 0}, + 'Normalization': {'time_us': 0.0, 'count': 0}, + 'Communication': {'time_us': 0.0, 'count': 0}, + 'Other': {'time_us': 0.0, 'count': 0}, + } + # Build CPU op dims map for GEMM classification + cpu_op_dims = build_cpu_op_dims_map(events) + for event in events: + if event.get('cat') != 'kernel': + continue + name = event.get('name', '') + dur = event.get('dur', 0) + if dur <= 0: + continue + layer_type = None + name_lower = name.lower() + # Communication (exclude long warmup) + if any(x in name_lower for x in ['nccl', 'cross_device_reduce', 'allreduce', 'allgather', 'all_gather', 'reducescatter', 'reduce_scatter']): + if dur < 1e6: + layer_type = 'Communication' + else: + continue + if layer_type is None and ('rmsnorm' in name_lower or 'layernorm' in name_lower): + layer_type = 'Normalization' + if layer_type is None and any(x in name_lower for x in ['flashinfer', 'attention', 'mla', 'fmha']): + if 'BatchMLAPageAttention' in name or 'PrefillWithKVCache' in name: + layer_type = 'SDPA' + elif 'Rotary' in name: + layer_type = 'QKVO' + # GEMM kernels + if layer_type is None and any(x in name_lower for x in ['deep_gemm', 'nvjet', 'gemm', 'matmul', 'splitkreduce']): + ext_id = event.get('args', {}).get('External id') + tp_rank = extract_tp_rank(event.get('pid')) + dims = None + if ext_id is not None: + dims = cpu_op_dims.get((tp_rank, ext_id)) or cpu_op_dims.get((tp_rank, ext_id - 1)) + if dims: + m, n, k = dims[0], dims[1], dims[2] + layer_type = classify_layer_type(m, n, k, name) + else: + # Try parse deep_gemm template + match = re.search(r'deep_gemm[^<]*<[^,]*,\s*(\d+)u,\s*(\d+)u', name) + if match: + n_, k_ = int(match.group(1)), int(match.group(2)) 
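+ # M is not recoverable from the parsed template; 992 appears to act as a
+ # representative prefill M so classify_layer_type can key off N/K alone
+ # (assumption about intent, not derived from the trace)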
+ layer_type = classify_layer_type(992, n_, k_, name) + else: + layer_type = 'QKVO' + # Activations + if layer_type is None and any(x in name_lower for x in ['act_and_mul', 'silu', 'gelu', 'activation']): + layer_type = 'FFN' + if layer_type is None and any(x in name_lower for x in ['moe', 'router', 'topk', 'expert_tokens', 'router_gemm']): + layer_type = 'FFN' + if layer_type is None and any(x in name_lower for x in ['quant', 'per_token_group']): + layer_type = 'QKVO' + if layer_type is None and any(x in name_lower for x in ['kv_buffer', 'kv_cache', 'mla_k', 'mla_v']): + layer_type = 'QKVO' + if layer_type is None: + layer_type = 'Other' + lt = layer_times[layer_type] + lt['time_us'] += dur + lt['count'] += 1 + total_time = sum(lt['time_us'] for lt in layer_times.values()) + for lt in layer_times.values(): + lt['percentage'] = (lt['time_us'] / total_time * 100) if total_time > 0 else 0 + lt['time_ms'] = lt['time_us'] / 1000.0 + lt['time_ms_per_gpu'] = lt['time_us'] / num_gpus / 1000.0 + layer_times['_total'] = { + 'time_us': total_time, + 'time_ms': total_time / 1000.0, + 'time_ms_per_gpu': total_time / num_gpus / 1000.0, + 'num_gpus': num_gpus, + } + return layer_times + + +############################################################################### +# Communication overlap analysis +############################################################################### + +def analyze_communication_overlap(events: List[Dict[str, Any]], warmup_threshold_s: float = 1.0) -> Dict[str, Any]: + """Analyse communication overlap with compute across GPUs. + + The returned dict contains overall communication time broken down into + same‑GPU overlap, cross‑GPU pipeline overlap, exposed time and warmup. + It also contains per‑type breakdown for NCCL operations. + """ + from collections import defaultdict + warmup_threshold_us = warmup_threshold_s * 1e6 + # Group kernels per GPU + kernels_by_gpu = defaultdict(list) + for e in events: + if e.get('cat') != 'kernel': + continue + name = e.get('name', '') + ts = e.get('ts', 0) + dur = e.get('dur', 0) + pid = e.get('pid') + tid = e.get('tid') + if dur <= 0: + continue + name_lower = name.lower() + is_comm = any(x in name_lower for x in ['nccl', 'cross_device_reduce', 'allreduce', 'allgather', 'all_gather', 'reducescatter', 'reduce_scatter', 'alltoall', 'broadcast']) + kernels_by_gpu[pid].append({ + 'name': name, + 'ts': ts, + 'dur': dur, + 'end': ts + dur, + 'tid': tid, + 'is_comm': is_comm, + }) + gpus = sorted(kernels_by_gpu.keys()) + for gpu in kernels_by_gpu: + kernels_by_gpu[gpu].sort(key=lambda x: x['ts']) + total_comm_time = 0 + same_gpu_overlap = 0 + cross_gpu_overlap = 0 + no_overlap = 0 + warmup_time = 0 + warmup_count = 0 + comm_by_type = defaultdict(lambda: { + 'count': 0, 'time_us': 0, + 'same_gpu_overlap_us': 0, 'cross_gpu_overlap_us': 0, 'no_overlap_us': 0, + 'warmup_count': 0, 'warmup_time_us': 0, + }) + for gpu, kernels in kernels_by_gpu.items(): + other_gpus = [g for g in gpus if g != gpu] + for ck in kernels: + if not ck['is_comm']: + continue + ck_start, ck_end = ck['ts'], ck['end'] + ck_dur = ck['dur'] + # Identify type + name_lower = ck['name'].lower() + if 'cross_device_reduce' in name_lower: + kernel_type = 'cross_device_reduce' + elif 'allreduce' in name_lower: + kernel_type = 'allreduce' + elif 'allgather' in name_lower or 'all_gather' in name_lower: + kernel_type = 'all_gather' + elif 'reducescatter' in name_lower or 'reduce_scatter' in name_lower: + kernel_type = 'reduce_scatter' + else: + kernel_type = 'other_comm' + # Warmup 
detection
+ if ck_dur > warmup_threshold_us:
+ warmup_time += ck_dur
+ warmup_count += 1
+ comm_by_type[kernel_type]['warmup_count'] += 1
+ comm_by_type[kernel_type]['warmup_time_us'] += ck_dur
+ continue
+ total_comm_time += ck_dur
+ comm_by_type[kernel_type]['count'] += 1
+ comm_by_type[kernel_type]['time_us'] += ck_dur
+ # Same GPU overlap
+ same_gpu_compute = [k for k in kernels if not k['is_comm'] and k['ts'] < ck_end and k['end'] > ck_start and k['tid'] != ck['tid']]
+ cross_gpu_compute = []
+ for other_gpu in other_gpus:
+ for ok in kernels_by_gpu[other_gpu]:
+ if not ok['is_comm'] and ok['ts'] < ck_end and ok['end'] > ck_start:
+ cross_gpu_compute.append(ok)
+ same_overlap = 0
+ if same_gpu_compute:
+ intervals = []
+ for ok in same_gpu_compute:
+ intervals.append((max(ck_start, ok['ts']), min(ck_end, ok['end'])))
+ intervals.sort()
+ merged = [intervals[0]] if intervals else []
+ for s, e in intervals[1:]:
+ if s <= merged[-1][1]:
+ merged[-1] = (merged[-1][0], max(merged[-1][1], e))
+ else:
+ merged.append((s, e))
+ same_overlap = sum(e - s for s, e in merged)
+ cross_overlap_time = 0
+ if cross_gpu_compute:
+ intervals = []
+ for ok in cross_gpu_compute:
+ intervals.append((max(ck_start, ok['ts']), min(ck_end, ok['end'])))
+ intervals.sort()
+ merged = [intervals[0]] if intervals else []
+ for s, e in intervals[1:]:
+ if s <= merged[-1][1]:
+ merged[-1] = (merged[-1][0], max(merged[-1][1], e))
+ else:
+ merged.append((s, e))
+ cross_overlap_time = sum(e - s for s, e in merged)
+ # Same-GPU and cross-GPU windows can cover the same instant, so take the
+ # max rather than the sum to avoid double-counting overlapped time
+ total_overlap_time = max(same_overlap, cross_overlap_time)
+ exposed = ck_dur - total_overlap_time
+ same_gpu_overlap += same_overlap
+ cross_gpu_overlap += cross_overlap_time
+ no_overlap += exposed
+ comm_by_type[kernel_type]['same_gpu_overlap_us'] += same_overlap
+ comm_by_type[kernel_type]['cross_gpu_overlap_us'] += cross_overlap_time
+ comm_by_type[kernel_type]['no_overlap_us'] += exposed
+ return {
+ 'total_comm_time_us': total_comm_time,
+ 'same_gpu_overlap_us': same_gpu_overlap,
+ 'cross_gpu_overlap_us': cross_gpu_overlap,
+ 'exposed_time_us': no_overlap,
+ 'warmup_time_us': warmup_time,
+ 'warmup_count': warmup_count,
+ 'by_type': dict(comm_by_type),
+ 'num_gpus': max(len(gpus), 1),
+ }
+
+
+###############################################################################
+# Network roofline analysis
+###############################################################################
+
+def analyze_network_roofline(events: List[Dict[str, Any]], gemm_infos: List[GemmInfo],
+ gpu_specs: GPUSpecs, tp_degree: int) -> Dict[str, Any]:
+ """Analyse network communication roofline for tensor parallel GEMMs.
+
+ Returns a dict with the critical arithmetic intensities, AllReduce
+ statistics and a per‑phase (prefill/decode) breakdown of TP GEMM
+ operations.
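+
+ A row-parallel GEMM is network-bound when its network arithmetic
+ intensity, flops_per_gpu / allreduce_bytes, falls below the critical
+ value peak_flops(dtype) / nvlink_bw_bytes. With hypothetical hardware
+ numbers, 4500 TFLOPs FP8 peak over 900 GB/s NVLink gives a critical
+ AI of 4.5e15 / 9e11 = 5000 FLOPs/byte.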
+ """ + # NVLink bandwidth in bytes/s + nvlink_bw_bytes = gpu_specs.nvlink_bw_gb_s * 1e9 + def peak_flops_for_dtype(dtype: str) -> float: + return get_dtype_peak_tflops(dtype, gpu_specs) * 1e12 + dtype_list = ['fp8', 'bf16', 'fp16'] + critical_ai_hbm = {dt: (peak_flops_for_dtype(dt) / (gpu_specs.memory_bw_tb_s * 1e12) if gpu_specs.memory_bw_tb_s > 0 else float('inf')) for dt in dtype_list} + critical_ai_network = {dt: (peak_flops_for_dtype(dt) / nvlink_bw_bytes if nvlink_bw_bytes > 0 else float('inf')) for dt in dtype_list} + # Identify AllReduce kernels + # Treat both "allreduce" and "cross_device_reduce" kernels as AllReduce + allreduce_durations = [ + e.get('dur', 0) + for e in events + if e.get('cat') == 'kernel' + and (('allreduce' in e.get('name', '').lower()) or ('cross_device_reduce' in e.get('name', '').lower())) + and e.get('dur', 0) <= 1e6 + ] + results = { + 'critical_ai_hbm': critical_ai_hbm, + 'critical_ai_network': critical_ai_network, + 'nvlink_bw_gb_s': gpu_specs.nvlink_bw_gb_s, + 'tp_degree': tp_degree, + 'phases': {}, + 'allreduce_stats': { + 'count': len(allreduce_durations), + 'total_time_us': sum(allreduce_durations), + 'avg_time_us': (sum(allreduce_durations) / len(allreduce_durations)) if allreduce_durations else 0, + }, + } + # Split gemms into prefill vs decode based on M dimension + prefill_gemms = [g for g in gemm_infos if g.m > 128] + decode_gemms = [g for g in gemm_infos if g.m <= 128] + for phase_name, phase_gemms in [('prefill', prefill_gemms), ('decode', decode_gemms)]: + if not phase_gemms: + continue + # Pick the most common M + m_values = [g.m for g in phase_gemms] + M = max(set(m_values), key=m_values.count) + # Aggregate by (N,K,dtype,out_dtype) + dim_stats = defaultdict(lambda: {'count': 0, 'time_us': 0, 'flops': 0}) + for g in phase_gemms: + key = (g.n, g.k, g.dtype, g.output_dtype or 'bf16') + dim_stats[key]['count'] += 1 + dim_stats[key]['time_us'] += g.duration_us + dim_stats[key]['flops'] += g.flops + phase_ops = [] + hidden_size = 7168 # DeepSeek assumption + for (N, K, dt, out_dt), stats in dim_stats.items(): + is_row_parallel = (N == hidden_size) + flops_per_gpu = 2 * M * N * K + if is_row_parallel: + dtype_bytes = int(get_bytes_per_element(out_dt)) if out_dt else 2 + allreduce_bytes = 2 * (tp_degree - 1) / tp_degree * M * N * dtype_bytes + network_ai = flops_per_gpu / allreduce_bytes if allreduce_bytes > 0 else float('inf') + t_network_us = allreduce_bytes / nvlink_bw_bytes * 1e6 + else: + allreduce_bytes = 0 + network_ai = float('inf') + t_network_us = 0 + peak_flops_op = peak_flops_for_dtype(dt) + t_compute_us = flops_per_gpu / peak_flops_op * 1e6 if peak_flops_op > 0 else float('inf') + bound = 'network' if is_row_parallel and network_ai < critical_ai_network.get(dt, float('inf')) else 'compute' + phase_ops.append({ + 'M': M, 'N': N, 'K': K, 'dtype': dt, 'out_dtype': out_dt, + 'parallelism': 'row-parallel' if is_row_parallel else 'column-parallel', + 'flops_per_gpu': flops_per_gpu, + 'allreduce_bytes': allreduce_bytes, + 'network_ai': network_ai, + 't_compute_us': t_compute_us, + 't_network_us': t_network_us, + 'bound': bound, + 'kernel_count': stats['count'], + 'measured_time_us': stats['time_us'], + }) + results['phases'][phase_name] = { + 'M': M, + 'operations': phase_ops, + 'total_gemm_time_us': sum(g.duration_us for g in phase_gemms), + 'total_gemm_count': len(phase_gemms), + } + return results + + +############################################################################### +# MFU/MBU annotation of trace events 
+############################################################################### + +def add_mfu_to_trace(trace_data: Dict[str, Any], gpu_specs: GPUSpecs, config: Config) -> Dict[str, Any]: + """Annotate kernel events in a trace with MFU/MBU metrics. + + This function mirrors the logic of :func:`analyze_gemm_kernels` but writes + metrics back into the ``args`` dict of each kernel or CPU op. It can be + used to generate an enriched trace JSON which can be visualised in + Chrome’s trace viewer. + """ + events = trace_data.get('traceEvents', []) + cpu_op_dims = build_cpu_op_dims_map(events) + sibling_dims = build_sibling_dims_map(events, cpu_op_dims) + kernel_times_by_key = defaultdict(float) + for event in events: + if event.get('cat') == 'kernel': + ext_id = event.get('args', {}).get('External id') + tp_rank = extract_tp_rank(event.get('pid')) + if ext_id is not None: + kernel_times_by_key[(tp_rank, ext_id)] += event.get('dur', 0) + modified_count = 0 + for event in events: + if event.get('cat') == 'kernel': + name = event.get('name', '') + classification = classify_kernel(name) + if not classification.is_gemm: + continue + duration_us = event.get('dur', 0) + if duration_us <= 0: + continue + ext_id = event.get('args', {}).get('External id') + tp_rank = extract_tp_rank(event.get('pid')) + grid = event.get('args', {}).get('grid', [1, 1, 1]) + # Extract dims + dims = None + if ext_id is not None: + dims = cpu_op_dims.get((tp_rank, ext_id)) or cpu_op_dims.get((tp_rank, ext_id - 1)) or cpu_op_dims.get((tp_rank, ext_id + 1)) + if dims is None and classification.source == 'deep_gemm': + parsed = parse_deep_gemm_kernel_dims(name, grid, None) + if parsed: + m_, n_, k_, dtype_ = parsed + # Use the configured model dtype for the output rather than a fixed + # BF16. See discussion in :func:`analyze_gemm_kernels`. 
+ dims = (m_, n_, k_, 'fp8', 'fp8', normalize_dtype(config.model_dtype)) + inferred_layer_type = None + if dims is None and ext_id is None: + inferred = infer_cuda_graph_kernel_dims(name, grid, config, sibling_dims=sibling_dims) + if inferred: + m_, n_, k_, dtype_, inferred_layer_type = inferred + # Use model dtype for fallback output dtype + c_dtype = normalize_dtype(config.model_dtype) + dims = (m_, n_, k_, dtype_, dtype_, c_dtype) + if dims is None: + continue + if len(dims) >= 6: + m, n, k, a_dtype, b_dtype, c_dtype = dims[:6] + elif len(dims) == 5: + m, n, k, input_dtype, output_dtype = dims + a_dtype = b_dtype = input_dtype + c_dtype = output_dtype + else: + m, n, k, input_dtype = dims + a_dtype = b_dtype = input_dtype + c_dtype = normalize_dtype(config.model_dtype) + if m <= 0 or n <= 0 or k <= 0: + continue + if not a_dtype and classification.dtype: + a_dtype = classification.dtype + if not b_dtype and classification.dtype: + b_dtype = classification.dtype + if not c_dtype and classification.dtype: + c_dtype = 'bf16' if classification.dtype == 'fp8' else classification.dtype + dtype = compute_dtype_from_inputs(a_dtype, b_dtype) + # Metrics + flops = calculate_gemm_flops(m, n, k) + bytes_accessed = calculate_gemm_bytes(m, n, k, a_dtype, b_dtype, c_dtype) + peak_tflops = get_dtype_peak_tflops(dtype, gpu_specs) + mfu = calculate_mfu(flops, duration_us, peak_tflops) + mbu = calculate_mbu(bytes_accessed, duration_us, gpu_specs.memory_bw_tb_s) + ai = calculate_arithmetic_intensity(flops, bytes_accessed) + roofline_tflops, roofline_bound = calculate_roofline_tflops(ai, gpu_specs, peak_tflops) + achieved_tflops = (flops / 1e12) / (duration_us / 1e6) + achieved_bw_tb_s = (bytes_accessed / 1e12) / (duration_us / 1e6) + layer_type = inferred_layer_type if inferred_layer_type else classify_layer_type(m, n, k, name) + if 'args' not in event: + event['args'] = {} + args = event['args'] + args['MFU (%)'] = round(mfu, 2) + args['MBU (%)'] = round(mbu, 2) + args['Achieved TFLOPS'] = round(achieved_tflops, 2) + args['Peak TFLOPS'] = round(peak_tflops, 2) + args['Roofline TFLOPS'] = round(roofline_tflops, 2) + args['Roofline Bound'] = roofline_bound + args['Achieved BW (TB/s)'] = round(achieved_bw_tb_s, 3) + args['Peak BW (TB/s)'] = round(gpu_specs.memory_bw_tb_s, 2) + args['Arithmetic Intensity'] = round(ai, 2) + args['FLOPs'] = flops + args['Bytes'] = bytes_accessed + args['GEMM M'] = m + args['GEMM N'] = n + args['GEMM K'] = k + # Record the compute dtype (kernel execution precision) and individual input/output dtypes. + # ``dtype`` is the compute/accumulator dtype used for performance calculations. + args['GEMM dtype'] = dtype + args['GEMM A dtype'] = a_dtype + args['GEMM B dtype'] = b_dtype + args['GEMM C dtype'] = c_dtype + # Aggregate input and output dtypes: if A and B have the same dtype + # then the input dtype is that dtype; otherwise denote as 'mixed'. + input_dtype = a_dtype if a_dtype == b_dtype else 'mixed' + args['GEMM input dtype'] = input_dtype + # The output dtype corresponds to C's dtype. 
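+ # (illustrative: A='fp8', B='fp8' -> input 'fp8'; A='fp8', B='bf16' -> 'mixed')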
+ args['GEMM output dtype'] = c_dtype + args['Layer Type'] = layer_type + modified_count += 1 + elif event.get('cat') == 'cpu_op': + name = event.get('name', '') + if not any(x in name.lower() for x in ['deep_gemm', 'fp8_gemm']): + continue + dims = extract_dimensions_from_cpu_op(event) + if dims: + if len(dims) >= 6: + m, n, k, a_dtype, b_dtype, c_dtype = dims[:6] + dtype = compute_dtype_from_inputs(a_dtype, b_dtype) + elif len(dims) == 5: + m, n, k, input_dtype, output_dtype = dims + a_dtype = b_dtype = input_dtype + c_dtype = output_dtype + dtype = compute_dtype_from_inputs(a_dtype, b_dtype) + else: + m, n, k, input_dtype = dims + a_dtype = b_dtype = input_dtype + # Use model default dtype for the output when the profiler does + # not record C dtype for this CPU op + c_dtype = normalize_dtype(config.model_dtype) + dtype = compute_dtype_from_inputs(a_dtype, b_dtype) + ext_id = event.get('args', {}).get('External id') + tp_rank = extract_tp_rank(event.get('pid')) + key = (tp_rank, ext_id) if ext_id is not None else None + duration_us = event.get('dur', 0) + if key and key in kernel_times_by_key: + duration_us = kernel_times_by_key[key] + if duration_us > 0: + flops = calculate_gemm_flops(m, n, k) + bytes_accessed = calculate_gemm_bytes(m, n, k, a_dtype, b_dtype, c_dtype) + peak_tflops = get_dtype_peak_tflops(dtype, gpu_specs) + mfu = calculate_mfu(flops, duration_us, peak_tflops) + mbu = calculate_mbu(bytes_accessed, duration_us, gpu_specs.memory_bw_tb_s) + ai = calculate_arithmetic_intensity(flops, bytes_accessed) + roofline_tflops, roofline_bound = calculate_roofline_tflops(ai, gpu_specs, peak_tflops) + if 'args' not in event: + event['args'] = {} + args = event['args'] + args['MFU (%)'] = round(mfu, 2) + args['MBU (%)'] = round(mbu, 2) + args['Achieved TFLOPS'] = round((flops / 1e12) / (duration_us / 1e6), 2) + args['Roofline TFLOPS'] = round(roofline_tflops, 2) + args['Roofline Bound'] = roofline_bound + args['Arithmetic Intensity'] = round(ai, 2) + # Annotate compute, individual input/output dtypes and aggregate input/output dtypes. 
+ args['GEMM dtype'] = dtype + args['GEMM A dtype'] = a_dtype + args['GEMM B dtype'] = b_dtype + args['GEMM C dtype'] = c_dtype + input_dtype = a_dtype if a_dtype == b_dtype else 'mixed' + args['GEMM input dtype'] = input_dtype + args['GEMM output dtype'] = c_dtype + modified_count += 1 + # Could print or log modified_count here if desired + return trace_data + + +############################################################################### +# Trace loading and saving +############################################################################### + +def load_trace(input_path: str) -> Dict[str, Any]: + """Load a trace JSON or JSON.GZ file from disk.""" + path = Path(input_path) + if path.suffix == '.gz': + with gzip.open(path, 'rt', encoding='utf-8') as f: + return json.load(f) + else: + with open(path, 'r', encoding='utf-8') as f: + return json.load(f) + + +def save_trace(trace_data: Dict[str, Any], output_path: str, compress: bool = False): + """Save a trace dict to disk, optionally compressing with gzip.""" + path = Path(output_path) + if compress or path.suffix == '.gz': + with gzip.open(path, 'wt', encoding='utf-8') as f: + json.dump(trace_data, f) + else: + with open(path, 'w', encoding='utf-8') as f: + json.dump(trace_data, f, indent=2) + + +############################################################################### +# Summary reporting +############################################################################### + +def print_summary(gemm_infos: List[GemmInfo], layer_times: Dict[str, Any], gpu_specs: GPUSpecs, + comm_overlap: Optional[Dict[str, Any]] = None, + network_roofline: Optional[Dict[str, Any]] = None, + events: Optional[List[Dict[str, Any]]] = None, + grouped_gemm_infos: Optional[List[GroupedGemmInfo]] = None): + """Print a comprehensive analysis summary to stdout. + + This function consolidates the print logic from the original script into a + single place. It accepts pre‑computed GEMM information, layer timing + breakdown, communication overlap statistics, network roofline analysis and + grouped GEMM information. All printing is performed here to avoid + scattering summary code throughout the analysis. 
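+
+ Note that all reported averages are duration-weighted, i.e.
+ avg = sum(metric_i * dur_i) / sum(dur_i), so long-running kernels
+ dominate the summary MFU/MBU figures.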
+ """ + if not gemm_infos: + print("No GEMM operations found") + return + num_gpus = layer_times.get('_total', {}).get('num_gpus', 1) + total_flops = sum(g.flops for g in gemm_infos) + total_bytes = sum(g.bytes_accessed for g in gemm_infos) + total_time_us = sum(g.duration_us for g in gemm_infos) + per_gpu_time_us = total_time_us / num_gpus + per_gpu_flops = total_flops / num_gpus + per_gpu_time_s = per_gpu_time_us / 1e6 if num_gpus > 0 else 0 + overall_tflops = (per_gpu_flops / 1e12) / per_gpu_time_s if per_gpu_time_s > 0 else 0 + def fmt_tflops(tf: float) -> str: + return f"{tf/1000:.1f} PFLOPS" if tf >= 1000 else f"{tf:.1f} TFLOPS" + avg_mfu = (sum(g.mfu * g.duration_us for g in gemm_infos) / total_time_us) if total_time_us > 0 else 0 + avg_mbu = (sum(g.mbu * g.duration_us for g in gemm_infos) / total_time_us) if total_time_us > 0 else 0 + print("\n" + "="*80) + print("GEMM/MatMul Analysis Summary (MFU, MBU, Roofline)") + print("="*80) + print(f"GPU: {gpu_specs.name} (x{num_gpus} GPUs in trace)") + if gpu_specs.fp4_tflops > 0: + print(f"Peak FP4: {fmt_tflops(gpu_specs.fp4_tflops)}") + print(f"Peak FP8: {fmt_tflops(gpu_specs.fp8_tflops)}") + print(f"Peak BF16: {fmt_tflops(gpu_specs.fp16_tflops)}") + print(f"Peak Memory BW: {gpu_specs.memory_bw_tb_s:.2f} TB/s") + print(f"L2 Cache: {gpu_specs.l2_cache_mb:.0f} MB") + print("-"*80) + print(f"Total GEMM kernels analysed: {len(gemm_infos)} (with known M dimension)") + print(f"Total GEMM FLOPs: {total_flops / 1e12:.2f} TFLOPs ({per_gpu_flops / 1e12:.2f} per GPU)") + print(f"Total GEMM bytes: {total_bytes / 1e9:.2f} GB") + print(f"Total GEMM time: {total_time_us/1000:.2f} ms ({per_gpu_time_us/1000:.2f} ms per GPU)") + print(f"Average TFLOPS (per GPU): {overall_tflops:.2f}") + print(f"Weighted Average MFU: {avg_mfu:.2f}%") + print(f"Weighted Average MBU: {avg_mbu:.2f}%") + # Group by dtype + dtype_groups = defaultdict(list) + for g in gemm_infos: + dtype_groups[g.dtype].append(g) + print("-"*80) + print("By Data Type:") + for dtype, ops in dtype_groups.items(): + time_sum = sum(g.duration_us for g in ops) + avg_mfu_d = (sum(g.mfu * g.duration_us for g in ops) / time_sum) if time_sum > 0 else 0 + avg_mbu_d = (sum(g.mbu * g.duration_us for g in ops) / time_sum) if time_sum > 0 else 0 + print(f" {dtype.upper():<5}: {len(ops)} ops, {time_sum/1000/num_gpus:.2f} ms/GPU, MFU: {avg_mfu_d:.2f}%, MBU: {avg_mbu_d:.2f}%") + # Roofline bound breakdown + memory_bound = [g for g in gemm_infos if g.roofline_bound == 'memory'] + compute_bound = [g for g in gemm_infos if g.roofline_bound == 'compute'] + print("-"*80) + print("By Roofline Bound:") + if memory_bound: + mb_time = sum(g.duration_us for g in memory_bound) + mb_avg_mbu = (sum(g.mbu * g.duration_us for g in memory_bound) / mb_time) if mb_time > 0 else 0 + mb_avg_bw = (sum(g.achieved_bw_tb_s * g.duration_us for g in memory_bound) / mb_time * 1000) if mb_time > 0 else 0 + print(f" Memory-bound: {len(memory_bound)} ops, {mb_time/1000/num_gpus:.2f} ms/GPU ({mb_time/total_time_us*100:.1f}%)") + print(f" Avg MBU: {mb_avg_mbu:.1f}%, Avg BW: {mb_avg_bw:.0f} GB/s") + if compute_bound: + cb_time = sum(g.duration_us for g in compute_bound) + cb_avg_mfu = (sum(g.mfu * g.duration_us for g in compute_bound) / cb_time) if cb_time > 0 else 0 + print(f" Compute-bound: {len(compute_bound)} ops, {cb_time/1000/num_gpus:.2f} ms/GPU ({cb_time/total_time_us*100:.1f}%)") + print(f" Avg MFU: {cb_avg_mfu:.1f}%") + # Phase breakdown + prefill_ops = [g for g in gemm_infos if g.m > 128] + decode_ops = [g for g in gemm_infos if g.m 
<= 128] + print("-"*80) + print("By Phase (based on M dimension):") + if prefill_ops: + pf_time = sum(g.duration_us for g in prefill_ops) + pf_avg_mfu = (sum(g.mfu * g.duration_us for g in prefill_ops) / pf_time) if pf_time > 0 else 0 + pf_avg_mbu = (sum(g.mbu * g.duration_us for g in prefill_ops) / pf_time) if pf_time > 0 else 0 + pf_avg_bw = (sum(g.achieved_bw_tb_s * g.duration_us for g in prefill_ops) / pf_time * 1000) if pf_time > 0 else 0 + common_m = max(set([g.m for g in prefill_ops]), key=lambda x: sum(1 for g in prefill_ops if g.m == x)) + print(f" Prefill (M={common_m}): {len(prefill_ops)} ops, {pf_time/1000/num_gpus:.2f} ms/GPU ({pf_time/total_time_us*100:.1f}%)") + print(f" MFU: {pf_avg_mfu:.1f}%, MBU: {pf_avg_mbu:.1f}%, BW: {pf_avg_bw:.0f} GB/s") + if decode_ops: + dc_time = sum(g.duration_us for g in decode_ops) + dc_avg_mfu = (sum(g.mfu * g.duration_us for g in decode_ops) / dc_time) if dc_time > 0 else 0 + dc_avg_mbu = (sum(g.mbu * g.duration_us for g in decode_ops) / dc_time) if dc_time > 0 else 0 + dc_avg_bw = (sum(g.achieved_bw_tb_s * g.duration_us for g in decode_ops) / dc_time * 1000) if dc_time > 0 else 0 + common_m = max(set([g.m for g in decode_ops]), key=lambda x: sum(1 for g in decode_ops if g.m == x)) if decode_ops else 0 + print(f" Decode (M={common_m}): {len(decode_ops)} ops, {dc_time/1000/num_gpus:.2f} ms/GPU ({dc_time/total_time_us*100:.1f}%)") + if dc_avg_mbu > 100: + print(f" MFU: {dc_avg_mfu:.1f}%, MBU: {dc_avg_mbu:.1f}% (INVALID - dimension inference error)") + print(f" BW: {dc_avg_bw:.0f} GB/s") + print(f" Note: MBU>100% is physically impossible and indicates dimension inference errors.") + else: + print(f" MFU: {dc_avg_mfu:.1f}%, MBU: {dc_avg_mbu:.1f}%, BW: {dc_avg_bw:.0f} GB/s") + # Top 10 by MFU + print("\n" + "-"*80) + print("Top 10 GEMMs by MFU:") + # Include per-operand dtypes when reporting top kernels. Compute dtype + # ``g.dtype`` may differ from individual operand dtypes (A/B/C). Show + # A/B/C explicitly for clarity. + for i, g in enumerate(sorted(gemm_infos, key=lambda g: g.mfu, reverse=True)[:10]): + roof_eff = (g.tflops / g.roofline_tflops * 100) if g.roofline_tflops > 0 else 0 + mbu_flag = " (INVALID)" if g.mbu > 100 else "" + print(f" {i+1}. M={g.m}, N={g.n}, K={g.k}, compute={g.dtype}, A={g.a_dtype}, B={g.b_dtype}, C={g.c_dtype}, {g.layer_type}: MFU={g.mfu:.2f}%{mbu_flag}, MBU={g.mbu:.1f}%, Roofline Eff={roof_eff:.1f}%") + print(f" Achieved={g.tflops:.1f} TFLOPS, BW={g.achieved_bw_tb_s*1000:.0f} GB/s, AI={g.arithmetic_intensity:.1f}") + print(f" [Trace: TP{g.tp_rank}, ts={g.timestamp_us}]") + # Bottom 10 by MFU (duration > 5us) + print("\nBottom 10 GEMMs by MFU (duration > 5us):") + significant_ops = [g for g in gemm_infos if g.duration_us > 5] + for i, g in enumerate(sorted(significant_ops, key=lambda g: g.mfu)[:10]): + roof_eff = (g.tflops / g.roofline_tflops * 100) if g.roofline_tflops > 0 else 0 + print(f" {i+1}. 
M={g.m}, N={g.n}, K={g.k}, compute={g.dtype}, A={g.a_dtype}, B={g.b_dtype}, C={g.c_dtype}, {g.layer_type}: MFU={g.mfu:.2f}%, MBU={g.mbu:.1f}%, Roofline Eff={roof_eff:.1f}%") + print(f" Achieved={g.tflops:.1f} TFLOPS, BW={g.achieved_bw_tb_s*1000:.0f} GB/s, AI={g.arithmetic_intensity:.1f}") + print(f" [Trace: TP{g.tp_rank}, ts={g.timestamp_us}]") + # Top 10 by MBU + print("\n" + "-"*80) + print("Top 10 GEMMs by MBU (Memory Bandwidth Utilisation):") + high_mbu = [g for g in gemm_infos if g.mbu > 100] + if high_mbu: + print(" WARNING: MBU > 100% indicates dimension inference error!") + print(" These kernels likely lack External ID and were inferred incorrectly.") + print() + for i, g in enumerate(sorted(gemm_infos, key=lambda g: g.mbu, reverse=True)[:10]): + if g.mbu > 100: + weight_time_us = g.weight_bytes / (gpu_specs.memory_bw_tb_s * 1e12) * 1e6 if g.weight_bytes > 0 else 0 + print(f" {i+1}. M={g.m}, N={g.n}, K={g.k}, compute={g.dtype}, A={g.a_dtype}, B={g.b_dtype}, C={g.c_dtype}, {g.layer_type}: MBU={g.mbu:.1f}% (INVALID)") + print(f" Weight load time {weight_time_us:.1f}µs > kernel {g.duration_us:.1f}µs") + else: + peak_gb_s = gpu_specs.memory_bw_tb_s * 1000 + print(f" {i+1}. M={g.m}, N={g.n}, K={g.k}, compute={g.dtype}, A={g.a_dtype}, B={g.b_dtype}, C={g.c_dtype}, {g.layer_type}: MBU={g.mbu:.1f}%, BW={g.achieved_bw_tb_s*1000:.0f} GB/s (peak: {peak_gb_s:.0f} GB/s)") + print(f" MFU={g.mfu:.2f}%, AI={g.arithmetic_intensity:.1f}, {g.roofline_bound}-bound") + print(f" [Trace: TP{g.tp_rank}, ts={g.timestamp_us}]") + # Top 10 by time + print("\n" + "-"*80) + print("Top 10 GEMMs by time:") + for i, g in enumerate(sorted(gemm_infos, key=lambda g: g.duration_us, reverse=True)[:10]): + print(f" {i+1}. M={g.m}, N={g.n}, K={g.k}, compute={g.dtype}, A={g.a_dtype}, B={g.b_dtype}, C={g.c_dtype}, {g.layer_type}: {g.duration_us:.2f}us, MFU={g.mfu:.2f}%, {g.tflops:.1f} TFLOPS, AI={g.arithmetic_intensity:.1f}") + print(f" [Trace: TP{g.tp_rank}, ts={g.timestamp_us}]") + # Grouped GEMM summary + if grouped_gemm_infos: + print("\n" + "="*80) + print("Grouped GEMM Analysis (Fused MoE)") + print("="*80) + print(f"GPU: {gpu_specs.name}") + print(f"Peak FP8: {gpu_specs.fp8_tflops/1000:.1f} PFLOPS") + print("-"*80) + prefill_ops = [g for g in grouped_gemm_infos if g.external_id > 0] + decode_ops = [g for g in grouped_gemm_infos if g.external_id == 0] + total_grouped_flops = sum(g.total_flops for g in grouped_gemm_infos) + total_grouped_bytes = sum(g.total_bytes for g in grouped_gemm_infos) + total_grouped_time = sum(g.duration_us for g in grouped_gemm_infos) + print(f"Total grouped GEMM operations: {len(grouped_gemm_infos)}") + print(f" Prefill ops (with External ID): {len(prefill_ops)}") + print(f" Decode ops (CUDA Graph/inferred): {len(decode_ops)}") + print(f"Total FLOPs: {total_grouped_flops/1e12:.2f} TFLOPs") + print(f"Total bytes: {total_grouped_bytes/1e9:.2f} GB") + print() + if prefill_ops: + print("Prefill Phase (fused_moe_kernel with External ID):") + pf_flops = sum(g.total_flops for g in prefill_ops) + pf_time = sum(g.duration_us for g in prefill_ops) + pf_bytes = sum(g.total_bytes for g in prefill_ops) + sample = prefill_ops[0] + print(f" Dimensions: {sample.num_tokens} tokens × top_{sample.top_k} experts") + print(f" {sample.num_experts} total experts, hidden={sample.hidden_size}") + print(f" w1_inter={sample.w1_intermediate}, w2_inter={sample.w2_intermediate}") + print(f" Token-expert pairs: {sample.total_token_expert_pairs}") + print() + pf_avg_mfu = (sum(g.mfu * g.duration_us for g in prefill_ops) / 
pf_time) if pf_time > 0 else 0 + pf_avg_mbu = (sum(g.mbu * g.duration_us for g in prefill_ops) / pf_time) if pf_time > 0 else 0 + pf_avg_bw = (sum(g.achieved_bw_tb_s * g.duration_us for g in prefill_ops) / pf_time * 1000) if pf_time > 0 else 0 + pf_avg_tflops = (pf_flops / 1e12) / (pf_time / 1e6) if pf_time > 0 else 0 + print(f" Total time: {pf_time/1000:.2f} ms ({len(prefill_ops)} ops)") + print(f" Total FLOPs: {pf_flops/1e12:.2f} TFLOPs") + print(f" Achieved: {pf_avg_tflops:.1f} TFLOPS") + print(f" MFU: {pf_avg_mfu:.1f}%, MBU: {pf_avg_mbu:.1f}%") + print(f" Bandwidth: {pf_avg_bw:.0f} GB/s") + print(f" Arithmetic Intensity: {sample.arithmetic_intensity:.1f} FLOPs/byte") + print(f" Roofline bound: {sample.roofline_bound}") + print() + print(" Top 5 Prefill MoE ops by MFU:") + for i, g in enumerate(sorted(prefill_ops, key=lambda g: g.mfu, reverse=True)[:5]): + print(f" {i+1}. {g.num_tokens}tok×top{g.top_k}, {g.weight_dtype}: MFU={g.mfu:.1f}%, {g.tflops:.1f} TFLOPS, {g.duration_us:.1f}us") + print(f" [ExtID={g.external_id}, TP{g.tp_rank}]") + if decode_ops: + print() + print("Decode Phase (fused_moe_kernel, CUDA Graph):") + dc_flops = sum(g.total_flops for g in decode_ops) + dc_time = sum(g.duration_us for g in decode_ops) + dc_bytes = sum(g.total_bytes for g in decode_ops) + w1_ops = [g for g in decode_ops if g.w1_intermediate > 0] + w2_ops = [g for g in decode_ops if g.w2_intermediate > 0] + if w1_ops or w2_ops: + sample = w1_ops[0] if w1_ops else (w2_ops[0] if w2_ops else None) + if sample: + print(f" Dimensions (inferred): {sample.num_tokens} tokens × top_{sample.top_k} experts") + print(f" {sample.num_experts} experts, hidden={sample.hidden_size}") + print(f" w1_inter={sample.w1_intermediate if w1_ops else 0} (gate+up), w2_inter={sample.w2_intermediate if w2_ops else 0} (down)") + print() + dc_avg_mfu = (sum(g.mfu * g.duration_us for g in decode_ops) / dc_time) if dc_time > 0 else 0 + dc_avg_mbu = (sum(g.mbu * g.duration_us for g in decode_ops) / dc_time) if dc_time > 0 else 0 + dc_avg_bw = (sum(g.achieved_bw_tb_s * g.duration_us for g in decode_ops) / dc_time * 1000) if dc_time > 0 else 0 + dc_avg_tflops = (dc_flops / 1e12) / (dc_time / 1e6) if dc_time > 0 else 0 + total_kernels = sum(g.num_kernels for g in decode_ops) + print(f" Total time: {dc_time/1000:.2f} ms/GPU ({total_kernels} kernels across all GPUs)") + print(f" Total FLOPs: {dc_flops/1e12:.4f} TFLOPs") + print(f" Achieved: {dc_avg_tflops:.1f} TFLOPS") + print(f" MFU: {dc_avg_mfu:.1f}%, MBU: {dc_avg_mbu:.1f}%") + print(f" Bandwidth: {dc_avg_bw:.0f} GB/s") + if w1_ops and w2_ops: + print() + print(" By projection type:") + w1_time = sum(g.duration_us for g in w1_ops) + w1_flops = sum(g.total_flops for g in w1_ops) + w1_mfu = (sum(g.mfu * g.duration_us for g in w1_ops) / w1_time) if w1_time > 0 else 0 + print(f" W1 (gate+up): {w1_time/1000:.2f}ms, {w1_flops/1e12:.4f} TFLOPs, MFU={w1_mfu:.1f}%") + w2_time = sum(g.duration_us for g in w2_ops) + w2_flops = sum(g.total_flops for g in w2_ops) + w2_mfu = (sum(g.mfu * g.duration_us for g in w2_ops) / w2_time) if w2_time > 0 else 0 + print(f" W2 (down): {w2_time/1000:.2f}ms, {w2_flops/1e12:.4f} TFLOPs, MFU={w2_mfu:.1f}%") + print("-"*80) + # Layer time breakdown + print("\n" + "="*80) + print("Layer Type Time Breakdown (All Kernels)") + print("="*80) + total_info = layer_times.get('_total', {}) + total_kernel_time = total_info.get('time_ms', 0) + per_gpu_time = total_info.get('time_ms_per_gpu', total_kernel_time) + print(f"Total kernel time (sum across {num_gpus} GPUs): 
{total_kernel_time:.2f} ms") + print(f"Per-GPU average kernel time: {per_gpu_time:.2f} ms\n") + layer_order = ['QKVO', 'SDPA', 'FFN', 'Normalization', 'Communication', 'Other'] + for layer_name in layer_order: + lt = layer_times.get(layer_name, {}) + time_ms = lt.get('time_ms_per_gpu', 0) + pct = lt.get('percentage', 0) + count = lt.get('count', 0) + print(f" {layer_name:<15s}: {time_ms:10.2f} ms/GPU ({pct:5.1f}%) [{count:6d} kernels]") + # Communication overlap + print("\n" + "-"*80) + print("Communication Overlap Analysis") + print("-"*80) + if comm_overlap: + total_comm_us = comm_overlap['total_comm_time_us'] + same_us = comm_overlap['same_gpu_overlap_us'] + cross_us = comm_overlap['cross_gpu_overlap_us'] + exposed_us = comm_overlap['exposed_time_us'] + warmup_us = comm_overlap.get('warmup_time_us', 0) + warmup_count = comm_overlap.get('warmup_count', 0) + num_gpus_co = comm_overlap['num_gpus'] + total_comm_ms = total_comm_us / 1000 / num_gpus_co + same_ms = same_us / 1000 / num_gpus_co + cross_ms = cross_us / 1000 / num_gpus_co + exposed_ms = exposed_us / 1000 / num_gpus_co + if warmup_count > 0: + print(f" Warmup/barrier kernels excluded: {warmup_count} kernels, {warmup_us/1000/num_gpus_co:.2f} ms/GPU\n") + print(f" Total communication time (excluding warmup): {total_comm_ms:10.2f} ms/GPU") + print() + if total_comm_us > 0: + same_pct = same_us / total_comm_us * 100 + cross_pct = cross_us / total_comm_us * 100 + exposed_pct = exposed_us / total_comm_us * 100 + else: + same_pct = cross_pct = exposed_pct = 0 + print(f" Same-GPU overlap: {same_ms:10.2f} ms/GPU ({same_pct:5.1f}%)") + print(f" (Compute on same GPU, different stream)") + print(f" Cross-GPU pipeline: {cross_ms:10.2f} ms/GPU ({cross_pct:5.1f}%)") + print(f" (Compute on other GPUs - pipeline)") + print(f" Exposed (no overlap): {exposed_ms:10.2f} ms/GPU ({exposed_pct:5.1f}%)") + print() + print(" By communication type:") + by_type = comm_overlap['by_type'] + for ctype, data in sorted(by_type.items(), key=lambda x: -x[1]['time_us']): + if data['count'] == 0 and data.get('warmup_count', 0) == 0: + continue + time_ms_type = data['time_us'] / 1000 / num_gpus_co + cross_pct_type = (data['cross_gpu_overlap_us'] / data['time_us'] * 100) if data['time_us'] > 0 else 0 + exposed_pct_type = (data['no_overlap_us'] / data['time_us'] * 100) if data['time_us'] > 0 else 0 + warmup_info = '' + if data.get('warmup_count', 0) > 0: + warmup_ms_type = data['warmup_time_us'] / 1000 / num_gpus_co + warmup_info = f" (+{data['warmup_count']} warmup, {warmup_ms_type:.1f}ms)" + print(f" {ctype:<25s}: {time_ms_type:8.2f} ms/GPU, {data['count']:5d} calls{warmup_info}") + if data['count'] > 0: + print(f" Pipeline overlap: {cross_pct_type:5.1f}%, Exposed: {exposed_pct_type:5.1f}%") + else: + lt = layer_times.get('Communication', {}) + time_ms_co = lt.get('time_ms_per_gpu', 0) + pct_co = lt.get('percentage', 0) + count_co = lt.get('count', 0) + print(f" Total: {time_ms_co:10.2f} ms/GPU ({pct_co:5.1f}%) [{count_co:6d} kernels]") + print(" (Run with full analysis to see overlap breakdown)") + # Network roofline + if network_roofline: + print("\n" + "-"*80) + print("Network Communication Roofline Analysis") + print("-"*80) + print(" Reference: https://jax-ml.github.io/scaling-book/roofline/\n") + crit_hbm = network_roofline.get('critical_ai_hbm') + crit_net = network_roofline.get('critical_ai_network') + if isinstance(crit_hbm, dict) and isinstance(crit_net, dict): + for dt in ['fp8', 'bf16', 'fp16']: + print(f" {dt.upper():>4s} HBM Roofline: {crit_hbm[dt]:8.1f} 
FLOPs/byte") + print(f" {dt.upper():>4s} Network Roofline: {crit_net[dt]:8.1f} FLOPs/byte") + else: + print(f" HBM Roofline: {crit_hbm:8.1f} FLOPs/byte") + print(f" Network Roofline: {crit_net:8.1f} FLOPs/byte") + print(f"\n Hardware: NVLink BW = {network_roofline['nvlink_bw_gb_s']:.0f} GB/s, TP = {network_roofline['tp_degree']}") + print() + for phase_name, phase_data in network_roofline.get('phases', {}).items(): + M = phase_data['M'] + print(f" {phase_name.capitalize()} Phase (M={M}):") + print(f" {'Operation':<25s} {'Parallelism':<15s} {'Network AI':>12s} {'T_compute':>10s} {'T_network':>10s} {'Bound':>12s}") + print(f" {'-'*25:<25s} {'-'*15:<15s} {'-'*12:>12s} {'-'*10:>10s} {'-'*10:>10s} {'-'*12:>12s}") + ops = sorted(phase_data.get('operations', []), key=lambda x: -x['measured_time_us']) + for op in ops[:6]: + op_name = f"N={op['N']},K={op['K']},{op['dtype']}" if op.get('dtype') else f"N={op['N']},K={op['K']}" + parallelism = op.get('parallelism', 'unknown') + ai_str = 'N/A' if op['network_ai'] == float('inf') else f"{op['network_ai']:.0f}" + t_comp = f"{op['t_compute_us']:.1f}us" + t_net = f"{op['t_network_us']:.1f}us" if op['t_network_us'] > 0 else 'N/A' + bound = op['bound'] + print(f" {op_name:<25s} {parallelism:<15s} {ai_str:>12s} {t_comp:>10s} {t_net:>10s} {bound:>12s}") + # Summary per phase + total_gemm_time = phase_data['total_gemm_time_us'] + row_ops = [op for op in phase_data['operations'] if op.get('parallelism') == 'row-parallel'] + col_ops = [op for op in phase_data['operations'] if op.get('parallelism') == 'column-parallel'] + row_time = sum(op['measured_time_us'] for op in row_ops) + col_time = sum(op['measured_time_us'] for op in col_ops) + net_bound = [op for op in row_ops if 'network' in op['bound']] + net_bound_time = sum(op['measured_time_us'] for op in net_bound) + print() + print(f" Row-parallel: {len(row_ops)} ops, {row_time/1000:.2f}ms ({net_bound_time/1000:.2f}ms network-bound)") + print(f" Column-parallel: {len(col_ops)} ops, {col_time/1000:.2f}ms") + print() + ar_stats = network_roofline.get('allreduce_stats', {}) + if ar_stats.get('count', 0) > 0: + print(" AllReduce Statistics (excluding warmup):") + print(f" Count: {ar_stats['count']}") + print(f" Total time: {ar_stats['total_time_us']/1000:.2f} ms") + print(f" Avg time: {ar_stats['avg_time_us']:.2f} us") + # GEMM layer breakdown + print("\n" + "-"*80) + print("GEMM Layer Type Breakdown (kernels with known dimensions):") + gemm_by_layer = defaultdict(lambda: {'time_us': 0, 'count': 0, 'mfu_sum': 0}) + for g in gemm_infos: + lt = g.layer_type + gemm_by_layer[lt]['time_us'] += g.duration_us + gemm_by_layer[lt]['count'] += 1 + gemm_by_layer[lt]['mfu_sum'] += g.mfu * g.duration_us + gemm_total_time = sum(d['time_us'] for d in gemm_by_layer.values()) + for layer_name in ['QKVO', 'FFN', 'Other']: + if layer_name in gemm_by_layer: + data = gemm_by_layer[layer_name] + time_ms = data['time_us'] / 1000 / num_gpus + pct = (data['time_us'] / gemm_total_time * 100) if gemm_total_time > 0 else 0 + avg_mfu_lt = data['mfu_sum'] / data['time_us'] if data['time_us'] > 0 else 0 + print(f" {layer_name:<10s}: {time_ms:10.2f} ms/GPU ({pct:5.1f}%) [{data['count']:5d} kernels] Avg MFU: {avg_mfu_lt:.1f}%") + # Unmatched kernels summary + if events: + analysed_signatures = set(g.kernel_name[:50] for g in gemm_infos) + unmatched_time_us = 0 + unmatched_count = 0 + unmatched_types = defaultdict(lambda: {'count': 0, 'time_us': 0}) + for e in events: + if e.get('cat') != 'kernel': + continue + name = e.get('name', '') + name_lower = 
name.lower() + if not any(x in name_lower for x in ['gemm', 'matmul', 'nvjet']): + continue + ext_id = e.get('args', {}).get('External id') + if ext_id is None and name[:50] not in analysed_signatures: + grid = e.get('args', {}).get('grid', []) + if infer_cuda_graph_kernel_dims(name, grid, config=None) is None: # type: ignore + unmatched_time_us += e.get('dur', 0) + unmatched_count += 1 + if 'nvjet' in name_lower: + unmatched_types['nvjet']['count'] += 1 + unmatched_types['nvjet']['time_us'] += e.get('dur', 0) + elif 'router_gemm' in name_lower: + unmatched_types['router_gemm']['count'] += 1 + unmatched_types['router_gemm']['time_us'] += e.get('dur', 0) + else: + unmatched_types['other']['count'] += 1 + unmatched_types['other']['time_us'] += e.get('dur', 0) + if unmatched_count > 0: + unmatched_time_ms = unmatched_time_us / 1000 / num_gpus + print(f"\n Note: {unmatched_count} GEMM kernels ({unmatched_time_ms:.2f} ms/GPU) could not be analysed:") + for ktype, data in sorted(unmatched_types.items(), key=lambda x: -x[1]['time_us']): + print(f" {ktype}: {data['count']} kernels, {data['time_us']/1000/num_gpus:.2f} ms/GPU") + print("="*80) \ No newline at end of file